当前位置: 首页 > 面试题库 >

Tensorflow,在RNN中保存状态的最佳方法?

田骁
2023-03-14
问题内容

我目前在tensorflow中具有一系列链接在一​​起的RNN的以下代码。我不使用MultiRNN,因为稍后我将对每个图层的输出进行处理。

 for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        state = [tf.zeros((BATCH_SIZE, sz)) for sz in rnn_func.state_size]
        time_outputs = [None] * TIME_STEPS

        for t in range(TIME_STEPS):
            rnn_input = getTimeStep(rnn_outputs[r - 1], t)
            time_outputs[t], state = rnn_func(rnn_input, state)
            time_outputs[t] = tf.reshape(time_outputs[t], (-1, 1, RNN_SIZE))
            scope.reuse_variables()
        rnn_outputs[r] = tf.concat(1, time_outputs)

目前,我有固定的时间步数。但是,我想将其更改为只有一个时间步长,但要记住批次之间的状态。因此,我需要为每个层创建一个状态变量,并将其分配给每个层的最终状态。这样的事情。

for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        saved_state = tf.get_variable('saved_state', ...)
        rnn_outputs[r], state = rnn_func(rnn_outputs[r - 1], saved_state)
        saved_state = tf.assign(saved_state, state)

然后,对于每一层,我都需要评估sess.run函数中的保存状态以及调用训练函数。我需要为每个rnn层执行此操作。这似乎有点麻烦。我需要跟踪每个保存的状态并在运行中对其进行评估。同样,然后运行将需要将状态从我的GPU复制到主机内存,这会造成效率低下和不必要的情况。有更好的方法吗?


问题答案:

这是state_is_tuple=True通过定义状态变量来更新LSTM初始状态的代码。它还支持多层。

我们定义了两个函数-
一个用于获取具有初始零状态的状态变量,另一个用于返回操作的函数,可以传递给该函数以session.run用LSTM的最后一个隐藏状态更新状态变量。

def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    return tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a tuple in order to combine all update_ops into a single operation.
    # The tuple's actual value should not be used.
    return tf.tuple(update_ops)

我们可以用它来更新每批LSTM的状态。请注意,我tf.nn.dynamic_rnn用于展开:

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell_layer = tf.contrib.rnn.GRUCell(256)
cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers)

# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)

# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)

# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})

该答案的主要区别在于,state_is_tuple=True使LSTM的状态成为包含两个变量(单元状态和隐藏状态)而不是单个变量的LSTMStateTuple。然后,使用多层可以使LSTM的状态成为LSTMStateTuples的元组-
每层一个。

重置为零

使用训练有素的模型进行预测/解码时,您可能需要将状态重置为零。然后,您可以使用此功能:

def get_state_reset_op(state_variables, cell, batch_size):
    # Return an operation to set each variable in a list of LSTMStateTuples to zero
    zero_states = cell.zero_state(batch_size, tf.float32)
    return get_state_update_op(state_variables, zero_states)

例如上面的例子:

reset_state_op = get_state_reset_op(state, cell, max_batch_size)
# Reset the state to zero before feeding input
sess.run([reset_state_op])
sess.run([outputs, update_op], {data: ...})


 类似资料:
  • 问题内容: 我的Web应用程序使用会话来存储有关用户登录后的信息,并在用户在应用程序中逐页浏览时维护该信息。在这个特定的应用程序,我存储,和人的。 我想在登录时提供“保持登录状态”选项,该选项会将cookie在用户的计算机上放置两周,当他们返回应用程序时,将使用相同的详细信息重新开始会话。 这样做的最佳方法是什么?我不想将其存储在cookie中,因为这似乎会使一个用户更容易伪造另一位用户的身份。

  • 我需要保存一个字符串值,这个值每小时都在变化,由一个路由获取,供其他路由使用。我使用的是Spring XML DSL。 我让它工作正常,但它似乎很笨拙。我有一个带有setter和getter的java类来包装字符串,我用: 然后通过另一种途径把它放回体内: 有没有一种不需要定制java类的纯Spring方式在xml中做到这一点?

  • 问题内容: 我上了一堂课,想跟踪学生的统计数据。我打算稍后制作一个GUI来处理这些数据。 我的主要问题是:保存和以后检索此数据的最佳方法是什么? 我已经读过关于pickle和JSON的文章,但是我并没有真正了解它们的工作方式(尤其是关于它们如何保存数据的信息,例如哪种格式和位置)。 问题答案: 对于持久性数据(存储有关学生的信息),数据库是一个不错的选择。如前所述,Python附带了Sqlite3

  • 但我正在寻找一种更优雅的方式来存储存储中的获取状态。 因此,我尝试了另一种方法,并考虑创建一个fetchingActionsReducer,它将每个提取操作添加到存储中的dict(对象),然后状态如下所示: 导出常量loadingTasks=(state={},action)=>{switch(action.type){case start_loading_task:return object.a

  • 问题内容: 我有一些将数据存储在页面中的表格。我想找到一种方法来临时保存用户在导航中输入的数据,并保持到确认订单为止。谁能帮我? 问题答案: 这正是会议的目的

  • 我有一个Android应用程序,我想实现一个功能,用户可以上传SVG,然后将SVG保存在firebase实时数据库中,然后其他用户可以看到它并与之交互,我选择SVG而不是常规照片或PNG的原因是因为它非常小,不会消耗存储或数据来加载, 加上它们具有多种显示密度和不同的屏幕比例更好,问题是我不知道这是否可行,如果它是我不知道它将被保存为的格式是什么,因为我有一个保存SVG路径的想法,但没有找到将其转