当前位置: 首页 > 工具软件 > dataset > 使用案例 >

Dataset基础创建、对不同数据集处理及数据可视化

骆磊
2023-12-01

创建Dataset子类获取数据

这个子类必须实现两个方法:
1、getitem(self, index):获取通过索引返回图片(切片),通过dataset[:]方法
2、len(self):获取数据集长度,通过len(dataset)方法

class MyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
    def __getitem__(self, index):   
    def __len__(self):

init方法实现

init:实现得到图片路径列表
方法1、将图片路径放入初始化函数中
这里是有两种类型图片,并且通过glob得到文件列表

glob模块提供了一个函数用于从目录通配符搜索中生成文件列表

type1dir = glob.glob(/*/*.jpg)
type2dir = glob.glob(/*/*.png)
def __init__(self, type1dir, type2dir) -> None:
        super().__init__()
        self.t1dir = type1dir
        self.t2dir = type2dir

方法2、在初始化函数中进行文件列表的转换

os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。

type1dir = “/*/*.jpg”
type2dir = “/*/*.png”
def __init__(self, type1dir, type2dir) -> None:
        super().__init__()
        self.t1dir = type1dir
        self.t2dir = type2dir
        self.t1dir_path = os.listdir(self.t1dir)
        self.t2dir_path = os.listdir(self.t2dir)

getitem方法实现

通过传入index参数对图片文件列表进行切片

def __getitem__(self, index):
        type1item = self.t1dir[index]
        type2item = self.t2dir[index]
        return type1item, type2item

len方法实现

返回整个dataset长度

def __len__(self):
        return len(self.t1dir)

transform方法实现

首先定义transforms类

ToTensor(): 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1]
Resize(): 重置图像分辨率
Normalize(): 对数据按通道进行标准化,即先减均值,再除以标准差,注意是 hwc

transforms = transforms.Compose([transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5)])

同时在DataSet子类中引用:

def __init__(self, type1dir, type2dir, transforms) -> None:
        super().__init__()
        self.t1dir = type1dir
        self.t2dir = type2dir
        self.transforms = transforms
def __getitem__(self, index):
        t1item = self.t1dir[index]
        t2item = self.t2dir[index]
        pil_t1item = Image.open(t1item)
        pil_t2item = Image.open(t2item)
        data1 = self.transforms(pil_t1item)
        data2 = self.transforms(pil_t2item)
        return data1, data2

数据可视化

思路

1、设置画布大小
2、将数据格式转换为narray类型,并调整格式
3、循环读取数据
4、将数据加入画布,并设置画布细节
5、展示画布

读取dataloader数据(tensor类型)

figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True) 相当于创建画布大小
num:图像编号或名称,数字为编号 ,字符串为名称
figsize:指定figure的宽和高,单位为英寸;
dpi参数指定绘图对象的分辨率,即每英寸多少个像素,缺省值为80 1英寸等于2.5cm,A4纸是 21*30cm的纸张
facecolor:背景颜色
edgecolor:边框颜色
frameon:是否显示边框

zip: 返回元组的迭代器,其中第 i 个元组包含来自每个参数序列或可迭代对象的第 i 个元素

permute(dims): 将tensor的维度换位。

plt.subplot(nrows, ncols, plot_number)
这个函数用来表示把figure分成nrows*ncols的子图表示,
nrows:子图的行数
ncols:子图的列数
plot_number 索引值,表示把图画在第plot_number个位置

plt.title(): 函数用于设置图像标题。

plt.imshow(X, interpolation=None)
X:图像数据
(M, N):标量数据的图像,灰度图
(M, N, 3):RGB图像

plt.axis(‘square’): 作图为正方形,并且x,y轴范围相同,即
plt.axis(‘equal’): x,y轴刻度等长
plt.axis(‘off’): 关闭坐标轴 官网上也贴出了其他的一些选项

plt.show(): 展示画布

plt.figure(figsize=(12, 8))
for step, (img1, img2) enumerate(zip(img1_loader, img2_loader)):
	#将channel放后面,h、w放前面, 并转换成ndarray类型
	img1 = img1.permute(1, 2, 0).numpy()
	# 图片分为2行3列
	plt.subplot(2, 3, step + 1)
	# title自己设的标签名称
	plt.title(label[i])
	plt.imshow(img1)
	#取消坐标轴
	plt.axis('off')
	plt.savefig(f'-/image_at_epoch_{format(epoch)}.png')
plt.show()

这里假设使用测试数据集进行绘图(Gan模型)

def test(Gen, testImg_input, testFake_input):
	testImg_output = Gen(testImg_input).detach().numpy().permute(0, 2, 3, 1).cpu()
	testImg_input = testImg_input.permute(0, 2, 3, 1).numpy().cpu()
	testFake_input = testFake_input.permute(0, 2, 3, 1).numpu().cpu()
	plt.figure(figsize=(10, 10))
	for step, (img, _) in enumerate(test_dataloader):
		plt.subplot(

读取dataset(imgLIst)

plt.figure(figsize=(8, 8))
for i , img in enumerate(train_a[:4]):
    img = Image.open(img)
    img = np.array(img)
    plt.subplot(2, 2, i + 1)
    plt.imshow(img)
    plt.title(str(img.shape))
    plt.axis("off")
plt.show()
 类似资料: