torchvision.datasets
优质
小牛编辑
137浏览
2023-12-01
译者:BXuan694
所有的数据集都是torch.utils.data.Dataset
的子类, 即:它们实现了__getitem__
和__len__
方法。因此,它们都可以传递给torch.utils.data.DataLoader
,进而通过torch.multiprocessing
实现批数据的并行化加载。例如:
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
目前为止,收录的数据集包括:
数据集
- MNIST
- Fashion-MNIST
- EMNIST
- COCO
- Captions
- Detection
- LSUN
- ImageFolder
- DatasetFolder
- Imagenet-12
- CIFAR
- STL10
- SVHN
- PhotoTour
- SBU
- Flickr
- VOC
以上数据集的接口基本上很相近。它们至少包括两个公共的参数transform
和target_transform
,以便分别对输入和和目标做变换。
class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)
MNIST数据集。
参数:
- root(string)– 数据集的根目录,其中存放
processed/training.pt
和processed/test.pt
文件。 - train(bool, 可选)– 如果设置为True,从
training.pt
创建数据集,否则从test.pt
创建。 - download(bool, 可选)– 如果设置为True, 从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
- transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:
transforms.RandomCrop
。 - target_transform (可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)
Fashion-MNIST数据集。
参数:
- root(string)– 数据集的根目录,其中存放
processed/training.pt
和processed/test.pt
文件。 - train(bool, 可选)– 如果设置为True,从
training.pt
创建数据集,否则从test.pt
创建。 - download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
- transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:
transforms.RandomCrop
。 - target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
class torchvision.datasets.EMNIST(root, split, **kwargs)
EMNIST数据集。
参数:
- root(string)– 数据集的根目录,其中存放
processed/training.pt
和processed/test.pt
文件。 - split(string)– 该数据集分成6种:
byclass
,bymerge
,balanced
,letters
,digits
和mnist
。这个参数指定了选择其中的哪一种。 - train(bool, 可选)– 如果设置为True,从
training.pt
创建数据集,否则从test.pt
创建。 - download(bool, 可选)– 如果设置为True, 从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
- transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:
transforms.RandomCrop
。 - target_transform(可被调用 , 可选) – 一种函数或变换,输入目标,进行变换。
注意:
以下要求预先安装COCO API。
class torchvision.datasets.CocoCaptions(root, annFile, transform=None, target_transform=None)
MS Coco Captions数据集。
参数:
- root(string)– 下载数据的目标目录。
- annFile(string)– json标注文件的路径。
- transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:
transforms.ToTensor
。 - target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
示例
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
annFile = 'json annotation file',
transform=transforms.ToTensor())
print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample
print("Image Size: ", img.size())
print(target)
输出:
Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
__getitem__(index)
参数: | index (int) – 索引 |
---|---|
返回: | 元组(image, target),其中target是列表类型,包含了对图片image的描述。 |
--- | --- |
返回类型: | tuple |
--- | --- |
class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None)
MS Coco Detection数据集。
参数:
- root(string)– 下载数据的目标目录。
- annFile(string)– json标注文件的路径。
- transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:
transforms.ToTensor
。 - target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
__getitem__(index)
参数: | index (int) – 索引 |
---|---|
返回: | 元组(image, target),其中target是coco.loadAnns 返回的对象。 |
--- | --- |
返回类型: | tuple |
--- | --- |
class torchvision.datasets.LSUN(root, classes='train', transform=None, target_transform=None)
LSUN数据集。
参数:
- root(string)– 存放数据文件的根目录。
- classes(string 或 list)– {‘train’, ‘val’, ‘test’}之一,或要加载类别的列表,如[‘bedroom_train’, ‘church_train’]。
- transform(可被调用 , 可选) – 一种函数或变换,输入PIL图片,返回变换之后的数据。如:
transforms.RandomCrop
。 - target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
__getitem__(index)
参数: | index (int) – 索引 |
---|---|
返回: | 元组(image, target),其中target是目标类别的索引。 |
--- | --- |
Return type: | tuple |
--- | --- |
class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)
一种通用数据加载器,其图片应该按照如下的形式保存:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
参数:
- root(string)– 根目录路径。
- transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:
transforms.RandomCrop
。 - target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
- loader – 一种函数,可以由给定的路径加载图片。
__getitem__(index)
参数: | index (int) – 索引 |
---|---|
返回: | (sample, target),其中target是目标类的类索引。 |
--- | --- |
返回类型: | tuple |
--- | --- |
class torchvision.datasets.DatasetFolder(root, loader, extensions, transform=None, target_transform=None)
一种通用数据加载器,其数据应该按照如下的形式保存:
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
参数:
- root(string)– 根目录路径。
- loader(可被调用)– 一种函数,可以由给定的路径加载数据。
- extensions(list[string])– 列表,包含允许的扩展。
- transform(可被调用 , 可选)– 一种函数或变换,输入数据,返回变换之后的数据。如:对于图片有
transforms.RandomCrop
。 - target_transform – 一种函数或变换,输入目标,进行变换。
__getitem__(index)
参数: | index (int) – 索引 |
---|---|
返回: | (sample, target),其中target是目标类的类索引. |
--- | --- |
返回类型: | tuple |
--- | --- |
这个类可以很容易地实现ImageFolder
数据集。数据预处理见此处。
示例。
class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
CIFAR10数据集。
参数:
- root(string)– 数据集根目录,要么其中应存在
cifar-10-batches-py
文件夹,要么当download设置为True时cifar-10-batches-py
文件夹保存在此处。 - train(bool, 可选)– 如果设置为True, 从训练集中创建,否则从测试集中创建。
- transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:
transforms.RandomCrop
。 - target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
- download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数: | index (int) – 索引 |
---|---|
返回: | (image, target),其中target是目标类的类索引。 |
--- | --- |
返回类型: | tuple |
--- | --- |
class torchvision.datasets.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
CIFAR100数据集。
这是CIFAR10
数据集的一个子集。
class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)
STL10数据集。
参数:
- root(string)– 数据集根目录,应该包含
stl10_binary
文件夹。 - split(string)– {‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’}之一,选择相应的数据集。
- transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:
transforms.RandomCrop
。 - target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
- download(bool, optional)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数: | index (int) – 索引 |
---|---|
返回: | (image, target),其中target应是目标类的类索引。 |
--- | --- |
返回类型: | tuple |
--- | --- |
class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)
SVHN数据集。注意:SVHN数据集将10
指定为数字0
的标签。然而,这里我们将0
指定为数字0
的标签以兼容PyTorch的损失函数,因为损失函数要求类标签在[0, C-1]
的范围内。
参数:
- root(string)– 数据集根目录,应包含
SVHN
文件夹。 - split(string)– {‘train’, ‘test’, ‘extra’}之一,相应的数据集会被选择。‘extra’是extra训练集。
- transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:
transforms.RandomCrop
。 - target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
- download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数: | index (int) – 索引 |
---|---|
返回: | (image, target),其中target是目标类的类索引。 |
--- | --- |
返回类型: | tuple |
--- | --- |
class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)
Learning Local Image Descriptors Data数据集。
参数:
- root(string)– 保存图片的根目录。
- name(string)– 要加载的数据集。
- transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。
- download (bool, optional) – 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数: | index (int) – 索引 |
---|---|
返回: | (data1, data2, matches) |
--- | --- |
返回类型: | tuple |
--- | --- |