先前写过不同的方式,今天做一个简单的整理。
第一种,单独写一个类用于绑定测试图像及对应的groundtruth,详情参考以下代码(手敲,可能有拼写错误),然后就可以使用诸如dataset=TestDataSet(...)的方式导入测试数据集,在使用的时候就可以直接在dataloader中加载图像及对应的groundtruth.
import os
import torch
import numpy as np
import torch.utils.data as tud
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
class TestDataSet(tud.Dataset):
def __init__(self,img_root,gt_root,img_transform=None,gt_tranform=None):
self.img_root=img_root
self.gt_root=gt_root
self.img_list=sorted(os.listdir(self.img_root))
self.gt_list=sorted(os.listdir(self.gt_root))
self.file_num=len(self.img_list)
self.img_tranform=img_transform
self.gt_transform=gt_transform
def __getitem__(self,index):
img_name=os.path.join(self.img_root,self.img_list[index])
gt_name=os.path.join(self.gt_root,self.gt_list[index])
img=Image.open(img_name)
gt=Image.open(gt_name)
if self.img_transform is not None:
img=self.img_transform(img)
if self.gt_transform is not None:
gt=self.gt_transform(gt)
return img,gt
def __len__(self):
return self.file_num
以上方式虽然简单直观,但只适合一个子类别的情况。如果一个数据集下面有多个图像子类别,如果继续使用以上代码则需要根据子类别反复修改路径,这显然是不明智的。这时候,可以利用自身ImageFolder读取多个子类别的数据,只不过换一个粗暴一些的迭代器就好了,这时候应该也能保证input和groundtruth对应起来。(以下代码为手敲,只为表达意思,不保证准确性)
import os
import torch
from torchvision.datasets import ImageFolder
import numpy as np
import PIL.Image as Image
import torch.utils.data as tud
import torchvision.transforms as transforms
dataset["test"]=ImageFolder(os.path.join(opt.dataroot,opt.dataset,"test"))
dataset["gt"]=ImageFolder(os.path.join(opt.dataroot,opt.dataset,"ground_truth"))
iterTest=dataloader["test"].__iter__()
iterGt=dataloader["gt"].__iter__()
batchNum=len(dataloader["test"])
for i in range(batchNum):
batchTest=iterTest.__next__()
self.inputTest=batchTest[0].to(self.device)
batchGt=iterGt.__next__()
self.gt=batchGt[0].to(self.device)
....
最新补充,参考https://stackoverflow.com/questions/59467781/pytorch-dataloader-for-image-gt-dataset,直接在dataset里绑定img和对应的ground truth.核心思路就是自定义一个数据集类,然后重写make_datset函数。为了阅读方便,我做了一定的代码精简,鲁棒性变差,但是能保证我的场景下使用无误。以下代码纯手敲,可能存在typo,使用的时候可能需要小心,但是思路肯定是清晰的。
需要指出的是,以下代码对应的文件夹层级设置如下
---data
------test
---------color
------------001.png
------------002.png
------------003.png
---------blue
------------001.png
------------002.png
------------003.png
------ground truth
---------color
------------001_mask.png
------------002_mask.png
------------003_mask.png
---------blue
------------001_mask.png
------------002_mask.png
------------003_mask.png
import os
import sys
from torchvision.datasets.folder import default_loader
from torchvision.dataset.vision import VisionDataset
def make_img_gt_dataset(img_root,gt_root):
# Find sub-classes in the root. Since one sub-class in img_root should match one sub-class # in gt_root, we always suppose img_root and gt_root have folders with the same names
if sys.version_infor>=(3,5):
classes=[d.name for d in os.scandir(img_root) if d.is_dir()]
else:
classes=[d for d in os.listdir(img_root) if os.path.isdir(os.path.join(img_root,d))]
classes.sort()
images=[]
for sub_class in classes:
d1=os.path.join(img_root,sub_class)
d2=os.path.join(gt_root,sub_class)
img_names=sorted(os.listdir(d1))
gt_names=sorted(os.listdir(d2))
for img_name in img_names:
img_name_without_ext=img_name[0:len(img_name)-4]
gt_name=img_name_without_ext+"_mask.png"
if gt_name in gt_names:
img_path=os.path.join(img_root,sub_class,img_name)
gt_path=os.path.join(gt_root,sub_class,gt_name)
item=(img_path,gt_path)
images.append(item)
return images
class TestGtDataset(VisionDataset):
def __init__(self,img_root,gt_root,loader=default_loader,img_transform=None,gt_transform=None):
super().__init__(root=img_root,transform=img_transform,target_transform=gt_transform)
self.loader=loader
samples=make_img_gt_dataset(img_root,gt_root)
self.samples=samples
self.img_samples=[s[0] for s in samples]
self.gt_samples=[s[1] for s in samples]
def __getitem__(self,index):
img_path,gt_path=self.samples[index]
img_sample=self.loader(img_path)
gt_sample=self.loader(gt_path)
if self.transform is not None:
img_sample=self.transform(img_sample)
if self.target_transform is not None:
gt_sample=self.target_transform(gt_sample)
return img_sample,gt_sample
def __len__(self):
return len(self.samples)