Reference: https://tensorflow.google.cn/api_docs/python/tf/io/decode_csv
tf.io.decode_csv(
    records, record_defaults, field_delim=',', use_quote_delim=True,
    na_value='', select_cols=None, name=None
)
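The signature also exposes field_delim (the separator, ',' by default) and select_cols (a sorted list of column indices to parse). Here is a small sketch that assumes a '|'-separated record where only columns 0 and 2 are needed. When select_cols is set, record_defaults lists one default per selected column, not one per column in the record:

```python
import tensorflow as tf

# A '|'-delimited record with 4 fields; only columns 0 and 2 are parsed.
record = "7|skip|3.5|x"

# select_cols must be strictly increasing, and record_defaults has one
# entry per selected column (2 here), not per column in the record (4).
a, c = tf.io.decode_csv(
    record,
    record_defaults=[tf.constant(0, tf.int64), tf.constant(0.0, tf.float32)],
    field_delim="|",
    select_cols=[0, 2],
)
print(a.numpy(), c.numpy())  # 7 3.5
```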
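import numpy as np
import tensorflow as tf

records = '1,2,3,4,5'  # one CSV record with five fields
record_defaults = [
    tf.constant(0, dtype=tf.int32),  # field 1 -> int32
    0,                               # field 2 -> Python int, parsed as int32
    np.nan,                          # field 3 -> Python float, parsed as float32
    "tensorflow",                    # field 4 -> string
    tf.constant([]),                 # field 5 -> empty float32 tensor: required field
]
tf.io.decode_csv(records, record_defaults)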
Result:
[<tf.Tensor: id=100, shape=(), dtype=int32, numpy=1>, <tf.Tensor: id=101, shape=(), dtype=int32, numpy=2>, <tf.Tensor: id=102, shape=(), dtype=float32, numpy=3.0>, <tf.Tensor: id=103, shape=(), dtype=string, numpy=b'4'>, <tf.Tensor: id=104, shape=(), dtype=float32, numpy=5.0>]
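In the result, each column's dtype follows its default: the Python int becomes int32, np.nan and the empty tf.constant([]) become float32, and the Python string becomes string. To get int64 or float64 columns instead, one option is to pass explicitly typed defaults. This is a sketch, and the record "1,,2.5" is made up for illustration:

```python
import numpy as np
import tensorflow as tf

# Explicitly typed defaults choose the output dtype for each column.
record_defaults = [
    tf.constant(0, dtype=tf.int64),         # int64 column
    tf.constant(np.nan, dtype=tf.float64),  # float64 column, NaN if the field is empty
    tf.constant([], dtype=tf.float64),      # required float64 column (no default)
]
out = tf.io.decode_csv("1,,2.5", record_defaults)
print([t.dtype for t in out])  # [tf.int64, tf.float64, tf.float64]
print(out[1].numpy())          # nan (empty field filled with the default)
```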
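record_defaults must contain one entry per field in records. The dtype of each entry sets the dtype of the corresponding output tensor, and its value is used when that field is empty. An empty tensor such as tf.constant([]) has no default value, so that field is required.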
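The sketch below shows how these rules play out. It uses a batch of records, an empty field, and a custom na_value, and then triggers the two error cases: a missing required field and a field-count mismatch. The sample records and the "NA" marker are illustrative assumptions:

```python
import tensorflow as tf

# decode_csv parses every element of a string tensor and returns one
# output tensor per column, each shaped like `records` (here: (2,)).
records = tf.constant(["1,,x", "2,NA,"])
ids, scores, tags = tf.io.decode_csv(
    records,
    record_defaults=[0, -1, "none"],  # int32, int32, string columns
    na_value="NA",                    # treat "NA" like an empty field
)
print(ids.numpy())     # [1 2]
print(scores.numpy())  # [-1 -1]  (empty field and "NA" both fall back to -1)
print(tags.numpy())    # [b'x' b'none']

# An empty tensor as the default marks a column as required, so a missing
# value raises an error instead of being filled in.
try:
    tf.io.decode_csv("1,", record_defaults=[tf.constant([], tf.int32)] * 2)
except tf.errors.InvalidArgumentError as e:
    print("required field missing:", e.message)

# The number of fields in each record must equal len(record_defaults).
try:
    tf.io.decode_csv("1,2,3", record_defaults=[0, 0])
except tf.errors.InvalidArgumentError as e:
    print("field count mismatch:", e.message)
```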