官方解释:Dataloader 组合了 dataset & sampler,提供在数据上的 iterable
主要参数:
1、dataset:这个dataset一定要是torch.utils.data.Dataset本身或继承自它的类
里面最主要的方法是 __getitem__(self, index) 用于根据index索引来取数据的
2、batch_size:每个batch批次要返回几条数据
3、shuffle:是否打乱数据,默认False
4、sampler:sample strategy,数据选取策略,有它就不用shuffle了,因为sample本身就是一种无序。这个sampler貌似也一定要是torch.utils.data.sampler.Sampler本身或继承自它的类。
里面最主要的方法是__iter__(self) 方法,每次调用 iter 只能获取 batchsize 个数据,也就是一个批次的数据。
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
5、…… 后面就不说了
这里先贴一段我的代码:
trainloader = DataLoader(
ImageDataset(self.dataset.train, transform=self.transform_train),
# 为传入的数据中的每个id选择config.k个样本
sampler=ClassUniformlySampler(self.dataset.train, class_position=1, k=config.k), # 传入的数据中第2维是类别,所以class_position=1
batch_size=config.p * config.k, num_workers=config.workers,
# shuffle=True, # 有了ClassUniformlySampler就不用shuffle了
pin_memory=pin_memory, drop_last=False
)
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
print("batch_idx: ", batch_idx)
for i in range(len(pids)):
print(pids[i], imgs[i].shape)
一开始我并不真的明白它的内部原理,为什么执行 enumerate 代码,就可以源源不断地返回所需数据,后来我跟了一下整个代码才明白(如果你也在这个地方范迷糊,可以继续往下看,如果没有,则可离开)
在执行 trainloader = DataLoader()语句的时候,DataLoder,ImageDataset,ClassUniformlySampler 并没有什么特殊的操作,都仅仅是init初始化了一下。
这里所使用的 ClassUniformlySampler ,是Sampler类的一种,作用是对数据中的所有id仅保留k条数据。因此它在初始化时,生成了一个字典,key为类别,value为属于该类别的所有数据的索引。(这里仅作讲解使用,无需深入学习)
代码取自他处,已难寻根,在此标注一下**
class ClassUniformlySampler(Sampler):
'''
random sample according to class label
Arguments:
data_source (Dataset): data_loader to sample from
class_position (int): which one is used as class
k (int): sample k images of each class
'''
def __init__(self, data_source, class_position, k):
self.class_position = class_position
self.k = k
self.samples = data_source
self.class_dict = self._tuple2dict(self.samples) # 返回一个字典,key为类别,value为属于该类别的所有数据的索引
def __iter__(self):
self.sample_list = self._generate_list(self.class_dict)
return iter(self.sample_list)
def __len__(self):
return len(self.sample_list)
def _tuple2dict(self, inputs):
'''
:param inputs: list with tuple elemnts, [(image_path1, class_index_1), (imagespath_2, class_index_2), ...]
:return: dict, {class_index_i: [samples_index1, samples_index2, ...]}
'''
dict = {}
for index, each_input in enumerate(inputs):
class_index = each_input[self.class_position]
if class_index not in list(dict.keys()):
dict[class_index] = [index]
else:
dict[class_index].append(index)
return dict
def _generate_list(self, dict):
'''
:param dict: dict, whose values are list
:return:
'''
sample_list = []
dict_copy = dict.copy()
keys = list(dict_copy.keys())
random.shuffle(keys)
for key in keys:
value = dict_copy[key]
if len(value) >= self.k:
random.shuffle(value)
sample_list.extend(value[0: self.k])
else:
value = value * self.k
random.shuffle(value)
sample_list.extend(value[0: self.k])
return sample_list
在第一次执行 for batch_idx, (imgs, pids, _) in enumerate(trainloader) 时,首先调用的是sampler.__iter__() 方法,对所有数据进行采样后返回一个存储了所采样的数据的索引列表,并用iter(sampler_list) 作为返回。iter方法在一开始已经提及,每次调用只能返回 batchsize 条数据。
随后,Dataset就上场了,它只需根据 sampler_list 中的索引挨个取数据即可,取到第 batchsize 条数据的时候,iter 就不会再让它取了。
这之后,每一次执行 for batch_idx, (imgs, pids, _) in enumerate(trainloader) 时,Dataset 都会从上一次iter中断的数据索引处继续取 batchsize 个数据,直到取完所有数据。
注:因为在采样时,已经打乱了原有的数据顺序,对于采样后返回的sample_list,即使按顺序取,也不是真的有序,而且这样还可以防止重复抽取到相同数据,数据取完就可以结束一个epoch