PyTorch Image Models(简称 timm)是一个大型 PyTorch 代码库,集成了常用的 models、layers、utilities、optimizers、schedulers、data-loaders/augmentations 以及 reference training/validation scripts。可通过以下命令安装:
pip install timm
import timm
import torch

# ---------------------------------------------------------------------------
# Discovering available architectures
# ---------------------------------------------------------------------------
# Every model name registered in timm.
model_list = timm.list_models()
print(model_list)
# Only the model names for which pretrained weights are available.
model_pretrain_list = timm.list_models(pretrained=True)
print(model_pretrain_list)
# Names can be filtered with a wildcard pattern, e.g. every ResNet variant.
model_resnet = timm.list_models('*resnet*')
print(model_resnet)


def _run_inference(model, inputs):
    """Run a single forward pass of ``model`` on ``inputs`` for inspection.

    The model is switched to eval mode so that layers such as BatchNorm or
    Dropout (if the architecture has them) behave as at inference time and do
    not update their stored state from this random batch. Autograd is
    disabled because no gradients are needed. The model is left in eval mode
    afterwards.

    Args:
        model: a ``torch.nn.Module`` (here, one built by ``timm.create_model``).
        inputs: the input tensor passed straight to ``model``.

    Returns:
        Whatever ``model(inputs)`` returns: a tensor for classifier models,
        or a list of feature-map tensors for ``features_only=True`` models.
    """
    model.eval()
    with torch.no_grad():
        return model(inputs)


# ---------------------------------------------------------------------------
# 1. Pretrained MobileNetV2 with its default classifier head
# ---------------------------------------------------------------------------
# Random input of shape (1, 3, 256, 512); a non-square image is accepted.
x = torch.randn((1, 3, 256, 512))
# pretrained=True loads pretrained weights (they may be downloaded on first use).
# NOTE(review): "modle" looks like a typo for "model"; the name is kept
# unchanged in case code beyond this excerpt refers to it.
modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True)
out = _run_inference(modle_mobilenetv2, x)
print(out.shape)
# torch.Size([1, 1000])

# ---------------------------------------------------------------------------
# 2. Custom classifier width via num_classes
# ---------------------------------------------------------------------------
x = torch.randn((1, 3, 256, 512))
modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, num_classes=100)
out = _run_inference(modle_mobilenetv2, x)
print(out.shape)
# torch.Size([1, 100])

# ---------------------------------------------------------------------------
# 3. Custom number of input channels via in_chans (10 instead of 3)
# ---------------------------------------------------------------------------
x = torch.randn((1, 10, 256, 512))
modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, in_chans=10)
out = _run_inference(modle_mobilenetv2, x)
print(out.shape)
# torch.Size([1, 1000])

# ---------------------------------------------------------------------------
# 4. Feature extractor via features_only=True
# ---------------------------------------------------------------------------
# The model now returns a list of intermediate feature maps instead of
# class scores; each successive map below halves the spatial size.
x = torch.randn((1, 3, 256, 512))
modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, features_only=True)
out = _run_inference(modle_mobilenetv2, x)
for feature_map in out:
    print(feature_map.shape)
# torch.Size([1, 16, 128, 256])
# torch.Size([1, 24, 64, 128])
# torch.Size([1, 32, 32, 64])
# torch.Size([1, 96, 16, 32])
# torch.Size([1, 320, 8, 16])