当前位置: 首页 > 面试题库 >

当state_is_tuple = True时如何设置TensorFlow RNN状态?

长孙沈义
2023-03-14
问题内容

我已经使用TensorFlow编写了RNN语言模型。该模型被实现为一个RNN类。图结构内置在构造函数中,而RNN.trainandRNN.test方法则运行它。

当我移到训练集中的新文档时,或者当我想在训练期间运行验证集时,我希望能够重置RNN状态。我通过在训练循环中管理状态,并通过Feed字典将其传递到图形中来实现此目的。

在构造函数中,我像这样定义RNN

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

训练循环如下所示

 for document in document:
     state = session.run(self.reset_state)
     for x, y in document:
          _, state = session.run([self.train_step, self.next_state], 
                                 feed_dict={self.x:x, self.y:y, self.state:state})

x并且y是在文档中的训练数据批次。这样做的想法是,每批批处理后都会传递最新状态,除非当我开始创建新文档时,通过运行将状态清零self.reset_state

这一切正常。现在我想更改我的RNN以使用推荐的state_is_tuple=True。但是,我不知道如何通过提要字典传递更复杂的LSTM状态对象。另外我也不知道要将哪些参数传递给self.state = tf.placeholder(...)构造函数中的行。

这里正确的策略是什么?仍然没有太多示例代码或文档dynamic_rnn可用。

TensorFlow问题2695和2838似乎相关。

关于WILDML的博客文章解决了这些问题,但没有直接说明答案。

另请参见TensorFlow:记住下一批的LSTM状态(有状态LSTM)。


问题答案:

Tensorflow占位符的一个问题是,您只能使用Python列表或Numpy数组来提供它(我认为)。因此,您无法在LSTMStateTuple的元组中保存两次运行之间的状态。

我通过将状态保存在这样的张量中解决了这个问题

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

在LSTM层中有两个组件, 单元状态隐藏状态 ,这就是“
2”的含义。(这篇文章很棒:https :
//arxiv.org/pdf/1506.00019.pdf)

构建图时,您将解压缩并创建元组状态,如下所示:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
         [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
          for idx in range(num_layers)]
)

然后您以通常的方式获得新状态

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

可能不应该这样……也许他们正在研究解决方案。



 类似资料:
  • 问题内容: 我需要禁用IPv6。为此,Java文档指示设置jvm属性。 但是我不了解如何从代码本身做到这一点。 许多论坛都演示了如何从命令提示符下执行此操作,但是我需要在运行时执行此操作。 问题答案: 您可以使用 这等效于通过以下命令在命令行中传递它

  • 我是Kibana新手,将数据加载到Elastic 5.0.0-alpha3中,并使用Kibana5.0.0-alpha3进行可视化。我可以将一些数字字段显示为直方图,但当我想使用文本字段时,我会得到: 我被警告说数据(出版商的名字)可能已经被分析成子字段,但是我还是想显示。 如何设置< code>fielddata=true? 编辑:Kibana github上最近的问题表明这是5.0.0中的新功

  • 我开始使用Spring Statemachine,但在管理对象的状态时遇到了一些麻烦。 null 你觉得我的做法怎么样?

  • 当我构建Azure Function(Java)的以下代码时,isSessionsEnabled未在生成的函数中设置。json。 如何在函数中设置isSessionsEnabled=true。json? 构建使用Gradle。 https://github.com/microsoft/azure-gradle-plugins/blob/master/azure-functions-gradle-p

  • 当我使用concurrentKafkaListenerContainerFactory时,有什么方法可以设置主题吗?我根本不想做任何注释。

  • 我正在使用JPA类创建数据库。 如果我们有许多对一个关系,我们可以重写ForeignKey name name,如下所示: 以DB为单位,我们会有这样一个结果: 我使用这个版本的JARS: 我的java版本是1.8