这个子类必须实现两个方法:
1、getitem(self, index):获取通过索引返回图片(切片),通过dataset[:]方法
2、len(self):获取数据集长度,通过len(dataset)方法
class MyDataset(Dataset):
def __init__(self) -> None:
super().__init__()
def __getitem__(self, index):
def __len__(self):
init:实现得到图片路径列表
方法1、将图片路径放入初始化函数中
这里是有两种类型图片,并且通过glob得到文件列表
glob模块提供了一个函数用于从目录通配符搜索中生成文件列表
type1dir = glob.glob(/*/*.jpg)
type2dir = glob.glob(/*/*.png)
def __init__(self, type1dir, type2dir) -> None:
super().__init__()
self.t1dir = type1dir
self.t2dir = type2dir
方法2、在初始化函数中进行文件列表的转换
os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。
type1dir = “/*/*.jpg”
type2dir = “/*/*.png”
def __init__(self, type1dir, type2dir) -> None:
super().__init__()
self.t1dir = type1dir
self.t2dir = type2dir
self.t1dir_path = os.listdir(self.t1dir)
self.t2dir_path = os.listdir(self.t2dir)
通过传入index参数对图片文件列表进行切片
def __getitem__(self, index):
type1item = self.t1dir[index]
type2item = self.t2dir[index]
return type1item, type2item
返回整个dataset长度
def __len__(self):
return len(self.t1dir)
首先定义transforms类
ToTensor(): 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1]
Resize(): 重置图像分辨率
Normalize(): 对数据按通道进行标准化,即先减均值,再除以标准差,注意是 hwc
transforms = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=0.5, std=0.5)])
同时在DataSet子类中引用:
def __init__(self, type1dir, type2dir, transforms) -> None:
super().__init__()
self.t1dir = type1dir
self.t2dir = type2dir
self.transforms = transforms
def __getitem__(self, index):
t1item = self.t1dir[index]
t2item = self.t2dir[index]
pil_t1item = Image.open(t1item)
pil_t2item = Image.open(t2item)
data1 = self.transforms(pil_t1item)
data2 = self.transforms(pil_t2item)
return data1, data2
1、设置画布大小
2、将数据格式转换为narray类型,并调整格式
3、循环读取数据
4、将数据加入画布,并设置画布细节
5、展示画布
figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True) 相当于创建画布大小
num:图像编号或名称,数字为编号 ,字符串为名称
figsize:指定figure的宽和高,单位为英寸;
dpi参数指定绘图对象的分辨率,即每英寸多少个像素,缺省值为80 1英寸等于2.5cm,A4纸是 21*30cm的纸张
facecolor:背景颜色
edgecolor:边框颜色
frameon:是否显示边框
zip: 返回元组的迭代器,其中第 i 个元组包含来自每个参数序列或可迭代对象的第 i 个元素
permute(dims): 将tensor的维度换位。
plt.subplot(nrows, ncols, plot_number)
这个函数用来表示把figure分成nrows*ncols的子图表示,
nrows:子图的行数
ncols:子图的列数
plot_number 索引值,表示把图画在第plot_number个位置
plt.title(): 函数用于设置图像标题。
plt.imshow(X, interpolation=None)
X:图像数据
(M, N):标量数据的图像,灰度图
(M, N, 3):RGB图像
plt.axis(‘square’): 作图为正方形,并且x,y轴范围相同,即
plt.axis(‘equal’): x,y轴刻度等长
plt.axis(‘off’): 关闭坐标轴 官网上也贴出了其他的一些选项
plt.show(): 展示画布
plt.figure(figsize=(12, 8))
for step, (img1, img2) enumerate(zip(img1_loader, img2_loader)):
#将channel放后面,h、w放前面, 并转换成ndarray类型
img1 = img1.permute(1, 2, 0).numpy()
# 图片分为2行3列
plt.subplot(2, 3, step + 1)
# title自己设的标签名称
plt.title(label[i])
plt.imshow(img1)
#取消坐标轴
plt.axis('off')
plt.savefig(f'-/image_at_epoch_{format(epoch)}.png')
plt.show()
def test(Gen, testImg_input, testFake_input):
testImg_output = Gen(testImg_input).detach().numpy().permute(0, 2, 3, 1).cpu()
testImg_input = testImg_input.permute(0, 2, 3, 1).numpy().cpu()
testFake_input = testFake_input.permute(0, 2, 3, 1).numpu().cpu()
plt.figure(figsize=(10, 10))
for step, (img, _) in enumerate(test_dataloader):
plt.subplot(
plt.figure(figsize=(8, 8))
for i , img in enumerate(train_a[:4]):
img = Image.open(img)
img = np.array(img)
plt.subplot(2, 2, i + 1)
plt.imshow(img)
plt.title(str(img.shape))
plt.axis("off")
plt.show()