当前位置: 首页 > 知识库问答 >
问题:

从Tensorflow PrefinchDataset中提取目标

韦泳
2023-03-14

我仍然在学习tenstorflow和keras,我怀疑这个问题有一个非常简单的答案,我只是因为不熟悉而错过了。

我有一个PrefinchDataset对象:

> print(tf_test)
$ <PrefetchDataset shapes: ((None, 99), (None,)), types: (tf.float32, tf.int64)>

...由特征和目标组成。我可以使用for循环对其进行迭代:

> for example in tf_test:
>     print(example[0].numpy())
>     print(example[1].numpy())
>     exit()
$ [[-0.31 -0.94 -1.12 ... 0.18 -0.27]
   [-0.22 -0.54 -0.14 ... 0.33 -0.55]
   [-0.60 -0.02 -1.41 ... 0.21 -0.63]
   ...
   [-0.03 -0.91 -0.12 ... 0.77 -0.23]
   [-0.76 -1.48 -0.15 ... 0.38 -0.35]
   [-0.55 -0.08 -0.69 ... 0.44 -0.36]]
  [0 0 1 0 1 0 0 0 1 0 1 1 0 1 0 0 0
   ...
   0 1 1 0]

然而,这是非常缓慢的。我想做的是访问与类标签相对应的张量,并将其转换成一个Numpy数组,或者一个列表,或者任何类型的迭代,这些都可以输入到Scikit学习的分类报告和/或混淆矩阵中:

> y_pred = model.predict(tf_test)
> print(y_pred)
$ [[0.01]
   [0.14]
   [0.00]
   ...
   [0.32]
   [0.03]
   [0.00]]
> y_pred_list = [int(x[0]) for x in y_pred]             # assumes value >= 0.5 is positive prediction
> y_true = []                                           # what I need help with
> print(sklearn.metrics.confusion_matrix(y_true, y_pred_list)

...或者访问数据,以便它可以在tenstorflow的混淆矩阵中使用:

> labels = []                                           # what I need help with
> predictions = y_pred_list                             # could we just use a tensor?
> print(tf.math.confusion_matrix(labels, predictions)

在这两种情况下,从原始对象中以一种计算上不昂贵的方式获取目标数据的一般能力将非常有帮助(并且可能有助于我的基本直觉:Tenorflow和keras)。

如有任何建议,将不胜感激。

共有3个答案

范华清
2023-03-14

如果要保留批次或将所有标签提取为单个张量,可以使用以下函数


def get_labels_from_tfdataset(tfdataset, batched=False):

    labels = list(map(lambda x: x[1], tfdataset)) # Get labels 

    if not batched:
        return tf.concat(labels, axis=0) # concat the list of batched labels

    return labels
党权
2023-03-14

您可以使用map从每个(输入,标签)对中选择输入或标签,并将其转换为列表:

import tensorflow as tf
import numpy as np

inputs = np.random.rand(100, 99)
targets = np.random.rand(100)

ds = tf.data.Dataset.from_tensor_slices((inputs, targets))

X_train = list(map(lambda x: x[0], ds))
y_train = list(map(lambda x: x[1], ds))
章哲茂
2023-03-14

您可以使用list(ds)将其转换为列表,然后使用tf将其重新编译为普通数据集。数据数据集。从张量切片(列表(ds))。从那以后,你的噩梦又开始了,但至少这是一场别人以前经历过的噩梦。

请注意,对于更复杂的数据集(例如嵌套字典),在调用list(ds)后,您需要进行更多的预处理,但这应该适用于您所询问的示例。

这远不是一个令人满意的答案,但不幸的是,这个类完全没有文档记录,标准的数据集技巧都不起作用。

 类似资料:
  • 使用总和,我如何对消息部分是json(所以不完全是)的日志条目执行查询? 示例条目: 生产。警告:我们受到速率限制{ " class ":" App \ WebhookService \ WebhookExecutor "," headers":{"Date":["Thu,2020年4月30日02:10:32 GMT"]," Content-Type ":[" application/JSON "

  • 使用OpenXML(C#)解析*. docx文档有一个问题。 下面是我的步骤: 1。加载*。docx文档 2。接收段落列表 3。在每个段落中查找文本、图像和表格元素 4。为每个文本和图像元素创建html标记 5。将输出另存为*。html文件 我已经了解了如何在文档中定位图像文件并将其解压缩。现在有一个步骤要做——找到表格在文本(段落)中的位置。 如果有人知道如何在*中定位表。docx文档使用Ope

  • Q非常业余的程序员在这里,寻求你的帮助。 我必须经常编辑这样的xml文件 使用一个相当复杂的正则表达式搜索和替换过程,我只能提取标记属性的值。(这就是我所关心的)。 但是这很耗时,而且在Python中必须有非常简单的方法来查找属性标记="SOME_TEXT"部分并将所有值放入一个数组中,然后打印出该数组(到文件中)。但是我无法弄清楚:( 我正在寻找一种不包括导入任何类型的XML库的方法,因为我想让

  • 我从http请求中得到了这个QString,我需要做的是只提取字符串“一致“在标签内 怎么做?

  • 问题内容: AJAX调用返回的响应文本包括JSON字符串。我需要: 提取JSON字符串 修改它 然后重新插入以更新原始字符串 我不太担心步骤2和3,但是我不知道如何执行步骤1。我当时在考虑使用正则表达式,但是我不知道该怎么做,因为我的JSON可能具有嵌套对象的多个级别或数组。 问题答案: 您不能使用正则表达式从任意文本中提取JSON。由于正则表达式通常不够强大,无法验证JSON(除非可以使用PCR

  • 我目前正在将Ed Burns的JSF 2.0教科书中的虚拟培训师示例应用程序从JSF托管bean转换为CDI。到目前为止,我遇到的大多数问题都与作用域和忘记正确注入有关,但现在我正在努力克服最近的一个障碍,即从RequestMap中提取CDIBean(实际上是一个实体类)。从目前为止我所能确定的情况来看,似乎可以通过使用样板文件非常简单地提取请求范围的托管Bean。映射实现提供的get(Strin