当前位置: 首页 > 工具软件 > nude > 使用案例 >

读取数据集nude

戚峻
2023-12-01
class Mydata(Dataset):
    """In-memory image dataset with two classes, 'nude' (0) and 'nonnude' (1).

    Every entry returned by ``os.listdir(root)`` is opened with PIL, converted
    to RGB and kept in memory. Depending on ``train`` the samples are stored
    in ``train_data`` / ``train_label`` or in ``test_data`` / ``test_label``;
    the other pair of lists stays empty.

    The label is parsed from the image path: the path is split on ``_``,
    ``.`` and digits, and the third token from the end must be a key of
    ``label_dic``.
    NOTE(review): this relies on the dataset's file-naming scheme, which is
    not visible here -- confirm against the actual file names.

    Args:
        root: Directory that contains the image files.
        train: If True the samples are served from the ``train_*`` lists,
            otherwise from the ``test_*`` lists.
        transform: Optional callable applied to the image in ``__getitem__``.
        target_transform: Optional callable applied to the label in
            ``__getitem__``.

    Raises:
        KeyError: If the token parsed from a path is not in ``label_dic``.
        IndexError: If a path splits into fewer than three tokens.
    """

    def __init__(self, root, train=True, transform=None,
                 target_transform=None):
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        self.train_label = []
        self.train_data = []
        self.test_label = []
        self.test_data = []
        self.label_dic = {'nude': 0, 'nonnude': 1}

        # Pick the destination lists once instead of branching per file.
        if self.train:
            data, labels = self.train_data, self.train_label
        else:
            data, labels = self.test_data, self.test_label

        # Raw string: same pattern as before, without invalid-escape warnings.
        # Splits on "_", ".", and digits (two at a time first, then single).
        split_path = re.compile(r'_|\.|\d{2}|\d').split

        for file_name in os.listdir(self.root):
            # os.path.join gives the same path as plain concatenation when
            # root ends with a separator, and also works when it does not.
            path = os.path.join(self.root, file_name)
            image = Image.open(path).convert('RGB')
            # The third token from the end of the split path is the class name.
            labels.append(self.label_dic[split_path(path)[-3]])
            data.append(image)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` with the transforms applied.

        When ``transform`` is None the stored image is returned unchanged
        (the previous code raised ``UnboundLocalError`` in that case).
        """
        if self.train:
            img, target = self.train_data[index], self.train_label[index]
        else:
            img, target = self.test_data[index], self.test_label[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the active (train or test) split."""
        if self.train:
            return len(self.train_data)
        return len(self.test_data)

正则表达式参考:
https://blog.csdn.net/weixin_40136018/article/details/81183504

 类似资料: