当前位置: 首页 > 知识库问答 >
问题:

如何将PyTorch中的DataLoader更改为读取一幅图像进行预测?

端木飞
2023-03-14

目前,我有一个预先训练过的模型,它使用数据加载器读取一批图像来训练模型。

self.data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, 
   num_workers=1, pin_memory=True)

...

model.eval()
for step, inputs in enumerate(test_loader.data_loader):
   outputs = model(torch.cat([inputs], 1))

...

我想对图像进行处理(预测),因为它们是从队列中到达的。它应该类似于读取单个图像并运行模型对其进行预测的代码。大致如下:

from PIL import Image

new_input = Image.open(image_path)
model.eval()
outputs = model(torch.cat([new_input ], 1))

我想知道您是否可以指导我如何做到这一点,并在DataLoader中应用相同的转换。

共有2个答案

常哲彦
2023-03-14

我不知道dataLoader,但您可以使用以下功能加载单个图像:

def safe_pil_loader(path, from_memory=False):
try:
    if from_memory:
        img = Image.open(path)
        res = img.convert('RGB')
    else:
        with open(path, 'rb') as f:
            img = Image.open(f)
            res = img.convert('RGB')
except:
    res = Image.new('RGB', (227, 227), color=0)
return res

对于应用转换,您可以执行以下操作:

trans = transforms.Compose([
            transforms.Resize(299),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            normalize,
        ])
img=trans(img)
班景龙
2023-03-14

您可以使用IterableDataset执行此操作:

from torch.utils.data import IterableDataset

class MyDataset(IterableDataset):
    def __init__(self, image_queue):
      self.queue = image_queue

    def read_next_image(self):
        while self.queue.qsize() > 0:
            # you can add transform here
            yield self.queue.get()
        return None

    def __iter__(self):
        return self.read_next_image()

和batch_size=1:

import queue
import torchvision.transforms.functional as TF

buffer = queue.Queue()
new_input = Image.open(image_path)
buffer.put(TF.to_tensor(new_input)) 
# ... Populate queue here

dataset = MyDataset(buffer)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
for data in dataloader:
   model(data) # data is one-image batch of size [1,3,H,W] where 3 - number of color channels
 类似资料:
  • 我使用torch的标准数据加载器。乌提尔斯。数据我创建dataset类,然后按以下方式构建DataLoader: 它运行完美,但是数据集足够大——300k图像。因此,使用DataLoader读取图像需要大量时间。所以在调试阶段构建这么大的DataLoader真的很糟糕!我只是想测试一些我的假设,想快点做!我不需要为此加载整个数据集。 我试图找到一种方法,如何只加载数据集的一小部分,而不在整个数据集

  • 这样循环有点慢。我尝试添加numba的@njit decorator,但显然它与opencv有问题。 输入图像为32x32像素。它们映射到32x32圆的输出图像。每个圆绘制在一个20x20像素的正方形内。也就是说,输出图像为640x640像素 一张图像需要大约100毫秒才能转化为圆圈,我希望能将其降低到30毫秒或更低 有什么建议吗?

  • 我正在使用GridLayoutManager的回收器视图。我还使用以下代码使第一个项目变大 除了一个以外,其他一切都正常:位置1处的物品(即大图中的下一个物品)被垂直拉长,以匹配大件物品的高度。从第3行开始,所有图像如图所示。 我怎么才能摆脱这个? 编辑:经过一些分析 所以问题似乎是,大图像是水平的两个跨度,但垂直的单个跨度,因为我已经强制我的ImageView是方形的,它看起来像是也取了两行,而

  • 我看了其他答案,试着: 以及: 在类中,我把图像文件放在资源文件夹中,也放在与我的文件和我项目的根文件夹中,甚至在开始时包含了/符号URL字符串,但没有工作。我想知道最近有没有人尝试过并成功了?

  • 以下程序演示如何将彩色图像读取为灰度并使用JavaFX窗口显示。 在这里通过将标志与带有彩色图像路径的字符串一起传递来读取图像。 假定以下是上述程序中指定的输入图像。那么输出结果应该为一个灰色的图片。