当前位置: 首页 > 知识库问答 >
问题:

pytorch dataloader和/或_getitem__________函数中的浅拷贝和深拷贝

万高轩
2023-03-14

我遇到了一个自定义pytorch dataloader的问题,我认为它与__getitem__()函数中的浅拷贝和深拷贝有关。但是,有些行为我不理解。我不知道它是来自pytorch dataloader类还是其他地方。

我根据自己的复杂用例创建了一个最小的工作示例。最初,我将一个数据集保存为. hdf5,并将其加载到__init__()中。对于NN,我希望元素归一化为1(我除以它们的总和),并分别返回总和。:

# imports
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
# create dataset with fixed seed
np.random.seed(1234)
data = np.random.rand(20, 4)
print(data)
# create custom dataset class

class TestDataset(Dataset):
    """ Test dataset to illustrate bug in get_item """

    def __init__(self, data_array, transform=None, apply_logit=True, with_noise=False):
        """
        Args:
            data_array (np.array): representing data loaded from hdf5 file or so
            transform (None, callable or 'norm'): if data should be transformed
            apply_logit (bool): if logit transform should be applied at the end
            with_noise (bool): if noise should be applied in each call
        """

        self.data = data_array

        self.transform = transform
        self.apply_logit = apply_logit
        self.with_noise = with_noise


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        data = self.data[idx]

        if self.with_noise:
            data = add_noise(data)

        data_sum = data.sum(axis=(-1), keepdims=True)

        if self.transform:
            if self.transform == 'norm':
                data /= (data_sum + 1e-16) # this should be avoided
            else:
                data = self.transform(data)

        if self.apply_logit:
            data = logit_trafo(data)

        sample = {'data': data, 'data_sum': data_sum.squeeze()}

        return sample

def get_dataloader(data_array, device, batch_size=2, apply_logit=True, with_noise=False, normed=False):

    kwargs = {'num_workers': 2, 'pin_memory': True} if device.type is 'cuda' else {}

    dataset = TestDataset(data_array, transform='norm' if normed else None, apply_logit=apply_logit,
                              with_noise=with_noise)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False, **kwargs)

def add_noise(input_tensor):
    noise = np.random.rand(*input_tensor.shape)*1e-6
    return input_tensor+noise

ALPHA = 1e-6
def logit(x):
    return np.log(x / (1.0 - x))

def logit_trafo(x):
    local_x = ALPHA + (1. - 2.*ALPHA) * x
    return logit(local_x)
# with_noise=False will print just [1. 1.] after one epoch (due to the /= operation above)
# with_noise=True will remove this issue. Why?

mydata = get_dataloader(data, torch.device('cpu'), apply_logit=False, with_noise=False, normed=True)
with torch.no_grad():
    for n in range(3):
        print("epoch: ", n)
        for i, elem in enumerate(mydata):
            print('batch: ', i, #elem['data'].numpy(), 
                  elem['data_sum'].numpy())

我得到以下输出:

[[0.19151945 0.62210877 0.43772774 0.78535858]
 [0.77997581 0.27259261 0.27646426 0.80187218]
 [0.95813935 0.87593263 0.35781727 0.50099513]
 [0.68346294 0.71270203 0.37025075 0.56119619]
 [0.50308317 0.01376845 0.77282662 0.88264119]
 [0.36488598 0.61539618 0.07538124 0.36882401]
 [0.9331401  0.65137814 0.39720258 0.78873014]
 [0.31683612 0.56809865 0.86912739 0.43617342]
 [0.80214764 0.14376682 0.70426097 0.70458131]
 [0.21879211 0.92486763 0.44214076 0.90931596]
 [0.05980922 0.18428708 0.04735528 0.67488094]
 [0.59462478 0.53331016 0.04332406 0.56143308]
 [0.32966845 0.50296683 0.11189432 0.60719371]
 [0.56594464 0.00676406 0.61744171 0.91212289]
 [0.79052413 0.99208147 0.95880176 0.79196414]
 [0.28525096 0.62491671 0.4780938  0.19567518]
 [0.38231745 0.05387369 0.45164841 0.98200474]
 [0.1239427  0.1193809  0.73852306 0.58730363]
 [0.47163253 0.10712682 0.22921857 0.89996519]
 [0.41675354 0.53585166 0.00620852 0.30064171]]

epoch:  0
batch:  0 [2.03671454 2.13090485]
batch:  1 [2.69288438 2.3276119 ]
batch:  2 [2.17231943 1.42448741]
batch:  3 [2.77045097 2.19023559]
batch:  4 [2.35475675 2.49511645]
batch:  5 [0.96633253 1.73269209]
batch:  6 [1.5517233 2.1022733]
batch:  7 [3.5333715  1.58393664]
batch:  8 [1.86984429 1.56915029]
batch:  9 [1.70794311 1.25945542]
epoch:  1
batch:  0 [1. 1.]
batch:  1 [1. 1.]
batch:  2 [1. 1.]
batch:  3 [1. 1.]
batch:  4 [1. 1.]
batch:  5 [1. 1.]
batch:  6 [1. 1.]
batch:  7 [1. 1.]
batch:  8 [1. 1.]
batch:  9 [1. 1.]
epoch:  2
batch:  0 [1. 1.]
batch:  1 [1. 1.]
batch:  2 [1. 1.]
batch:  3 [1. 1.]
batch:  4 [1. 1.]
batch:  5 [1. 1.]
batch:  6 [1. 1.]
batch:  7 [1. 1.]
batch:  8 [1. 1.]
batch:  9 [1. 1.]

在第一个历元之后,应该给出每个输入向量之和的条目返回1。根据我的理解,原因是\uuu getitem()\uuu中的/=操作覆盖了原始数组(因为它只是一个浅拷贝)。但是,当我创建带有和\u noise=True的数据加载器时,输出变为

epoch:  0
batch:  0 [2.03671714 2.13090728]
batch:  1 [2.69288618 2.32761437]
batch:  2 [2.17232151 1.42449024]
batch:  3 [2.7704527  2.19023717]
batch:  4 [2.35475926 2.49511859]
batch:  5 [0.96633553 1.73269352]
batch:  6 [1.55172434 2.10227475]
batch:  7 [3.53337356 1.58393908]
batch:  8 [1.86984558 1.56915276]
batch:  9 [1.70794503 1.25945833]
epoch:  1
batch:  0 [2.03671729 2.13090765]
batch:  1 [2.69288721 2.32761405]
batch:  2 [2.17232208 1.42449008]
batch:  3 [2.77045253 2.19023718]
batch:  4 [2.35475815 2.4951189 ]
batch:  5 [0.96633595 1.73269401]
batch:  6 [1.55172476 2.10227547]
batch:  7 [3.53337382 1.58393882]
batch:  8 [1.86984584 1.56915165]
batch:  9 [1.70794547 1.25945795]
epoch:  2
batch:  0 [2.03671533 2.13090593]
batch:  1 [2.69288633 2.32761373]
batch:  2 [2.17232158 1.42448975]
batch:  3 [2.77045371 2.19023796]
batch:  4 [2.3547586  2.49511857]
batch:  5 [0.96633348 1.73269476]
batch:  6 [1.55172544 2.10227616]
batch:  7 [3.53337367 1.58393892]
batch:  8 [1.86984568 1.56915256]
batch:  9 [1.70794379 1.25945825]

如果我添加的噪波乘以0,情况也是如此

为什么呢?为什么它突然变成了一个深拷贝?


共有1个答案

萧明贤
2023-03-14

谢谢你,疯狂物理学家!我必须阅读它和代码几次看到的问题:

如果不调用add_noise(),行data/=(data_sum1e-16)将更改原始输入数组。因此,对它的每一次后续调用都会返回已经标准化的数据。调用add_noise()创建一个新的数组,它的编码方式。就地操作只会更改新数组,而不会触及原始数组(这是我错过的步骤)。因此,随后的调用返回原始的,而不是规范化的数组。

 类似资料:
  • 主要内容:到底是浅拷贝还是深拷贝对于基本类型的数据以及简单的对象,它们之间的拷贝非常简单,就是按位复制内存。例如: b 和 obj2 都是以拷贝的方式初始化的,具体来说,就是将 a 和 obj1 所在内存中的数据按照二进制位(Bit)复制到 b 和 obj2 所在的内存, 这种默认的拷贝行为就是 浅拷贝 ,这和调用 memcpy() 函数的效果非常类似。 对于简单的类,默认的拷贝构造函数一般就够用了,我们也没有必要再显式地定义一

  • 浅拷贝 对于对象或数组类型,当我们将a赋值给b,然后更改b中的属性,a也会随着变化。 也就是说,a和b指向了同一块堆内存,所以修改其中任意的值,另一个值都会随之变化,这就是浅拷贝。 深拷贝 那么相应的,如果给b放到新的内存中,将a的各个属性都复制到新内存里,就是深拷贝。 也就是说,当b中的属性有变化的时候,a内的属性不会发生变化。 参考链接: 深拷贝与浅拷贝的实现(一) javaScript中浅拷

  • 一、引言 对象拷贝(Object Copy)就是将一个对象的属性拷贝到另一个有着相同类类型的对象中去。在程序中拷贝对象是很常见的,主要是为了在新的上下文环境中复用对象的部分或全部数据。Java中有三种类型的对象拷贝:浅拷贝(Shallow Copy)、深拷贝(Deep Copy)、延迟拷贝(Lazy Copy)。 二、浅拷贝 1、什么是浅拷贝 浅拷贝是按位拷贝对象,它会创建一个新对象,这个对象有着

  • 本文向大家介绍javascript深拷贝和浅拷贝详解,包括了javascript深拷贝和浅拷贝详解的使用技巧和注意事项,需要的朋友参考一下 一、数组的深浅拷贝 在使用JavaScript对数组进行操作的时候,我们经常需要将数组进行备份,事实证明如果只是简单的将它赋予其他变量,那么我们只要更改其中的任何一个,然后其他的也会跟着改变,这就导致了问题的发生。 这是为什么呢? 因为如果只是简单的赋值,它只

  • 本文向大家介绍iOS 深拷贝 和浅拷贝的区别相关面试题,主要包含被问及iOS 深拷贝 和浅拷贝的区别时的应答技巧和注意事项,需要的朋友参考一下 深拷贝和浅拷贝的区别? 答案:浅层复制:只复制指向对象的指针,而不复制引用对象本身。 深层复制:复制引用对象本身。 意思就是说我有个A对象,复制一份后得到A_copy对象后,对于浅复制来说,A和A_copy指向的是同一个内存资源,复制的只不过是是一个指针,

  • 本文向大家介绍深入理解python中的浅拷贝和深拷贝,包括了深入理解python中的浅拷贝和深拷贝的使用技巧和注意事项,需要的朋友参考一下 在讲什么是深浅拷贝之前,我们先来看这样一个现象: 为什么我只对b进行修改,却影响到了a呢?看过我在之前的文章中就说过:序列中保存的都是内存的引用。 所以,当我们通过b去修改里面的空列表的时候,其实就是修改内存中的同一个对象,所以会影响到a。 代码验证无误,所以