当前位置: 首页 > 面试题库 >

获取TensorFlow中dynamic_rnn的最后输出

祁高格
2023-03-14
问题内容

我有一个3D张量形状[batch, None, dim],其中第二维(即时间步长)未知。我使用dynamic_rnn以下代码段来处理此类输入:

import numpy as np
import tensorflow as tf

batch = 2
dim = 3
hidden = 4

lengths = tf.placeholder(dtype=tf.int32, shape=[batch])
inputs = tf.placeholder(dtype=tf.float32, shape=[batch, None, dim])
cell = tf.nn.rnn_cell.GRUCell(hidden)
cell_state = cell.zero_state(batch, tf.float32)
output, _ = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)

实际上,以一些实际数字运行此代码后,我得到了一些合理的结果:

inputs_ = np.asarray([[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]],
                    [[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9]]],
                    dtype=np.int32)
lengths_ = np.asarray([3, 1], dtype=np.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_ = sess.run(output, {inputs: inputs_, lengths: lengths_})
    print(output_)

输出为:

[[[ 0.          0.          0.          0.        ]
  [ 0.02188676 -0.01294564  0.05340237 -0.47148666]
  [ 0.0343586  -0.02243731  0.0870839  -0.89869428]
  [ 0.          0.          0.          0.        ]]

 [[ 0.00284752 -0.00315077  0.00108094 -0.99883419]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]]]

有没有一种方法可以通过动态RNN[batch, 1, hidden]最后一个相关输出 获得形状的3-D张量?谢谢!


问题答案:

这就是collect_nd的目的!

def extract_axis_1(data, ind):
    """
    Get specified elements along the first axis of tensor.
    :param data: Tensorflow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """

    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.gather_nd(data, indices)

    return res

在您的情况下:

output = extract_axis_1(output, lengths - 1)

现在output是维的张量[batch_size, num_cells]



 类似资料:
  • 问题内容: 我有一个正在运行并使用的脚本 这返回 我想要内联最后四个字符,这样 问题答案: 怎么样,用开关。例如,要获取“ hello”的最后四个字符: 请注意,我使用5(4 + 1),因为会添加一个换行符。如下面的Brad Koch所建议,请使用来防止添加换行符。

  • 在评论中,这是我在ATM上的代码: Atm I有两个函数:一个可以将字符串转换为日历的星期数,另一个是我正在搜索的方法。目前它只处理今天的一周中的一天正确的,应该做的工作为一周中的每隔一天的部分是缺失的(评论与...)

  • 我用Tensorflow培训了一个模型(更快的rcnn\U resnet101\U coco\U 2018\U 01\U 28)。我有一个“.pb”图形。 要制作一个冻结图,我需要输入和输出节点。 如何在图表中找到它? 我这里有完整的节点列表。 没有任何节点像Softmax,占位符,因为它在其他帖子中建议。

  • 当我将张量流精简模型添加到我的Android应用程序时。它建议自动生成的代码。 现在让我们假设我在python中的输入形状是一个50数字[1,2,3...]的int数组,它给出了一个浮点值的输出。 我必须以何种方式更改代码。

  • 问题内容: 到目前为止,这是我的JavaScript代码: 当前,它需要URL中数组的倒数第二个项目。但是,我想检查数组中的最后一项是否正确,如果是,请改为抓取倒数第三项。 问题答案: 如果您的服务器为“ index.html”和“ inDEX.htML”提供相同的文件,则您也可以使用:。 但是,如果可能的话,您可能要考虑在服务器端进行此操作:它将更加干净并且可以在没有JS的情况下使用。

  • 问题内容: 我需要定义数字的最后一位数字,并将其分配给值。此后,返回最后一位数字。 我的代码段无法正常工作… 码: 题: 如何解决这个问题? 问题答案: 刚回来; 即取模数。这将比解析字符串要快得多。 如果可以为负则使用