举例:Faster-RCNN基于vgg19提取features,但是只使用了vgg19一部分模型提取features。
步骤:
下载vgg19的pth文件,在anaconda中直接设置pretrained=True下载一般都比较慢,在model_zoo里面有各种预训练模型的下载链接:
model_urls = {
‘vgg11‘: ‘https://download.pytorch.org/models/vgg11-bbd30ac9.pth‘,
‘vgg13‘: ‘https://download.pytorch.org/models/vgg13-c768596a.pth‘,
‘vgg16‘: ‘https://download.pytorch.org/models/vgg16-397923af.pth‘,
‘vgg19‘: ‘https://download.pytorch.org/models/vgg19-dcbb9e9d.pth‘,
‘vgg11_bn‘: ‘https://download.pytorch.org/models/vgg11_bn-6002323d.pth‘,
‘vgg13_bn‘: ‘https://download.pytorch.org/models/vgg13_bn-abd245e5.pth‘,
‘vgg16_bn‘: ‘https://download.pytorch.org/models/vgg16_bn-6c64b313.pth‘,
‘vgg19_bn‘: ‘https://download.pytorch.org/models/vgg19_bn-c79401a0.pth‘ }
下载好的模型,可以用下面这段代码看一下模型参数,并且改一下模型。在vgg19.pth同级目录建立一个test.py。
import torch
import torch.nn as nn
import torchvision.models as models
vgg16 = models.vgg16(pretrained=False)
#打印出预训练模型的参数
vgg16.load_state_dict(torch.load(‘vgg16-397923af.pth‘))
print(‘vgg16:\n‘, vgg16)
modified_features = nn.Sequential(*list(vgg16.features.children())[:-1])
# to relu5_3
print(‘modified_features:\n‘, modified_features )#打印修改后的模型参数
修改好之后features就可以拿去做Faster-RCNN提取特征用了。