Estimator类代表了一个模型,以及如何对这个模型进行训练和评估,
class Estimator(builtins.object)
可以按照下面方式创建一个E
def resnet_v1_10_1(features,
labels,
mode,
params):
learning_rate = params.get('learning_rate', 0.001)
optimizer_func = params.get('optimizer', lambda learning_rate: tf.train.AdagradOptimizer(learning_rate=learning_rate))
use_pretrained_weights = params['use_pretrained_weights']
init_ckpt = params['pretrained_ckpt']
inputs = features
is_training = bool(mode == tf.estimator.ModeKeys.TRAIN)
with slim.arg_scope(resnet_arg_scope()):
backbone = resnet_v1_10(inputs,
is_training=is_training,
scope='resnet_v1_10')
va_predictor = tf.estimator.Estimator(
model_fn=resnet_v1_10_1,
params=config.resnet_params,
model_dir=config.model_dir,
config=run_config)
参数介绍:
1. model_fn: 模型函数,这个模型给定输入和参数,会返回训练,验证或者预测等所需要的操作节点,是一个Python函数,它根据给定的输入构建模型。模型函数把输入特征作为参数,将相应的标签作为张量,它也能以某种方式来告知用户模型是在训练,评估或是在执行推理传入的
2. params, **参数**应该是模型超参数的一个集合。会被传递到model_fn,keys是参数的名称,这可以是一个dictionary,Estimator只传递超参数,不会检查,但是我们将在这个例子中把它表示成一个HParams对象,就像namedtuple一样。传入的**配置**用于指定如何运行训练和评估,以及在。这个配置是一个RunConfig对象,该对象会把模型运行环境相关的信息告诉Estimator。
3. model_dir: 日志文件夹,指出在哪里存储结果,所有的输出,如检查点,事件文件等,会写入到这个文件夹中
4. config, 为tf.estimator.RunConfig对象,包含了执行环境的信息。如果没有传递config,则它会被Estimator实例化,使用的是默认设置。
类内方法:
1。 _init_(self,model_fn,model_dir=None,config=None,params=None, warm_start_from=None)
model_fn: 模型函数,格式如下:
参数
1、features: 这是input_fn返回的第一项是images(input_fn是train, evaluate predict的参数),类型应该是单一的tensor 或dict,,
def input_fn(self):
train_dataset = BarrierAttributesJson.create_dataset(self.dataset_size)
transform_fn = lambda value: BarrierAttributesJson.transform_fn(value)
map_fn = lambda value: tf.py_func(transform_fn, [value], (tf.float32, tf.float32))
dataset = train_dataset.shuffle(buffer_size=self.dataset_size, reshuffle_each_iteration=True)
dataset = dataset.repeat().map(map_fn, num_parallel_calls=self.num_parallel_calls)
dataset = dataset.batch(self.batch_size).prefetch(self.prefetch_size)
images, annotation = dataset.make_one_shot_iterator().get_next()
images.set_shape([None] + list(self.input_shape))
annotation.set_shape([None, 7])
return images, annotation
2. labels: 这是input_fn返回的第二项annotation, 类型应该是单一的tensor或 dict,如果mode为ModeKeys.PREDICT,则会默认为labels=None,
3、model:可选,指定是训练,验证还是测试。
返回值: EstimatorSpec
#####2、train(self, input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None)
1、steps: 模型训练的步数,如果是None,则一直训练,直到input_fn抛出了超过界限的异常,steps是递进进行,如果执行了两次训练steps=10,则总共训练了20次,
2、max_steps,模型训练的最大步数,如果是None,则一直训练,如果你不想递进训练,直到input_fn抛出了超过界限的异常,若同时设置了steps和max_steps,比如steps=1000, max_steps=3000,上次已经训练到了3000,则此时会再训练中1000,即make_dir保存的是4000的结果
3、saving_listeners: checkpointSaverLISTENER对象的列表,用于在保存检查点之前或之后立即执行的回调函数
返回:self:为了链接下去。
#####3、evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)
根据所给数据input_fn,对模型进行验证。对于每一步,执行input_fn,返回数据的一个batch,一直进行验证,直到
steps个batches进行完毕或者input_fn抛出了越界异常
参数:
checkpont_path,用于验证的检查点路径,如果是none,则使用
model_dir中最新的检查点,
name:验证的名字,使用者可以针对不同的数据集运行多个验证操作,比如训练集和测试集,不同的验证结果被保存在不同的文件夹中,且分别出现在tensorboard中。
返回:
返回一个字典,包括model_fn中指定的评价指标,global_step包含验证进恒行的全局步数
import argparse
import cv2
import tensorflow as tf
from tensorflow_toolkit.vehicle_attributes.vehicle_attributes.trainer import create_session, resnet_v1_10_1, InputTrainData
from tensorflow_toolkit.utils.tfutils.helpers import load_module
def parse_args():
parser = argparse.ArgumentParser(description='Perform training of vehicle attributes model')
parser.add_argument('--path_to_config', help='Path to a config.py',
default='../cars_100/config.py')
return parser.parse_args()
def train(config):
cv2.setNumThreads(1)
session_config = create_session(config, 'train')
#Estimator 类,估算器,用来训练和验证 TensorFlow 模型。
run_config = tf.estimator.RunConfig(session_config=session_config,
keep_checkpoint_every_n_hours=config.train.keep_checkpoint_every_n_hours,
save_summary_steps=config.train.save_summary_steps,
save_checkpoints_steps=config.train.save_checkpoints_steps,
tf_random_seed=config.train.random_seed)
va_predictor = tf.estimator.Estimator(
model_fn=resnet_v1_10_1,
params=config.resnet_params,
model_dir=config.model_dir,
config=run_config)#创建一个估算器Estimator,需要传入一个模型函数,一组参数和一些配置
#模型函数是一个python函数,它根据给定的输入构建模型,
#传入的参数应该是模型超参数的一个集合,可以是一个dictonary,
#传入的配置用于指定如何运行训练和评估,以及在哪里存储结果,
input_data = InputTrainData(batch_size=config.train.batch_size,
input_shape=config.input_shape,
json_path=config.train.annotation_path)
va_predictor.train(
input_fn=input_data.input_fn,
steps=config.train.steps,
hooks=[])
def main(_):
args = parse_args()
cfg = load_module(args.path_to_config)
train(cfg)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run(main)