WeNet训练初体验

  |   0 评论   |   0 浏览

背景

WeNet

WeNet是一个面向工业级产品的开源端到端语音识别解决方案,同时支持流式及非流式识别,并能高效运行于云端及嵌入式端。

初体验

在环境准备后,还有六个步骤。官方建议手动一步一步执行,然后查看结果,来熟悉整个过程。

环境准备

conda create -n wenet python=3.8
conda activate wenet
pip install -r requirements.txt
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c conda-forge

步骤-1:下载数据

cd wenet/examples/aishell/s0
bash run.sh --stage -1 --stop_stage -1 --data data

数据会自动被下载到data目录中,结果为:

$ ls data/
data_aishell  data_aishell.tgz  resource_aishell  resource_aishell.tgz

步骤0:数据预处理

将数据整理为WeNet所需的格式。

bash run.sh --stage 0 --stop-stage 0 --data data
Preparing data/local/train transcriptions
Preparing data/local/dev transcriptions
Preparing data/local/test transcriptions
local/aishell_data_prep.sh: AISHELL data preparation succeeded

目标为 wav.scp 和 text 两个文件,格式分别为:

wav.scp

BAC009S0724W0121 data/data_aishell/wav/dev/S0724/BAC009S0724W0121.wav
BAC009S0724W0122 data/data_aishell/wav/dev/S0724/BAC009S0724W0122.wav
BAC009S0724W0123 data/data_aishell/wav/dev/S0724/BAC009S0724W0123.wav
BAC009S0724W0124 data/data_aishell/wav/dev/S0724/BAC009S0724W0124.wav
BAC009S0724W0125 data/data_aishell/wav/dev/S0724/BAC009S0724W0125.wav
BAC009S0724W0126 data/data_aishell/wav/dev/S0724/BAC009S0724W0126.wav
BAC009S0724W0127 data/data_aishell/wav/dev/S0724/BAC009S0724W0127.wav
BAC009S0724W0128 data/data_aishell/wav/dev/S0724/BAC009S0724W0128.wav
BAC009S0724W0129 data/data_aishell/wav/dev/S0724/BAC009S0724W0129.wav
BAC009S0724W0130 data/data_aishell/wav/dev/S0724/BAC009S0724W0130.wav
...

text

BAC009S0724W0121     广州  市  房地  产中  介  协会  分析
BAC009S0724W0122     广州  市  房地  产中  介  协会  还  表示
BAC009S0724W0123     相比  于  其他  一  线  城市
BAC009S0724W0124     广州  二手  住宅  市场  表现  一直  相对  稳健
BAC009S0724W0125     而  在  股市  大幅  震荡  的  环境  下
BAC009S0724W0126     预计  第  三  季度  将  陆续  有  部分  股市  资金  重归  楼市
BAC009S0724W0127     但  受  穗  六条  及  二  套房  首  付  七成  的  制约
BAC009S0724W0128     下半  年  楼市  能  是否  能  回到  快速  上升  通道  依然  存在  变数
BAC009S0724W0129     其中  越秀  区  涨幅  领先
BAC009S0724W0130     天河  区  的  签约  面积  在  豪宅  交  投  增多  的  带动  下  上升  较快

dev, test和train三个目录下的数据量分别为:

14326 dev/text
    7176 test/text
  120098 train/text

步骤1:(可选)计算cmvn特征

由于TorchAudio会在数据加载的时候,实时计算特征。所以这步就是文本的直接复制。

tools/compute_cmvn_stats.py is used to extract global cmvn(cepstral mean and variance normalization) statistics. These statistics will be used to normalize the acoustic features. Setting cmvn=false will skip this step.

UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
processed 1000 wavs, 452620 frames
processed 2000 wavs, 900146 frames
processed 3000 wavs, 1346665 frames
processed 4000 wavs, 1808066 frames

结果

tree raw_wav/
raw_wav/
├── dev
│   ├── text
│   └── wav.scp
├── test
│   ├── text
│   └── wav.scp
└── train
    ├── global_cmvn
    ├── text
    └── wav.scp

步骤2: 生成label-token词典

bash run.sh --stage 2 --stop-stage 2 --data data
Make a dictionary

结果

head data/dict/lang_char.txt
<blank> 0 // 对应于CTC中的空符号
<unk> 1 // 未知token,对应于OOV词汇
一 2
丁 3
七 4
万 5
丈 6
三 7
上 8
下 9
…………
龄 4228
龙 4229
龚 4230
龟 4231
<sos/eos> 4232 // 对应于语音的开始和结束,使用同样的id

步骤3:准备WeNet数据格式

将所有的输入输出信息,生成到一个WeNet格式的单一文件中。

bash run.sh --stage 3 --stop-stage 3 --data data
Prepare data, prepare requried format
tools/format_data.sh --nj 32 --feat-type wav --feat raw_wav/dev/wav.scp raw_wav/dev data/dict/lang_char.txt

结果

