当前位置: 首页 > 面试题库 >

输入placholder中的Tensorflow批处理大小

萧晓博
2023-03-14
问题内容

我是Tensorflow的新手,所以我不明白为什么输入占位符经常根据用于训练的批次大小来确定尺寸。

在此示例中,我在此处以及在Mnist官方教程中找到了

from get_mnist_data_tf import read_data_sets
mnist = read_data_sets("MNIST_data/", one_hot=True)
import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
sess.run(tf.initialize_all_variables())
y = tf.nn.softmax(tf.matmul(x,W) + b)
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
for i in range(1000):
  batch = mnist.train.next_batch(50)
  train_step.run(feed_dict={x: batch[0], y_: batch[1]})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

print(accuracy.eval(feed_dict={x: mnist.test.images,
                               y_: mnist.test.labels}))

那么确定尺寸并创建模型输入并对其进行训练的最佳和正确方法是什么?


问题答案:

您在此处指定模型输入。您希望将批处理大小保留为None,这意味着您可以使用可变数量的输入(一个或多个)来运行模型。批处理对于有效使用您的计算资源很重要。

x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

下一条重要的线是:

batch = mnist.train.next_batch(50)

在这里,您将发送50个元素作为输入,但您也可以将其更改为一个

batch = mnist.train.next_batch(1)

无需修改图形。如果指定批处理大小(在第一个代码段中用一些数字代替“无”),则每次都必须更改,这并不理想,特别是在生产中。



 类似资料:
  • 主要内容:重定向输出(Stdout和Stderr),抑制程序输出有三个键盘输入的通用“文件”,在屏幕上打印文本和在屏幕上打印错误。 标准输入文件(stdin)包含程序/脚本的输入。 标准输出(Standard Out)文件(stdout)被用来写输出以显示在屏幕上。 最后一种叫作的“标准错误”文件包含用于显示在屏幕上的任何错误消息。 这三个标准文件中的每一个(也称为标准流)分别使用数字,和进行引用。Stdin是文件,stdout是文件,stderr是文件。 重

  • 问题内容: 我有一些以表示的数据。它是一个未知大小的张量(应分批输入),每个项目的大小都为。经历,所以现在有尺寸,其中是嵌入尺寸并指未知的批量大小。 此处描述: 我现在正尝试将输入数据中的每个样本(现在通过嵌入维度进行扩展)乘以矩阵变量,而我似乎不知道该怎么做。 我首先尝试使用,但是由于形状不匹配而导致错误。然后,我通过扩展的维度和应用来尝试以下操作(我还尝试了从进行的功能,结果相同): 这将通过

  • 我有一个很像tensorflow语音命令演示的模型,只是它需要一个大小可变的1D数组作为输入。现在,我发现很难使用tflite\u convert将此模型转换为tflite,因为tflite\u convert需要输入形状。 据说tf lite需要固定大小的输入以提高效率,您可以在推理过程中调整输入大小,作为模型的一部分。然而,我认为这将涉及截断我不想要的输入。有什么方法可以让TF lite发挥作

  • 我目前正在使用Python API开发一个更大的Apache Beam管道,它从BigQuery中读取数据,并最终将其写回另一个BigQuery任务。 其中一个转换需要使用二进制程序来转换数据,为此,它需要加载一个23GB的二进制查找数据文件。因此,启动和运行该程序需要大量的开销(每次加载/运行大约需要2分钟)和RAM,并且仅为一条记录启动该程序是没有意义的。此外,每次都需要将23GB文件从云存储

  • 问题内容: 我有一个dao,它基本上使用hibernate将记录插入到一​​个表中,该dao用标记为注释,并且我有一个服务,该服务会生成其他一些东西,然后调用我的dao。我的服务也标注了使用。 我叫服务循环。我在dao上的插入内容是否可以批量或一个接一个地工作?我如何确定它们可以批量工作?hibernateTransaction Manager是否管理批处理插入? 我正在使用Oracle DB。

  • 问题内容: 我试图使用RNN(特别是LSTM)进行序列预测。但是,我遇到了序列长度可变的问题。例如, 我正在尝试使用一个基于此基准的简单RNN预测当前单词之后的下一个单词,以构建PTB LSTM模型 。 但是,该参数(用于展开到先前的隐藏状态)在每个Tensorflow的时期应保持相同。基本上,批处理句子是不可能的,因为句子的长度会有所不同。 在这里,对于我来说,每个句子都需要更改。我已经尝试了几