当前位置: 首页 > 知识库问答 >
问题:

如何使用Dataset API读取变量长度列表的TFRecords文件?

阳勇
2023-03-14

我想使用Tensorflow的Dataset API来读取变量长度列表的TFRecords文件。这是我的密码。

def _int64_feature(value):
    # value must be a numpy array.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def main1():
    # Write an array to TFrecord.
    # a is an array which contains lists of variant length.
    a = np.array([[0, 54, 91, 153, 177],
                 [0, 50, 89, 147, 196],
                 [0, 38, 79, 157],
                 [0, 49, 89, 147, 177],
                 [0, 32, 73, 145]])

    writer = tf.python_io.TFRecordWriter('file')

    for i in range(a.shape[0]): # i = 0 ~ 4
        x_train = a[i]
        feature = {'i': _int64_feature(np.array([i])), 'data': _int64_feature(x_train)}

        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))

        # Serialize to string and write on the file
        writer.write(example.SerializeToString())

    writer.close()

    # Check TFRocord file.
    record_iterator = tf.python_io.tf_record_iterator(path='file')
    for string_record in record_iterator:
        example = tf.train.Example()
        example.ParseFromString(string_record)

        i = (example.features.feature['i'].int64_list.value)
        data = (example.features.feature['data'].int64_list.value)
        #data = np.fromstring(data_string, dtype=np.int64)
        print(i, data)

    # Use Dataset API to read the TFRecord file.
    def _parse_function(example_proto):
        keys_to_features = {'i'   :tf.FixedLenFeature([], tf.int64),
                            'data':tf.FixedLenFeature([], tf.int64)}
        parsed_features = tf.parse_single_example(example_proto, keys_to_features)
        return parsed_features['i'], parsed_features['data']

    ds = tf.data.TFRecordDataset('file')
    iterator = ds.map(_parse_function).make_one_shot_iterator()
    i, data = iterator.get_next()
    with tf.Session() as sess:
        print(i.eval())
        print(data.eval())

检查TF记录文件

[0] [0, 54, 91, 153, 177]
[1] [0, 50, 89, 147, 196]
[2] [0, 38, 79, 157]
[3] [0, 49, 89, 147, 177]
[4] [0, 32, 73, 145]

但是当我试图使用数据集API读取TFRecords文件时,它显示了以下错误。

张量流。python框架错误。InvalidArgumentError:名称:,键:数据,索引:0。int64值的数目!=预期。值大小:5,但输出形状:[]

多谢各位
更新:我尝试使用以下代码读取带有Dataset API的TFRecord,但都失败了。

def _parse_function(example_proto):
    keys_to_features = {'i'   :tf.FixedLenFeature([], tf.int64),
                        'data':tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    return parsed_features['i'], parsed_features['data']

ds = tf.data.TFRecordDataset('file')
iterator = ds.map(_parse_function).make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    print(sess.run([i, data]))

def _parse_function(example_proto):
    keys_to_features = {'i'   :tf.VarLenFeature(tf.int64),
                        'data':tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    return parsed_features['i'], parsed_features['data']

ds = tf.data.TFRecordDataset('file')
iterator = ds.map(_parse_function).make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    print(sess.run([i, data]))

错误是:

回溯(最后一次调用):文件“/usr/local/lib/python3.5/dist packages/tensorflow/python/framework/tensor_util.py”,第468行,make_tensor_proto str_values=[compat.as_bytes(x)for x in proto_values]文件“/usr/local/lib/python3.5/dist packages/tensorflow/python/python/framework/python/framework/framework/tensor/tensor_util.py”,第468行,str_值=[compat对于proto_values]文件“/usr/local/lib/python3.5/dist packages/tensorflow/python/util/compat.py”中的x,第65行,在as_字节(bytes_或_text)中,类型错误:应为二进制或unicode字符串,已获取

在处理上述异常期间,发生了另一个异常:

Traceback(最近一次调用):文件"2tfrecord.py",第126行,在main 1()文件"2tfrecord.py",第72行,在main 1迭代器=ds.map(_parse_function)。make_one_shot_iterator()文件"/usr/本地/lib/python3.5/dist-包/Tensorflow/python/data/ops/dataset_ops.py",第712行,在地图中返回MapDataset(自,map_func)File"/usr/loc/lib/python3.5/dist-包/Tensorflow/python/data/ops/dataset_ops.py",第1385行,在初始化自。_map_func.add_to_graph(ops.get_default_graph())file"/usr/loce/lib/python3.5/dist-pack/tensorflow/python/框架/function.py",第486行,add_to_graph。_create_definition_if_needed()File"/usr/loce/lib/python3.5/dist-pack/tensorflow/python/框架/function.py",第321行,在_create_definition_if_needed._create_definition_if_needed_impl()File"/usr/loce/lib/python3.5/dist-包/tensorflow/python/框架/function.py",第338行,_create_definition_if_needed_impl输出=自己。_func(*输入)File"/usr/loce/lib/python3.5/dist-包/tensorflow/python/data/ops/dataset_ops.py",第1376行,tf_map_funcflattened_ret=[ops.convert_to_tensor(t)用于nest.flatten中的t]File"/usr/local/lib/python3.5/dist-包/tensorflow/python/data/ops/dataset_ops.py",第1376行,flattened_ret = [ops.convert_to_tensor(t)用于nest.flatten中的t(ret)]File"/usr/local/lib/python3.5/dist-pack/tenorflow/python/框架/ops.py",第836行,convert_to_tensoras_ref=False)File"/usr/local/lib/python3.5/dist-pack/tenorflow/python/框架/ops.py",第926行,在internal_convert_to_tensorret=conversion_func(value,dtype=dtype,name=name,as_ref=as_ref)File"/usr/local/lib/python3.5/dist-包/tenstorflow/python/框架/constant_op.py",第229行,在_constant_tensor_conversion_function返回常量(v,dtype=dtype,name=name)File"/usr/local/lib/python3.5/dist-包/tenorflow/python/框架/constant_op.py",第208行,常量值,dtype=dtype,形状=形状,verify_shape=verify_shape))File"/usr/local/lib/python3.5/dist-包/tenorflow/python/框架/tensor_util. py",第472行,在make_tensor_proto"支持的类型."%(类型(值),值))TypeError:无法将类型的对象转换为张量。内容:稀疏张量(索引=张量(ParseSingle示例/Slice_Indices_i: 0,形状=(?,1),dtype=int64),值=张量(ParseSingle示例/Parse示例/Parse示例: 3,形状=(?,), dtype=int64),dense_shape=张量(ParseSingle示例/Squeeze_Shape_i: 0,形状=(1,),dtype=int64))。考虑将元素转换为支持的类型。

Python版本:3.5.2
Tensorflow版本:1.4.1

共有2个答案

公西毅
2023-03-14

错误很简单。您的数据不是FixedLenFeature而是VarLenFeature。替换您的线路:

 'data':tf.FixedLenFeature([], tf.int64)}

 'data':tf.VarLenFeature(tf.int64)}

另外,当您调用print(i.eval())print(data.eval())时,您将调用迭代器两次。第一个print将打印0,但第二个将打印第二行的值[0,50,89,147,196]。您可以执行print(sess.run([i,data]))从同一行获取i数据。

虞承泽
2023-03-14

经过数小时的搜索和尝试,我相信答案终于出现了。下面是我的代码。

def _int64_feature(value):
    # value must be a numpy array.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value.flatten()))

# Write an array to TFrecord.
# a is an array which contains lists of variant length.
a = np.array([[0, 54, 91, 153, 177],
              [0, 50, 89, 147, 196],
              [0, 38, 79, 157],
              [0, 49, 89, 147, 177],
              [0, 32, 73, 145]])

writer = tf.python_io.TFRecordWriter('file')

