
Applying a transform to a Subset

牛智志
2023-12-01

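Suppose you have a custom dataset that has already been split into training, validation, and test sets with torch.utils.data.random_split(dataset, lengths), and you now need data augmentation on the training set only. random_split returns Subset objects, and Subset takes no transform argument, so you need to define a small wrapper class to apply the augmentation: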

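import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

datadir = my_data_path  # placeholder: path to the root folder of your images
all_dataset = datasets.ImageFolder(
    datadir,
    transforms.Compose([
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]),
)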
all_length = len(all_dataset)
train_size = int(all_length * 0.7)
val_size = int(all_length * 0.1)
test_size = all_length - train_size - val_size

# A fixed seed makes the split reproducible across runs
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    all_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(40),
)

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.2),
    transforms.RandomVerticalFlip(p=0.2)
])

# Wrapper dataset that applies data augmentation to the training subset
class train_dataset_transformed(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)

train_dataset = train_dataset_transformed(train_dataset, transform=train_transform)

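This gives a training set with data augmentation applied. The transform runs each time a sample is fetched, and the validation and test sets are left unchanged.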
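As a quick sanity check, the wrapper pattern above can be tried on a small synthetic dataset. The sketch below is only an illustration under stated assumptions: `TransformedSubset` is a hypothetical name for the same wrapper, a `TensorDataset` of made-up tensors stands in for `ImageFolder`, and `always_flip` is a deterministic stand-in for `RandomHorizontalFlip` so the effect can be asserted. It also shows that the subsets returned by `random_split` all index into one shared underlying dataset, which is presumably why the augmentation is attached to a wrapper rather than to the base dataset.

```python
import torch
from torch.utils.data import Dataset, TensorDataset, random_split


class TransformedSubset(Dataset):
    """Same wrapper pattern as train_dataset_transformed (hypothetical name)."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)


# Synthetic stand-in for ImageFolder: 10 "images" of shape 1x2x3, labels 0..9.
images = torch.arange(60, dtype=torch.float32).reshape(10, 1, 2, 3)
labels = torch.arange(10)
all_dataset = TensorDataset(images, labels)

train_subset, val_subset, test_subset = random_split(
    all_dataset, [7, 1, 2], generator=torch.Generator().manual_seed(40)
)


def always_flip(x):
    # Deterministic stand-in for a horizontal flip: reverse the width axis.
    return torch.flip(x, dims=[-1])


train_aug = TransformedSubset(train_subset, transform=always_flip)

# All subsets are views onto the same underlying dataset object.
shares_base = train_subset.dataset is val_subset.dataset

# The wrapper changes what the training set returns...
x_raw, y_raw = train_subset[0]
x_aug, y_aug = train_aug[0]
train_flipped = torch.equal(x_aug, torch.flip(x_raw, dims=[-1]))

# ...while the validation subset still returns the untouched sample.
x_val, y_val = val_subset[0]
val_untouched = torch.equal(x_val, images[val_subset.indices[0]])
```

With the real random flips, the same sample may come back flipped in one epoch and unflipped in the next, so a check like this is easier to write with a deterministic transform.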

Reference:

  1. https://discuss.pytorch.org/t/torch-utils-data-dataset-random-split/32209