当前位置: 首页 > 知识库问答 >
问题:

理解Keras中语音识别的CTC损失

轩辕海
2023-03-14

我试图了解CTC损失是如何为语音识别工作的,以及它如何在Keras中实现。

  1. 我认为我理解的(如果我错了,请纠正我!)

大体上,CTC损耗被添加到经典网络之上,以便逐个元素(文本或语音的字母)解码顺序信息,而不是直接解码元素块(例如单词)。

假设我们将一些句子的语句作为MFCC输入。

使用CTC损失的目标是学习如何使每个字母在每个时间步与MFCC匹配。因此,Dense softmax输出层由与句子组成所需元素数量一样多的神经元组成:

  • 字母表(a,b,…,z)

然后,softmax层具有29个神经元(26个用于一些特殊字符的字母表)。

为了实现它,我发现我可以这样做:

# CTC implementation from Keras example found at https://github.com/keras- 
# team/keras/blob/master/examples/image_ocr.py

def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    # the 2 is critical here since the first couple outputs of the RNN
    # tend to be garbage:
    # print "y_pred_shape: ", y_pred.shape
    y_pred = y_pred[:, 2:, :]
    # print "y_pred_shape: ", y_pred.shape
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)



input_data = Input(shape=(1000, 20))
#let's say each MFCC is (1000 timestamps x 20 features)

x = Bidirectional(lstm(...,return_sequences=True))(input_data)

x = Bidirectional(lstm(...,return_sequences=True))(x)

y_pred = TimeDistributed(Dense(units=ALPHABET_LENGTH, activation='softmax'))(x)

loss_out = Lambda(function=ctc_lambda_func, name='ctc', output_shape=(1,))(
                  [y_pred, y_true, input_length, label_length])

model = Model(inputs=[input_data, y_true, input_length,label_length], 
                      outputs=loss_out)

ALPHABET_LENGTH=29(字母表长度特殊字符)

以及:

  • y_true:包含真值标签的张量(样本,最大字符串长度)

(来源)

现在,我面临一些问题:

  • 这是编码和使用CTC丢失的正确方法吗

共有1个答案

澹台新知
2023-03-14
  • y_true您的基本真相数据。您将要与培训中的模型输出进行比较的数据。(另一方面,y_pred是模型的计算输出)

这种损失似乎期望您的模型的输出(y_pred)有不同的长度,以及您的地面真相数据(y_true)。这可能是为了避免计算句子结束后垃圾字符的损失(因为您需要一个固定大小的张量来同时处理大量的句子)

因为函数的留档是要求形状(样本,长度),所以格式是...每个句子中每个char的char索引。

有一些可能性。

如果所有长度相同,您可以轻松地将其用作常规损耗:

def ctc_loss(y_true, y_pred):

    return K.ctc_batch_cost(y_true, y_pred, input_length, label_length)
    #where input_length and label_length are constants you created previously
    #the easiest way here is to have a fixed batch size in training 
    #the lengths should have the same batch size (see shapes in the link for ctc_cost)    

model.compile(loss=ctc_loss, ...)   

#here is how you pass the labels for training
model.fit(input_data_X_train, ground_truth_data_Y_train, ....)

这有点复杂,你需要你的模型以某种方式告诉你每个输出句子的长度。

  • 有一个end_of_sentence字符,并检测它在句子中的位置。
  • 让你的模型的一个分支来计算这个数字,并将其舍入为整数。
  • (Hardcore)如果你使用有状态的手动训练循环,获取你决定完成一个句子的迭代的索引

我喜欢第一个想法,并将在这里举例说明。

def ctc_find_eos(y_true, y_pred):

    #convert y_pred from one-hot to label indices
    y_pred_ind = K.argmax(y_pred, axis=-1)

    #to make sure y_pred has one end_of_sentence (to avoid errors)
    y_pred_end = K.concatenate([
                                  y_pred_ind[:,:-1], 
                                  eos_index * K.ones_like(y_pred_ind[:,-1:])
                               ], axis = 1)

    #to make sure the first occurrence of the char is more important than subsequent ones
    occurrence_weights = K.arange(start = max_length, stop=0, dtype=K.floatx())

    #is eos?
    is_eos_true = K.cast_to_floatx(K.equal(y_true, eos_index))
    is_eos_pred = K.cast_to_floatx(K.equal(y_pred_end, eos_index))

    #lengths
    true_lengths = 1 + K.argmax(occurrence_weights * is_eos_true, axis=1)
    pred_lengths = 1 + K.argmax(occurrence_weights * is_eos_pred, axis=1)

    #reshape
    true_lengths = K.reshape(true_lengths, (-1,1))
    pred_lengths = K.reshape(pred_lengths, (-1,1))

    return K.ctc_batch_cost(y_true, y_pred, pred_lengths, true_lengths)

model.compile(loss=ctc_find_eos, ....)

如果使用另一个选项,请使用模型分支来计算长度,将这些长度连接到输出的第一步或最后一步,并确保对地面真相数据中的真实长度执行相同的操作。然后,在损失函数中,只取长度的部分:

def ctc_concatenated_length(y_true, y_pred):

    #assuming you concatenated the length in the first step
    true_lengths = y_true[:,:1] #may need to cast to int
    y_true = y_true[:, 1:]

    #since y_pred uses one-hot, you will need to concatenate to full size of the last axis, 
    #thus the 0 here
    pred_lengths = K.cast(y_pred[:, :1, 0], "int32")
    y_pred = y_pred[:, 1:]

    return K.ctc_batch_cost(y_true, y_pred, pred_lengths, true_lengths)
 类似资料:
  • 由于连接到不同的API,我目前正在开发一个工具,允许我阅读所有的通知。 它工作得很好,但现在我想用一些声音命令来做一些动作。 就像当软件说“一封来自Bob的邮件”时,我想说“阅读”或“存档”。 我的软件是通过一个节点服务器运行的,目前我没有任何浏览器实现,但它可以是一个计划。 在NodeJS中,启用语音到文本的最佳方式是什么? 我在它上面看到了很多线程,但主要是使用浏览器,如果可能的话,我希望在一

  • 语音识别是以语音为研究对象,通过语音信号处理和模式识别让机器自动识别和理解人类口述的语言。语音识别技术就是让机器通过识别和理解过程把语音信号转变为相应的文本或命令的高技术。语音识别是一门涉及面很广的交叉学科,它与声学、语音学、语言学、信息理论、模式识别理论以及神经生物学等学科都有非常密切的关系。语音识别技术正逐步成为计算机信息处理技术中的关键技术,语音技术的应用已经成为一个具有竞争性的新兴高技术产

  • 识别简单的语句。

  • 光环板内置的麦克风和Wi-Fi功能相结合,可以实现语音识别相关的应用。通过接入互联网,可以使用各大主流科技公司提供的语音识别服务,像是微软语音识别服务。使用联网功能需要登陆慧编程账号。 注册/登陆慧编程 点击工具栏右侧的登陆/注册按钮,依据提示登陆/注册账号。 启用上传模式 点击启用上传模式。 新建语音识别项目 我们将新建一个语音识别项目,使用语音来点亮光环板的LED灯。 连接网络 1. 添加事件

  • 1.1. ASR(语音识别) HTTP接口文档 1.1.1. 概述 1.1.2. 服务地址 1.1.3. 协议详解 1.1.4. HTTP API 接入参考Demo 1.1.5. 协议概述 1.1. ASR(语音识别) HTTP接口文档 1.1.1. 概述 本文档目的是描述Rokid云ASR(语音识别)Http接口协议,面向想要了解ASR细节,并具有一定开发能力的开发者或用户。 1.1.2. 服务

  • 1.1. ASR(语音识别) WebSocket接口文档 1.1.1. 概述 1.1.2. 服务地址 1.1.3. 协议详解 1.1.4. 协议地址 1.1.5. 协议概述 1.1.6. ASR 云端一些细节 1.1. ASR(语音识别) WebSocket接口文档 1.1.1. 概述 本文档目的是描述Rokid云ASR(语音识别)WebSocket接口协议,面向想要了解ASR细节,并具有一定开发