当前位置: 首页 > 编程笔记 >

将自己的数据集制作成TFRecord格式教程

孟和玉
2023-03-14
本文向大家介绍将自己的数据集制作成TFRecord格式教程,包括了将自己的数据集制作成TFRecord格式教程的使用技巧和注意事项,需要的朋友参考一下

在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格

式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。

1. 原本的数据集

此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。

2.制作成TFRecord格式

tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。

#生成整数型的属性
def _int64_feature(value):
 return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
 
#生成字符串类型的属性
def _bytes_feature(value):
 return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
 
#制作TFRecord格式
def createTFRecord(filename,mapfile):
 class_map = {}
 data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
 classes = {'xiansu60','xiansu100'}
 #输出TFRecord文件的地址
 
 writer = tf.python_io.TFRecordWriter(filename)
 
 for index,name in enumerate(classes):
  class_path=data_dir+name+'/'
  class_map[index] = name
  for img_name in os.listdir(class_path):
   img_path = class_path + img_name #每个图片的地址
   img = Image.open(img_path)
   img= img.resize((224,224))
   img_raw = img.tobytes()   #将图片转化成二进制格式
   example = tf.train.Example(features = tf.train.Features(feature = {
    'label':_int64_feature(index),
    'image_raw': _bytes_feature(img_raw)
   }))
   writer.write(example.SerializeToString())
 writer.close()
 
 txtfile = open(mapfile,'w+')
 for key in class_map.keys():
  txtfile.writelines(str(key)+":"+class_map[key]+"\n")
 txtfile.close()

此段代码,运行完后会产生生成的.tfrecord文件。

3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程

#读取train.tfrecord中的数据
def read_and_decode(filename): 
 #创建一个reader来读取TFRecord文件中的样例
 reader = tf.TFRecordReader()
 #创建一个队列来维护输入文件列表
 filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
 #从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
 _,serialized_example = reader.read(filename_queue)
#  print _,serialized_example
 
 #解析读入的一个样例,如果需要解析多个,可以用parse_example
 features = tf.parse_single_example(
 serialized_example,
 features = {'label':tf.FixedLenFeature([], tf.int64),
    'image_raw': tf.FixedLenFeature([], tf.string),})
 #将字符串解析成图像对应的像素数组
 img = tf.decode_raw(features['image_raw'], tf.uint8)
 img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
 img = tf.image.per_image_standardization(img)
 labels = tf.cast(features['label'], tf.int32)
 return img, labels

4. 将图片几个一打包,形成batch

def createBatch(filename,batchsize):
 images,labels = read_and_decode(filename)
 
 min_after_dequeue = 10
 capacity = min_after_dequeue + 3 * batchsize
 
 image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
              batch_size=batchsize, 
              capacity=capacity, 
              min_after_dequeue=min_after_dequeue
              )
 
 label_batch = tf.one_hot(label_batch,depth=2)
 return image_batch, label_batch

5.主函数

if __name__ =="__main__":
 #训练图片两张为一个batch,进行训练,测试图片一起进行测试
 mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
 train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#  createTFRecord(train_filename,mapfile)
 test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#  createTFRecord(test_filename,mapfile)
 image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
 test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
 with tf.Session() as sess:
  initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
  sess.run(initop)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess = sess, coord = coord)
 
  try:
   step = 0
   while 1:
    _image_batch,_label_batch = sess.run([image_batch,label_batch])
    step += 1
    print step
    print (_label_batch)
  except tf.errors.OutOfRangeError:
   print (" trainData done!")
   
  try:
   step = 0
   while 1:
    _test_images,_test_labels = sess.run([test_images,test_labels])
    step += 1
    print step
 #     print _image_batch.shape
    print (_test_labels)
  except tf.errors.OutOfRangeError:
   print (" TEST done!")
  coord.request_stop()
  coord.join(threads)

此时,生成的batch,就可以feed进网络了。

以上这篇将自己的数据集制作成TFRecord格式教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持小牛知识库。

 类似资料:
  • 本文向大家介绍对python制作自己的数据集实例讲解,包括了对python制作自己的数据集实例讲解的使用技巧和注意事项,需要的朋友参考一下 一、数据集介绍 点击打开链接17_Category_Flower 是一个不同种类鲜花的图像数据,包含 17 不同种类的鲜花,每类 80 张该类鲜花的图片,鲜花种类是英国地区常见鲜花。下载数据后解压文件,然后将不同的花剪切到对应的文件夹,如下图所示: 每个文件夹

  • 问题内容: 这是我最后一个问题的后续,从Pandas数据帧转换为TensorFlow张量对象 我现在处于下一步,需要更多帮助。我正在尝试替换此行代码 替换我自己的数据。我找到了这个答案:TensorFlow教程batch_xs,batch_ys = mnist.train.next_batch(100)的next_batch来自哪里?但我不明白: 1)为什么.next_batch()在我的张量上不

  • 本文向大家介绍Tensorflow 训练自己的数据集将数据直接导入到内存,包括了Tensorflow 训练自己的数据集将数据直接导入到内存的使用技巧和注意事项,需要的朋友参考一下 制作自己的训练集 下图是我们数据的存放格式,在data目录下有验证集与测试集分别对应iris_test, iris_train 为了向伟大的MNIST致敬,我们采用的数据名称格式和MNIST类似 classificati

  • 问题内容: 我的问题是关于数据标准化。 信息 我正在尝试将数据库中的测试结果制成表格。我要记录的信息是test_instance,user_id,test_id,完成(日期/时间),(测试时间),分数,不正确的问题和已复习的问题。 在大多数情况下,我认为我是根据表1来组织信息的,但是我想尽办法找出记录错误或已审查问题的最佳方法有些不为所动。请注意,我 不要 希望把所有不正确的问题集中在一个条目按表

  • 早些时候,我问了Flink一个简单的hello world示例。这给了我一些很好的例子! 然而,我想问一个更“流”的例子,我们每秒生成一个输入值。这在理想情况下是随机的,但即使每次都是相同的值也可以。 目标是获得一个无需/最少外部接触就能“移动”的流。 因此,我的问题是: 我发现如何显示这与外部生成数据和写入Kafka,或听一个公共源,但是我试图解决它与最小的依赖性(像在Nifi与Generate

  • 我对Spark SQL很陌生。在执行一项培训任务时,我遇到了以下问题,无法找到答案(以下所有示例都有点愚蠢,但出于演示目的,应该仍然可以)。 我的应用程序读取拼花文件并根据其内容创建数据集: 数据集。show()调用结果: 然后,我将数据集转换为一个新的数据集,其中包含Person类型: 哪里 最后,当我显示数据集的内容时,我希望看到 然而,我明白了 这是toString()方法的结果,而标头是正