当前位置: 首页 > 编程笔记 >

pytorch 实现模型不同层设置不同的学习率方式

欧阳智志
2023-03-14
本文向大家介绍pytorch 实现模型不同层设置不同的学习率方式,包括了pytorch 实现模型不同层设置不同的学习率方式的使用技巧和注意事项,需要的朋友参考一下

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(html" target="_blank">lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)
 

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持小牛知识库。

 类似资料:
  • 问题内容: 我想知道是否有一种方法可以对Caffe中的不同层使用不同的学习率。我正在尝试修改预训练的模型,并将其用于其他任务。我想要的是加快对新添加的层的培训,并使受过培训的层保持较低的学习率,以防止它们变形。例如,我有一个5转换层的预训练模型。现在,我添加一个新的转换层并对其进行微调。前5层的学习率为0.00001,后5层的学习率为0.001。任何想法如何实现这一目标? 问题答案: 使用2个优化

  • 我是使用谓词的新手,不确定我是否正确理解了它。我有一个抽象的类,其中每小时和薪水员工是分开创建的。我的问题依赖于我的类,我不确定如何检查它是否是每小时员工并返回真或假。 我需要为以下所有条件创建不同的谓词: 所有员工,仅每小时一次,仅限工资和全职。 到目前为止,我只是想先让“仅每小时”谓词正常工作,然后再找出剩下的谓词。我不确定在“p”后面加什么来检查它是哪种类型的员工。我目前拥有的是: 我还有一

  • 在为存储在本地数据库中的数据实现搜索功能时,您将如何处理? > 每次用户输入 桌子上有200-300行。每行包含10-15列。 附言:我总是对缓存的列表执行过滤,主要是因为无论如何我都需要显示完整的列表,所以整个数据集已经被预取了。 让我们假设您不必首先显示整个列表,并且只在用户开始键入时显示,在这种情况下执行数据库查询总体上更好吗? 我只是不确定。我看见一位同事在做询问。 是的,我懒得测试数据库

  • 我似乎无法理解学习率的价值。我得到的是下面。 我已经尝试了200个epoch的模型,并希望查看/更改学习速率。这不是正确的方法吗?

  • 这是一个面向对象设计模式专家的问题。 假设我有一个类,负责读取/解析数据流(携带不同类型的信息包)。每个数据包都携带不同类型的信息,因此理想情况下,我会为每种类型的数据包创建一个类(,,…每个都有自己的接口)。 然后,方法将遍历流,并将类(或指向类的指针或引用)返回到相应的包类型。 你会用空指针吗?一个泛型类()可以提供一个公共(部分)接口来查询数据包的类型(这样就可以在运行时做出任何决定),怎么

  • 我试图将预训练的BN权重从pytorch模型复制到其等效的Keras模型,但我一直得到不同的输出。 我阅读了Keras和Pytorch BN文档,我认为差异在于它们计算“平均值”和“var”的方式。 Pytorch: 平均值和标准偏差是按小批量的每个尺寸计算的 来源:Pytorch BatchNorm 因此,他们对样本进行平均。 Keras: 轴:整数,应规格化的轴(通常是特征轴)。例如,在具有d