当前位置: 首页 > 工具软件 > Fairseq > 使用案例 >

fairseq笔记

鞠晋
2023-12-01

训练新模型

以机器翻译为例子开始

数据预处理

Fairseq 包含多个翻译数据集的示例预处理脚本:IWSLT 2014(德语-英语)、WMT 2014(英语-法语)和 WMT 2014(英语-德语)。预处理和二值化 IWSLT 数据集:

> cd examples/translation/  #把当前路径切换到翻译示例下
> bash prepare-iwslt14.sh   #运行预处理脚本
> cd ../..                  #返回上上级目录。也就是退回到fairseq-master/
> TEXT=examples/translation/iwslt14.tokenized.de-en

#这个指令会调用 anaconda/scripts下的fairseq-preprocess.exe(如果是在windows下) ,
#这个实质上,根据F:\ANACONDA\Lib\site-packages\fairseq-0.10.0.dist-info的entry_points.txt
#fairseq-preprocess = fairseq_cli.preprocess:cli_main
#已经指明了这个exe实际上执行的是cli_main
> fairseq-preprocess --source-lang de --target-lang en \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/iwslt14.tokenized.de-en

预处理脚本

我在脚本中写了注释,方便对脚本语言不熟练的同学入门,如果熟练的同学可以回忆一下这个预处理做了什么,然后跳转到关于fairseq-preprocess的源码讲解部分

  • 下载数据集

  • 下载subword-nmt和moses

  • 清洗训练集

    • 去掉包含<url>,<talkid>,<keyword>的行,删除<title>,</title>,<description>,</description>标记
    • 用moses分词,把标点符号和英文单词分开
    • 保留1-175长度和源语言目标语言长度1.5比例内的句子
    • 全部单词变小写
  • 把训练数据集(清洗后的)按照22:1划分成训练集和验证集

  • 把原验证集测试集全划分成测试集

  • bpe算法

#!/usr/bin/env bash
#
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh

#echo是输出一些信息,git clone是下载github上的仓库
echo 'Cloning Moses github repository (for tokenization scripts)...' 
git clone https://github.com/moses-smt/mosesdecoder.git

echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git

#定义一些路径变量
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
LC=$SCRIPTS/tokenizer/lowercase.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=10000
# 指定去哪里下数据集
URL="http://dl.fbaipublicfiles.com/fairseq/data/iwslt14/de-en.tgz"
GZ=de-en.tgz
# -d 是linux脚本中用来判断是不是目录的
if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit
fi

src=de
tgt=en
lang=de-en
prep=iwslt14.tokenized.de-en
tmp=$prep/tmp
orig=orig
# 创建多层目录,主要是因为tmp和prep是多层目录
mkdir -p $orig $tmp $prep

echo "Downloading data from ${URL}..."
cd $orig
#从互联网上下载东西
wget "$URL"
#判断GZ是不是文件
if [ -f $GZ ]; then
    echo "Data successfully downloaded."
else
    echo "Data not successfully downloaded."
    exit
fi
# 解压GZ到当前文件夹
tar zxvf $GZ
#返回上一级
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    f=train.tags.$lang.$l
    tok=train.tags.$lang.tok.$l
    # |是管道,就是这一条语句的输出作为下一条语句的输入,等于把grep这几行去掉
    cat $orig/$lang/$f | \
    grep -v '<url>' | \
    grep -v '<talkid>' | \
    grep -v '<keywords>' | \
    # s是替换,g是替换多次
    sed -e 's/<title>//g' | \
    sed -e 's/<\/title>//g' | \
    sed -e 's/<description>//g' | \
    sed -e 's/<\/description>//g' | \
    # 分词,就是把标点符号和单词分开
    perl $TOKENIZER -threads 8 -l $l > $tmp/$tok
    #换行……?
    echo ""
done
# http://www.statmt.org/moses/?n=FactoredTraining.PrepareTraining
# 保留1-175长度和源语言目标语言长度1.5比例内的句子
#  clean-corpus-n.perl CORPUS L1 L2 OUT MIN MAX
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175
#小写
for l in $src $tgt; do
    perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l
done

echo "pre-processing valid/test data..."
for l in $src $tgt; do
    # ls是个列表,*和正则不是一个东西,正则的是要求 前面有点什么,他的0闭包,但是文件扩展他本身就是任何字符(串)。
    for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do
    #o是一个绝对路径,这个##*是删除除了最后一个路径(也就是文件名)的东西,等于保留IWSLT14.TED*.$l.xml
    # 两个重复的符号是最大匹配,一个是最小匹配,#是去掉左边,%是去掉右边(逆序匹配),后面的就是正则了
    fname=${o##*/}
    # f就是等于 temp/IWSLT14.TED*.$l
    f=$tmp/${fname%.*}
    echo $o $f
    # 找到有segid 的行,去掉seg id,去掉句子结尾的seg,替换中文引号为英文引号
    grep '<seg id' $o | \
        sed -e 's/<seg id="[0-9]*">\s*//g' | \
        sed -e 's/\s*<\/seg>\s*//g' | \
        sed -e "s/\’/\'/g" | \
    #分词,小写
    perl $TOKENIZER -threads 8 -l $l | \
    perl $LC > $f
    echo ""
    done
done


echo "creating train, valid, test..."
# 每23行输出一行给valid,其实这个句子应该这么理解
# [awk '{if (NR%23 == 0)  print $0; }' $tmp/train.tags.de-en.$l] > $tmp/valid.$l
for l in $src $tgt; do
    awk '{if (NR%23 == 0)  print $0; }' $tmp/train.tags.de-en.$l > $tmp/valid.$l
    awk '{if (NR%23 != 0)  print $0; }' $tmp/train.tags.de-en.$l > $tmp/train.$l
    #把这几个文件合并成$tmp/test.$l
    cat $tmp/IWSLT14.TED.dev2010.de-en.$l \
        $tmp/IWSLT14.TEDX.dev2012.de-en.$l \
        $tmp/IWSLT14.TED.tst2010.de-en.$l \
        $tmp/IWSLT14.TED.tst2011.de-en.$l \
        $tmp/IWSLT14.TED.tst2012.de-en.$l \
        > $tmp/test.$l
done

TRAIN=$tmp/train.en-de
BPE_CODE=$prep/code
rm -f $TRAIN
for l in $src $tgt; do
    cat $tmp/train.$l >> $TRAIN
done

#调用subword_nmt 学习code ,这些参数在我的另一个文章中有
echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE

#使用code进行bpe生成
for L in $src $tgt; do
    for f in train.$L valid.$L test.$L; do
        echo "apply_bpe.py to ${f}..."
        python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $prep/$f
    done
done

preprocess的源码分析

先讲怎么从命令行中截获参数,fairseq用的是python 标准库里的argparser,顾名思义,是参数解析器。

python的参数类型

在python中,有如下四种类型的参数,分别是位置参数、关键字参数、默认参数和可变参数。其中位置参数和关键字参数讲的是调用的方式,比如以下例子:

def print_hello(name, sex)

print_hello('小明','male') #位置参数

print_hello(sex='male',name='小明') #关键字参数

位置参数通过参数定义的位置来传递参数;关键字参数通过键值对的方式来传递参数,不需要考虑位置关系。当关键字参数和位置参数混用的时候,需要特别注意,位置参数必须在关键字参数之前,所以解析的方式就是逐个把位置参数送入形参,再把关键字和形参结合,一旦关键字参数和位置参数相同,比如print_hello(1, name='小明')就会因为name有两个实参报错。

默认参数就是在形参定义的时候,带上的默认值,比如

def print_hello(sex,name='male')
    # name是默认参数

这时候调用的时候就可以不传name的实参进去,当然需要注意的是,默认参数也必须在位置参数之后。

可变参数就是有时候我们不确定调用的时候会传递多少个参数,此时可以用packing包裹位置参数或者关键字参数。
比如包裹位置参数的例子

#定义
def func(*args):
    ....

# 调用
func()
func(a)
func(a, b, c)

所有传进去的参数都会被args收集,他是一个tuple类型的变量。

def func(**kargs):
    ....

func(a=1)
func(a=1, b=2, c=3)

kargs是一个dict类型的变量。
需要注意的是,args和kargs并不是必须的命名,只是一种习惯,区别是元组还是字典,靠的是**的数量。

我们什么时候需要argparse,比如我们写好了一个python脚本hello.py
你可以直接使用

python hello.py

来运行这个脚本,但有时候你并不满足只是运行,可能还需要从外界获取一点额外的信息,比如说,使用者的名字,运行的次数,等等。这时你就可以使用argparse,达到下面的效果:

python hello.py --name 小明 --time 3

这时小明和3就会被传入hello.py中,并且可以被获取,我们不会讲的特别深入,只保证你能明白fairseq用这个做什么,感兴趣可以自行从python的官方文档中阅读。

讲完argparse的用途,我们讲怎么做。
基本上是三部曲,第一是创建解析器,第二步是往解析器里添加需要解析的参数,第三步是开始解析参数。这就好像开辅导班,第一步是租店面,第二步是确定教什么科目,第三步是招收对应的老师为学生授课这样(奇怪的例子orz)

argparse中,有一个命令解析类,叫做ArgumentParser,他的构造函数中的所有参数都是关键字参数,也就是要用键值对的方式传进去,他有非常多的成员,我们只讲fairseq用的部分。
bool类型的add_help,当这个参数是true的时候,你可以通过-h或者–help读到这个hello.py所有参数的帮助(当然这得是你写了才有东西输出)。
bool类型的allow_abbrev,这个是允许使用缩写的意思,在python3.5以后默认开启,比如我们定义了–time 这个参数,当你实际使用时采用–ti 3,也可以被识别到time参数上,当然,一旦你输入的缩写是多个参数的共同前缀,产生了歧义时,这个选项就无法使用了。

在定义完一个解析器之后,我们需要为里面加上需要解析的内容,这个是通过解析器类的add_argument方法进行的,比如我们需要让这个参数解析器接受--time,就是通过这个方法加的。这个方法常用的参数如下:
default - 说白了如果定义了这个,就允许这个参数被当成默认参数处理,如果没写,或者把这一个参数定义成None,那这个参数就无法在命令行缺省该参数时使用。
name or flags - 这个参数就是用来写–time的,为这个解析器类加上需要解析的参数,需要注意,如果在同一个add_argument里面写多个flags('-f', '--foo'),也可以只写一个name比如(bar),需要注意,不加-的会被解析成位置参数,是不允许缺省的,一旦缺省了会报错,而加了-的会被认为是可选参数,同时如果flags不是一个字母,前面要加–,如果是一个字母,只用一个-。
dest - 这个说起来有点绕啊,其实他和name需要区分一下,他是parser创建完后,解析了参数之后(后面会说的parse_args()方法),你用什么变量名来获取刚刚的参数,比如说啊
parser.add_argument(’-f’, ‘–foo-bar’, ‘–foo’),这里面所有的参数都是flags对吧,那之后我们要调用这个参数,就是通过parse.foo_bar,因为如果有–的,会选择第一个–的names去掉杠杠,而且把里面的-变成_(这是因为变量名的规范要求),如果只有-的,就取第一个-的name作为内容。


讲完参数的类型后,我们开始看源码部分的argparse内容。fairseq-preprocess调用的是fairseq_cli/preprocess.pycli-main()

def cli_main():
    parser = options.get_preprocessing_parser()
    args = parser.parse_args()
    main(args)

其中options的路径是fairseq/options.py,我们看看预处理的解析器函数内容是什么。

def get_preprocessing_parser(default_task="translation"):
    parser = get_parser("Preprocessing", default_task)
    add_preprocess_args(parser)
    return parser

上面这三行里面,get_parser创建了一个parser,他的两个实参中,第一个字符串类的desc并没有用上,更像是一个和不同cli指令区分开的标记,只是为了增加可读性。这个函数的具体注释写在代码体里面了。

def get_parser(desc, default_task="translation"):
    """

    Args:
        desc: 没用上,这里像是一个信息标记
        default_task:  默认任务

    Returns:

    """
    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.

    #创建用户参数解析器
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) #不添加帮助,不允许运行时缩写指代参数
    usr_parser.add_argument("--user-dir", default=None)
    #这个和parse_args很像,区别在于当解析碰到额外指令时不会报错,会把额外的不存在的参数存下来
    #不过这里没打算接受,所以用了_,本质应该是为了加强鲁棒性。
    usr_args, _ = usr_parser.parse_known_args()
    #导入用户的自定义模块
    utils.import_user_module(usr_args)

    #然后通过数据类初始化这个parser
    parser = argparse.ArgumentParser(allow_abbrev=False)
    gen_parser_from_dataclass(parser, CommonConfig())

    from fairseq.registry import REGISTRIES

    for registry_name, REGISTRY in REGISTRIES.items():
        parser.add_argument(
            "--" + registry_name.replace("_", "-"),
            default=REGISTRY["default"],
            choices=REGISTRY["registry"].keys(),
        )

    # Task definitions can be found under fairseq/tasks/
    from fairseq.tasks import TASK_REGISTRY

    parser.add_argument(
        "--task",
        metavar="TASK",
        default=default_task,
        choices=TASK_REGISTRY.keys(),
        help="task",
    )
    # fmt: on
    return parser
 类似资料: