1.5.6.7 使用TensorFlow模板应用
介绍
TensorFlow template application表示通用的TensorFlow应用代码,用户可以直接使用这些模板而不需要编写TensorFlow应用代码。 用户训练数据一般都是稠密的CSV格式,或稀疏的LIBSVM格式或图片,这些数据都可以转成TFRecords,模型本身则可以使用代码生成,通过不同的超参数组合可以实现通过的TensorFlow应用,用户直接下载模板或者使用Xiaomi Cloud-ML服务直接提交应用,甚至不需要编写一行代码即可训练生成模型。
CSV数据
如果数据是稠密的,甚至是图片,也可以先转成CSV格式,然后使用Python脚本或者Spark转成TFRecords,保存到本地或者FDS中。
然后使用dense_classifier.py训练模型,模板地址是 https://github.com/tobegit3hub/deep_recommend_system/blob/master/dense_classifier.py 。
Cancer数据
默认使用Cancer数据集,feature_size是9,label_size是2,可以设置训练的epoch number、learning rate、optmizier、dnn或者网络层数等超参数。
./dense_classifier.py --batch_size 1024 --epoch_number 1000 --step_to_validate 10 --optmizier adagrad --model dnn --model_network "128 32 8"
Iris数据
可以使用iris数据集,需要指定feature_size是4,label_size是3。
./dense_classifier.py --train_tfrecords_file ./data/iris/iris_train.csv.tfrecords --validate_tfrecords_file ./data/iris/iris_test.csv.tfrecords --feature_size 4 --label_size 3
Lung cancer数据
可以使用lung cancer数据集,需要指定feature_size是262144,label_size是2,并且可以指定使用CNN模型。
./dense_classifier.py --train_tfrecords_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --validate_tfrecords_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --feature_size 262144 --label_size 2 --batch_size 2 --validate_batch_size 2 --epoch_number -1 --model cnn
LIBSVM数据
如果数据可以用LIBSVM格式表示,可以使用Python脚本或者Spark转成TFRecords,保存到本地或者FDS中。 然后使用sparse_classifier.py训练模型,模板地址是 https://github.com/tobegit3hub/deep_recommend_system/blob/master/sparse_classifier.py 。
实现方法
大家开发TensorFlow应用时,可以将可变的参数定义为命令行接受的超参数,用户不同模型传入不同的超参数即可,保证应用的通用性。
对于模型不通用需要大量定制的,暂时无法做成TensorFlow模型应用,这些由用户开发和维护即可。