当前位置: 首页 > 编程笔记 >

PyTorch实现重写/改写Dataset并载入Dataloader

邵兴庆
2023-03-14
本文向大家介绍PyTorch实现重写/改写Dataset并载入Dataloader,包括了PyTorch实现重写/改写Dataset并载入Dataloader的使用技巧和注意事项,需要的朋友参考一下

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem__和__len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:


得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float() 

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索小牛知识库以前的文章或继续浏览下面的相关文章希望大家以后多多支持小牛知识库!

 类似资料:
  • 问题内容: 每当抛出javascript异常时,我们还想做一些额外的事情。 从以下文档: 角度表达式中任何未捕获的异常都委托给此服务。默认的实现只是将$ log.error委托给浏览器控制台。 它说“默认实现”的事实使我认为有一种方法可以为服务提供我们自己的实现,并在引发异常时做我们想要的事情。我的问题是,你如何做到这一点?我们如何使所有异常都保留给该服务,然后提供我们希望发生的功能? 问题答案:

  • 我得到一个包含100行数据的表。(Sqlite3和Linux) 这些行上的每个都由多个进程更新。既然同一行不能修改两次(一个进程“拥有”一行并且只有一行),你认为我真的需要使用事务吗?

  • 我想我在并发s3写入方面有问题。两个(或更多)进程同时将几乎相同的内容写入相同的s3位置。我想确定控制这种情况的并发规则。 按照设计,除了一个进程外,所有进程都会在写入s3时被杀死。(我说过,他们写的内容“几乎”相同,因为除了一个进程之外,所有进程都被杀死了。如果所有进程都被允许生存,他们最终会写相同的内容。) 我的理论是,被终止的进程在s3上留下了一个不完整的文件,而另一个文件(可能已完全写入)

  • 本文向大家介绍Python实现数据库并行读取和写入实例,包括了Python实现数据库并行读取和写入实例的使用技巧和注意事项,需要的朋友参考一下 这篇主要记录一下如何实现对数据库的并行运算来节省代码运行时间。语言是Python,其他语言思路一样。 前言 一共23w条数据,是之前通过自然语言分析处理过的数据,附一张截图: 要实现对news主体的读取,并且找到其中含有的股票名称,只要发现,就将这支股票和

  • 现在准备要构建一个工具,用来把前面idata.txt里的数据按group分行显示,就像这样: 2 9 10 3 1 2 3 我们可以借助语法分析树的Listener机制来对词法分析结束后生成的记号流进行改写,我们不需要实现每一个Listener接口方法,只需要在捕获到group的时候把换行符插到它末尾就行。实现改写的代码如下所示: import org.antlr.v4.runtime.Toke