bash run.sh --stage 3 --stop-stage 3 --data data
Prepare data, prepare requried format
tools/format_data.sh --nj 32 --feat-type wav --feat raw_wav/dev/wav.scp raw_wav/dev data/dict/lang_char.txt
sym2int.pl: replacing 斡 with 1
sym2int.pl: replacing 沂 with 1
sym2int.pl: replacing 绢 with 1
sym2int.pl: replacing 髌 with 1
sym2int.pl: replacing 湄 with 1
sym2int.pl: replacing 圃 with 1
sym2int.pl: replacing 柚 with 1
sym2int.pl: replacing 柚 with 1
sym2int.pl: replacing 柚 with 1
sym2int.pl: replacing 圃 with 1
sym2int.pl: replacing 荇 with 1
sym2int.pl: replacing 芪 with 1
sym2int.pl: replacing 薙 with 1
sym2int.pl: replacing 潦 with 1
sym2int.pl: replacing 鲤 with 1
sym2int.pl: replacing 涟 with 1
sym2int.pl: replacing 筏 with 1
sym2int.pl: replacing 筏 with 1
sym2int.pl: replacing 筏 with 1
sym2int.pl: replacing 锚 with 1
sym2int.pl: not warning for OOVs any more times
** Replaced 110 instances of OOVs with 1
tools/format_data.sh --nj 32 --feat-type wav --feat raw_wav/test/wav.scp raw_wav/test data/dict/lang_char.txt
sym2int.pl: replacing 甯 with 1
sym2int.pl: replacing 甯 with 1
sym2int.pl: replacing 甯 with 1
sym2int.pl: replacing 埕 with 1
sym2int.pl: replacing 甯 with 1
sym2int.pl: replacing 藕 with 1
sym2int.pl: replacing 祎 with 1
sym2int.pl: replacing 祎 with 1
sym2int.pl: replacing 甯 with 1
sym2int.pl: replacing 垭 with 1
sym2int.pl: replacing 纾 with 1
sym2int.pl: replacing 疙 with 1
sym2int.pl: replacing b with 1
sym2int.pl: replacing 垡 with 1
sym2int.pl: replacing 疡 with 1
sym2int.pl: replacing 嗪 with 1
sym2int.pl: replacing 雹 with 1
sym2int.pl: replacing 淅 with 1
sym2int.pl: replacing 谶 with 1
sym2int.pl: replacing 诏 with 1
sym2int.pl: not warning for OOVs any more times
** Replaced 51 instances of OOVs with 1
tools/format_data.sh --nj 32 --feat-type wav --feat raw_wav/train/wav.scp raw_wav/train data/dict/lang_char.txt

步骤4: NN训练

bash run.sh --stage 4 --stop-stage 4 --data data

多CPU模式

使用Multi-GPU的DDP(DataDistributedParallel)模式。

注意事项:多GPU训练,推荐使用nccl,如果不行,则可用gloo或者torch==1.6.0。

恢复模式

设置

checkpoint=exp/your_exp/$n.pt

配置

参考 conf/train_conformer.yaml文件。

使用Tensorboard

tensorboard --logdir tensorboard/$your_exp_name/ --port 12598 --bind_all

步骤5:wav识别

在步骤4中的checkpoint出现时,我们即可以做wav识别。

bash run.sh --stage 5 --stop-stage 5 --data data

这里用了一个54.pt,结果为:

==> exp/conformer/test_attention_rescoring/wer <==
Overall -> 5.91 % N=104765 C=98689 S=5899 D=177 I=113
Mandarin -> 5.90 % N=104762 C=98689 S=5896 D=177 I=113
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0

===========================================================================

==> exp/conformer/test_attention/wer <==
Overall -> 7.49 % N=104765 C=97034 S=5862 D=1869 I=120
Mandarin -> 7.49 % N=104762 C=97034 S=5860 D=1868 I=120
Other -> 100.00 % N=3 C=0 S=2 D=1 I=0

===========================================================================

==> exp/conformer/test_ctc_greedy_search/wer <==
Overall -> 6.74 % N=104765 C=97843 S=6748 D=174 I=134
Mandarin -> 6.73 % N=104762 C=97843 S=6745 D=174 I=134
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0

===========================================================================

==> exp/conformer/test_ctc_prefix_beam_search/wer <==
Overall -> 6.73 % N=104765 C=97848 S=6748 D=169 I=137
Mandarin -> 6.73 % N=104762 C=97848 S=6745 D=169 I=137
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0

步骤6:导出模型

wenet/bin/export_jit.py
will export the trained model using Libtorch. The exported model files can be easily used for inference in other programming languages such as C++.

python wenet/bin/export_jit.py \
--config exp/conformer/train.yaml \
--checkpoint exp/conformer/56.pt \
--output_file final.zip \
--output_quant_file final_quant.zip

结果

70M Sep 19 21:02 final_quant.zip
199M Sep 19 21:01 final.zip

LM训练

7.0 环境准备

srilm

srilm:在官网下载,编译。将srilm/lm/bin/i686-m64下的工具复制到~/bin下。

openfst

下载 openfst-1.6.5

glog

sudo yum install glog-devel gflags-devel

wenet-kaldi

这个编译起来比较费劲,在CMakeList.txt中增加头文件搜索路径

set( CMAKE_CXX_FLAGS "-std=c++11 -O3 -I /home/fw/work/2021/06.wenet/wenet/runtime/core/kaldi" )
set(CXX_FLAGS -std=c++11)

最后弃坑,还是直接编译kaldi吧。

decoder_main

这个必须要自己编译的。进入目录runtime/server/x86,查看 README_CN.md,或者直接使用镜像:

docker pull mobvoiwenet/wenet:v0.5.0

7.1 准备词典

dict=/home/fw/work/2021/06.wenet/wenet/examples/aishell/s0/data/dict/lang_char.txt
unit_file=$dict
mkdir -p data/local/dict
cp $unit_file data/local/dict/units.txt
tools/fst/prepare_dict.py $unit_file data/resource_aishell/lexicon.txt data/local/dict/lexicon.txt

结果

$ head data/local/dict/lexicon.txt
啊 啊
啊啊啊 啊 啊 啊
阿 阿
阿尔 阿 尔
阿根廷 阿 根 廷
阿九 阿 九
阿克 阿 克
阿拉伯数字 阿 拉 伯 数 字
阿拉法特 阿 拉 法 特
阿拉木图 阿 拉 木 图

7.2 训练LM

lm=data/local/lm
mkdir -p $lm
tools/filter_scp.pl data/train/text data/data_aishell/transcript/aishell_transcript_v0.8.txt > $lm/text
local/aishell_train_lms.sh

结果

file data/local/lm/heldout: 10000 sentences, 89496 words, 0 OOVs
0 zeroprobs, logprob= -272791.2 ppl= 551.7352 ppl1= 1117.077

结果

ls -lh data/local/lm/*
-rw-rw-r-- 1 fw fw 511K Sep 21 14:17 data/local/lm/heldout
-rw-rw-r-- 1 fw fw  16M Sep 21 14:17 data/local/lm/lm.arpa
-rw-rw-r-- 1 fw fw 8.0M Sep 21 13:15 data/local/lm/text
-rw-rw-r-- 1 fw fw 7.9M Sep 21 14:17 data/local/lm/text.no_oov
-rw-rw-r-- 1 fw fw 5.5M Sep 21 14:17 data/local/lm/train
-rw-rw-r-- 1 fw fw 2.1M Sep 21 14:17 data/local/lm/unigram.counts
-rw-rw-r-- 1 fw fw 607K Sep 21 14:17 data/local/lm/word.counts
-rw-rw-r-- 1 fw fw 1.1M Sep 21 14:17 data/local/lm/wordlist

7.3 准备TLG

tools/fst/compile_lexicon_token_fst.sh  data/local/dict data/local/tmp data/local/lang
tools/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;

结果

fstaddselfloops 'echo 4234 |' 'echo 123660 |'
Lexicon and token FSTs compiling succeeded

arpa2fst --read-symbol-table=data/lang_test/words.txt --keep-symbols=true -
LOG (arpa2fst[5.5.971~1-07043]:Read():arpa-file-parser.cc:94) Reading \data\ section.
LOG (arpa2fst[5.5.971~1-07043]:Read():arpa-file-parser.cc:149) Reading \1-grams: section.
LOG (arpa2fst[5.5.971~1-07043]:Read():arpa-file-parser.cc:149) Reading \2-grams: section.
LOG (arpa2fst[5.5.971~1-07043]:Read():arpa-file-parser.cc:149) Reading \3-grams: section.
Checking how stochastic G is (the first of these numbers should be small):
fstisstochastic data/lang_test/G.fst
0 -1.14386
fsttablecompose data/lang_test/L.fst data/lang_test/G.fst
fstminimizeencoded
fstdeterminizestar --use-log=true
fsttablecompose data/lang_test/T.fst data/lang_test/LG.fst
Composing decoding graph TLG.fst succeeded

7.4 解码

./tools/decode.sh --nj 16 \
    --beam 15.0 --lattice_beam 7.5 --max_active 7000 \
    --blank_skip_thresh 0.98 --ctc_weight 0.5   --rescoring_weight 1.0 \
    --fst_path data/lang_test/TLG.fst \
    data/test/wav.scp data/test/text $dir/final.zip \
    data/lang_test/words.txt $dir/lm_with_runtime
sudo docker run --rm -it -v /home/fw/work/2021/06.wenet/:/home/mount docker.io/mobvoiwenet/wenet:v0.5.0 bash

其它

Amazon SageMaker

Amazon SageMaker是一项完全托管的机器学习服务,涵盖了数据标记、数据处理、模型训练、超参调优、模型部署及持续模型监控等基本流程;也提供自动打标签,自动机器学习,监控模型训练等高阶功能。

其通过全托管的机器学习基础设施和对主流框架的支持,可以降低客户机器学习的整体拥有成本。

模型在训练的过程中,需要用到大量的计算资源,我们可以借助Amazon SageMaker非常方便的启动包含多台完全托管的训练实例集群,加速训练过程。

参考