Many things I have only looked up and used without ever summarizing and absorbing them, so I still need to accumulate and record them one by one.
Table of Contents
tf.data.Dataset.batch / map / shuffle / repeat
tf.data.Dataset.make_one_shot_iterator().get_next()
Class TFRecordDataset: A Dataset comprising records from one or more TFRecord files.
Creates a TFRecordDataset to read one or more TFRecord files.
**In other words, reading tfrecord files creates an instance of the TFRecordDataset class, whose parent class is Dataset (that is, it creates a tf.data.Dataset object), so it inherits every method that tf.data.Dataset has.**
**Dataset is the base class. It represents a sequence of elements, where each element contains one or more Tensor objects; for example, in an image pipeline, an element can be a single training sample consisting of a pair of tensors: one for the image data and one for the label. The class includes methods for creating and transforming datasets, and it also allows a dataset to be initialized from in-memory data.**
TextLineDataset: reads lines from text files.
TFRecordDataset: reads records from TFRecord files.
FixedLengthRecordDataset: reads fixed-length records from binary files.
Iterator: provides the main way to extract elements from a dataset. The operation returned by Iterator.get_next() yields the next element of the Dataset and serves as the interface between the input pipeline and the model.
A dataset is made up of elements that all share the same structure. An element contains one or more tf.Tensor objects, called components; each component has a tf.DType describing the type of the values in the tensor and a tf.TensorShape describing its (possibly partially specified) static shape.
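As a quick illustration of components and their types and shapes, here is a minimal sketch using in-memory tensors instead of files:

import tensorflow as tf

# Each element of this dataset has two components: a length-10 vector and a scalar.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random_uniform([4, 10]), tf.random_uniform([4])))
print(dataset.output_types)   # (tf.float32, tf.float32)
print(dataset.output_shapes)  # roughly: (TensorShape([10]), TensorShape([]))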
**TFRecordDataset** parses the records in tfrecord files. To parse each record you can use the dataset's map method; you can also use the class's repeat, shuffle, and batch methods to repeat, shuffle, and batch the dataset (a sketch follows the parameter descriptions below).
Parameters: (filenames, compression_type=None, buffer_size=None, num_parallel_reads=None)
In general, only the first parameter, filenames, needs to be passed.
filenames: A tf.string tensor or tf.data.Dataset containing one or more filenames.
compression_type: (Optional.) A tf.string scalar evaluating to one of "" (no compression), "ZLIB", or "GZIP".
buffer_size: (Optional.) A tf.int64 scalar representing the number of bytes in the read buffer. 0 means no buffering.
num_parallel_reads: (Optional.) A tf.int64 scalar representing the number of files to read in parallel. Defaults to reading files sequentially.
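A minimal sketch of creating a TFRecordDataset and parsing each record with map; the file pattern "./part-r-*" and the feature keys "label" and "image" are hypothetical placeholders for whatever schema was actually written into your tfrecord files:

import tensorflow as tf

filenames = tf.gfile.Glob("./part-r-*")        # one or more TFRecord files
dataset = tf.data.TFRecordDataset(filenames)   # compression_type defaults to "" (none)

def _parse(serialized_example):
    # Hypothetical schema: adjust the keys and dtypes to match your data.
    parsed = tf.parse_single_example(
        serialized_example,
        features={
            "label": tf.FixedLenFeature([], tf.int64),
            "image": tf.FixedLenFeature([], tf.string),
        })
    return parsed["image"], parsed["label"]

dataset = dataset.map(_parse)                  # parse every serialized record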
batch(batch_size, drop_remainder=False)
Parameters:
batch_size: the number of elements to take per batch. A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch.
drop_remainder: whether to drop the last batch. (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.
Returns: A Dataset.
batch combines consecutive elements into batches; for example, the following line groups the elements of dataset into batches of size 32:
dataset = dataset.batch(32)
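A minimal, self-contained sketch of batch (using tf.data.Dataset.range instead of a tfrecord file) that also shows the smaller final batch being kept by default:

import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(4)   # drop_remainder=False by default
one_batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    print(sess.run(one_batch))  # [0 1 2 3]
    print(sess.run(one_batch))  # [4 5 6 7]
    print(sess.run(one_batch))  # [8 9], the smaller last batch is kept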
shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)
shuffle randomly shuffles the elements of this dataset; its buffer_size parameter sets the size of the buffer used when shuffling.
Parameters:
buffer_size: A tf.int64 scalar tf.Tensor, representing the number of elements from this dataset from which the new dataset will sample.
seed: (Optional.) A tf.int64 scalar tf.Tensor, representing the random seed that will be used to create the distribution. See tf.set_random_seed for behavior.
reshuffle_each_iteration: (Optional.) A boolean, which if true indicates that the dataset should be pseudorandomly reshuffled each time it is iterated over. (Defaults to True.)
Returns: A Dataset.
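A minimal sketch of shuffle; when buffer_size is at least the dataset size the shuffle is uniform over the whole dataset, while a small buffer only shuffles elements locally:

import tensorflow as tf

dataset = tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=42)
one_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    print([sess.run(one_element) for _ in range(10)])  # 0..9 in some random order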
repeat(count=None)
repeat repeats the whole sequence count times; it is mainly used to implement epochs in machine learning. count: (Optional.) A tf.int64 scalar tf.Tensor, representing the number of times the dataset should be repeated. The default behavior (if count is None or -1) is for the dataset to be repeated indefinitely.
Returns: A Dataset.
If the original data makes up one epoch, repeat(5) turns it into 5 epochs:
dataset = dataset.repeat(5)
If repeat() is called with no argument, the generated sequence repeats indefinitely and never ends, so iterating over it never throws tf.errors.OutOfRangeError.
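A minimal sketch contrasting a finite repeat with the endless repeat(); with repeat(2), iterating past the second epoch raises tf.errors.OutOfRangeError:

import tensorflow as tf

dataset = tf.data.Dataset.range(3).repeat(2)   # two epochs: 0 1 2 0 1 2
one_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        # With repeat() and no count, this exception would never be raised.
        print("end of dataset")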
Extracting the elements from a dataset. Creates an Iterator for enumerating the elements of this dataset.
How do we extract the elements from this dataset? Instantiate an Iterator from the Dataset and then iterate over it. The statement iterator = dataset.make_one_shot_iterator() instantiates an Iterator from the dataset; it is a "one shot iterator", meaning it can only be read once, from beginning to end. one_element = iterator.get_next() takes one element from the iterator. In non-Eager mode, one_element is only a Tensor, not an actual value; only after calling sess.run(one_element) is a value actually produced.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import tensorflow as tf
def get_did():
    with open('./tower.did', 'r') as fr:
        line = fr.readline()
        _did = json.loads(line.strip())
    return _did
def parse_query_features(filename):
    # Print every feature of the records stored in a tfrecord file.
    # _did = get_did()
    # print(_did)
    with tf.Session() as sess:
        dataset = tf.data.TFRecordDataset(tf.gfile.Glob(filename))
        # The batch size below controls how many records are fetched and printed.
        dataset = dataset.batch(2)
        batch_data = dataset.make_one_shot_iterator().get_next()
        data = sess.run(batch_data)
        for i in range(len(data)):
            if i % 1000 == 0:
                print(i)
            # Deserialize the raw bytes back into a tf.train.Example proto.
            example = tf.train.Example.FromString(data[i])
            features = example.features.feature
            # print(features.keys())
            # print('!!!FEATURES: ', type(features))
            fea = {}
            for ind, key in enumerate(features):
                feature = features[key]
                ftype = None
                fvalue = None
                # Each Feature proto holds exactly one of three value lists.
                if len(feature.bytes_list.value) > 0:
                    ftype = 'bytes_list'
                    fvalue = feature.bytes_list.value
                if len(feature.float_list.value) > 0:
                    ftype = 'float_list'
                    fvalue = feature.float_list.value
                if len(feature.int64_list.value) > 0:
                    ftype = 'int64_list'
                    fvalue = feature.int64_list.value
                # fea[key] = fvalue
                print("{}\t{}\t{}\t{}".format(ind, key, ftype, fvalue))
            # if fea["1000011_s_u_did_debug"][0] in _did:
            #     print(fea)
## Count the number of records in one part file of a tfrecord dataset.
def countTfRecord(file_path):
    count = 0
    for record in tf.python_io.tf_record_iterator(file_path):
        count += 1
    print("File {} contains {} records".format(file_path, count))
if __name__ == '__main__':
    # print(get_did())
    # Point data_path at the tfrecord file to inspect.
    data_path = "./part-r-00000"
    # Parse and print the records.
    parse_query_features(data_path)
    # Count the number of records in the file.
    # countTfRecord(data_path)
References and acknowledgements:
He Zhiyuan's introductory Dataset API tutorial: https://zhuanlan.zhihu.com/p/30751039
A well-written set of study notes on the tf.data.Dataset API: https://zhuanlan.zhihu.com/p/37106443
TensorFlow basics: tfrecord and tf.data.TFRecordDataset: https://blog.csdn.net/yeqiustu/article/details/79793454