for i in range(a.shape[0]): # i = 0 ~ 4
    x_train = np.array(a[i])
    feature = {'i'   : _int64_feature(np.array([i])), 
               'data': _int64_feature(x_train)}

    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))

    # Serialize to string and write on the file
    writer.write(example.SerializeToString())

writer.close()

# Check TFRocord file.
record_iterator = tf.python_io.tf_record_iterator(path='file')
for string_record in record_iterator:
    example = tf.train.Example()
    example.ParseFromString(string_record)

    i = (example.features.feature['i'].int64_list.value)
    data = (example.features.feature['data'].int64_list.value)
    print(i, data)

# Use Dataset API to read the TFRecord file.
filenames = ["file"]
dataset = tf.data.TFRecordDataset(filenames)
def _parse_function(example_proto):
    keys_to_features = {'i':tf.VarLenFeature(tf.int64),
                        'data':tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    return tf.sparse_tensor_to_dense(parsed_features['i']), \
           tf.sparse_tensor_to_dense(parsed_features['data'])
# Parse the record into tensors.
dataset = dataset.map(_parse_function)
# Shuffle the dataset
dataset = dataset.shuffle(buffer_size=1)
# Repeat the input indefinitly
dataset = dataset.repeat()  
# Generate batches
dataset = dataset.batch(1)
# Create a one-shot iterator
iterator = dataset.make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    print(sess.run([i, data]))
    print(sess.run([i, data]))
    print(sess.run([i, data]))

有一些事情需要注意。
1.这个SO问题帮助很大。
2.tf.VarLenFeature将返回SparseTensor,因此,使用tf.sparse_tensor_to_dense转换为密集张量是必要的。在我的代码中,parse_single_example()不能用parse_example()替换,它困扰了我一天。我不知道为什么parse_example()不起作用。如果有人知道原因,请指教。

 类似资料:
  • 问题内容: 我有两个过程,其中一个正在写(附加)到文件,另一个正在从文件读取。这两个进程正在同时运行,但无法通信。另一个读取器进程可能在写入器进程完成之前开始。 这种方法有效,但read()通常返回已读取零字节且无错误的信息。它们的零长度读取与非零长度读取之比很高,效率很低。 有没有办法解决?这是在POSIX文件系统上。 问题答案: 没有通信通道,就无法保证在读取正在写入的文件时,防止零字节读取甚

  • 问题内容: 我想使用php阅读列表,列出网页文件夹中的文件名。有没有简单的脚本可以实现? 问题答案: 最简单,最有趣的方式(imo)是glob 但是标准方法是使用目录功能。 还有SPL DirectoryIterator方法 。如果你感兴趣

  • 问题内容: 我对NodeJ很陌生。而且我正在尝试将文件读入变量。这是我的代码。 但每次我运行该脚本,我得到 和 我想念什么?请帮助! 问题答案: 正如您在问题下的注释中所述,节点是异步的-意味着当您调用第二个函数时,您的函数尚未完成执行。 如果在读取文件后将日志语句移动到回调中,则应该看到输出的内容: 即使这将解决您眼前的问题,但如果不了解节点的异步特性,您将遇到很多问题。

  • 我正在尝试使用递归函数打印列表,该列表具有由我的以下代码产生的列表的最大长度: 我需要将下面的输出传递给找到最大长度的递归函数: 基于我对这个问题答案的理解,我尝试使用以下代码来实现它,但我无法很好地实现递归部分。以下是我的尝试: 注:我需要使用递归来解决最长递增序列的问题。

  • 问题内容: 我正在使用Selenium IDE来测试基于Web的HR / SW系统。 有一个用于输入员工休假的屏幕。 我有近3000名员工。 我构建了一个测试案例,该案例使用变量输入一位员工的假期。 如何在不创建3000次测试用例的情况下为所有3000名员工重复测试用例。要做到这一点将是不可能的。注意:每位员工都有不同的休假数据(类型,开始日期,结束日期) 有什么方法可以使用文件(Excel,…)