当前位置: 首页 > 面试题库 >

TensorFlow用于二进制分类

隆芷阳
2023-03-14
问题内容

我正在尝试将此MNIST示例调整为二进制分类。

但是,改变我的时候NLABELS,从NLABELS=2NLABELS=1,损失函数总是返回0(和准确度1)。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# Import data
mnist = input_data.read_data_sets('data', one_hot=True)
NLABELS = 2

sess = tf.InteractiveSession()

# Create the model
x = tf.placeholder(tf.float32, [None, 784], name='x-input')
W = tf.Variable(tf.zeros([784, NLABELS]), name='weights')
b = tf.Variable(tf.zeros([NLABELS], name='bias'))

y = tf.nn.softmax(tf.matmul(x, W) + b)

# Add summary ops to collect data
_ = tf.histogram_summary('weights', W)
_ = tf.histogram_summary('biases', b)
_ = tf.histogram_summary('y', y)

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, NLABELS], name='y-input')

# More name scopes will clean up the graph representation
with tf.name_scope('cross_entropy'):
    cross_entropy = -tf.reduce_mean(y_ * tf.log(y))
    _ = tf.scalar_summary('cross entropy', cross_entropy)
with tf.name_scope('train'):
    train_step = tf.train.GradientDescentOptimizer(10.).minimize(cross_entropy)

with tf.name_scope('test'):
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    _ = tf.scalar_summary('accuracy', accuracy)

# Merge all the summaries and write them out to /tmp/mnist_logs
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter('logs', sess.graph_def)
tf.initialize_all_variables().run()

# Train the model, and feed in test data and record summaries every 10 steps

for i in range(1000):
    if i % 10 == 0:  # Record summary data and the accuracy
        labels = mnist.test.labels[:, 0:NLABELS]
        feed = {x: mnist.test.images, y_: labels}

        result = sess.run([merged, accuracy, cross_entropy], feed_dict=feed)
        summary_str = result[0]
        acc = result[1]
        loss = result[2]
        writer.add_summary(summary_str, i)
        print('Accuracy at step %s: %s - loss: %f' % (i, acc, loss)) 
   else:
        batch_xs, batch_ys = mnist.train.next_batch(100)
        batch_ys = batch_ys[:, 0:NLABELS]
        feed = {x: batch_xs, y_: batch_ys}
    sess.run(train_step, feed_dict=feed)

我检查了两者的尺寸batch_ys(馈入y),_y并且它们都是1xN矩阵,NLABELS=1因此问题似乎早于此。也许与矩阵乘法有关?

我实际上在一个真实的项目中也遇到了同样的问题,所以任何帮助都将不胜感激……谢谢!


问题答案:

原始的MNIST示例使用单热编码来表示数据中的标签:这意味着,如果存在NLABELS = 10类(如MNIST中的类),则目标输出[1 0 0 0 0 0 0 0 0 0]针对的是0类,[0 1 0 0 0 0 0 0 0 0]针对的1类,等等。tf.nn.softmax()操作员转换logit计算tf.matmul(x, W) + b得出不同输出类别之间的概率分布,然后将其与的输入值进行比较y_

如果NLABELS = 1,这种行为,如果当时只有一个类,以及tf.nn.softmax()运算将计算的概率1.0为类,从而导致的交叉熵0.0,因为tf.log(1.0)0.0对所有的例子。

您可以尝试(至少)两种方法进行二进制分类:

  1. 最简单的方法是设置NLABELS = 2两个可能的类,并按照[1 0]标签0和[0 1]标签1对训练数据进行编码。

  2. 你可以保持html" target="_blank">标签作为整数01和使用tf.nn.sparse_softmax_cross_entropy_with_logits()



 类似资料:
  • 我一直在与TensorFlow的构建器进行斗争,以便能够为我的模型服务,我试图在为模型服务后向我的分类器提供数据 我的问题是如何向模型提供输入?我看过Google的inception教程使用的代码 并试图实施它 据我所知,输入被传递给一个名为serialized_tf_example的张量,顾名思义,该张量将输入序列化为string,但是他们使用我不理解的tf.fixedlenfeature,然后

  • 本文向大家介绍tensorflow 1.0用CNN进行图像分类,包括了tensorflow 1.0用CNN进行图像分类的使用技巧和注意事项,需要的朋友参考一下 tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。 任务:花卉分类 版本:tensorflow 1.0 数据:flower-photos

  • objdump工具用来显示二进制文件的信息,就是以一种可阅读的格式让你更多地了解二进制文件可能带有的附加信息。 14.1. 常用参数说明 -f 显示文件头信息 -D 反汇编所有section (-d反汇编特定section) -h 显示目标文件各个section的头部摘要信息 -x 显示所有可用的头信息,包括符号表、重定位入口。-x 等价于 -a -f -h -r -t 同时指定。 -i 显示对于

  • 我可以运行这个程序,但由于某些原因,它会显示/放置随机字符,而不是二进制的初始值,而且我似乎无法将程序从十进制运行回二进制。我该如何改进这些代码。要明确说明它不会将二进制转换为十进制,我将如何将其转换回十进制转换为二进制,如果有一些代码可以帮助我,将不胜感激。

  • 主要内容:二进制,八进制,十六进制我们平时使用的数字都是由 0~9 共十个数字组成的,例如 1、9、10、297、952 等,一个数字最多能表示九,如果要表示十、十一、二十九、一百等,就需要多个数字组合起来。 例如表示 5+8 的结果,一个数字不够,只能”进位“,用 13 来表示;这时”进一位“相当于十,”进两位“相当于二十。 因为逢十进一(满十进一),也因为只有 0~9 共十个数字,所以叫做 十进制(Decimalism)。十进

  • 本文向大家介绍PHP实现十进制、二进制、八进制和十六进制转换相关函数用法分析,包括了PHP实现十进制、二进制、八进制和十六进制转换相关函数用法分析的使用技巧和注意事项,需要的朋友参考一下 本文实例讲述了PHP实现十进制、二进制、八进制和十六进制转换相关函数用法。分享给大家供大家参考,具体如下: 1.二进制: 1.1.二进制转十进制: 函数:bindec(string $binary_string)