model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。
应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。
再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。
补充:pytorch中state_dict的理解
在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。
import torch import torch.nn as nn import torchvision import numpy as np from torchsummary import summary # Define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # Initialize model model = TheModelClass() # Initialize optimizer optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Print model's state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor,"\t", model.state_dict()[param_tensor].size()) # Print optimizer's state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "\t", optimizer.state_dict()[var_name])
输出如下:
Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) Optimizer's state_dict: state {} param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]
我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!
补充:pytorch保存模型时报错***object has no attribute 'state_dict'
net=BaseNet()
保存net时报错 object has no attribute 'state_dict'
torch.save(net.state_dict(), models_dir)
原因是定义类的时候不是继承nn.Module类,比如:
class BaseNet(object): def __init__(self):
class BaseNet(nn.Module): def __init__(self): super(BaseNet, self).__init__()
以上为个人经验,希望能给大家一个参考,也希望大家多多支持小牛知识库。如有错误或未考虑完全的地方,望不吝赐教。
我遇到了一个自定义pytorch dataloader的问题,我认为它与函数中的浅拷贝和深拷贝有关。但是,有些行为我不理解。我不知道它是来自pytorch dataloader类还是其他地方。 我根据自己的复杂用例创建了一个最小的工作示例。最初,我将一个数据集保存为,并将其加载到中。对于NN,我希望元素归一化为1(我除以它们的总和),并分别返回总和。: 我得到以下输出: 在第一个历元之后,应该给出
本文向大家介绍javascript深拷贝和浅拷贝详解,包括了javascript深拷贝和浅拷贝详解的使用技巧和注意事项,需要的朋友参考一下 一、数组的深浅拷贝 在使用JavaScript对数组进行操作的时候,我们经常需要将数组进行备份,事实证明如果只是简单的将它赋予其他变量,那么我们只要更改其中的任何一个,然后其他的也会跟着改变,这就导致了问题的发生。 这是为什么呢? 因为如果只是简单的赋值,它只
这个深拷贝的代码,解决了循环引用的问题,但是我没看明白是怎么解决的循环引用? 当递归遇到引用了自己的属性的时候,WeakMap返回的不是空对象吗?不应该是指向自己啊。求指教
本文向大家介绍pytorch 状态字典:state_dict使用详解,包括了pytorch 状态字典:state_dict使用详解的使用技巧和注意事项,需要的朋友参考一下 pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等) (注意,只有那些参数可以训练的layer才会被保存到模型的s
本文向大家介绍java 深拷贝与浅拷贝机制详解,包括了java 深拷贝与浅拷贝机制详解的使用技巧和注意事项,需要的朋友参考一下 java 深拷贝与浅拷贝机制详解 概要: 在Java中,拷贝分为深拷贝和浅拷贝两种。java在公共超类Object中实现了一种叫做clone的方法,这种方法clone出来的新对象为浅拷贝,而通过自己定义的clone方法为深拷贝。 (一)Object中clone方法 如果
本文向大家介绍深入理解python中的浅拷贝和深拷贝,包括了深入理解python中的浅拷贝和深拷贝的使用技巧和注意事项,需要的朋友参考一下 在讲什么是深浅拷贝之前,我们先来看这样一个现象: 为什么我只对b进行修改,却影响到了a呢?看过我在之前的文章中就说过:序列中保存的都是内存的引用。 所以,当我们通过b去修改里面的空列表的时候,其实就是修改内存中的同一个对象,所以会影响到a。 代码验证无误,所以