1.5.6.7 使用TensorFlow模板应用

优质
小牛编辑
132浏览
2023-12-01

介绍

TensorFlow template application表示通用的TensorFlow应用代码,用户可以直接使用这些模板而不需要编写TensorFlow应用代码。 用户训练数据一般都是稠密的CSV格式,或稀疏的LIBSVM格式或图片,这些数据都可以转成TFRecords,模型本身则可以使用代码生成,通过不同的超参数组合可以实现通过的TensorFlow应用,用户直接下载模板或者使用Xiaomi Cloud-ML服务直接提交应用,甚至不需要编写一行代码即可训练生成模型。

CSV数据

如果数据是稠密的,甚至是图片,也可以先转成CSV格式,然后使用Python脚本或者Spark转成TFRecords,保存到本地或者FDS中。

然后使用dense_classifier.py训练模型,模板地址是 https://github.com/tobegit3hub/deep_recommend_system/blob/master/dense_classifier.py

Cancer数据

默认使用Cancer数据集,feature_size是9,label_size是2,可以设置训练的epoch number、learning rate、optmizier、dnn或者网络层数等超参数。

./dense_classifier.py --batch_size 1024 --epoch_number 1000 --step_to_validate 10 --optmizier adagrad --model dnn --model_network "128 32 8"

Iris数据

可以使用iris数据集,需要指定feature_size是4,label_size是3。

./dense_classifier.py --train_tfrecords_file ./data/iris/iris_train.csv.tfrecords --validate_tfrecords_file ./data/iris/iris_test.csv.tfrecords --feature_size 4 --label_size 3

Lung cancer数据

可以使用lung cancer数据集,需要指定feature_size是262144,label_size是2,并且可以指定使用CNN模型。

./dense_classifier.py --train_tfrecords_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --validate_tfrecords_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --feature_size 262144 --label_size 2 --batch_size 2 --validate_batch_size 2 --epoch_number -1 --model cnn

LIBSVM数据

如果数据可以用LIBSVM格式表示,可以使用Python脚本或者Spark转成TFRecords,保存到本地或者FDS中。 然后使用sparse_classifier.py训练模型,模板地址是 https://github.com/tobegit3hub/deep_recommend_system/blob/master/sparse_classifier.py

实现方法

大家开发TensorFlow应用时,可以将可变的参数定义为命令行接受的超参数,用户不同模型传入不同的超参数即可,保证应用的通用性。

对于模型不通用需要大量定制的,暂时无法做成TensorFlow模型应用,这些由用户开发和维护即可。