class Mydata(Dataset):
    """Image classification dataset built from a flat directory of image files.

    Every file in ``root`` is opened eagerly, converted to RGB and kept in
    memory together with an integer label parsed from its file name (see
    ``_label_for``). Depending on ``train`` the samples go into
    ``train_data``/``train_label`` or ``test_data``/``test_label``; the other
    pair of lists stays empty.

    Args:
        root: Directory containing the images. A trailing path separator is
            optional.
        train: If True, populate the training lists; otherwise the test lists.
        transform: Optional callable applied to each PIL image in
            ``__getitem__``.
        target_transform: Optional callable applied to each integer label in
            ``__getitem__``.
    """

    # File names are split on '_', '.', and runs of one or two digits; the
    # label word is taken as the third piece from the end, e.g.
    # "nude01.jpg" -> ['nude', '', 'jpg'] -> 'nude'.
    _LABEL_SPLIT = re.compile(r'_|\.|\d{2}|\d')

    def __init__(self, root, train=True, transform=None,
                 target_transform=None):
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.train_label = []
        self.train_data = []
        self.test_label = []
        self.test_data = []
        self.label_dic = {'nude': 0, 'nonnude': 1}

        # Only one split is populated per instance; select its lists once.
        if self.train:
            data, labels = self.train_data, self.train_label
        else:
            data, labels = self.test_data, self.test_label

        # os.listdir order is arbitrary; sorting keeps index -> sample stable.
        for file_name in sorted(os.listdir(self.root)):
            path = os.path.join(self.root, file_name)
            # The context manager closes the underlying file; convert()
            # returns a fully loaded copy that remains valid afterwards.
            with Image.open(path) as img:
                image = img.convert('RGB')
            labels.append(self._label_for(file_name))
            data.append(image)

    def _label_for(self, file_name):
        """Return the integer label encoded in ``file_name``.

        Only the base name is parsed, so directory components of a path
        cannot leak into the label token.

        Raises:
            IndexError: if the name splits into fewer than three pieces.
            KeyError: if the parsed token is not a key of ``label_dic``.
        """
        token = self._LABEL_SPLIT.split(os.path.basename(file_name))[-3]
        try:
            return self.label_dic[token]
        except KeyError:
            raise KeyError(
                f"cannot infer label from file name {file_name!r} "
                f"(parsed token {token!r})"
            ) from None

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index`` of the active split.

        ``transform`` and ``target_transform`` are applied when set; otherwise
        the stored image and integer label are returned unchanged.
        """
        if self.train:
            img, target = self.train_data[index], self.train_label[index]
        else:
            img, target = self.test_data[index], self.test_label[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of samples in the active split."""
        if self.train:
            return len(self.train_data)
        return len(self.test_data)
# Regular-expression reference (正则表达式参考):
# https://blog.csdn.net/weixin_40136018/article/details/81183504