当前位置: 首页 > 工具软件 > dataloader > 使用案例 >

Dataloader的使用

隆钊
2023-12-01

本文主要使用CIFAR10数据集来讲解Dataloader的使用方法,并写入tensorboard中,可以更好的去查看。



前言

在pytorch中如何读取数据主要有两个类,分别是Dataset和Dataloader。
dataset可以理解为:提供一种方式去获取数据及其label(标签)。
可以实现(1)如何获取每一个数据及其label;(2)总共有多少数据。这两个功能。

dataloader可以理解为:为后面的网络提供不同的数据形式。

相比较dataset,dataloader更难一点,前面我们讲过dataset的使用了,今天我们来讲dataloader的基本用法


一、DataLoader类的官方解释

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)

看着是不是很多,没关系,下面我们来一一讲解。(其中很多参数都是有默认值的,我们一般写程序时,用的较多的是加粗的几个参数)

详细注释:

dataset (Dataset) – 从中加载数据的数据集。
batch_size (int, optional) – 每批要加载多少个样本(默认值:1)。
shuffle (bool, optional) – 设置为 True 让数据在每个 epoch 重新洗牌(默认值:False)。
sampler(Sampler 或 Iterable,可选)——定义从数据集中抽取样本的策略。可以是任何实现了 len 的 Iterable。如果指定,则不得指定 shuffle。
batch_sampler(Sampler 或 Iterable,可选)- 类似于 sampler,但一次返回一批索引。与 batch_size、shuffle、sampler 和 drop_last 互斥。
num_workers (int, optional) – 用于数据加载的子进程数。 0 表示数据将在主进程中加载​​。 (默认值:0)
collat​​e_fn(可调用,可选)——合并样本列表以形成小批量张量。从地图样式数据集中使用批量加载时使用。
pin_memory (bool, optional) – 如果为 True,数据加载器将在返回之前将张量复制到 CUDA 固定内存中。如果您的数据元素是自定义类型,或者您的 collat​​e_fn 返回一个自定义类型的批次,请参见下面的示例。
drop_last (bool, optional) – 如果数据集大小不能被批次大小整除,则设置为 True 以丢弃最后一个不完整的批次。如果 False 并且数据集的大小不能被批大小整除,则最后一批将更小。 (默认:假)
timeout (numeric, optional) -- 如果为正,则从工人那里收集批次的超时值。应始终为非负数。 (默认值:0)
worker_init_fn (callable, optional) – 如果不是None,这将在播种之后和数据加载之前以worker id([0,num_workers - 1]中的一个int)作为输入在每个worker子进程上调用。 (默认:无)

二、使用方法

1.准备调试的数据集

使用CIFAR10数据集,并用totensor将数据转换成tensor的数据类型,代码如下:

import torchvision
test_data = torchvision.datasets.CIFAR10(
root="./dataset", 
train=False, 
transform=torchvision.transforms.ToTensor())

注释:
root="./dataset",     数据集保存在dataset文件中
train=False,          false代表测试数据集
transform=torchvision.transforms.ToTensor())   将数据集中的图片转换为tensor数据类型

接着,使用DataLoader。代码如下:

from torch.utils.data import DataLoader
test_loader = DataLoader(dataset=test_data, batch_size=128, shuffle=True, num_workers=0, drop_last=False  )

注释:
dataset=test_data,    从test_data中加载数据
batch_size=128,       一次加载128个样本
shuffle=True,          true表示 在每一次循环中都变化
num_workers=0,       在主程序进行
drop_last=False         保留最后一个不完整的批次

2.查看DataLoader的结果

因为数据已经转换成tensor的数据类型了,所以可以直接写入tensorboard中查看,代码如下:

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")   #写入tensorboard中

for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs,targets = data

        writer.add_images("epoch : {}".format(epoch), imgs, step)    
        #因为不是一张图片,所以使用images
        step = step + 1
writer.close()

3.完整代码

from torch.utils.data import DataLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter


#准备的测试数据集
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=128, shuffle=True, num_workers=0, drop_last=False  )

#测试数据集中第一张图片及target
# img, target = test_data[0]
# print(img.shape)
# print(target)

writer = SummaryWriter("logs")   #写入tensorboard中

for epoch in range(2):  #进行两次循环
    step = 0
    for data in test_loader:
        imgs,targets = data
        # print(imgs.shape)
        # print(targets)

        writer.add_images("epoch : {}".format(epoch), imgs, step)    #因为不是一张图片,所以使用images
        step = step + 1
writer.close()


总结

CIFAR10数据集很适合我们初期去学习,DataLoader直接与后面的神经网络有关,所以要重点学习。

 类似资料: