当前位置: 首页 > 工具软件 > Slim Lang > 使用案例 >

slim的train

袁波
2023-12-01

函数定义

tf.contrib.slim.learning.train

_USE_DEFAULT=0
def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          session_wrapper=None,
          trace_every_n_steps=None,
          ignore_live_threads=False)

参数

使用TensorFlow 的监督器(supervisor)来运行训练循环。 提供sync_optimizer时,讲同步进行梯度更新,否则将异步进行梯度更新。

  • train_op: 这是一个Tensor,当被执行的时候,将进行梯度更新并返回损失值。
  • logdir: 训练损失(trian loss)写入的目录。
  • train_step_fn: 为了执行单次梯度跟新操作,这个函数将会被调用。这个函数必须有四个参数:session、train_op、 global step、dictionary.
  • train_step_kwargs: 传给train_step_fn的一个dictionary,默认情况下,两个叫做"should_stop" 和"should_log"的布尔值需要提供。
  • log_every_n_steps: 多少次迭代保存一次训练损失。
  • graph: 传递给监督其supervisor的图,如果为空,则使用默认的graph。
  • master: tensorflow master的地址
  • is_chief: 指定是否在主要副本上运行training。
  • global_step: 代表全局step的tensor,如果为空,那么将会调用training_util.get_or_create_global_step(),
  • number_of_steps: 训练时最大的梯度更新次数,当global step大于这个值时,停止训练。如果这个值为空,则训练不会停止下来。
  • init_op: 初始化操作,如果为空,则调用tf.global_variables_initializer()初始化
  • init_feed_dict: 当执行初始化操作时的需要feed进去的一个字典
  • local_init_op: 局部初始化操作,如果为空,则调用tf.local_variables_initializer()和 tf.tables_initializer()来初始化
  • init_fn: 在Init_op被执行后,一个可选的调用函数。这个函数需要接受一个参数,即被初始化的session。
  • ready_op: 检查模型是否准备好了的操作,如果为空,将会调用tf.report_uninitialized_variables()。
  • summary_op: summary操作。
  • save_summaries_secs: 多少秒保存一次summaries。
  • summary_writer: 一个SummaryWriter,如果为None,则不会又summary会被写入。如果没有设置该值,将会自动创建一个SummaryWriter。
  • startup_delay_steps: 在梯度更新之前需要等待的step数。如果sync_optimizer被提供,则这个值必须为0.
  • saver: 保存checkpoint文件的saver,如果为None,一个默认的saver将会被创建。
  • save_interval_secs: 多少秒保存一次模型的checkpoint文件到logdir。
  • sync_optimizer: tf.train.SyncReplicasOptimizer的一个实例,或者这个实例的一个列表。如果这个参数被更新,则梯度更新操作将同步进行。如果为None,则梯度更新操作将异步进行。
  • session_config: tf.ConfigProto的一个实例,用于配置Session,如果为None,则将会创建一个默认值。
  • session_wrapper: 会话包装器,它把tf.Session作为唯一的参数传入,返回一个和tf.Session具有相同方法的包装后的session。如果不为None,则包装后的对象将会在训练中使用。
  • trace_every_n_steps: 以一种Chrome trace format生成并保存Timeline 保存的频率为trace_every_n_steps,如果为None, 没有任何trace信息将会别保存。
  • ignore_live_threads: 如果为True,则忽略那些在停止supervisor之后仍然在运行的线程,而不是引发RuntimeError。

基本流程

  • 定义模型 --> 定义loss --> 定义optimizer(learning_rate) --> 创建train_op(loss, optimizer) --> 执行训练(trian_op, log_dir)
# 加载数据/创建模型
images, labels = LoadData(...)
predictions = MyModel(images)

# 定义损失函数loss
slim.losses.log_loss(predictions, labels)
total_loss = slim.losses.get_total_loss()

# 定义优化器optimizer
optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)

# 创建train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)

# 运行训练.
slim.learning.train(train_op, my_log_dir)

模型恢复/初始化

从checkpoint文件恢复模型

通过指定init_fn参数(该参数是一个函数)

  • 恢复所有变量
# 创建 train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)

# 创建初始化赋值 op
checkpoint_path = '/path/to/checkpoint'

 #恢复所有变量
variables_to_restore = slim.get_model_variables()

init_fn = slim.assign_from_checkpoint_fn(checkpoint_path, variables_to_restore)

# 运行训练
slim.learning.train(train_op, my_log_dir, init_fn=init_fn)
  • 恢复部分变量
...
#从检查点文件中恢复name='v2'的变量
variables_to_restore = slim.get_variables_by_name("v2")     
# or 从检查点文件中恢复name带有2的所有变量
variables_to_restore = slim.get_variables_by_suffix("2")     
# or 从检查点文件中恢复命名空间scope='nested'的所有变量
variables_to_restore = slim.get_variables(scope="nested")    
# or 恢复命名空间scope='nested'的所有变量
variables_to_restore = slim.get_variables_to_restore(include=['fc6', 'fc7', 'fc8'])  
# or 除了命名空间scope='v1'的变量
variables_to_restore = slim.get_variables_to_restore(exclude=['fc6', 'fc7', 'fc8'])      
...
  • 修改变量名
# 'conv1/weights'从checkpoint的'vgg16/conv1/weights'中恢复
def name_in_checkpoint(var):
  return 'vgg16/' + var.op.name

# 'conv1/weights'和'conv1/bias'从checkpoint的'conv1/params1'和'conv1/params2'中恢复
def name_in_checkpoint(var):
  if "weights" in var.op.name:
    return var.op.name.replace("weights", "params1")
  if "bias" in var.op.name:
    return var.op.name.replace("bias", "params2")

variables_to_restore = slim.get_model_variables()
variables_to_restore = {name_in_checkpoint(var):var for var in variables_to_restore}
...

用内存变量初始化

# 创建 train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)

# 创建变量名到值的映射(from variable names to values):
var0_initial_value = ReadFromDisk(...)
var1_initial_value = ReadFromDisk(...)

var_names_to_values = {'var0': var0_initial_value,
                       'var1': var1_initial_value,
}
init_assign_op, init_feed_dict = slim.assign_from_values(var_names_to_values)

# 创建初始化赋值函数
def InitAssignFn(sess):
    sess.run(init_assign_op, init_feed_dict)
# 也可以使用 init_fn = slim.assign_from_values_fn(var_names_to_values)

# 运行训练
slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)

"冻结"部分层

通过定制train_op实现

# slim.learning.create_train_op
def create_train_op(total_loss,
                  optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    clip_gradient_norm=0,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    gradient_multipliers=None,
                    check_numerics=True)
variables_to_train = slim.get_trainable_variables(scope='vgg16')
variabels_to_exclude = slim.get_trainable_variables(scope='vgg16/fc8')
variables_to_train = [var for var in variables_to_train if var not in variabels_to_exclude]
train_op = slim.learning.create_train_op(total_loss, optimizer, variables_to_train=variables_to_train)
...

非梯度更新

更新不在Graph图中的option,如moving mean 和 moving variance等。

# 方式一:
# 强制 TF-Slim 不采用任何 update_ops:
train_op = slim.learning.create_train_op(
     total_loss,
     optimizer,
     update_ops=[])
# 方式二:
# 替换 update ops 集:
train_op = slim.learning.create_train_op(
     total_loss,
     optimizer,
     update_ops=my_other_update_ops)
# 方式三
# 新增 update ops 到默认的 updates:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update1)
train_op = slim.learning.create_train_op(
     total_loss,
     optimizer)
 类似资料: