当前位置: 首页 > 知识库问答 >
问题:

在pytorch的seq2seq模型中,批处理是如何工作的?

翟丰茂
2023-03-14

我试图在Pytorch中实现seq2seq模型,我对批处理有一些问题。例如,我有一批数据,其尺寸是

[batch_sizesequence_lengthsencoding_dimension]

其中,批次中每个示例的序列长度不同。

现在,我通过将批处理中的每个元素填充到最长序列的长度来完成编码部分。

通过这种方式,如果我向我的网络输入一个与上述形状相同的批次,我会得到以下输出:

输出,形状[批次大小、序列长度、隐藏层尺寸]

隐藏状态,形状[批次大小,隐藏层尺寸]

单元状态,形状[批次大小,隐藏层尺寸]

现在,从输出中,我为每个序列取最后一个相关元素,即沿着sequence_lengths维度的元素,对应于序列的最后一个未填充元素。因此,我得到的最终输出是形状[batch_size,hidden_layer_dimension]

但是现在我有一个问题,从这个向量中解码它。如何处理同一批中不同长度序列的解码?我试着用谷歌搜索,发现了这个,但他们似乎没有解决这个问题。我曾想过对整个批处理逐个元素进行处理,但随后我遇到了传递初始隐藏状态的问题,因为来自编码器的将是形状[batch\u size,hidden\u layer\u dimension],而来自解码器的将是形状[1,hidden\u layer\u dimension]

我是不是漏了什么?谢谢你的帮助!

共有2个答案

邬楚青
2023-03-14

可能有点老,但我目前面临着类似的问题:

Im将数据分批输入seq2seq模型,但当解码器预测来自

decoder_input = torch.tensor([[self.SOS_token] * self.batch_size], device=device)
# decoder_input.size() = ([1,2])
for di in range(target_length):
    decoder_output, decoder_hidden, decoder_attention = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
    topv, topi = decoder_output.data.topk(1)
    decoder_input = topi.squeeze().detach()
    # decoder_input.size() = ([1])
    outputs[di] = decoder_output
    if decoder_input.item() == self.EOS_token:
        break

因此,如果我使用batch_size=2,我有2个SOS_标记,但我只得到1个单词作为预测,这是模型无法计算的,因为它需要不同的大小。我可以把它像SOS_记号一样乘以吗?最好的

卫高谊
2023-03-14

你没有错过任何东西。我可以帮助你,因为我已经在几个序列到序列的应用程序使用PyTorch。下面我给大家举一个简单的例子。

class Seq2Seq(nn.Module):
    """A Seq2seq network trained on predicting the next query."""

    def __init__(self, dictionary, embedding_index, args):
        super(Seq2Seq, self).__init__()

        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        self.embedding.init_embedding_weights(dictionary, embedding_index, self.config.emsize)

        self.encoder = Encoder(self.config.emsize, self.config.nhid_enc, self.config.bidirection, self.config)
        self.decoder = Decoder(self.config.emsize, self.config.nhid_enc * self.num_directions, len(dictionary),
                                   self.config)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length):
        losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1)).squeeze()
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        losses = losses * mask.float()
        num_non_zero_elem = torch.nonzero(mask.data).size()
        if not num_non_zero_elem:
        return losses.sum(), 0 if not num_non_zero_elem else losses.sum(), num_non_zero_elem[0]

    def forward(self, q1_var, q1_len, q2_var, q2_len):
        # encode the query
        embedded_q1 = self.embedding(q1_var)
        encoded_q1, hidden = self.encoder(embedded_q1, q1_len)

        if self.config.bidirection:
            if self.config.model == 'LSTM':
                h_t, c_t = hidden[0][-2:], hidden[1][-2:]
                decoder_hidden = torch.cat((h_t[0].unsqueeze(0), h_t[1].unsqueeze(0)), 2), torch.cat(
                    (c_t[0].unsqueeze(0), c_t[1].unsqueeze(0)), 2)
            else:
                h_t = hidden[0][-2:]
                decoder_hidden = torch.cat((h_t[0].unsqueeze(0), h_t[1].unsqueeze(0)), 2)
        else:
            if self.config.model == 'LSTM':
                decoder_hidden = hidden[0][-1], hidden[1][-1]
            else:
                decoder_hidden = hidden[-1]

        decoding_loss, total_local_decoding_loss_element = 0, 0
        for idx in range(q2_var.size(1) - 1):
            input_variable = q2_var[:, idx]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(embedded_decoder_input, decoder_hidden)
            local_loss, num_local_loss = self.compute_decoding_loss(decoder_output, q2_var[:, idx + 1], idx, q2_len)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return decoding_loss

