当前位置: 首页 > 工具软件 > resume > 使用案例 >

if args.resume:从断点处开始继续训练模型——How to resume training?(Faster RCNN)

乜胜泫
2023-12-01
if args.resume:
    load_name = os.path.join(output_dir,
      'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
    print("loading checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)
    args.session = checkpoint['session']
    args.start_epoch = checkpoint['epoch']
    fasterRCNN.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr = optimizer.param_groups[0]['lr']
    if 'pooling_mode' in checkpoint.keys():
      cfg.POOLING_MODE = checkpoint['pooling_mode']
    print("loaded checkpoint %s" % (load_name))

进入trainval_net.py文件,进入模型参数配置函数def parse_args()函数,修改resume trained model部分的参数,将:--r 修改为True,再添加对应的'--checksession''--checkepoch''--checkpoint'的参数值。(对应中断节点的模型训练文件参数,如:)

最终,在训练模型的命令行中添加修改的参数,即可。

CUDA_VISIBLE_DEVICES=0 python trainval_net.py --dataset pascal_voc --net res101 --bs 8 --nw 0 --lr 0.001 --lr_decay_step 5 --cuda --r True --checksession 1 --checkepoch 10 --checkpoint 91

最后的--r True --checksession 1 --checkepoch 10 --checkpoint 91即为控制模型从断点处继续开始,而额外添加的控制参数。

# resume trained model——pytorch重新加载参数
  parser.add_argument('--r', dest='resume',
                      help='resume checkpoint or not',
                      default=False, type=bool)
  parser.add_argument('--checksession', dest='checksession',
                      help='checksession to load model',
                      default=1, type=int)
  parser.add_argument('--checkepoch', dest='checkepoch',
                      help='checkepoch to load model',
                      default=1, type=int)
  parser.add_argument('--checkpoint', dest='checkpoint',
                      help='checkpoint to load model',
                      default=0, type=int)
 类似资料: