
PyTorch: DataLoader and enumerate

赵景曜
2023-12-01

Understanding shuffle=True:
I previously did not know what shuffle actually does. Given samples a, b, c, d and batch_size=2, which of the following happens?
1. Batches are taken in order first, then shuffled internally, i.e. a, b is taken as a batch and shuffled within that batch;
2. The whole dataset is shuffled first, and batches are taken afterwards.
The experiment below shows it is the second.
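This shuffle-then-batch order can also be seen from the samplers that DataLoader composes internally: RandomSampler first produces a permutation of all indices, and BatchSampler then slices that shuffled order into consecutive batches. A minimal sketch (the seed value is arbitrary, chosen only for reproducibility):

```python
import torch
from torch.utils.data import RandomSampler, BatchSampler

data = ['a', 'b', 'c', 'd']  # four samples, indices 0..3

g = torch.Generator()
g.manual_seed(0)  # fixed seed so each run shuffles the same way

# RandomSampler permutes ALL indices first ...
sampler = RandomSampler(data, generator=g)
# ... and BatchSampler then groups the shuffled indices into batches.
batches = list(BatchSampler(sampler, batch_size=2, drop_last=False))
print(batches)  # e.g. [[2, 0], [1, 3]]

# Every sample appears exactly once across the batches -- the whole
# dataset was shuffled, not each batch individually:
flat = [i for batch in batches for i in batch]
assert sorted(flat) == [0, 1, 2, 3]
```

If batches were taken in order and only shuffled internally, the first batch could only ever contain indices 0 and 1; with shuffling first, any pair can appear together.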

import torch
from torch.utils.data import TensorDataset, DataLoader

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b)  # wrap the data a together with the labels b

# slicing returns the first two samples and their labels
print(train_ids[0:2])
print('=' * 80)
# iterating yields one (sample, label) pair at a time
for x_train, y_label in train_ids:
    print(x_train, y_label)
# wrap the dataset with a DataLoader for batching
print('=' * 80)

train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader):  # enumerate yields two values: the batch index and the batch data (inputs and labels)
    x_data, label = data
    print(' batch:{0}\n x_data:{1}\nlabel: {2}'.format(i, x_data, label))



######################
for i, data in enumerate(train_loader, 1):  # the second argument makes the batch index start at 1
    x_data, label = data
    print(' batch:{0}\n x_data:{1}\nlabel: {2}'.format(i, x_data, label))

DataLoader: takes a dataset (the training data together with its labels) and a batch_size; it splits the dataset into len(train_ids) / batch_size batches (rounded up when the size is not divisible, unless drop_last=True), each containing batch_size samples.
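The batch count can be checked directly with len() on the loader. A small sketch of the divisible and non-divisible cases (the dataset sizes here are illustrative):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# 12 samples with batch_size=4, as in the example above: exactly 3 batches
loader = DataLoader(TensorDataset(torch.arange(12)), batch_size=4)
print(len(loader))  # 3

# When the size is not divisible, the last batch is smaller
# unless drop_last=True discards it:
loader2 = DataLoader(TensorDataset(torch.arange(10)), batch_size=4)
loader3 = DataLoader(TensorDataset(torch.arange(10)), batch_size=4, drop_last=True)
print(len(loader2), len(loader3))  # 3 2
```

So with 10 samples and batch_size=4, the loader yields batches of sizes 4, 4, 2 by default, and only the two full batches with drop_last=True.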

enumerate: returns two values: the batch index, and the batch of data drawn from train_loader (inputs and labels).

In for i, data in enumerate(train_loader, 1): the 1 makes the batch counter start at 1 rather than 0. It does not change the number of batches: there are still 3 of them; even if the counter started at 8, there would still be three batches, numbered 8, 9, and 10.
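The start argument is plain Python behavior, independent of DataLoader: it shifts only the counter, never the data. A quick check:

```python
# enumerate's second argument sets the starting value of the counter;
# the items themselves are unchanged.
pairs = list(enumerate(['a', 'b', 'c'], 8))
print(pairs)  # [(8, 'a'), (9, 'b'), (10, 'c')]
```

Three items in, three pairs out, counted from 8 instead of 0.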

E:\软件安装\python3.7\python.exe E:/软件安装/code/RSN-master/Res2net.py
train_ids=
 <torch.utils.data.dataset.TensorDataset object at 0x0000000002836EC8>
(tensor([[1, 2, 3],
        [4, 5, 6]]), tensor([44, 55]))
================================================================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
 batch:0
 x_data:tensor([[1, 2, 3],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3]])
label: tensor([44, 44, 55, 44])
 batch:1
 x_data:tensor([[4, 5, 6],
        [4, 5, 6],
        [7, 8, 9],
        [7, 8, 9]])
label: tensor([55, 55, 66, 66])
 batch:2
 x_data:tensor([[4, 5, 6],
        [1, 2, 3],
        [7, 8, 9],
        [7, 8, 9]])
label: tensor([55, 44, 66, 66])


###################################################
batch:1
 x_data:tensor([[7, 8, 9],
        [1, 2, 3],
        [1, 2, 3],
        [4, 5, 6]])
label: tensor([66, 44, 44, 55])
 batch:2
 x_data:tensor([[7, 8, 9],
        [4, 5, 6],
        [7, 8, 9],
        [1, 2, 3]])
label: tensor([66, 55, 66, 44])
 batch:3
 x_data:tensor([[4, 5, 6],
        [7, 8, 9],
        [1, 2, 3],
        [4, 5, 6]])
label: tensor([55, 66, 44, 55])

Process finished with exit code 0