您可以在这里看到完整的源代码。此应用程序是关于在给定当前web搜索查询的情况下预测用户的下一个web搜索查询。

你问题的回答者:

如何处理同一批中不同长度序列的解码?

你有填充序列,所以你可以考虑,因为所有的序列都是相同的长度。但是,当您计算损失时,您需要使用掩蔽忽略这些填充项的损失。

在上面的示例中,我使用了掩蔽技术来实现同样的效果。

此外,您在以下方面完全正确:您需要为小批量逐个元素解码。初始解码器状态[批大小、隐藏层大小]也很好。您只需要在维度0处取消查询,使其[1,批量大小,隐藏层大小]

请注意,您不需要循环遍历批处理中的每个示例,您可以一次执行整个批处理,但您需要循环遍历序列的元素。

 类似资料:
  • 我想了解Spring Batch是如何进行事务管理的。这不是一个技术问题,而是一个概念性的问题:Spring Batch使用什么方法?这种方法的后果是什么? 让我试着澄清一下这个问题。例如,在TaskletStep中,我看到步骤执行通常如下所示: 准备步骤元数据的几个JobRepository事务 每一块要处理的业务事务 更多JobRepository事务,用区块处理的结果更新步骤元数据 这似乎是

  • 我正在上一门使用处理的课。 我在理解map()函数时遇到了问题。 根据文件记载(http://www.processing.org/reference/map_.html): 将数字从一个范围重新映射到另一个范围。 在上面的第一个示例中,数字25从0到100范围内的值转换为从窗口的左边缘(0)到右边缘(宽度)的值。 如第二个示例所示,范围之外的数字不会被钳制到最小和最大参数值,因为范围之外的值通常

  • Batched Seq2Seq ExampleBased on the seq2seq-translation-batched.ipynb from practical-pytorch, but more extra features. This example runs grammatical error correction task where the source sequence is

  • 我对下面代码片段中的方法感到困惑。 我的困惑在于以下几行。 什么是张量。view()函数的作用是什么?我在很多地方见过它的用法,但我不明白它是如何解释它的参数的。 如果我将负值作为参数赋给函数,会发生什么情况?例如,如果我调用,? 有人能用一些例子解释一下函数的主要原理吗?

  • 上周我阅读了有关vertx的文档。我不明白的是vertx处理器是如何工作的?例如 和服务器是: (P.S.我知道我首先应该检查处理程序是否成功,然后采取一些措施,但为了简化代码,我删除了这种检查,如果在30秒内没有任何响应,则处理程序中会出现异常,也会从正式文档中删除。) 从上面的代码中,客户端每秒发送请求,并且不等待响应,但是它有一个处理程序,当响应到来时将执行该处理程序。 jdbcVertx监

  • 我有以下工作要处理在一定的时间间隔或特别的基础上。 作业中的步骤如下: 我也想要用户界面,在那里我可以触发一个特别的基础上的工作,而且我应该能够提供参数从用户界面。 我想用Spring batch来完成这个任务,但它更多的是用于读->处理->写之类的工作。这里,在第一步中,我正在生成由第二步读取的数据。我不确定我是否还可以使用Spring batch来实现这个,或者有更好的方法来实现这个。