代码运行至下列语句时:
for i, data in enumerate(train_loader):
有时会遇到以下三种报错:
TypeError: img should be PIL Image. Got <class 'dict'>
TypeError: img should be PIL Image. Got <class 'Torch.Tensor'>
TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>
这类问题往往出错在 dataset 的 getitem 函数处,因为这里涉及到对数据做 transform,transform 中又涉及到一些图像转换的问题。
class Resize(object):
"""Resize the input PIL Image to the given size.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
"""
def __init__(self, size, interpolation=Image.BILINEAR):
assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
self.size = size
self.interpolation = interpolation
def __call__(self, img):
"""
Args:
img (PIL Image): Image to be scaled.
Returns:
PIL Image: Rescaled image.
"""
return F.resize(img, self.size, self.interpolation)
def __repr__(self):
interpolate_str = _pil_interpolation_to_str[self.interpolation]
return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
我们来看 call() 方法,这里有几行注释,大致意思是, call 方法接受一个 PIL Image 格式的输入,经过 resize 方法后 返回一个 PIL Image 格式的输出,也就是说, pytorch 官方 中的 transform 默认是需要一个 PIL Image 格式输入的 。而有很多朋友会采用不同的方式读取 image ,比如使用 cv2.imread() 函数,我们可以来测试一下:
import cv2
img = cv2.imread('data/davis2016/JPEGImages/480p/train/00000.jpg')
type(img)
# Out[7]: numpy.ndarray
可以看见,使用 cv2.imread 函数读取的 image 是 numpy.ndarray 格式的,这时候如果直接对这个 image 做 transform 就会出现类型不匹配问题,这时候,需要在你写的 train_transform 中加一个 transform.ToPILImage() 函数,例如:
train_transforms = t.Compose([t.ToPILImage(), # Here
t.RandomHorizontalFlip(),
t.Resize((480, 852)),
t.ToTensor()])
依次可以类推,transform.Compose() 方法其实就是把一系列我们要对 image 做的操作(数据预处理,数据增强等)排列到一起,因此,我们要保证其从第一个函数到最后一个函数的输入都要是 PIL Image 格式。那些遇见错误的同学,要么是没有将 PIL Image 格式的图像做为 transform.Compose() 方法输入,要么是虽然输入了 PIL Image 格式图像,但是在一些列操作未结束之前就将其转为了 Tensor,见下列代码:
train_transforms = t.Compose([t.ToPILImage(),
t.RandomHorizontalFlip(),
t.ToTensor(), # Here
t.Resize((480, 852))])
这时,不用运行,我们就知道,这里肯定出错了,因为刚刚我们验证过 transform.Resize() 的 call 方法需要接受一个 PIL Image 格式的图像,而你提前使用了 transform.ToTensor() 方法将其转为了torch.Tensor 格式,这就肯定错了。
TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>
你应该检查你的图像是否为 PIL Image 格式,如果不是,可以使用 transform.ToPILImage() 方法。
TypeError: img should be PIL Image. Got <class 'Torch.Tensor'>
你应该检查你的 transform.ToTensor() 方法是否写在了你要做的操作之前,如果是,调换一下它们的位置。
TypeError: img should be PIL Image. Got <class 'dict'>
这个应该很少有人会遇见,这是我需要将一个 img 和它的 gt 做为一个字典一起返回的时候遇见的一个错误:
def __getitem__(self, idx):
img = readImage(self.img_list[idx], channel=3)
gt = readImage(self.mask_list[idx], channel=1)
sample = {'images': img, 'gts': gt}
if self.transform is not None:
sample = self.transform(sample)
return sample
解决方法是,先分别对 img 和 gt 做 transform 再把它们组合到一个字典里:
def __getitem__(self, idx):
img = readImage(self.img_list[idx], channel=3)
gt = readImage(self.mask_list[idx], channel=1)
if self.transform is not None:
img = self.transform(img)
gt = self.transform(gt)
sample = {'images': img, 'gts': gt}
return sample
import PIL.Image
def readImage(img_path, channel=3):
"""Keep reading image until succeed.
This can avoid IOError incurred by heavy IO process."""
got_img = False
if not os.path.exists(img_path):
raise IOError("{} does not exist".format(img_path))
while not got_img:
if channel == 3:
try:
img = Image.open(img_path).convert('RGB')
got_img = True
except IOError:
print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
pass
elif channel == 1:
try:
img = Image.open(img_path).convert('1')
got_img = True
except IOError:
print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
pass
return img
train_transforms = t.Compose([t.RandomHorizontalFlip(),
t.Resize((480, 852)),
t.ToTensor()])
注意:transform.ToTensor() 最好写在最后。
ps:其实…也有一些方法(例如,随即擦除,标准化等)不需要输入格式为 PIL.Image,下面这样写是可以的:
transform_train = T.Compose([
T.Random2DTranslation((256, 128)),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.RandomErasing(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
具体情况还是要自己多试一试,多去看看官方的写法。