用pytorch进行批训练其实很简单,只要把数据放入DataLoader
(可以把它看成一个收纳柜,它会帮你整理好)
大概步骤:
生成X
,Y
数据
将X
,Y
数据转为dataset
dataset = Data.TensorDataset(X,Y)
将dataset
放入DataLoader
中
loader = Data.DataLoader(
dataset=dataset,
batch_size=batch_size, #分几批
shuffle=True, #是否打乱数据,默认为False
num_workers=2, # 用多线程读数据,默认0表示不使用多线程
)
import torch
import torch.utils.data as Data
torch.manual_seed(1) # 确保每次取的数据一致,使得实验结果一致
# 准备数据
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
# 转换成torch可以识别的Dataset
torch_dataset = Data.TensorDataset(x, y) # 得到一个元组(x, y)
# 将dataset 放入DataLoader
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=5,
shuffle=True, # 每次训练打乱数据, 默认为False
num_workers=2, # 使用多进行程读取数据, 默认0,为不使用多进程
)
if __name__ == "__main__":
# train
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
print("Epoch: ", epoch, "| Step: ", step, "| batch x:",
batch_x.numpy(), "| batch y: ", batch_y.numpy())
注:如果使用多进程就需要用 if __name__ == "__main__":
,将多进程操作的部分放到if __name__ == "__main__":
下,定义部分可以不放;否则会报 RuntimeError: DataLoader worker (pid(s) ###, ###) exited unexpectedly
。