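对于自己的数据集，如果已经用 torch.utils.data.random_split(dataset, lengths) 划分好了训练集、验证集和测试集，又只需要对训练集做数据增强，就需要定义一个类来处理。原因是 random_split 返回的各个 Subset 共享同一个底层数据集（以及它的 transform），直接修改这个 transform 会同时影响验证集和测试集；用一个包装类在训练集的 __getitem__ 中额外施加增强即可：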
# Load the whole image folder once with a base pipeline shared by every split:
# resize to 64x64, convert to a tensor (pixel values in [0, 1]) and normalize
# with mean=0.5 / std=0.5, which maps those values to [-1, 1]. A single-element
# mean/std is broadcast across all image channels.
datadir = my_data_path  # placeholder: root directory of the ImageFolder tree
all_dataset = datasets.ImageFolder(
    datadir,
    transforms.Compose(
        [
            transforms.Resize([64, 64]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]
    ),
)

# 70% train / 10% val / remainder test. int() truncates, so the test split
# absorbs the rounding leftovers and the three sizes always sum to the total.
all_length = len(all_dataset)
train_size = int(all_length * 0.7)
val_size = int(all_length * 0.1)
test_size = all_length - train_size - val_size

# A fixed-seed generator makes the split reproducible across runs.
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    all_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(40),
)

# Augmentation meant for the training split only. It receives the tensors
# produced by the base pipeline above (i.e. after Normalize); flipping is
# unaffected by that ordering.
# NOTE(review): tensor inputs to these flips need a torchvision release with
# tensor transform support (0.8+) — confirm against the installed version.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.RandomVerticalFlip(p=0.2),
    ]
)
# Training dataset with data augmentation applied on top of a split subset.
class train_dataset_transformed(Dataset):
    """Wrap a dataset (typically a ``random_split`` Subset) with an extra transform.

    ``random_split`` returns ``Subset`` views that all share one underlying
    dataset, so changing that dataset's transform would also change the
    validation and test splits. This wrapper instead applies ``transform`` on
    top of whatever the wrapped dataset already returns, and only for samples
    read through the wrapper.

    Args:
        subset: Any indexable, sized object whose items are ``(x, y)`` pairs.
        transform: Optional callable applied to ``x`` on every access;
            ``None`` leaves samples unchanged.
    """

    def __init__(self, subset, transform=None):
        super().__init__()
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(x, y)`` for ``index`` with ``transform`` applied to ``x``.

        The transform is invoked on every access, so a random transform yields
        a fresh draw each time the same index is read.
        """
        x, y = self.subset[index]
        # Test against None rather than truthiness: a callable object that
        # defines __bool__ or __len__ can be falsy and still be a valid transform.
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        """Return the number of samples, i.e. the wrapped dataset's length."""
        return len(self.subset)
# Wrap the training Subset so the augmentation reaches training samples only;
# val_dataset and test_dataset keep just the base pipeline.
train_dataset = train_dataset_transformed(
    subset=train_dataset,
    transform=train_transform,
)
这样就得到了带数据增强的训练集。注意增强是在每次通过 __getitem__ 读取样本时即时施加的，并不会预先生成新样本，因此训练集的长度保持不变。
参考: