当前位置: 首页 > 工具软件 > Fairseq > 使用案例 >

fairseq库学习笔记(一)入门(Getting Started)

井学
2023-12-01

fairseq库学习笔记(一)入门

前言

Fairseq是一个用PyTorch编写的序列建模工具包,它允许研究人员和开发人员训练用于翻译、摘要、语言建模和其他文本生成任务的定制模型。本系列笔记主要以翻译官方文档为主,附带一些个人的学习记录。官方教程连接:link

1 入门(Getting Started)

1.1 评估预训练模型(Evaluating Pre-trained Models)

首先,下载一个预训练模型及其词汇表。

curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -

这个模型使用字节对编码(BPE)词汇表,因此我们必须在翻译源文本之前对其应用编码。这可以通过使用wmt14.en-fr.fconv-cuda / bpecodesapply_bpel .py脚本完成。@@用作延续标记,原始文本可以通过sed s/@@ //g或将 --remove-bpe标记传递给fairseq-generate来轻松恢复。在BPE之前,输入文本需要使用 mosesdecoder中的tokenizer.perl来分词。

让我们使用fairseq-interactive交互式生成翻译。在这里,我们使用5的beam size并使用Moses分词器和给定的字节对编码词汇表对输入进行预处理。它将自动删除BPE延续标记并对输出进行detokenize。

> MODEL_DIR=wmt14.en-fr.fconv-py
> fairseq-interactive \
    --path $MODEL_DIR/model.pt $MODEL_DIR \
    --beam 5 --source-lang en --target-lang fr \
    --tokenizer moses \
    --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
Why is it rare to discover new marine mammal species?
S-0     Why is it rare to discover new marine mam@@ mal species ?
H-0     -0.0643349438905716     Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
P-0     -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015

这个生成脚本生成三种类型的输出:以S这一行是原始源句的副本;H是生成的翻译输出(称为Hypothesis),它紧跟在一个平均的log-likelihood之后;而P是每个token位置的positional score,包括文本中省略的句末token。

你可能会看到其他类型的输出行是:

  • D:解码后的输出
  • T:参考目标
  • A:对齐信息
  • E:生成步骤的历史

笔者自已实验输出如下:

(slurm) jxqi@main-3:~/Study/fairseq_learn$ fairseq-interactive     --path $MODEL_DIR/model.pt $MODEL_DIR     --beam 2 --source-lang en --target-lang fr     --tokenizer moses     --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
2021-10-06 10:37:15 | INFO | fairseq_cli.interactive | Namespace(all_gather_list_size=16384, batch_size=1, batch_size_valid=None, beam=2, bf16=False, bpe='subword_nmt', bpe_codes='wmt14.en-fr.fconv-py/bpecodes', bpe_separator='@@', broadcast_buffers=False, bucket_cap_mb=25, buffer_size=1, checkpoint_shard_count=1, checkpoint_suffix='', constraints=None, cpu=False, criterion='cross_entropy', curriculum=0, data='wmt14.en-fr.fconv-py', data_buffer_size=10, dataset_impl=None, ddp_backend='c10d', decoding_format=None, device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method=None, distributed_no_spawn=False, distributed_num_procs=1, distributed_port=-1, distributed_rank=0, distributed_world_size=1, distributed_wrapper='DDP', diverse_beam_groups=-1, diverse_beam_strength=0.5, diversity_rate=-1.0, empty_cache_freq=0, eval_bleu=False, eval_bleu_args=None, eval_bleu_detok='space', eval_bleu_detok_args=None, eval_bleu_print_samples=False, eval_bleu_remove_bpe=None, eval_tokenized_bleu=False, fast_stat_sync=False, find_unused_parameters=False, fix_batches_to_gpus=False, fixed_validation_seed=None, force_anneal=None, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, gen_subset='test', input='-', iter_decode_eos_penalty=0.0, iter_decode_force_max_iter=False, iter_decode_max_iter=10, iter_decode_with_beam=1, iter_decode_with_external_reranker=False, left_pad_source='True', left_pad_target='False', lenpen=1, lm_path=None, lm_weight=0.0, load_alignments=False, localsgd_frequency=3, log_format=None, log_interval=100, lr_scheduler='fixed', lr_shrink=0.1, match_source_len=False, max_len_a=0, max_len_b=200, max_source_positions=1024, max_target_positions=1024, max_tokens=None, max_tokens_valid=None, memory_efficient_bf16=False, memory_efficient_fp16=False, min_len=1, min_loss_scale=0.0001, model_overrides='{}', model_parallel_size=1, moses_no_dash_splits=False, moses_no_escape=False, moses_source_lang=None, moses_target_lang=None, nbest=1, no_beamable_mm=False, no_early_stop=False, no_progress_bar=False, no_repeat_ngram_size=0, no_seed_provided=False, nprocs_per_node=4, num_batch_buckets=0, num_shards=1, num_workers=1, optimizer=None, path='wmt14.en-fr.fconv-py/model.pt', pipeline_balance=None, pipeline_checkpoint='never', pipeline_chunks=0, pipeline_decoder_balance=None, pipeline_decoder_devices=None, pipeline_devices=None, pipeline_encoder_balance=None, pipeline_encoder_devices=None, pipeline_model_parallel=False, prefix_size=0, print_alignment=False, print_step=False, profile=False, quantization_config_path=None, quiet=False, remove_bpe=None, replace_unk=None, required_batch_size_multiple=8, required_seq_len_multiple=1, results_path=None, retain_dropout=False, retain_dropout_modules=None, retain_iter_history=False, sacrebleu=False, sampling=False, sampling_topk=-1, sampling_topp=-1.0, score_reference=False, scoring='bleu', seed=1, shard_id=0, skip_invalid_size_inputs_valid_test=False, slowmo_algorithm='LocalSGD', slowmo_momentum=None, source_lang='en', target_lang='fr', task='translation', temperature=1.0, tensorboard_logdir=None, threshold_loss_scale=None, tokenizer='moses', tpu=False, train_subset='train', truncate_source=False, unkpen=0, unnormalized=False, upsample_primary=1, user_dir=None, valid_subset='valid', validate_after_updates=0, validate_interval=1, validate_interval_updates=0, warmup_updates=0, zero_sharding='none')
2021-10-06 10:37:16 | INFO | fairseq.tasks.translation | [en] dictionary: 43771 types
2021-10-06 10:37:16 | INFO | fairseq.tasks.translation | [fr] dictionary: 43807 types
2021-10-06 10:37:16 | INFO | fairseq_cli.interactive | loading model(s) from wmt14.en-fr.fconv-py/model.pt
2021-10-06 10:37:34 | INFO | fairseq_cli.interactive | NOTE: hypothesis and token scores are output in base 2
2021-10-06 10:37:34 | INFO | fairseq_cli.interactive | Type the input sentence and press return:
Why is it rare to discover new marine mammal species?
S-0     Why is it rare to discover new marine mam@@ mal species ?
W-0     5.055   seconds
H-0     -0.2200193852186203     Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
D-0     -0.2200193852186203     Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
P-0     -0.3204 -0.4505 -0.1860 -0.3856 -0.2469 -0.2784 -0.1588 -0.2395 -0.1447 -0.1068 -0.1588 -0.1175 -0.1786 -0.1421 -0.1858
Hello world.
S-1     H@@ ello world .
W-1     0.150   seconds
H-1     -0.36586591601371765    Bon@@ jour le monde .
D-1     -0.36586591601371765    Bonjour le monde.
P-1     -0.7724 -0.0206 -0.9498 -0.0860 -0.1848 -0.1815
What is your name?
S-2     What is your name ?
W-2     0.194   seconds
H-2     -0.26296576857566833    Quel est votre nom ?
D-2     -0.26296576857566833    Quel est votre nom ?
P-2     -0.7339 -0.1240 -0.3228 -0.0817 -0.1359 -0.1796
what's your name?
S-3     what 's your name ?
W-3     0.214   seconds
H-3     -0.5068069100379944     quel est votre nom ?
D-3     -0.5068069100379944     quel est votre nom ?
P-3     -2.0806 -0.1678 -0.3916 -0.0744 -0.1446 -0.1818

这里我又自定义输入了几句诸如"Hello world.","what’s your name?"之类的英语句子,模型给出了翻译花费的时间和翻译结果。

1.2 训练一个新的模型(Training a New Model)

1.2.1 数据预处理(Data Pre-processing)

Fairseq包含几个翻译数据集的示例预处理脚本:IWSLT 2014(德语翻译成英语),WMT 2014(英语翻译成法语)和WMT 2014(英语翻译成德语)。对IWSLT数据集进行预处理和二值化:

首先需要我们使用git将fairseq这个repo先clone下来。之后进到相应的文件夹中准备好实验数据。

> cd examples/translation/
> bash prepare-iwslt14.sh
> cd ../..
> TEXT=examples/translation/iwslt14.tokenized.de-en
> fairseq-preprocess --source-lang de --target-lang en \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/iwslt14.tokenized.de-en

1.2.2 训练(Training)

使用fairseq-train来训练一个新模型。下面是一些适用于IWSLT 2014数据集的示例设置:

> mkdir -p checkpoints/fconv
> CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
    --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
    --arch fconv_iwslt_de_en --save-dir checkpoints/fconv

1.2.3 生成(Generation)

一旦模型经过训练,就可以使用fairseq-generate(用于二进制数据)或fairseq-interactive(用于原始文本)生成翻译:

> fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/fconv/checkpoint_best.pt \
    --batch-size 128 --beam 5
| [de] dictionary: 35475 types
| [en] dictionary: 24739 types
| data-bin/iwslt14.tokenized.de-en test 6750 examples
| model fconv
| loaded checkpoint trainings/fconv/checkpoint_best.pt
S-721   danke .
T-721   thank you .
...

1.3 高级训练选项(Advanced Training Options)

1.3.1 使用延迟更新来获得更大的mini-batch size(Large mini-batch training with delayed updates)

–update-freq选项可以被用来从多个mini-batch中累积梯度并且延迟更新,以此实现一个更大的有效的bacth_size。延迟更新还可以通过减少gpu间的通信成本和节省gpu间工作负载差异造成的空闲时间来提高训练速度。

在单个GPU上进行训练,其有效批大小相当于在8个GPU上进行训练。

> CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)

1.3.2 使用半精度浮点数(Training with half precision floating point (FP16))

注意: FP16要求Volta GPU和CUDA 9.1或更高的版本。

最近的gpu支持高效的半精度浮点计算,例如,使用 Nvidia Tensor Cores。Fairseq支持FP16训练,这可以通过设置–fp16实现。

> fairseq-train --fp16 (...)

1.3.3 分布式训练(Distributed training)

fairseq中的分布式训练是在torch.distributed上实现的。启动任务最简单的方法是使用torch.distributed.launch工具。

例如,要在2个各有8个gpu(总共16个gpu)的节点上训练大型English-German Transformer模型,在每个节点上运行以下命令,在第二个节点上用node_rank=1替换node_rank=0,并确保将 --master_addr 更新为第一个节点的IP地址。

> python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
    --master_port=12345 \
    $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
    --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
    --lr 0.0005 \
    --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 3584 \
    --max-epoch 70 \
    --fp16

在SLURM集群中,fairseq会自动检测节点和gpu的数量,但需要提供端口号。

> salloc --gpus=16 --nodes 2 (...)
> srun fairseq-train --distributed-port 12345 (...).

1.3.4 共享非常大的数据集(Sharding very large datasets)

在非常大的数据集上进行训练可能是具有挑战性的,特别是在您的机器没有太多系统RAM的情况下。fairseq中的大多数任务都支持在分片数据集上进行训练,在分片数据集中,原始数据集被预处理成非重叠的块(或分片)。

例如,不必将所有数据预处理到单个data-bin目录中,而是可以分割数据并创建data-bin1、data-bin2等。然后你可以像这样调整你的训练指令:

> fairseq-train data-bin1:data-bin2:data-bin3 (...)

训练将在每个分片上进行迭代,一个接一个,每个分片对应一个epoch,从而减少系统内存占用。

1.4 命令行工具(Command-line Tools)

待补充

 类似资料: