Fairseq框架
Faieseq
Fairseq是一个用PyTorch编写的序列建模工具包,它允许研究人员和开发人员训练用于翻译、摘要、语言建模和其他文本生成任务的定制模型。本系列笔记主要以翻译官方文档为主,附带一些个人的学习记录。官方教程连接:https://fairseq.readthedocs.io/en/latest/getting_started.html。
环境配置
git clone https://github.com/pytorch/fairseq.git
cd fairseq
pip install --editable ./
# 安装失败的可能原因:版本问题
PyTorch version >= 1.10.0
Python version >= 3.8
#Linux 出现gcc版本问题
conda install https://anaconda.org/brown-data-science/gcc/5.4.0/download/linux-64/gcc-5.4.0-0.tar.bz2
tokenization and BPE
# 下载Moses, tokenization采用Moses
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git
# 脚本地址
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
LC=$SCRIPTS/tokenizer/lowercase.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=10000
# 数据集下载路径
URL="http://dl.fbaipublicfiles.com/fairseq/data/iwslt14/de-en.tgz"
# 数据集名称
GZ=de-en.tgz
# 监测Moses scripts是否正确
if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
exit
fi
# 源语言类型
src=de
# 目标语言类型
tgt=en
# 翻译类型
lang=de-en
prep=iwslt14.tokenized.de-en
tmp=$prep/tmp
orig=orig
mkdir -p $orig $tmp $prep
# 将数据集下载到orig路径下
echo "Downloading data from ${URL}..."
cd $orig
wget "$URL"
# 判断数据集是否成功下载
if [ -f $GZ ]; then
echo "Data successfully downloaded."
else
echo "Data not successfully downloaded."
exit
fi
# 解压数据集
tar zxvf $GZ
cd ..
# 训练集tokenization
echo "pre-processing train data..."
for l in $src $tgt; do
f=train.tags.$lang.$l
tok=train.tags.$lang.tok.$l
cat $orig/$lang/$f | \
grep -v '<url>' | \
grep -v '<talkid>' | \
grep -v '<keywords>' | \
sed -e 's/<title>//g' | \LC
sed -e 's/<\/title>//g' | \
sed -e 's/<description>//g' | \
sed -e 's/<\/description>//g' | \
perl $TOKENIZER -threads 8 -l $l > $tmp/$tok
echo ""
done
# 训练集清洗
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175
for l in $src $tgt; do
perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l
done
# 验证集和测试集清洗
echo "pre-processing valid/test data..."
for l in $src $tgt; do
for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do
fname=${o##*/}
f=$tmp/${fname%.*}
echo $o $f
grep '<seg id' $o | \
sed -e 's/<seg id="[0-9]*">\s*//g' | \
sed -e 's/\s*<\/seg>\s*//g' | \
sed -e "s/\’/\'/g" | \
perl $TOKENIZER -threads 8 -l $l | \
perl $LC > $f
echo ""
done
done
echo "creating train, valid, test..."
# 分割训练集和验证集
for l in $src $tgt; do
awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/valid.$l
awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/train.$l
# 合成测试集
cat $tmp/IWSLT14.TED.dev2010.de-en.$l \
$tmp/IWSLT14.TEDX.dev2012.de-en.$l \
$tmp/IWSLT14.TED.tst2010.de-en.$l \
$tmp/IWSLT14.TED.tst2011.de-en.$l \
$tmp/IWSLT14.TED.tst2012.de-en.$l \
> $tmp/test.$l
done
# 训练集目录
TRAIN=$tmp/train.en-de
BPE_CODE=$prep/code
rm -f $TRAIN
for l in $src $tgt; do
cat $tmp/train.$l >> $TRAIN
done
echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
# 三个数据集都进行BPE分词
for L in $src $tgt; do
for f in train.$L valid.$L test.$L; do
echo "apply_bpe.py to ${f}..."
python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $prep/$f
done
done
命令解析:
--task,意思是任务类型,默认是“translation”(翻译);包括translation_from_pretrained_bart等Fairseq自带的任务类型;
--arch,意思是model architecture,模型结构,可选项包括,transformer,lstm等;这里采用自带的结构fastcorrect
--lr 学习率,为初始学习率,后续可能被--lr-scheduler修改;--lr-scheduler 是lr更新计划,这里采用inverse_sqrt 方法;
--dropout 字面意思就是丢弃,这里指的是在训练模型时,丢弃一部分数据来防止过拟合;
--optimizer adam --adam-betas '(0.9, 0.98)' 参数优化策略;
--criterion fc_loss 训练准则;
--max-tokens 9000 一个batch最大的token数量;
--save-dir $SAVE_DIR 存储checkpoints的路径,checkpoint即模型;
--user-dir $EXP_HOME/FastCorrect 一个包含扩展的python模块,这里的扩展是指模型结构或者任务,和task是相对的,一般不适用官方规定的arch时需要手动设置这个路径;
--max-epoch 30 当达到这个30个epoch的时候,停止训练;
--update-freq 4 参数更新频率,每4个batch更新参数;
--fp16 使用FP16
--num-workers 8 8个子线程用于load数据;
举例:
PRETRAIN=mbart.cc25 # fix if you moved the downloaded checkpoint
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN
fairseq-train path_2_data \
--encoder-normalize-before --decoder-normalize-before \
--arch mbart_large --layernorm-embedding \
--task translation_from_pretrained_bart \
--source-lang en_XX --target-lang ro_RO \
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler polynomial_decay --lr 3e-05 --warmup-updates 2500 --total-num-update 40000 \
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
--max-tokens 1024 --update-freq 2 \
--save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
--seed 222 --log-format simple --log-interval 2 \
--restore-file $PRETRAIN \
--reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
--langs $langs \
--ddp-backend legacy_ddp
预训练
fairseq-preprocess \
--source-lang "source" \
--target-lang "target" \
--trainpref "${TASK}/train.bpe" \
--validpref "${TASK}/val.bpe" \
--destdir "${TASK}-bin/" \
--workers 60 \
--srcdict dict.txt \
--tgtdict dict.txt;
最后更新: January 18, 2023