我想看看 resnet50 模型,第 48 层的输出,我写了下面的代码,但是运行报错了
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import Tensor
import torch.nn as nn
# 加载ResNet-50模型
model = torchvision.models.resnet50(pretrained=True)
# 获取前48层的子模型
model = nn.Sequential(*list(model.children())[:48])
# 修改fc层
# model.fc = nn.Linear(2048, 512)
# 设置模型为评估模式
model.eval()
# 图像预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
])
# 加载并预处理图像
image = Image.open('std.jpg')
image = transform(image).unsqueeze(0) # 添加批次维度
# 使用模型进行推理
with torch.no_grad():
features: Tensor = model(image)
print(features.shape)
报错如下:
/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Traceback (most recent call last):
File "/Users/ponponon/Desktop/code/me/resnet_example/resnet48_handle_image_into_vector.py", line 35, in <module>
features: Tensor = model(image)
File "/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/ponponon/.local/share/virtualenvs/image2vector-n-kX1tX6/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2048x1 and 2048x1000)
我明明把后面的 fc 给丢掉了呀,为什么还报错呢?
我该如何修改?
改成下面这样就可以了