WeNet训练初体验
背景
WeNet
WeNet是一个面向工业级产品的开源端到端语音识别解决方案,同时支持流式及非流式识别,并能高效运行于云端及嵌入式端。
初体验
在环境准备后,还有六个步骤。官方建议手动一步一步执行,然后查看结果,来熟悉整个过程。
环境准备
conda create -n wenet python=3.8
conda activate wenet
pip install -r requirements.txt
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c conda-forge
步骤-1:下载数据
cd wenet/examples/aishell/s0
bash run.sh --stage -1 --stop_stage -1 --data data
数据会自动被下载到data目录中,结果为:
$ ls data/
data_aishell data_aishell.tgz resource_aishell resource_aishell.tgz
步骤0:数据预处理
将数据整理为WeNet所需的格式。
bash run.sh --stage 0 --stop-stage 0 --data data
Preparing data/local/train transcriptions
Preparing data/local/dev transcriptions
Preparing data/local/test transcriptions
local/aishell_data_prep.sh: AISHELL data preparation succeeded
目标为 wav.scp 和 text 两个文件,格式分别为:
wav.scp
BAC009S0724W0121 data/data_aishell/wav/dev/S0724/BAC009S0724W0121.wav
BAC009S0724W0122 data/data_aishell/wav/dev/S0724/BAC009S0724W0122.wav
BAC009S0724W0123 data/data_aishell/wav/dev/S0724/BAC009S0724W0123.wav
BAC009S0724W0124 data/data_aishell/wav/dev/S0724/BAC009S0724W0124.wav
BAC009S0724W0125 data/data_aishell/wav/dev/S0724/BAC009S0724W0125.wav
BAC009S0724W0126 data/data_aishell/wav/dev/S0724/BAC009S0724W0126.wav
BAC009S0724W0127 data/data_aishell/wav/dev/S0724/BAC009S0724W0127.wav
BAC009S0724W0128 data/data_aishell/wav/dev/S0724/BAC009S0724W0128.wav
BAC009S0724W0129 data/data_aishell/wav/dev/S0724/BAC009S0724W0129.wav
BAC009S0724W0130 data/data_aishell/wav/dev/S0724/BAC009S0724W0130.wav
...
text
BAC009S0724W0121 广州 市 房地 产中 介 协会 分析
BAC009S0724W0122 广州 市 房地 产中 介 协会 还 表示
BAC009S0724W0123 相比 于 其他 一 线 城市
BAC009S0724W0124 广州 二手 住宅 市场 表现 一直 相对 稳健
BAC009S0724W0125 而 在 股市 大幅 震荡 的 环境 下
BAC009S0724W0126 预计 第 三 季度 将 陆续 有 部分 股市 资金 重归 楼市
BAC009S0724W0127 但 受 穗 六条 及 二 套房 首 付 七成 的 制约
BAC009S0724W0128 下半 年 楼市 能 是否 能 回到 快速 上升 通道 依然 存在 变数
BAC009S0724W0129 其中 越秀 区 涨幅 领先
BAC009S0724W0130 天河 区 的 签约 面积 在 豪宅 交 投 增多 的 带动 下 上升 较快
dev, test和train三个目录下的数据量分别为:
14326 dev/text
7176 test/text
120098 train/text
步骤1:(可选)计算cmvn特征
由于TorchAudio会在数据加载的时候,实时计算特征。所以这步就是文本的直接复制。
tools/compute_cmvn_stats.py
is used to extract global cmvn(cepstral mean and variance normalization) statistics. These statistics will be used to normalize the acoustic features. Setting cmvn=false
will skip this step.
UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
warnings.warn(_create_warning_msg(
processed 1000 wavs, 452620 frames
processed 2000 wavs, 900146 frames
processed 3000 wavs, 1346665 frames
processed 4000 wavs, 1808066 frames
结果
tree raw_wav/
raw_wav/
├── dev
│ ├── text
│ └── wav.scp
├── test
│ ├── text
│ └── wav.scp
└── train
├── global_cmvn
├── text
└── wav.scp
步骤2: 生成label-token词典
bash run.sh --stage 2 --stop-stage 2 --data data
Make a dictionary
结果
head data/dict/lang_char.txt
<blank> 0 // 对应于CTC中的空符号
<unk> 1 // 未知token,对应于OOV词汇
一 2
丁 3
七 4
万 5
丈 6
三 7
上 8
下 9
…………
龄 4228
龙 4229
龚 4230
龟 4231
<sos/eos> 4232 // 对应于语音的开始和结束,使用同样的id
步骤3:准备WeNet数据格式
将所有的输入输出信息,生成到一个WeNet格式的单一文件中。
bash run.sh --stage 3 --stop-stage 3 --data data
Prepare data, prepare requried format
tools/format_data.sh --nj 32 --feat-type wav --feat raw_wav/dev/wav.scp raw_wav/dev data/dict/lang_char.txt
结果
bash run.sh --stage 3 --stop-stage 3 --data data
Prepare data, prepare requried format
tools/format_data.sh --nj 32 --feat-type wav --feat raw_wav/dev/wav.scp raw_wav/dev data/dict/lang_char.txt
sym2int.pl: replacing 斡 with 1
sym2int.pl: replacing 沂 with 1
sym2int.pl: replacing 绢 with 1
sym2int.pl: replacing 髌 with 1
sym2int.pl: replacing 湄 with 1
sym2int.pl: replacing 圃 with 1
sym2int.pl: replacing 柚 with 1
sym2int.pl: replacing 柚 with 1
sym2int.pl: replacing 柚 with 1
sym2int.pl: replacing 圃 with 1
sym2int.pl: replacing 荇 with 1
sym2int.pl: replacing 芪 with 1
sym2int.pl: replacing 薙 with 1
sym2int.pl: replacing 潦 with 1
sym2int.pl: replacing 鲤 with 1
sym2int.pl: replacing 涟 with 1
sym2int.pl: replacing 筏 with 1
sym2int.pl: replacing 筏 with 1
sym2int.pl: replacing 筏 with 1
sym2int.pl: replacing 锚 with 1
sym2int.pl: not warning for OOVs any more times
** Replaced 110 instances of OOVs with 1
tools/format_data.sh --nj 32 --feat-type wav --feat raw_wav/test/wav.scp raw_wav/test data/dict/lang_char.txt
sym2int.pl: replacing 甯 with 1
sym2int.pl: replacing 甯 with 1
sym2int.pl: replacing 甯 with 1
sym2int.pl: replacing 埕 with 1
sym2int.pl: replacing 甯 with 1
sym2int.pl: replacing 藕 with 1
sym2int.pl: replacing 祎 with 1
sym2int.pl: replacing 祎 with 1
sym2int.pl: replacing 甯 with 1
sym2int.pl: replacing 垭 with 1
sym2int.pl: replacing 纾 with 1
sym2int.pl: replacing 疙 with 1
sym2int.pl: replacing b with 1
sym2int.pl: replacing 垡 with 1
sym2int.pl: replacing 疡 with 1
sym2int.pl: replacing 嗪 with 1
sym2int.pl: replacing 雹 with 1
sym2int.pl: replacing 淅 with 1
sym2int.pl: replacing 谶 with 1
sym2int.pl: replacing 诏 with 1
sym2int.pl: not warning for OOVs any more times
** Replaced 51 instances of OOVs with 1
tools/format_data.sh --nj 32 --feat-type wav --feat raw_wav/train/wav.scp raw_wav/train data/dict/lang_char.txt
步骤4: NN训练
bash run.sh --stage 4 --stop-stage 4 --data data
多CPU模式
使用Multi-GPU的DDP(DataDistributedParallel)模式。
注意事项:多GPU训练,推荐使用nccl,如果不行,则可用gloo或者torch==1.6.0。
恢复模式
设置
checkpoint=exp/your_exp/$n.pt
配置
参考 conf/train_conformer.yaml
文件。
使用Tensorboard
tensorboard --logdir tensorboard/$your_exp_name/ --port 12598 --bind_all
步骤5:wav识别
在步骤4中的checkpoint出现时,我们即可以做wav识别。
bash run.sh --stage 5 --stop-stage 5 --data data
这里用了一个54.pt
,结果为:
==> exp/conformer/test_attention_rescoring/wer <==
Overall -> 5.91 % N=104765 C=98689 S=5899 D=177 I=113
Mandarin -> 5.90 % N=104762 C=98689 S=5896 D=177 I=113
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
===========================================================================
==> exp/conformer/test_attention/wer <==
Overall -> 7.49 % N=104765 C=97034 S=5862 D=1869 I=120
Mandarin -> 7.49 % N=104762 C=97034 S=5860 D=1868 I=120
Other -> 100.00 % N=3 C=0 S=2 D=1 I=0
===========================================================================
==> exp/conformer/test_ctc_greedy_search/wer <==
Overall -> 6.74 % N=104765 C=97843 S=6748 D=174 I=134
Mandarin -> 6.73 % N=104762 C=97843 S=6745 D=174 I=134
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
===========================================================================
==> exp/conformer/test_ctc_prefix_beam_search/wer <==
Overall -> 6.73 % N=104765 C=97848 S=6748 D=169 I=137
Mandarin -> 6.73 % N=104762 C=97848 S=6745 D=169 I=137
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
步骤6:导出模型
wenet/bin/export_jit.py
will export the trained model using Libtorch. The exported model files can be easily used for inference in other programming languages such as C++.
python wenet/bin/export_jit.py \
--config exp/conformer/train.yaml \
--checkpoint exp/conformer/56.pt \
--output_file final.zip \
--output_quant_file final_quant.zip
结果
70M Sep 19 21:02 final_quant.zip
199M Sep 19 21:01 final.zip
LM训练
7.0 环境准备
srilm
srilm:在官网下载,编译。将srilm/lm/bin/i686-m64
下的工具复制到~/bin
下。
openfst
glog
sudo yum install glog-devel gflags-devel
wenet-kaldi
这个编译起来比较费劲,在CMakeList.txt中增加头文件搜索路径
set( CMAKE_CXX_FLAGS "-std=c++11 -O3 -I /home/fw/work/2021/06.wenet/wenet/runtime/core/kaldi" )
set(CXX_FLAGS -std=c++11)
最后弃坑,还是直接编译kaldi吧。
decoder_main
这个必须要自己编译的。进入目录runtime/server/x86
,查看 README_CN.md,或者直接使用镜像:
docker pull mobvoiwenet/wenet:v0.5.0
7.1 准备词典
dict=/home/fw/work/2021/06.wenet/wenet/examples/aishell/s0/data/dict/lang_char.txt
unit_file=$dict
mkdir -p data/local/dict
cp $unit_file data/local/dict/units.txt
tools/fst/prepare_dict.py $unit_file data/resource_aishell/lexicon.txt data/local/dict/lexicon.txt
结果
$ head data/local/dict/lexicon.txt
啊 啊
啊啊啊 啊 啊 啊
阿 阿
阿尔 阿 尔
阿根廷 阿 根 廷
阿九 阿 九
阿克 阿 克
阿拉伯数字 阿 拉 伯 数 字
阿拉法特 阿 拉 法 特
阿拉木图 阿 拉 木 图
7.2 训练LM
lm=data/local/lm
mkdir -p $lm
tools/filter_scp.pl data/train/text data/data_aishell/transcript/aishell_transcript_v0.8.txt > $lm/text
local/aishell_train_lms.sh
结果
file data/local/lm/heldout: 10000 sentences, 89496 words, 0 OOVs
0 zeroprobs, logprob= -272791.2 ppl= 551.7352 ppl1= 1117.077
结果
ls -lh data/local/lm/*
-rw-rw-r-- 1 fw fw 511K Sep 21 14:17 data/local/lm/heldout
-rw-rw-r-- 1 fw fw 16M Sep 21 14:17 data/local/lm/lm.arpa
-rw-rw-r-- 1 fw fw 8.0M Sep 21 13:15 data/local/lm/text
-rw-rw-r-- 1 fw fw 7.9M Sep 21 14:17 data/local/lm/text.no_oov
-rw-rw-r-- 1 fw fw 5.5M Sep 21 14:17 data/local/lm/train
-rw-rw-r-- 1 fw fw 2.1M Sep 21 14:17 data/local/lm/unigram.counts
-rw-rw-r-- 1 fw fw 607K Sep 21 14:17 data/local/lm/word.counts
-rw-rw-r-- 1 fw fw 1.1M Sep 21 14:17 data/local/lm/wordlist
7.3 准备TLG
tools/fst/compile_lexicon_token_fst.sh data/local/dict data/local/tmp data/local/lang
tools/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;
结果
fstaddselfloops 'echo 4234 |' 'echo 123660 |'
Lexicon and token FSTs compiling succeeded
和
arpa2fst --read-symbol-table=data/lang_test/words.txt --keep-symbols=true -
LOG (arpa2fst[5.5.971~1-07043]:Read():arpa-file-parser.cc:94) Reading \data\ section.
LOG (arpa2fst[5.5.971~1-07043]:Read():arpa-file-parser.cc:149) Reading \1-grams: section.
LOG (arpa2fst[5.5.971~1-07043]:Read():arpa-file-parser.cc:149) Reading \2-grams: section.
LOG (arpa2fst[5.5.971~1-07043]:Read():arpa-file-parser.cc:149) Reading \3-grams: section.
Checking how stochastic G is (the first of these numbers should be small):
fstisstochastic data/lang_test/G.fst
0 -1.14386
fsttablecompose data/lang_test/L.fst data/lang_test/G.fst
fstminimizeencoded
fstdeterminizestar --use-log=true
fsttablecompose data/lang_test/T.fst data/lang_test/LG.fst
Composing decoding graph TLG.fst succeeded
7.4 解码
./tools/decode.sh --nj 16 \
--beam 15.0 --lattice_beam 7.5 --max_active 7000 \
--blank_skip_thresh 0.98 --ctc_weight 0.5 --rescoring_weight 1.0 \
--fst_path data/lang_test/TLG.fst \
data/test/wav.scp data/test/text $dir/final.zip \
data/lang_test/words.txt $dir/lm_with_runtime
sudo docker run --rm -it -v /home/fw/work/2021/06.wenet/:/home/mount docker.io/mobvoiwenet/wenet:v0.5.0 bash
其它
Amazon SageMaker
Amazon SageMaker是一项完全托管的机器学习服务,涵盖了数据标记、数据处理、模型训练、超参调优、模型部署及持续模型监控等基本流程;也提供自动打标签,自动机器学习,监控模型训练等高阶功能。
其通过全托管的机器学习基础设施和对主流框架的支持,可以降低客户机器学习的整体拥有成本。
模型在训练的过程中,需要用到大量的计算资源,我们可以借助Amazon SageMaker非常方便的启动包含多台完全托管的训练实例集群,加速训练过程。