CTC在语音识别上的应用,loss为nan的处理

ctc在pytorch1.2以上的版本中有集成好的是实现。

torch.nn.functional.ctc_loss

原理不再介绍,有很多开源的实现。主要说一下自己遇到的问题。
在语音上应用时,会遇到loss为nan的情况,如果代码在交叉熵损失或者其他损失的情况下可以正常跑,说明数据没问题。主要原因出在对齐上。在一个batch中一条发音可能比较短,对应的目标文本也比较短,ctc就无法对齐,就会出现loss为Inf,后面就体现为nan。
网上找了很多解决方法,都是针对数据的,但是其实

torch.nn.functional.ctc_loss

已经考虑到了这一点,在函数中有一个参数zero_infinity。
直接看pytorch的官方文档就可以看到。

torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)

这个参数就是对齐很短的时候,把出现的inf损失置为0。
这样可能会带来一定的精度的下降,实验还没做完,后续看看效果。
zero_infinity=True就行了,对结果没什么影响。


版权声明:本文为qq_37950002原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。