当前位置: 首页 > 工具软件 > truth > 使用案例 >

Pytorch测试数据集和Groundtruth绑定在一起

爱唯
2023-12-01

先前写过不同的方式,今天做一个简单的整理。

 

第一种,单独写一个类用于绑定测试图像及对应的groundtruth,详情参考以下代码(手敲,可能有拼写错误),然后就可以使用诸如dataset=TestDataSet(...)的方式导入测试数据集,在使用的时候就可以直接在dataloader中加载图像及对应的groundtruth.

import os
import torch
import numpy as np
import torch.utils.data as tud
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

class TestDataSet(tud.Dataset):
    def __init__(self,img_root,gt_root,img_transform=None,gt_tranform=None):
        self.img_root=img_root
        self.gt_root=gt_root
        self.img_list=sorted(os.listdir(self.img_root))
        self.gt_list=sorted(os.listdir(self.gt_root))
        self.file_num=len(self.img_list)
        self.img_tranform=img_transform
        self.gt_transform=gt_transform

    def __getitem__(self,index):
        img_name=os.path.join(self.img_root,self.img_list[index])
        gt_name=os.path.join(self.gt_root,self.gt_list[index])

        img=Image.open(img_name)
        gt=Image.open(gt_name)

        if self.img_transform is not None:
            img=self.img_transform(img)
        if self.gt_transform is not None:
            gt=self.gt_transform(gt)

        return img,gt

    def __len__(self):
        return self.file_num

以上方式虽然简单直观,但只适合一个子类别的情况。如果一个数据集下面有多个图像子类别,如果继续使用以上代码则需要根据子类别反复修改路径,这显然是不明智的。这时候,可以利用自身ImageFolder读取多个子类别的数据,只不过换一个粗暴一些的迭代器就好了,这时候应该也能保证input和groundtruth对应起来。(以下代码为手敲,只为表达意思,不保证准确性)

import os
import torch
from torchvision.datasets import ImageFolder
import numpy as np
import PIL.Image as Image
import torch.utils.data as tud
import torchvision.transforms as transforms


dataset["test"]=ImageFolder(os.path.join(opt.dataroot,opt.dataset,"test"))
dataset["gt"]=ImageFolder(os.path.join(opt.dataroot,opt.dataset,"ground_truth"))


iterTest=dataloader["test"].__iter__()
iterGt=dataloader["gt"].__iter__()
batchNum=len(dataloader["test"])

for i in range(batchNum):
    batchTest=iterTest.__next__()
    self.inputTest=batchTest[0].to(self.device)

    batchGt=iterGt.__next__()
    self.gt=batchGt[0].to(self.device)

....

最新补充,参考https://stackoverflow.com/questions/59467781/pytorch-dataloader-for-image-gt-dataset,直接在dataset里绑定img和对应的ground truth.核心思路就是自定义一个数据集类,然后重写make_datset函数。为了阅读方便,我做了一定的代码精简,鲁棒性变差,但是能保证我的场景下使用无误。以下代码纯手敲,可能存在typo,使用的时候可能需要小心,但是思路肯定是清晰的。

需要指出的是,以下代码对应的文件夹层级设置如下

---data

------test

---------color

------------001.png

------------002.png

------------003.png

---------blue

------------001.png

------------002.png

------------003.png

------ground truth

---------color

------------001_mask.png

------------002_mask.png

------------003_mask.png

---------blue

------------001_mask.png

------------002_mask.png

------------003_mask.png

 

import os
import sys
from torchvision.datasets.folder import default_loader
from torchvision.dataset.vision import VisionDataset

def make_img_gt_dataset(img_root,gt_root):
# Find sub-classes in the root. Since one sub-class in img_root should match one sub-class # in gt_root, we always suppose img_root and gt_root have folders with the same names

    if sys.version_infor>=(3,5):
        classes=[d.name for d in os.scandir(img_root) if d.is_dir()]
    else:
        classes=[d for d in os.listdir(img_root) if os.path.isdir(os.path.join(img_root,d))]

    classes.sort()

    images=[]
    
    for sub_class in classes:
        d1=os.path.join(img_root,sub_class)
        d2=os.path.join(gt_root,sub_class)

        img_names=sorted(os.listdir(d1))
        gt_names=sorted(os.listdir(d2))

        for img_name in img_names:
            img_name_without_ext=img_name[0:len(img_name)-4]
            gt_name=img_name_without_ext+"_mask.png"

            if gt_name in gt_names:
                img_path=os.path.join(img_root,sub_class,img_name)
                gt_path=os.path.join(gt_root,sub_class,gt_name)

                item=(img_path,gt_path)
                images.append(item)

    return images

class TestGtDataset(VisionDataset):
    def __init__(self,img_root,gt_root,loader=default_loader,img_transform=None,gt_transform=None):
    super().__init__(root=img_root,transform=img_transform,target_transform=gt_transform)

    self.loader=loader

    samples=make_img_gt_dataset(img_root,gt_root)
    self.samples=samples
    self.img_samples=[s[0] for s in samples]
    self.gt_samples=[s[1] for s in samples]

    def __getitem__(self,index):
        img_path,gt_path=self.samples[index]
        img_sample=self.loader(img_path)
        gt_sample=self.loader(gt_path)

        if self.transform is not None:
            img_sample=self.transform(img_sample)
        if self.target_transform is not None:
            gt_sample=self.target_transform(gt_sample)

        return img_sample,gt_sample

    def __len__(self):
        return len(self.samples)

 

 类似资料: