当前位置: 首页 > 编程笔记 >

浅谈PyTorch的可重复性问题(如何使实验结果可复现)

钱震博
2023-03-14
本文向大家介绍浅谈PyTorch的可重复性问题(如何使实验结果可复现),包括了浅谈PyTorch的可重复性问题(如何使实验结果可复现)的使用技巧和注意事项,需要的朋友参考一下

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持小牛知识库。

 类似资料:
  • 问题内容: 我有一个使用TensorFlow创建多层感知器网络(带有辍学)的Python脚本,以进行二进制分类。即使我很小心地设置了Python和TensorFlow种子,我仍然得到了不可重复的结果。如果我运行一次然后再次运行,则会得到不同的结果。我什至可以运行一次,退出Python,重新启动Python,再次运行并获得不同的结果。 我尝试过的 我知道有人发布了有关在TensorFlow中获得不可

  • 本文向大家介绍浅谈pytorch、cuda、python的版本对齐问题,包括了浅谈pytorch、cuda、python的版本对齐问题的使用技巧和注意事项,需要的朋友参考一下 在使用深度学习模型训练的过程中,工具的准备也算是一个良好的开端吧。熟话说完事开头难,磨刀不误砍柴工,先把前期的问题搞通了,能为后期节省不少精力。 以pytorch工具为例: pytorch版本为1.0.1,自带python版

  • 钱箱类: 商户类: 输入数据: 我的任务 计算每个商家的总金额并返回商家列表 我正在尝试使用Stream API解决这个任务。并编写了以下代码: 结果 但显然,流返回四个对象,而不是所需的两个对象。我意识到,地图(第二行)为每个cashBoxId创建了四个对象。而且我不知道如何通过进行过滤,也不知道如何获得没有重复的结果。

  • 我正在学习Java的易失性,我的代码是这样的。 我知道如果flag没有volatile,线程就不会存在。这是能见度的问题。 但是,如果我在while循环中编写一些代码,如,t1线程将读取新值并停止循环。 我知道如何使用volatile来解决可见性问题,所以我的问题是: 为什么当我写?

  • 我正在 当我通过isundermaintanace时,真的最后没有执行。 我错过了什么?还有别的方法吗?

  • 我现在正在学习Java,我读到了这个问题:什么是Java字符串实习堆栈溢出 但我阅读了其他文章,其中提供了一些我不理解的示例: Java 11中的结果(在Java 8中应该相同)是: 我不知道为什么结果不同,我想这都是真的。有人能解释一下吗?