当前位置: 首页 > 工具软件 > tootgroup.py > 使用案例 >

train.py | test.py

慕容灿
2023-12-01

1. train

tools/train.py中找到以下表示开始训练的代码

    # -----------------------start training---------------------------
    logger.info('**********************Start training %s/%s(%s)**********************'
                % (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
    train_model(
        model, optimizer, train_loader, model_func=model_fn_decorator(),
        lr_scheduler=lr_scheduler, optim_cfg=cfg.OPTIMIZATION,
        start_epoch=start_epoch, total_epochs=args.epochs,
        start_iter=it, rank=cfg.LOCAL_RANK, tb_log=tb_log, 	
        ckpt_save_dir=ckpt_dir, train_sampler=train_sampler,
        lr_warmup_scheduler=lr_warmup_scheduler,
        ckpt_save_interval=args.ckpt_save_interval,
        max_ckpt_save_num=args.max_ckpt_save_num,
        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch
    )

train_model定位在头文件中的tools/train_utils/train_utils.py
其关键信息的框架为

for key in epochs:
	train_one_epoch #训练一个epoch
save_trained_model #储存训练好的模型

找到train_model中的train_one_epoch()

accumulated_iter = train_one_epoch(
                model, optimizer, train_loader, model_func,
                lr_scheduler=cur_scheduler,
                accumulated_iter=accumulated_iter, optim_cfg=optim_cfg,
                rank=rank, tbar=tbar, tb_log=tb_log,
                leave_pbar=(cur_epoch + 1 == total_epochs),
                total_it_each_epoch=total_it_each_epoch,
                dataloader_iter=dataloader_iter
            )

train_one_epoch函数的定义在同一py文件中,其关键信息为

def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
                    rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False): 
        #找到表示训练和梯度优化等的关键函数
        model.train() #一个固定语句
        optimizer.zero_grad() #梯度清零
        loss, tb_dict, disp_dict = model_func(model, batch) #求loss
        loss.backward() #反向传播
        clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP) #梯度裁剪
        optimizer.step() #更新

用到了几个Pytorch自带的函数:
model.train()
在训练模型时都会在前面加上model.train()
在测试模型时都会在前面加上model.eval()
如果不写这两个程序也可以运行,这两个方法是针对在训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。详细介绍。

clip_grad_norm_()
功能是梯度裁剪。即为了防止梯度爆炸,当梯度超过阈值optim_cfg时将其设置为阈值。详细介绍。

optimizer.step()
功能是根据网络反向传播的梯度信息,更新网络的参数,以降低loss。详细介绍。

训练结束之后info

    logger.info('**********************End training %s/%s(%s)**********************\n\n\n'
                % (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
 类似资料:

相关阅读

相关文章

相关问答