当前位置: 首页 > 知识库问答 >
问题:

Pyrotch LSTM输入尺寸

赫连华皓
2023-03-14

我试图用PyTorch LSTM训练一个简单的2层神经网络,但我很难解释PyTorch留档。具体来说,我不太确定如何处理我的训练数据。

我想做的是通过小批次在一个非常大的数据集上训练我的网络,每个批次有100个元素长。每个数据元素将具有5个特征。留档声明层的输入应该是形状(seq_len,batch_size,input_size)。我应该如何调整输入?

我一直在关注这篇文章:https://discuss.pytorch.org/t/understanding-lstm-input/31110/3如果我正确解释这一点,每个小批量应该是形状(100, 100, 5)。但是在这种情况下,seq_len和batch_size有什么区别?此外,这是否意味着输入LSTM层的第一层应该有5个单元?

非常感谢。

共有1个答案

冯通
2023-03-14

这是一个古老的问题,但由于它已经被浏览了80次,没有任何反应,让我来尝试一下。

LSTM网络用于预测序列。在NLP中,这将是一个单词序列;在经济学中,一系列经济指标;等

第一个参数是这些序列的长度。如果序列数据是由句子组成的,那么“汤姆有一只又黑又丑的猫”是一个长度为7(seq_len)的序列,每个单词一个,可能是第8个,表示句子的结尾。

当然,您可能会反对“如果我的序列长度不同怎么办?”这是一种常见的情况。

两种最常见的解决方案是:

>

  • 用空元素填充序列。例如,如果你的最长句子有15个单词,那么将上面的句子编码为“[Tom][has][a][black][and][aughous][cat][EOS][[]],其中EOS代表句子的结尾。突然,你所有的序列长度都变为15,这就解决了你的问题。一旦找到[EOS]令牌,模型将很快了解到它后面是无限序列的空令牌[],这种方法几乎不会对您的网络征税。

    发送等长的迷你批次。例如,用2个单词训练网络中的所有句子,然后用3个单词,然后用4个单词。当然,每个迷你批次的seq_len会增加,每个迷你批次的大小会根据您的数据中有多少个长度为N的序列而有所不同。

    一个两全其美的方法是将数据分成大小大致相等的小批量,按大致长度分组,只添加必要的填充。例如,如果你把长度为6、7和8的句子小批量组合在一起,那么长度为8的序列不需要填充,而长度为6的序列只需要2。如果你有一个大数据集,序列的长度变化很大,那是最好的方法。

    不过,选项1是最简单(也是最懒)的方法,在小数据集上效果很好。

    最后一件事。。。始终在末尾填充数据,而不是在开头。

    我希望这有帮助。

  •  类似资料:
    • 使用网格的列来设置输入框的大小,如 .large-6, .medium-6, 等。 更多网格系统知识,可以点击 相等大小列 以下演示了相等大小列的实例: 实例<form>   <div>     <div>       <label>medium-4 (100% on small, stacked)         <input type="text" placeholder="Name">   

    • 问题内容: 我试图使用keras训练LSTM模型,但我认为这里出了点问题。 我有一个错误 ValueError:检查输入时出错:预期lstm_17_input具有3个维,但数组的形状为(10000,0,20) 虽然我的代码看起来像 其中已的形状和前几个数据点像 并且具有这样的形状的,它是二进制(0/1)标签阵列。 有人可以指出我在哪里错了吗? 问题答案: 为了完整起见,这是发生了什么。 首先,像K

    • 我对TensorFlow和LSTM架构相当陌生。我在计算数据集的输入和输出(x_train、x_test、y_trainy_test)时遇到了问题。 我最初输入的形状: X_列车:(366,4) Ytrain和Ytest是一系列股票价格。Xtrain和Xtest是我想学习的四个预测股价的功能。 这是产生的错误: -------------------------------------------

    • 我写了这段代码。我的输入形状是(100 x100 X3)。我是深度学习的新手。我花了这么多时间在这个问题上,但无法解决这个问题。任何帮助都非常感谢。 错误:在[15]:运行文件('/user/Project/SM/src/ann\u algo\u keras.py',wdir='/user/Project/SM/src')中随机启动突触权重:模型:“sequential\u 3” conv2d_1

    • 文件 std::fs::File 本身实现了 Read 和 Write trait,所以文件的输入输出非常简单,只要得到一个 File 类型实例就可以调用读写接口进行文件输入与输出操作了。而要得到 File 就得让操作系统打开(open)或新建(create)一个文件。还是拿例子来说明 use std::io; use std::io::prelude::*; use std::fs::File;

    • 回顾一下我们写的第一个 Rust 程序就是带副作用的,其副作用就是向标准输出(stdout),通常是终端或屏幕,输出了 Hello, World! 让屏幕上这几个字符的地方点亮起来。println! 宏是最常见的输出,用宏来做输出的还有 print!,两者都是向标准输出(stdout)输出,两者的区别也一眼就能看出。至于格式化输出,基础运算符和字符串格式化小节有详细说明,这里就不再啰嗦了。 更通用