当前位置: 首页 > 知识库问答 >
问题:

人工智能 - 如何获取 ResNet-50 模型的前 48 层的输出?

谷梁浩思
2023-06-16

图片.png

我想看看 resnet50 模型,第 48 层的输出,我写了下面的代码,但是运行报错了

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import Tensor
import torch.nn as nn

# 加载ResNet-50模型
model = torchvision.models.resnet50(pretrained=True)

# 获取前48层的子模型
model = nn.Sequential(*list(model.children())[:48])

# 修改fc层
# model.fc = nn.Linear(2048, 512)

# 设置模型为评估模式
model.eval()

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
])

# 加载并预处理图像
image = Image.open('std.jpg')
image = transform(image).unsqueeze(0)  # 添加批次维度

# 使用模型进行推理
with torch.no_grad():
    features: Tensor = model(image)
    print(features.shape)

报错如下:

/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Traceback (most recent call last):
  File "/Users/ponponon/Desktop/code/me/resnet_example/resnet48_handle_image_into_vector.py", line 35, in <module>
    features: Tensor = model(image)
  File "/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2048x1 and 2048x1000)

我明明把后面的 fc 给丢掉了呀,为什么还报错呢?

我该如何修改?

共有1个答案

太叔岳
2023-06-16

改成下面这样就可以了

import torchvision.models as models
import torch.nn.functional as F
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import Tensor
import torch.nn as nn


class ImageRetrievalNet(nn.Module):

    def __init__(self, dim: int = 512):
        super().__init__()
        resnet50_model = models.resnet50()
        features = list(resnet50_model.children())[:-2]

        self.features = nn.Sequential(*features)

    def forward(self, x: Tensor):
        # featured_t shape: torch.Size([1, 2048, 7, 7])
        featured_t: Tensor = self.features(x)

        print(featured_t.shape)

        return featured_t


# 加载ResNet-50模型
model = ImageRetrievalNet()

# 修改fc层
# model.fc = nn.Linear(2048, 512)

# 设置模型为评估模式
model.eval()

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
])

# 加载并预处理图像
image = Image.open('std.jpg')
image = transform(image).unsqueeze(0)  # 添加批次维度

# 使用模型进行推理
with torch.no_grad():
    features: Tensor = model(image)
    print(features.shape)

 类似资料:
  • 主要内容:AI类型 - 1:基于功能,人工智能类型-2:基于功能人工智能可以分为多种类型,主要有两种类型的主要分类,它们基于能力并基于AI的功能。以下是解释AI类型的流程图。 AI类型 - 1:基于功能 基于能力的人工智能的类型如下 - 1. 弱AI或狭隘AI 狭隘AI是一种能够执行智能专用任务的AI。最常见和当前可用的AI是人工智能领域的狭隘AI。 狭隘的AI不能超出其领域或限制,因为它只针对一项特定任务进行培训。因此它也被称为弱AI。如果超出限制,缩小的A

  • https://github.com/zalandoresearch/fashion-mnist#visualization 我下载了测试数据集 但是解压后发现,变成了 unix 可执行文件 但是我想获取原始图像,怎么办呢?

  • 目前的开源视觉大模型有哪些? 我知道的只有智谱的 CogVLM,还有其他的吗? https://github.com/THUDM/CogVLM

  • Kubernetes 在人工智能领域的应用。 TBD kubeflow - Kubernetes 机器学习工具箱

  • 主要内容:1. 简单的反射代理,2. 基于模型的反射代理,3. 基于目标的代理,4. 基于效用的代理,5. 学习代理代理可以根据其感知智能和能力的程度分为五类。所有这些代理都可以改善其性能并在一段时间内产生更好的行动。这些如下: 简单的反射代理 基于模型的反射代理 基于目标的代理商 基于效用的代理 学习代理 1. 简单的反射代理 简单反射代理是最简单的代理。这些代理人根据当前的感知来做出决定,并忽略其余的感知历史。 这些代理只能在完全可观察的环境中取得成功。 简单反射代理在决策和行动过程中不考虑

  • 人工智能在当今社会中具有各种应用。它已成为当今时代的必要条件,因为它可以在多个行业中以有效的方式解决复杂问题,例如医疗保健,娱乐,金融,教育等。AI使我们的日常生活更加舒适和快速。 以下是一些应用人工智能的领域: 1. AI在天文学中应用 人工智能对于解决复杂的宇宙问题非常有用。人工智能技术有助于理解宇宙,例如它的工作原理,起源等。 2. AI在医疗保健领域应用 在过去的五到十年中,人工智能对医疗