当前位置: 首页 > 编程笔记 >

tensorflow 固定部分参数训练,只训练部分参数的实例

百里文景
2023-03-14
本文向大家介绍tensorflow 固定部分参数训练,只训练部分参数的实例,包括了tensorflow 固定部分参数训练,只训练部分参数的实例的使用技巧和注意事项,需要的朋友参考一下

在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练。

1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如

def inference(input_):
  """
  this is where you put your graph.
  the following is just an example.
  """
  
  conv1 = tf.layers.conv2d(input_)
 
  conv2 = tf.layers.conv2d(conv1)
 
  return conv2
 
 
input_ = tf.placeholder()
output = inference(input_)
...
calculate_loss_op = ...
train_op = ...
...
 
with tf.Session() as sess:
  sess.run([loss, train_op], feed_dict={input_: train_data})
 
  if validation == True:
    sess.run([loss], feed_dict={input_: validate_date})

这种方式很简单,也很直接了然。

2.但是,如果处理的数据量很大的时候,使用 tf.placeholder() 来载入数据会严重地拖慢训练的进度,因此,常用tfrecords文件来读取数据。

此时,很容易想到,将不同的值传入inference()函数中进行计算。

train_batch, label_batch = decode_train()
val_train_batch, val_label_batch = decode_validation()
 
 
train_result = inference(train_batch)
...
loss = ..
train_op = ...
...
 
if validation == True:
  val_result = inference(val_train_batch)
  val_loss = ..
  
 
with tf.Session() as sess:
  sess.run([loss, train_op])
 
  if validation == True:
    sess.run([val_result, val_loss])

这种方式看似能够直接调用inference()来对验证数据进行前向传播计算,但是,实则会在原图上添加上许多新的结点,这些结点的参数都是需要重新初始化的,也是就是说,验证的时候并不是使用训练的权重。

3.用一个tf.placeholder来控制是否训练、验证。

def inference(input_):
  ...
  ...
  ...
  
  return inference_result
 
 
train_batch, label_batch = decode_train()
val_batch, val_label = decode_validation()
 
is_training = tf.placeholder(tf.bool, shape=())
 
x = tf.cond(is_training, lambda: train_batch, lambda: val_batch)
y = tf.cond(is_training, lambda: train_label, lambda: val_label)
 
logits = inference(x)
loss = cal_loss(logits, y)
train_op = optimize(loss)
 
with tf.Session() as sess:
  
  loss, _ = sess.run([loss, train_op], feed_dict={is_training: True})
  
  if validation == True:
    loss = sess.run(loss, feed_dict={is_training: False})

使用这种方式就可以在一个大图里创建一个分支条件,从而通过控制placeholder来控制是否进行验证。

以上这篇tensorflow 固定部分参数训练,只训练部分参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持小牛知识库。

 类似资料:
  • 本文向大家介绍Pytorch加载部分预训练模型的参数实例,包括了Pytorch加载部分预训练模型的参数实例的使用技巧和注意事项,需要的朋友参考一下 前言 自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了。对于深度学习的初学者,Pytorch值得推荐。今天主要主要谈谈Pytorch是

  • 相关概念 客户端 (Client):客户端是一个用于建立 TensorFlow 计算图并创立与集群进行交互的会话层 tensorflow::Session 的程序。一般客户端是通过 python 或 C++ 实现的。一个独立的客户端进程可以同时与多个 TensorFlow 的服务端相连 (上面的计算流程一节),同时一个独立的服务端也可以与多个客户端相连。 集群 (Cluster) : 一个 Ten

  • 译者:bat67 最新版会在译者仓库首先同步。 目前为止,我们以及看到了如何定义网络,计算损失,并更新网络的权重。 现在可能会想, 数据呢? 通常来说,当必须处理图像、文本、音频或视频数据时,可以使用python标准库将数据加载到numpy数组里。然后将这个数组转化成torch.*Tensor。 对于图片,有Pillow,OpenCV等包可以使用 对于音频,有scipy和librosa等包可以使用

  • 本文向大家介绍TensorFlow实现随机训练和批量训练的方法,包括了TensorFlow实现随机训练和批量训练的方法的使用技巧和注意事项,需要的朋友参考一下 TensorFlow更新模型变量。它能一次操作一个数据点,也可以一次操作大量数据。一个训练例子上的操作可能导致比较“古怪”的学习过程,但使用大批量的训练会造成计算成本昂贵。到底选用哪种训练类型对机器学习算法的收敛非常关键。 为了Tensor

  • 简介 TensorFlow只是library,分布式TensorFlow应用需要我们在多个节点启动Python脚本组成分布式计算集群。 Xiaomi Cloud-ML支持标准的分布式TensorFlow应用,用户只需编写对应的Python脚本即可提交运行,用法与单机版类似。 代码规范 由于分布式TensorFlow应用需要启动多节点,每个节点需要知道自己的角色,一般都是通过命令行参数传入,而用户自

  • 问题内容: 我正在尝试运行以下Colab项目,但是当我想将训练数据分为验证和训练部分时,出现此错误: 我使用以下代码: 如何解决此错误? 问题答案: 根据Tensorflow Dataset docs ,百分比拆分是可能的,例如 如示例所示,更改列表时,您的代码将起作用: 使用上面的代码,有2590个条目,而有1080个。