在tools/train.py
中找到以下表示开始训练的代码
# -----------------------start training---------------------------
logger.info('**********************Start training %s/%s(%s)**********************'
% (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
train_model(
model, optimizer, train_loader, model_func=model_fn_decorator(),
lr_scheduler=lr_scheduler, optim_cfg=cfg.OPTIMIZATION,
start_epoch=start_epoch, total_epochs=args.epochs,
start_iter=it, rank=cfg.LOCAL_RANK, tb_log=tb_log,
ckpt_save_dir=ckpt_dir, train_sampler=train_sampler,
lr_warmup_scheduler=lr_warmup_scheduler,
ckpt_save_interval=args.ckpt_save_interval,
max_ckpt_save_num=args.max_ckpt_save_num,
merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch
)
train_model
定位在头文件中的tools/train_utils/train_utils.py
中
其关键信息的框架为
for key in epochs:
train_one_epoch #训练一个epoch
save_trained_model #储存训练好的模型
找到train_model
中的train_one_epoch()
accumulated_iter = train_one_epoch(
model, optimizer, train_loader, model_func,
lr_scheduler=cur_scheduler,
accumulated_iter=accumulated_iter, optim_cfg=optim_cfg,
rank=rank, tbar=tbar, tb_log=tb_log,
leave_pbar=(cur_epoch + 1 == total_epochs),
total_it_each_epoch=total_it_each_epoch,
dataloader_iter=dataloader_iter
)
train_one_epoch
函数的定义在同一py文件中,其关键信息为
def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False):
#找到表示训练和梯度优化等的关键函数
model.train() #一个固定语句
optimizer.zero_grad() #梯度清零
loss, tb_dict, disp_dict = model_func(model, batch) #求loss
loss.backward() #反向传播
clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP) #梯度裁剪
optimizer.step() #更新
用到了几个Pytorch自带的函数:
model.train()
:
在训练模型时都会在前面加上model.train()
在测试模型时都会在前面加上model.eval()
如果不写这两个程序也可以运行,这两个方法是针对在训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout
。详细介绍。
clip_grad_norm_()
:
功能是梯度裁剪。即为了防止梯度爆炸,当梯度超过阈值optim_cfg
时将其设置为阈值。详细介绍。
optimizer.step()
:
功能是根据网络反向传播的梯度信息,更新网络的参数,以降低loss。详细介绍。
训练结束之后info
logger.info('**********************End training %s/%s(%s)**********************\n\n\n'
% (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))