当前位置: 首页 > 工具软件 > MM Blog > 使用案例 >

timm库使用

秦浩漫
2023-12-01

一、timm库简介

PyTorch Image Models,简称timm,是一个巨大的PyTorch代码集合,整合了常用的models、layers、utilities、optimizers、schedulers、data-loaders/augmentations和reference training/validation scripts。

二、安装

pip install timm

三、使用

  1. 查看所有模型
    model_list = timm.list_models()
    print(model_list)
  1. 查看具有预训练参数的模型
    model_pretrain_list = timm.list_models(pretrained=True)
    print(model_pretrain_list)
  1. 检索特定模型
    model_resnet = timm.list_models('*resnet*')
    print(model_resnet)
  1. 创建模型
    x = torch.randn((1, 3, 256, 512))
    modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True)
    out = modle_mobilenetv2(x)
    # print(out.shape)
    # torch.Size([1, 1000])
  1. 创建模型–改变输出类别数
    x = torch.randn((1, 3, 256, 512))
    modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, num_classes=100)
    out = modle_mobilenetv2(x)
    # print(out.shape)
    # torch.Size([1, 100])
  1. 创建模型–改变输入通道数
    x = torch.randn((1, 10, 256, 512))
    modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, in_chans=10)
    out = modle_mobilenetv2(x)
    # print(out.shape)
    # torch.Size([1, 1000])
  1. 创建模型–只提取特征
    x = torch.randn((1, 3, 256, 512))
    modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, features_only=True)
    out = modle_mobilenetv2(x)
    # for o in out:
    #     print(o.shape)
    # torch.Size([1, 16, 128, 256])
    # torch.Size([1, 24, 64, 128])
    # torch.Size([1, 32, 32, 64])
    # torch.Size([1, 96, 16, 32])
    # torch.Size([1, 320, 8, 16])
 类似资料: