当前位置: 首页 > 工具软件 > tf.Transform > 使用案例 >

tf.estimator.Estimator解析

房泉
2023-12-01

Estimator类代表了一个模型,以及如何对这个模型进行训练和评估,

class Estimator(builtins.object)

可以按照下面方式创建一个E

def resnet_v1_10_1(features,
                   labels,
                   mode,
                   params):

  learning_rate = params.get('learning_rate', 0.001)
  optimizer_func = params.get('optimizer', lambda learning_rate: tf.train.AdagradOptimizer(learning_rate=learning_rate))
  use_pretrained_weights = params['use_pretrained_weights']
  init_ckpt = params['pretrained_ckpt']
  inputs = features
  is_training = bool(mode == tf.estimator.ModeKeys.TRAIN)
  with slim.arg_scope(resnet_arg_scope()):
    backbone = resnet_v1_10(inputs,
                            is_training=is_training,
                            scope='resnet_v1_10')

va_predictor = tf.estimator.Estimator(
    model_fn=resnet_v1_10_1,
    params=config.resnet_params,
    model_dir=config.model_dir,
    config=run_config)

 

参数介绍:

1. model_fn: 模型函数,这个模型给定输入和参数,会返回训练,验证或者预测等所需要的操作节点,是一个Python函数,它根据给定的输入构建模型。模型函数把输入特征作为参数,将相应的标签作为张量,它也能以某种方式来告知用户模型是在训练,评估或是在执行推理传入的

2. params, **参数**应该是模型超参数的一个集合。会被传递到model_fn,keys是参数的名称,这可以是一个dictionary,Estimator只传递超参数,不会检查,但是我们将在这个例子中把它表示成一个HParams对象,就像namedtuple一样。传入的**配置**用于指定如何运行训练和评估,以及在。这个配置是一个RunConfig对象,该对象会把模型运行环境相关的信息告诉Estimator。

3. model_dir: 日志文件夹,指出在哪里存储结果,所有的输出,如检查点,事件文件等,会写入到这个文件夹中

4. config, 为tf.estimator.RunConfig对象,包含了执行环境的信息。如果没有传递config,则它会被Estimator实例化,使用的是默认设置。

类内方法:

1。 _init_(self,model_fn,model_dir=None,config=None,params=None, warm_start_from=None)

model_fn: 模型函数,格式如下:

参数

1、features: 这是input_fn返回的第一项是images(input_fn是train, evaluate predict的参数),类型应该是单一的tensor 或dict,,

  def input_fn(self):
    train_dataset = BarrierAttributesJson.create_dataset(self.dataset_size)

    transform_fn = lambda value: BarrierAttributesJson.transform_fn(value)

    map_fn = lambda value: tf.py_func(transform_fn, [value], (tf.float32, tf.float32))

    dataset = train_dataset.shuffle(buffer_size=self.dataset_size, reshuffle_each_iteration=True)
    dataset = dataset.repeat().map(map_fn, num_parallel_calls=self.num_parallel_calls)
    dataset = dataset.batch(self.batch_size).prefetch(self.prefetch_size)

    images, annotation = dataset.make_one_shot_iterator().get_next()

    images.set_shape([None] + list(self.input_shape))
    annotation.set_shape([None, 7])

    return images, annotation

2. labels: 这是input_fn返回的第二项annotation, 类型应该是单一的tensor或 dict,如果mode为ModeKeys.PREDICT,则会默认为labels=None,

3、model:可选,指定是训练,验证还是测试。
 

返回值: EstimatorSpec

#####2、train(self, input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None)

1、steps: 模型训练的步数,如果是None,则一直训练,直到input_fn抛出了超过界限的异常,steps是递进进行,如果执行了两次训练steps=10,则总共训练了20次,

2、max_steps,模型训练的最大步数,如果是None,则一直训练,如果你不想递进训练,直到input_fn抛出了超过界限的异常,若同时设置了steps和max_steps,比如steps=1000, max_steps=3000,上次已经训练到了3000,则此时会再训练中1000,即make_dir保存的是4000的结果

3、saving_listeners: checkpointSaverLISTENER对象的列表,用于在保存检查点之前或之后立即执行的回调函数

返回:self:为了链接下去。

#####3、evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)

根据所给数据input_fn,对模型进行验证。对于每一步,执行input_fn,返回数据的一个batch,一直进行验证,直到

steps个batches进行完毕或者input_fn抛出了越界异常

参数:

checkpont_path,用于验证的检查点路径,如果是none,则使用model_dir中最新的检查点,

name:验证的名字,使用者可以针对不同的数据集运行多个验证操作,比如训练集和测试集,不同的验证结果被保存在不同的文件夹中,且分别出现在tensorboard中。

返回:

返回一个字典,包括model_fn中指定的评价指标,global_step包含验证进恒行的全局步数

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 


import argparse
import cv2
import tensorflow as tf

from tensorflow_toolkit.vehicle_attributes.vehicle_attributes.trainer import create_session, resnet_v1_10_1, InputTrainData
from tensorflow_toolkit.utils.tfutils.helpers import load_module

def parse_args():
  parser = argparse.ArgumentParser(description='Perform training of vehicle attributes model')
  parser.add_argument('--path_to_config', help='Path to a config.py',
                      default='../cars_100/config.py')
  return parser.parse_args()

def train(config):
  cv2.setNumThreads(1)

  session_config = create_session(config, 'train')
#Estimator 类,估算器,用来训练和验证 TensorFlow 模型。
  run_config = tf.estimator.RunConfig(session_config=session_config,
                                      keep_checkpoint_every_n_hours=config.train.keep_checkpoint_every_n_hours,
                                      save_summary_steps=config.train.save_summary_steps,
                                      save_checkpoints_steps=config.train.save_checkpoints_steps,
                                      tf_random_seed=config.train.random_seed)

  va_predictor = tf.estimator.Estimator(
    model_fn=resnet_v1_10_1,
    params=config.resnet_params,
    model_dir=config.model_dir,
    config=run_config)#创建一个估算器Estimator,需要传入一个模型函数,一组参数和一些配置
  #模型函数是一个python函数,它根据给定的输入构建模型,
  #传入的参数应该是模型超参数的一个集合,可以是一个dictonary,
  #传入的配置用于指定如何运行训练和评估,以及在哪里存储结果,

  input_data = InputTrainData(batch_size=config.train.batch_size,
                              input_shape=config.input_shape,
                              json_path=config.train.annotation_path)

  va_predictor.train(
    input_fn=input_data.input_fn,
    steps=config.train.steps,
    hooks=[])

def main(_):
  args = parse_args()
  cfg = load_module(args.path_to_config)
  train(cfg)

if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)

 

 类似资料: