当前位置: 首页 > 工具软件 > batch-loader > 使用案例 >

pytorch中的批训练(batch)

东弘扬
2023-12-01
  • 用pytorch进行批训练其实很简单,只要把数据放入DataLoader(可以把它看成一个收纳柜,它会帮你整理好)

    • 大概步骤:

      1. 生成XY数据

      2. XY数据转为dataset

        dataset = Data.TensorDataset(X,Y)
        
      3. dataset放入DataLoader

        loader = Data.DataLoader(
        	dataset=dataset,
            batch_size=batch_size, #分几批
            shuffle=True,  #是否打乱数据,默认为False
            num_workers=2,  # 用多线程读数据,默认0表示不使用多线程
        )
        

具体代码:

import torch
import torch.utils.data as Data

torch.manual_seed(1)   # 确保每次取的数据一致,使得实验结果一致

# 准备数据
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

# 转换成torch可以识别的Dataset
torch_dataset = Data.TensorDataset(x, y)    # 得到一个元组(x, y)

# 将dataset 放入DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=5,
    shuffle=True,  # 每次训练打乱数据, 默认为False
    num_workers=2,  # 使用多进行程读取数据, 默认0,为不使用多进程
)

if __name__ == "__main__":
    # train
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            print("Epoch: ", epoch, "| Step: ", step, "| batch x:",
                  batch_x.numpy(), "| batch y: ", batch_y.numpy())

:如果使用多进程就需要用 if __name__ == "__main__": ,将多进程操作的部分放到if __name__ == "__main__":下,定义部分可以不放;否则会报 RuntimeError: DataLoader worker (pid(s) ###, ###) exited unexpectedly

 类似资料: