在用dataset读取tfrecord的时候,看到别人的代码里面基本都有tf.data.Dataset.map()
这个部分,而且前面定义了解析tfrecord的函数decord_example(example)
之后,在后面的的map里面直接就dataset.map(decord_example)
这样使用,并没有给example
赋值。
具体代码在这里:
def decode_example(example, resize_height, resize_width, label_nums):
dics={
'image_raw':tf.FixedLenFeature([],tf.string),
'label':tf.FixedLenFeature([],tf.int64)
}
parsed_example = tf.parse_single_example(serialized=example, features=dics)
tf_image=tf.decode_raw(parsed_example['image_raw'], out_type=tf.uint8) # 这个其实就是图像的像素模式,之前我们使用矩阵来表示图像
tf_image=tf.reshape(tf_image, shape=[resize_height, resize_width, 3]) # 对图像的尺寸进行调整,调整成三通道图像
tf_image=tf.cast(tf_image,tf.float32)*(1./255) # 对图像进行归一化以便保持和原图像有相同的精度
tf_label=tf.cast(parsed_example['label'],tf.int64)
tf_label=tf.one_hot(tf_label, label_nums,on_value=1,off_value=0) # 将label转化成用one_hot编码的格式
return tf_image, tf_label
def create_dataset(tfrecords_file, batch_size, resize_height, resize_width, num_class):
dataset = tf.data.TFRecordDataset(tfrecords_file)
# dataset = dataset1.map(decode_example)
dataset = dataset.map(lambda x: decode_example(x, resize_height, resize_width, num_class))
dataset = dataset.shuffle(20000).batch(batch_size)
# dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
return dataset
对于这点我是百思不得其解。翻了一下午的博客论坛之后,看到一个比较合理的解释,结合我自己的理解,在这里写出来:在使用dataset = tf.data.TFRecordDataset(tfrecords_file)
生成一个新的dataset后,这个dataset已经含有在decord_example(example)
中的example
所需要的参数,看起来虽然没有传参,但其实参数是在内部进行了传递。因此可以直接用dataset = dataset1.map(decode_example)
如果我们还想在map
里添加额外的参数,就要用lambda表达式
,也就是dataset = dataset.map(lambda x: decode_example(x, resize_height, resize_width, num_class))
。而这里的x
看起来没有外部参数传进去,但其实是和上面所说的一样。在dataset = tf.data.TFRecordDataset(tfrecords_file)
创建dataset之后,x
也就是example
所需要的参数都已经在dataset里了。
借用一个例子就是:
import tensorflow as tf
def fun(x, arg):
return x * arg
my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))
这里x
的参数就是上面ds = tf.data.Dataset.range(5)
所创建出来的0~4的值。