CTC学习笔记
背景
CTC,Connectionist Temporal Classification,用来解决输入序列和输出序列难以一一对应的问题。
举例来说,在语音识别中,我们希望音频中的音素和翻译后的字符可以一一对应,这是训练时一个很天然的想法。但是要对齐是一件很困难的事,有人说话块,有人说话慢,每个人说话快慢不同,不可能手动地对音素和字符对齐,这样太耗时。
再比如,在OCR中使用RNN时,RNN的每一个输出要对应到字符图像中的每一个位置,要手工做这样的标记工作量太大,而且图像中的字符数量不同,字体样式不同,大小不同,导致输出不一定能和每个字符一一对应。
问题描述
为了简化问题,我们以对26个英文字符进行识别为例。考虑到有些位置没有字符,定义一个-作为空白符加入到字符集合,这样总共有27个字符。
我们将识别过程分为三步,第一步是将音频转换特征序列,第二步为将特征序列转换为27个字符的序列,第二步是将27个字符的序列转换为对应的文本。
问题解决
提取特征
在音频上,每个窗长取一段音频,总共有t个时间步(即窗口)。每个窗口提取特征维度为m,即m个数字。
特征到字符序列
以上面的矩阵做为输入,输出为27个字符的序列。记每个窗口对应的字符为y,总共有t个字符。每个y字符是27个字符的概率为y1, y2, ……,yn,其和为1,其中n为字符的个数。
如果用空间来说的话,输出序列是一个t维空间,空间的每个轴有n个值。
CTC引入了blank(该帧没有预测值),每个预测的分类对应的一整段语音中的一个spike(尖峰),其他不是尖峰的位置认为是blank。对于一段语音,CTC最后的输出是spike(尖峰)的序列,并不关心每一个音素持续了多长时间。
字符序列到文本
拿到输出序列后,可以定义一个B函数,来将序列转换为文本。
比如我们定义一个B变换,对输出序列(比如下例中的4个π)进行变换,变换成真实输出(比如下例中的state),把连续的相同字符删减为1个并删去空白符。举例说明,当t=12时:
B(π_1)=B(−−stta−t−−−e)=state
B(π_2)=B(sst−aaa−tee−)=state
B(π_3)=B(−−sttaa−tee−)=state
B(π_4)=B(sst−aa−t−−−e)=state
优化神经网络
我们使用神经网络来训练前面的特征到字符序列。这里使用LSTM神经网络。我们的优化目标函数为,在特征给定的情况下,输出为真实的文本l的概率。
由于不能输入公式,这里的细节略,请查看参考一。
由于有t个位置,每个位置有n种选择,即总共有$n^t$种可能。如果逐条遍历来求得,时间复杂度是指数级的。因此CTC借用了HMM中的“前向-后向算法”(forward-backward algorithm)来计算,将时间复杂度变为$nT$。
CTC中的前向后向算法
由于真实输出ll是一个序列,序列可以通过一个路径图中的一条路径来表示,我们也称输出序列ll为路径ll。
以前面的π1,π2,π3,π4为例子,画出两条路径(还有两条没画出来),如下图所示
定义 forward 为:在t时刻时,前面1-t这些时刻的路径中的概率的和。
t = 1时,符号只能为空白符或者l_1(即s)。
观察上图((图源见参考资料[1])可以发现,如果t=6时字符是a,那么t=5时只能是字符a,t,空白符三选一,否则经过B变换后无法得到state。
最后:forward和backward可以用前面的dp递推式计算出来,时间复杂度是nT,相比于前面的指数复杂度n^T大大减小了计算量。
这样对LSTM的输出y求导之后,再根据y对LSTM里面的权重参数w进行链式求导,就可以使用梯度下降的方法来更新参数了。
本文这里只是简要介绍,详细可见参考一。
CTC的预测
一种方法是Best Path search。计算概率最大的一条输出序列(假设时间步独立,那么直接在每个时间步取概率最大的字符输出即可),但是这样没有考虑多个输出序列对应一个真实输出这件事,举个例子,[s,s,-]和[s,s,s]的概率比[s,t,a]低,但是它们的概率之和会高于[s,t,a]。
第二种方法是Beam Search。假设指定B=3,预测过程如下图所示(图源见参考资料[2])。在第一个时间步选取概率最大的三个字符,然后在第二个时间步也选取概率最大的三个字符,两两组合(概率相乘)可以组合成9个序列,这些序列在B转换之后会得到一些相同输出,把具有相同输出的序列进行合并,比如有3个序列都可以转换成a,把它们合并(概率加在一起),计算出概率最大的三个序列,然后继续和下一个时间步的字符进行同样的合并。
讨论
CTC的基本假设
第一个是条件独立性。CTC做了一个假设就是不同时间步的输出之间是独立的。这个假设对于很多序列问题来说并不成立,输出序列之间往往存在联系。
第二个是单调对齐。CTC只允许单调对齐,在语音识别中可能是有效的,但是在机器翻译中,比如目标语句中的一些比较后的词,可能与源语句中前面的一些词对应,这个CTC是没法做到的。
第三个是多对一映射。CTC的输入和输出是多对一的关系。这意味着输出长度不能超过输入长度,这在手写字体识别或者语音中不是什么问题,因为通常输入都会大于输出,但是对于输出长度大于输入长度的问题CTC就无法处理了。
对齐
CTC关注的是一个输入序列到一个输出序列的结果,预测两个序列是否接近,不会给出输出序列中每个结果在时间点上是否和输入的序列正好对齐。即无法给出输入序列的词的时间信息。