Data structure: Dataset.element_spec

金赤岩
2023-12-01

tensorflow==2.2.0

1. Use Dataset.element_spec to inspect the type of each element. (1) Elements that fit in a single component

import tensorflow as tf

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4,10]))
print(dataset1.element_spec)  # the spec (shape and dtype) of each element
for element in dataset1:
    print(element)

Output

TensorSpec(shape=(10,), dtype=tf.float32, name=None)  # a single component
tf.Tensor(
[0.767676   0.73015213 0.5323138  0.50220263 0.95328426 0.19030821
 0.80313885 0.7522975  0.38625753 0.55641997], shape=(10,), dtype=float32)
tf.Tensor(
[0.02963817 0.6099347  0.94547665 0.51258385 0.45178926 0.9401567
 0.6142694  0.8651656  0.47888255 0.6972141 ], shape=(10,), dtype=float32)
tf.Tensor(
[0.1285268  0.9309876  0.5018612  0.3873167  0.98936284 0.35167396
 0.5177171  0.8133037  0.03449237 0.9925858 ], shape=(10,), dtype=float32)
tf.Tensor(
[0.24921525 0.45211852 0.5103594  0.30220544 0.30971515 0.15253949
 0.00923371 0.22665703 0.8954836  0.2224822 ], shape=(10,), dtype=float32)
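A minimal sketch, assuming TensorFlow 2.x as above, of how the spec relates to the input: `from_tensor_slices` slices along the first axis, so a `[4, 10]` tensor gives 4 elements whose spec drops that leading dimension. The `is_compatible_with` check and the element count are illustrative additions, not part of the original example.

```python
# Illustrative sketch; assumes TensorFlow 2.x.
import tensorflow as tf

# A [4, 10] tensor is sliced along axis 0 into 4 elements of shape (10,).
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

spec = dataset1.element_spec   # a single tf.TensorSpec (one component)
print(spec.shape, spec.dtype)  # the leading dimension 4 is not part of the spec

# Every element produced by the dataset matches this spec.
for element in dataset1:
    assert spec.is_compatible_with(element)

# The number of elements equals the size of the sliced axis.
num_elements = sum(1 for _ in dataset1)
print(num_elements)
```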

(2) The same approach works for a combination of two kinds of data, (1, 1): elements that fit in a tuple of components

import tensorflow as tf
# two components per element
dataset2 = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([4]),
     tf.random.uniform([4,100]))
)
print(dataset2.element_spec)

# the extra parentheses do not create a tuple: each element is one scalar component
dataset3 = tf.data.Dataset.from_tensor_slices((tf.random.uniform([4])))
for element in dataset3:
    print(element)

Output

(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.float32, name=None)) # a tuple of components
tf.Tensor(0.17025971, shape=(), dtype=float32)
tf.Tensor(0.49341345, shape=(), dtype=float32)
tf.Tensor(0.49253404, shape=(), dtype=float32)
tf.Tensor(0.8137529, shape=(), dtype=float32)
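A sketch, assuming TensorFlow 2.x, that contrasts the cases above: a real tuple, a parenthesized single tensor (as in `dataset3`), and a one-element tuple written with a trailing comma. It also shows that tuple components can be unpacked in a `for` loop and are passed to `map()` as separate arguments. The names `scalars`, `vectors`, `single`, `one_tuple` and `sums` are hypothetical.

```python
# Illustrative sketch; assumes TensorFlow 2.x. All variable names are hypothetical.
import tensorflow as tf

scalars = tf.random.uniform([4])
vectors = tf.random.uniform([4, 100])

# A tuple of two tensors: each element is a (scalar, vector) pair.
dataset2 = tf.data.Dataset.from_tensor_slices((scalars, vectors))

# Parentheses alone do not make a tuple: this is a single component.
single = tf.data.Dataset.from_tensor_slices((scalars))
# A trailing comma does: this is a tuple holding one component.
one_tuple = tf.data.Dataset.from_tensor_slices((scalars,))

print(dataset2.element_spec)
print(single.element_spec)
print(one_tuple.element_spec)

# Tuple components can be unpacked directly while iterating...
for a, b in dataset2.take(1):
    print(a.shape, b.shape)

# ...and map() passes them to the function as separate arguments.
sums = dataset2.map(lambda a, b: a + tf.reduce_sum(b))
print(sums.element_spec)
```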

(3) (1, (1, 1)): elements that fit in a nested tuple of components

import tensorflow as tf
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4,10])) # a single component
dataset2 = tf.data.Dataset.from_tensor_slices((
    tf.random.uniform([4]),
    tf.random.uniform([4,100],maxval=100,dtype=tf.int32)
))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))  # nested structure: (1, (1, 1))
print(dataset3.element_spec)

Output

(TensorSpec(shape=(10,), dtype=tf.float32, name=None), (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
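A sketch, assuming TensorFlow 2.x and the same `dataset1`/`dataset2` as above, of two ways to work with a nested spec: `tf.nest.flatten` lists the leaf `TensorSpec`s in order, and the same nesting can be unpacked directly in a loop. The names `flat_specs`, `a`, `b` and `c` are illustrative.

```python
# Illustrative sketch; assumes TensorFlow 2.x. flat_specs, a, b, c are hypothetical names.
import tensorflow as tf

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((
    tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32),
))
# zip() nests the structures of its inputs: (component, (component, component)).
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

# tf.nest.flatten lists the leaf TensorSpecs in depth-first order.
flat_specs = tf.nest.flatten(dataset3.element_spec)
for s in flat_specs:
    print(s.shape, s.dtype)

# The nested structure can be unpacked the same way in a for loop.
for a, (b, c) in dataset3.take(1):
    print(a.shape, b.shape, c.shape)
```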