torchvision.datasets

优质
小牛编辑
133浏览
2023-12-01

译者:BXuan694

所有的数据集都是torch.utils.data.Dataset的子类, 即:它们实现了__getitem____len__方法。因此,它们都可以传递给torch.utils.data.DataLoader,进而通过torch.multiprocessing实现批数据的并行化加载。例如:

imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

目前为止,收录的数据集包括:

数据集

  • MNIST
  • Fashion-MNIST
  • EMNIST
  • COCO
    • Captions
    • Detection
  • LSUN
  • ImageFolder
  • DatasetFolder
  • Imagenet-12
  • CIFAR
  • STL10
  • SVHN
  • PhotoTour
  • SBU
  • Flickr
  • VOC

以上数据集的接口基本上很相近。它们至少包括两个公共的参数transformtarget_transform,以便分别对输入和和目标做变换。

class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

MNIST数据集。

参数:

  • rootstring)– 数据集的根目录,其中存放processed/training.ptprocessed/test.pt文件。
  • trainbool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • downloadbool, 可选)– 如果设置为True, 从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)

Fashion-MNIST数据集。

参数:

  • rootstring)– 数据集的根目录,其中存放processed/training.ptprocessed/test.pt文件。
  • trainbool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • downloadbool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
class torchvision.datasets.EMNIST(root, split, **kwargs)

EMNIST数据集。

参数:

  • rootstring)– 数据集的根目录,其中存放processed/training.ptprocessed/test.pt文件。
  • splitstring)– 该数据集分成6种:byclassbymergebalancedlettersdigitsmnist。这个参数指定了选择其中的哪一种。
  • trainbool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • downloadbool, 可选)– 如果设置为True, 从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform可被调用 , 可选) – 一种函数或变换,输入目标,进行变换。

注意:

以下要求预先安装COCO API。

class torchvision.datasets.CocoCaptions(root, annFile, transform=None, target_transform=None)

MS Coco Captions数据集。

参数:

  • rootstring)– 下载数据的目标目录。
  • annFilestring)– json标注文件的路径。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.ToTensor
  • target_transform可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。

示例

import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
                        annFile = 'json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

print("Image Size: ", img.size())
print(target)

输出:

Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']

__getitem__(index)
参数:index (int) – 索引
返回:元组(image, target),其中target是列表类型,包含了对图片image的描述。
------
返回类型:tuple
------
class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None)

MS Coco Detection数据集。

参数:

  • rootstring)– 下载数据的目标目录。
  • annFilestring)– json标注文件的路径。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.ToTensor
  • target_transform可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
__getitem__(index)
参数:index (int) – 索引
返回:元组(image, target),其中target是coco.loadAnns返回的对象。
------
返回类型:tuple
------
class torchvision.datasets.LSUN(root, classes='train', transform=None, target_transform=None)

LSUN数据集。

参数:

  • rootstring)– 存放数据文件的根目录。
  • classesstring list)– {‘train’, ‘val’, ‘test’}之一,或要加载类别的列表,如[‘bedroom_train’, ‘church_train’]。
  • transform可被调用 , 可选) – 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
__getitem__(index)
参数:index (int) – 索引
返回:元组(image, target),其中target是目标类别的索引。
------
Return type:tuple
------
class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)

一种通用数据加载器,其图片应该按照如下的形式保存:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

参数:

  • rootstring)– 根目录路径。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • loader – 一种函数,可以由给定的路径加载图片。
__getitem__(index)
参数:index (int) – 索引
返回:(sample, target),其中target是目标类的类索引。
------
返回类型:tuple
------
class torchvision.datasets.DatasetFolder(root, loader, extensions, transform=None, target_transform=None)

一种通用数据加载器,其数据应该按照如下的形式保存:

root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext

root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext

参数:

  • rootstring)– 根目录路径。
  • loader可被调用)– 一种函数,可以由给定的路径加载数据。
  • extensionslist[string])– 列表,包含允许的扩展。
  • transform可被调用 , 可选)– 一种函数或变换,输入数据,返回变换之后的数据。如:对于图片有transforms.RandomCrop
  • target_transform – 一种函数或变换,输入目标,进行变换。
__getitem__(index)
参数:index (int) – 索引
返回:(sample, target),其中target是目标类的类索引.
------
返回类型:tuple
------

这个类可以很容易地实现ImageFolder数据集。数据预处理见此处。

示例。

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)

CIFAR10数据集。

参数:

  • rootstring)– 数据集根目录,要么其中应存在cifar-10-batches-py文件夹,要么当download设置为True时cifar-10-batches-py文件夹保存在此处。
  • trainbool, 可选)– 如果设置为True, 从训练集中创建,否则从测试集中创建。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • downloadbool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数:index (int) – 索引
返回:(image, target),其中target是目标类的类索引。
------
返回类型:tuple
------
class torchvision.datasets.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)

CIFAR100数据集。

这是CIFAR10数据集的一个子集。

class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)

STL10数据集。

参数:

  • rootstring)– 数据集根目录,应该包含stl10_binary文件夹。
  • splitstring)– {‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’}之一,选择相应的数据集。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • downloadbool, optional)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数:index (int) – 索引
返回:(image, target),其中target应是目标类的类索引。
------
返回类型:tuple
------
class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)

SVHN数据集。注意:SVHN数据集将10指定为数字0的标签。然而,这里我们将0指定为数字0的标签以兼容PyTorch的损失函数,因为损失函数要求类标签在[0, C-1]的范围内。

参数:

  • rootstring)– 数据集根目录,应包含SVHN文件夹。
  • splitstring)– {‘train’, ‘test’, ‘extra’}之一,相应的数据集会被选择。‘extra’是extra训练集。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • downloadbool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数:index (int) – 索引
返回:(image, target),其中target是目标类的类索引。
------
返回类型:tuple
------
class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)

Learning Local Image Descriptors Data数据集。

参数:

  • rootstring)– 保存图片的根目录。
  • namestring)– 要加载的数据集。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。
  • download (bool, optional) – 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数:index (int) – 索引
返回:(data1, data2, matches)
------
返回类型:tuple
------

最后更新:

类似资料

  • 本文向大家介绍pytorch1.0中torch.nn.Conv2d用法详解,包括了pytorch1.0中torch.nn.Conv2d用法详解的使用技巧和注意事项,需要的朋友参考一下 Conv2d的简单使用 torch 包 nn 中 Conv2d 的用法与 tensorflow 中类似,但不完全一样。 在 torch 中,Conv2d 有几个基本的参数,分别是 in_channels 输入图像的深

  • torchvision.datasets中包含了以下数据集 MNIST COCO(用于图像标注和目标检测)(Captioning and Detection) LSUN Classification ImageFolder Imagenet-12 CIFAR10 and CIFAR100 STL10 Datasets 拥有以下API: __getitem__ __len__ 由于以上Dataset

相关阅读