当前位置: 首页 > 编程笔记 >

解决pytorch 的state_dict()拷贝问题

任飞龙
2023-03-14
本文向大家介绍解决pytorch 的state_dict()拷贝问题,包括了解决pytorch 的state_dict()拷贝问题的使用技巧和注意事项,需要的朋友参考一下

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持小牛知识库。如有错误或未考虑完全的地方,望不吝赐教。

 类似资料:
  • 我遇到了一个自定义pytorch dataloader的问题,我认为它与函数中的浅拷贝和深拷贝有关。但是,有些行为我不理解。我不知道它是来自pytorch dataloader类还是其他地方。 我根据自己的复杂用例创建了一个最小的工作示例。最初,我将一个数据集保存为,并将其加载到中。对于NN,我希望元素归一化为1(我除以它们的总和),并分别返回总和。: 我得到以下输出: 在第一个历元之后,应该给出

  • 本文向大家介绍javascript深拷贝和浅拷贝详解,包括了javascript深拷贝和浅拷贝详解的使用技巧和注意事项,需要的朋友参考一下 一、数组的深浅拷贝 在使用JavaScript对数组进行操作的时候,我们经常需要将数组进行备份,事实证明如果只是简单的将它赋予其他变量,那么我们只要更改其中的任何一个,然后其他的也会跟着改变,这就导致了问题的发生。 这是为什么呢? 因为如果只是简单的赋值,它只

  • 这个深拷贝的代码,解决了循环引用的问题,但是我没看明白是怎么解决的循环引用? 当递归遇到引用了自己的属性的时候,WeakMap返回的不是空对象吗?不应该是指向自己啊。求指教

  • 本文向大家介绍pytorch 状态字典:state_dict使用详解,包括了pytorch 状态字典:state_dict使用详解的使用技巧和注意事项,需要的朋友参考一下 pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等) (注意,只有那些参数可以训练的layer才会被保存到模型的s

  • 本文向大家介绍java 深拷贝与浅拷贝机制详解,包括了java 深拷贝与浅拷贝机制详解的使用技巧和注意事项,需要的朋友参考一下  java 深拷贝与浅拷贝机制详解 概要: 在Java中,拷贝分为深拷贝和浅拷贝两种。java在公共超类Object中实现了一种叫做clone的方法,这种方法clone出来的新对象为浅拷贝,而通过自己定义的clone方法为深拷贝。 (一)Object中clone方法 如果

  • 本文向大家介绍深入理解python中的浅拷贝和深拷贝,包括了深入理解python中的浅拷贝和深拷贝的使用技巧和注意事项,需要的朋友参考一下 在讲什么是深浅拷贝之前,我们先来看这样一个现象: 为什么我只对b进行修改,却影响到了a呢?看过我在之前的文章中就说过:序列中保存的都是内存的引用。 所以,当我们通过b去修改里面的空列表的时候,其实就是修改内存中的同一个对象,所以会影响到a。 代码验证无误,所以