当前位置: 首页 > 知识库问答 >
问题:

如何用TF.Data API使用Keras生成器

呼延聪
2023-03-14
def make_generator():
    train_datagen = ImageDataGenerator(rescale=1. / 255)
    train_generator = 
    train_datagen.flow_from_directory(train_dataset_folder,target_size=(224, 224), class_mode='categorical', batch_size=32)
    return train_generator

train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32)).shuffle(64).repeat().batch(32)

请注意,如果尝试直接将train_generate作为tf.data.dataset.from_generate的参数,将出现错误。但是,上面的方法不会产生错误。

当我在会话中运行它以检查数据集的输出时,我会得到以下错误。

iterator = train_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
for i in range(100):
    sess.run(next_element)

发现1000个图像,属于2个类。-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in_run_fn(feed_dict,fetch_list,target_list,options,run_metadata)1276返回self._call_tf_sessionrun(->1277 options,feed_dict,fetch_list,target_list,run_metadata)1278

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in_call_tf_sessionrun(self,options,feed_dict,fetch_list,target_list,run_metadata)1366 self._session,options,feed_dict,fetch_list,target_list,->1367 run_metadata)1368

InvalidArgumEnterror:无法批处理组件0中具有不同形状的张量。第一个元素具有形状[32,224,224,3],而第29个元素具有形状[8,224,224,3]。[[{{node iteratorgetnext2}}=iteratorgetnextoutput_shapes=[,],output_types=[DT_FLOAT,DT_FLOAT],_device=“/job:localhost/replica:0/task:0/device:cpu:0”]]

在处理上述异常的过程中,发生了另一个异常:

请让我知道,如果有人有任何经验与此或知道任何替代的方式。

train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32))
model_regular.fit(train_dataset,steps_per_epoch=1000,epochs=2)

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self,x,y,batch_size,epochs,verbose,callbacks,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_pepoch,validation_steps,**kwargs)1507 steps_name='steps_per_epoch,1508

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in_n标准化_user_data(self,x,y,sample_weight,class_weight,batch_size,check_steps_name,steps,validation_split)948 x=self._dataset_iterator_cache[x]949 else:-->950 iterator=x.make_initializable_iterator()951 self.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py中的make_initializable_iterator(self,shared_name)119与ops.colocate_with(iterator_resource):120 initializer=gen_dataset_ops.make_iterator(self._as_variant_tensor(),>121 iterator_resource)122返回iterator_ops.iterator(iterator_resource,123

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_dataset_ops.py in make_iterator(dataset,iterator,name)2542如果_ctx是None或not_ctx._eager_context.is_eager:2543_,_,_op=_op_def_lib._apply_op_helper(->2544“makeiterator”,dataset=dataset,iterator=iterator,name=name)2545返回_op 2546_result=None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py在_apply_op_helper(self,op_type_name,name,**keywords)348#中需要将所有参数都压缩到一个列表中。349#pylint:disable=protected-access-->350 g=ops._get_graph_from_inputs(_flatten(keywords.values())))351#pylint:enable=protected-access 352除AssertionError为e外:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in_get_graph_from_inputs(op_input_list,graph)5659 graph=graph_element.graph 5660 elif original_graph_element不是none:->5661_assert_same_graph(original_graph_element,graph)5662 elif graph_element.graph不是graph:
5663引发

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in_assert_same_graph(original_item.graph)5595如果original_item.graph不是item.graph:5596引发ValueError(“%s必须来自与%s相同的图。”%(item,->5597 original_item))5598 5599

ValueError:张量(“iteratorv2:0”,shape=(),dtype=resource)必须来自与张量(“FlatMapDataSet:0”,shape=(),dtype=variant)相同的图。

这是一个bug,还是Keras fit方法不应该这样使用?

共有1个答案

卫飞
2023-03-14

我试着用一个简单的例子重现结果,我发现当在生成器函数和tf.data中使用批处理时,会得到不同的输出形状。

Keras函数train_datagen.flow_from_directory(batch_size=32)已经返回具有形状[batch_size,width,height,deptth]的数据。如果使用tf.data.dataset().batch(32),输出数据将再次批处理成形状[batch_size,batch_size,width,height,deptth]

由于某种原因,这可能会导致您的问题。

 类似资料:
  • 问题内容: 我之前从未使用过JSON,并且尝试使用以下javascript:http : //jqueryselectcombo.googlecode.com/files/jquery.selectCombo1.2.6.js 它需要以下格式的JSON输出: 您能否指导我举一个有关如何使用PHP生成上述JSON输出的示例? 问题答案: 最简单的方法可能是从所需的对的关联数组开始: 然后使用forea

  • 我正在努力创建rest客户端,我将调用一个API来提供这个大的json输出。我想知道如何通过输入这个json来自动创建Pojo类来晃动代码gen,并让它为我创建我的pojo类,这将节省手动时间。这是我尝试过的 要为生成PHP客户端,请执行以下操作:http://petstore.swagger.io/v2/swagger.json,请运行以下命令: (如果您使用的是Windows,请将最后一个命令

  • 我知道这似乎已经讨论过了,答案是肯定的,可以为不同的字符串生成相同的值,但不太可能(Java的hashCode可以为不同的字符串生成相同的值吗?)。然而,它确实发生在我的应用程序中。 以下代码将生成相同的hashcode:-347019262(jave 1.7.25) 在这种情况下,我确实需要hashcode,我希望使用它为字符串生成唯一的主键。看来我做得不对。有什么建议吗? 多谢!

  • 我理解如何通过使用算法来验证信用卡。 但我想知道如何反向工程的问题,并创建新的有效信用卡号码

  • 我试图在Java程序中实现一个随机数生成器。我在用数学。random(),但这似乎效果不太好。然后我尝试使用SecureRandom,但这对我的游戏来说太长了。然而,我遇到了这个生成器,MersenNetWisterng随机数生成器。这似乎是我想要的;速度很快,但仍然是随机的。 然而,我已经很长时间没有用Java编写了,只有2个月,我对API既不了解也不了解。如果有人能帮我解释一下如何在我的代码中

  • 问题内容: 如何使用数据库查询回调设置变量值?我该怎么办? 问题答案: 自从使用node.js已经有一段时间了,但是我想我可以提供帮助。 首先,在node中,您只有一个线程,应该使用回调。您的代码将发生的情况是查询将排队等待执行,但是循环将毫无意义地连续作为繁忙循环运行。 您应该可以通过以下回调来解决您的问题: 并这样使用 我在大约2年内没有编写任何node / js的代码,也没有进行测试,但是基