安装:
其中requirements.txt里安装的是很新的torchtext
git clone --branch 0.9.1 https://github.com/OpenNMT/OpenNMT-py.git
cd OpenNMT-py
pip install -r requirements.txt
cd ..
预处理:
其中src-train.txt
和tgt-train.txt
为原始英文数据,按行对齐
python ./OpenNMT-py/preprocess.py \
-train_src data/opennmt_format/src-train.txt \
-train_tgt data/opennmt_format/tgt-train.txt \
-valid_src data/opennmt_format/src-dev.txt \
-valid_tgt data/opennmt_format/tgt-dev.txt \
-save_data data/opennmt_format/preprocessed \
-src_seq_length 10000 \
-tgt_seq_length 10000 \
-src_seq_length_trunc 400 \
-tgt_seq_length_trunc 100 \
-dynamic_dict \
-share_vocab \
-src_vocab_size 32000 \
-tgt_vocab_size 32000 \
-shard_size 100000
然后训练:
python ./OpenNMT-py/train.py \
-data data/opennmt_format/preprocessed \
-save_model data/doc2query \
-layers 6 \