当前位置: 首页 > 面试题库 >

如何正确组合TensorFlow的数据集API和Keras?

束帅
2023-03-14
问题内容

Keras的fit_generator()模型方法期望生成器生成形状(输入,目标)的元组,其中两个元素都是NumPy数组。该文档似乎暗示,如果我将Dataset迭代器简单地包装在生成器中,并确保将Tensors转换为NumPy数组,那我应该很好。这段代码给我一个错误:

import numpy as np
import os
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
import tensorflow as tf
from tensorflow.contrib.data import Dataset

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

with tf.Session() as sess:
    def create_data_generator():
        dat1 = np.arange(4).reshape(-1, 1)
        ds1 = Dataset.from_tensor_slices(dat1).repeat()

        dat2 = np.arange(5, 9).reshape(-1, 1)
        ds2 = Dataset.from_tensor_slices(dat2).repeat()

        ds = Dataset.zip((ds1, ds2)).batch(4)
        iterator = ds.make_one_shot_iterator()
        while True:
            next_val = iterator.get_next()
            yield sess.run(next_val)

datagen = create_data_generator()

input_vals = Input(shape=(1,))
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mean_squared_error')
model.fit_generator(datagen, steps_per_epoch=1, epochs=5,
                    verbose=2, max_queue_size=2)

这是我得到的错误:

Using TensorFlow backend.
Epoch 1/5
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task
    generator_output = next(self._generator)
  File "./datagen_test.py", line 25, in create_data_generator
    yield sess.run(next_val)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__
    'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.)

Traceback (most recent call last):
  File "./datagen_test.py", line 34, in <module>
    verbose=2, max_queue_size=2)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
    return func(*args, **kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator
    generator_output = next(output_generator)
StopIteration

奇怪的是,next(datagen)在我初始化的位置之后直接添加包含一行datagen的代码会使代码运行正常,没有错误。

为什么我的原始代码不起作用?当我在代码中添加该行时,为什么它开始起作用?是否有更有效的方式将TensorFlow的Dataset API与Keras结合使用,而无需将Tensors转换为NumPy数组然后再次返回?


问题答案:

确实有一种更有效的使用方法,Dataset而不必将张量转换为numpy数组。但是,官方文档上还没有(尚未?)。在发行说明中,它是Keras 2.0.7中引入的功能。您可能需要安装keras> = 2.0.7才能使用它。

x = np.arange(4).reshape(-1, 1).astype('float32')
ds_x = Dataset.from_tensor_slices(x).repeat().batch(4)
it_x = ds_x.make_one_shot_iterator()

y = np.arange(5, 9).reshape(-1, 1).astype('float32')
ds_y = Dataset.from_tensor_slices(y).repeat().batch(4)
it_y = ds_y.make_one_shot_iterator()

input_vals = Input(tensor=it_x.get_next())
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()])
model.fit(steps_per_epoch=1, epochs=5, verbose=2)

几个区别:

  1. tensor参数提供给该Input层。Keras将从该张量读取值,并将其用作拟合模型的输入。
  2. target_tensors论据提供给Model.compile()
  3. 请记住将x和y都转换为float32。在正常使用情况下,Keras将为您完成此转换。但是现在您必须自己做。
  4. 批处理大小是在构建时指定的Dataset。使用steps_per_epoch和epochs控制何时停止模型拟合。
    总之,使用Input(tensor=...)model.compile(target_tensors=...)并且model.fit(x=None, y=None, ...)如果你的数据是从张量读取。


 类似资料:
  • 我试图找到使用httr包通过R连接Appannie的API的方法(根本没有API连接的经验)。API要求包含来自appannie网站的请求标题引用:注册App Annie帐户并生成API密钥。将此密钥添加到您的请求标头中,如下所示: 授权:持有人“”引用 我写了这样的代码 命令http_status(getdata)显示我"客户端错误:(401)未经授权"有人能帮我吗,我做错了什么?

  • 问题内容: 我正在尝试从使用Hibernate检索的数据库中序列化对象,而我只对对象的实际数据整体感兴趣(包括循环)。 现在,我一直在使用XStream,它似乎功能强大。XStream的问题在于,它对信息过于盲目。它可以按原样识别Hibernate的PersistentCollections,并包含所有Hibernate元数据。我不想序列化那些。 因此,有没有一种合理的方法可以从Persisten

  • 有很多例子可以说明如何创建和使用TensorFlow数据集。 我的问题是如何以numpy格式从TF数据集获取数据/标签?换言之,WAND将是上面这行的反向操作,即我有一个TF数据集,并希望从中获取图像和标签。

  • 问题内容: 尝试从我知道仅包含整数的数组中获取最高和最低值似乎比我想象的要难。 我希望这能显示出来。相反,它显示。因此,似乎排序是将值作为字符串处理。 有没有一种方法可以使sort函数对整数值进行实际排序? 问题答案: 默认情况下,sort方法按字母顺序对元素进行排序。要进行数字排序,只需添加一个处理数字排序的新方法(sortNumber,如下所示)- 在ES6中,可以使用箭头功能简化此操作: 说

  • 问题内容: 只要我有键值对,解组就非常简单了,但是我将如何以不同的顺序解组不同类型的数组呢?单个元素定义明确且已知,但顺序不明确。 我无法提出一个漂亮的解决方案。 我会尝试对所有元素进行错误处理吗?是否有某种工会类型可以为我做到这一点? 游乐场版 问题答案: Go官方博客上有一篇不错的文章:JSON和GO。可以将“任意数据解码”到接口{}中,并使用类型断言来动态确定类型。 您的代码可能可以修改为:

  • 问题内容: 我有以下代码供用户单击按钮使用,当他们单击按钮时,正在查看的特定字符串将被收藏并存储在其他位置。 我有两个问题。 我现在所拥有的是什么问题?因为当您单击按钮时它会崩溃。 您将如何完成将要加载的字符串并将其保存到数组的load array方法,以便用户以后可以看到该数组? 谢谢你的时间!! 原木猫 问题答案: 函数返回。您可能想要像这样使用它: 该数组可以稍后在第二部分中使用。希望这可以

  • 我有两个数据源A和B。我有两个数据集'A'和'B'分别来自它们。 谢谢你的帮助!

  • 本文向大家介绍java 如何实现正确的删除集合中的元素,包括了java 如何实现正确的删除集合中的元素的使用技巧和注意事项,需要的朋友参考一下 在java中如果我们需要遍历集合并删除其中的某些元素时,例如对于List来说,我们有三种办法。 1. 普通的for循环遍历并删除 main中调用 程序输出[1,2,3] 这是因为,删除时改变了list的长度。删除第一个2后,长度变为了3,这时list.ge