在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:
class Model(nn.Module): def __init__(self): super(Transfer_model, self).__init__() self.linear1 = nn.Linear(20, 50) self.linear2 = nn.Linear(50, 20) self.linear3 = nn.Linear(20, 2) def forward(self, x): pass
假如我们想要冻结linear1层,需要做如下操作:
model = Model() # 这里是一般情况,共享层往往不止一层,所以做一个for循环 for para in model.linear1.parameters(): para.requires_grad = False # 假如真的只有一层也可以这样操作: # model.linear1.weight.requires_grad = False
最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。
filter(function, iterable)
filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。
该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。
filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持小牛知识库。
本文向大家介绍对pytorch网络层结构的数组化详解,包括了对pytorch网络层结构的数组化详解的使用技巧和注意事项,需要的朋友参考一下 最近再写openpose,它的网络结构是多阶段的网络,所以写网络的时候很想用列表的方式,但是直接使用列表不能将网络中相应的部分放入到cuda中去。 其实这个问题很简单的,使用moduleList就好了。 1 我先是定义了一个函数,用来根据超参数,建立一个基础网
本文向大家介绍如何冻结JavaScript中的对象?,包括了如何冻结JavaScript中的对象?的使用技巧和注意事项,需要的朋友参考一下 在实时世界中,JavaScript没有其他语言中的传统类。它具有对象和构造函数。 Object.freeze()是许多有助于冻结对象的构造函数方法之一。 冻结对象不允许将新属性添加到该对象,并且还阻止该对象更改其自身的属性。Object.freeze()将始终
我希望将模型A的参数乘以一个标量lambda得到模型B(即模型B的架构和模型A的一样但每个参数都是模型A的lambda倍),现在我希望将一个tensor输入模型B,并对输出进行反向传播,然后优化参数lambda。但梯度无法反传到lambda上。具体代码如下 输出为None,现在我希望有代码能实现相同的功能并能成功计算lambda的梯度。 用tensorboard可视化计算图发现weighted_s
本文向大家介绍pytorch 求网络模型参数实例,包括了pytorch 求网络模型参数实例的使用技巧和注意事项,需要的朋友参考一下 用pytorch训练一个神经网络时,我们通常会很关心模型的参数总量。下面分别介绍来两种方法求模型参数 一 .求得每一层的模型参数,然后自然的可以计算出总的参数。 1.先初始化一个网络模型model 比如我这里是 model=cliqueNet(里面是些初始化的参数)
无论何时可能需要将数据库状态设置为静态,意味着数据库未响应任何读取和写入操作的状态。 简单地说,数据库处于冻结状态。 在本章中,可以学习如何从OrientDB命令行冻结数据库。 以下语句是冻结数据库命令的基本语法。 注 - 只有在连接到远程或本地数据库中的特定数据库后,才能使用此命令。 示例 在这个例子中,我们将使用我们在前一章中创建的名为的数据库。 我们将从CLI冻结这个数据库。 可以使用以下命
问题内容: 我开发了一个简单的Python应用程序来做一些事情,然后决定使用Tkinter添加一个简单的GUI。 问题在于,当main函数正在执行其工作时,窗口会冻结。 我知道这是一个普遍的问题,我已经读过我应该使用多线程(非常复杂,因为该函数还会更新GUI)或将我的代码划分为不同的函数,每个函数工作一段时间。无论如何,我不想为这样一个愚蠢的应用程序更改代码。 我的问题是:有没有简便的方法可以每秒