2021SC@SDUSC
源码:
models\__init__.py
models\gfpgan_model.py
本篇主要分析 models\__init__.py 与 models\gfpgan_model.py 中
class GFPGANModel(BaseModel) 类的 __init__(self, opt) 方法
目录
自动扫描并导入模型模块(以便将模型注册到注册表)
# Scan the models folder and collect the bare module name (directory and
# extension stripped) of every file whose name ends with '_model.py'.
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [
    osp.splitext(osp.basename(path))[0]
    for path in scandir(model_folder)
    if path.endswith('_model.py')
]
# Import every discovered model module. Importing executes the module body,
# so class decorators inside it (e.g. @MODEL_REGISTRY.register()) run here.
_model_modules = [
    importlib.import_module(f'gfpgan.models.{module_name}')
    for module_name in model_filenames
]
也就是说,这段代码实际导入的是 models 文件夹下的 gfpgan_model.py 文件。接下来我们来看一下这个文件。
本文件中只包含GFPGANModel(BaseModel)一个类
MODEL_REGISTRY 是一个注册表(Registry)对象。GFPGANModel 在定义时被 @MODEL_REGISTRY.register() 装饰,即以装饰器的形式调用 MODEL_REGISTRY 的 register 方法,把该类注册到注册表中。
# Register GFPGANModel in MODEL_REGISTRY through the register() decorator
# (presumably so it can later be looked up by name — confirm in the registry code).
@MODEL_REGISTRY.register()
class GFPGANModel(BaseModel):
"""GFPGAN model for <Towards real-world blind face restoration with generative facial prior>"""
即:论文《Towards Real-World Blind Face Restoration with Generative Facial Prior》(基于生成式人脸先验的真实世界盲人脸修复)所提出的 GFPGAN 模型。
简单看一下代码
# Excerpt: body of GFPGANModel.__init__(self, opt). The `def` line and the
# original indentation are not reproduced here; the line(s) directly after
# each `if` are presumably its body.
super(GFPGANModel, self).__init__(opt)
# Attribute initialised to 0; how it is used is not shown in this excerpt.
self.idx = 0
# Define the generator network, move it to the device and print it.
self.net_g = build_network(opt['network_g'])
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)
# Load a pretrained generator; path, param key and strictness come from opt['path'].
load_path = self.opt['path'].get('pretrain_network_g', None)
# Only load when a path is configured; param_key defaults to 'params', strict_load_g to True.
if load_path is not None:
param_key = self.opt['path'].get('param_key_g', 'params')
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
# log2 of the generator's configured output size, truncated to int (e.g. 512 -> 9).
self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))
# Training-only setup is delegated to init_training_settings().
if self.is_train:
self.init_training_settings()
读取预训练模型时,所用的路径与参数(pretrain_network_g、param_key_g、strict_load_g)实际上都来自 train_gfpgan_v1.yml 配置文件中的 path 部分。
初始化训练设置
1.读取opt['train']
# Excerpt from init_training_settings(): all training options are read from opt['train'].
train_opt = self.opt['train']
2.定义net_d
# Build the discriminator network
self.net_d = build_network(self.opt['network_d'])
# Move the model onto the device via model_to_device (the original note says GPU/CUDA)
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)
# Load a pretrained discriminator if a path is configured; no param key is passed,
# so load_network's own default key is used.
load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
3.定义net_g_ema
# net_g_ema is only used for testing and saving on a single GPU,
# so it does not need to be wrapped in DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# Load pretrained weights for the EMA generator from the checkpoint's 'params_ema' key;
# without a pretrained path, initialise it from net_g instead.
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
# Generator and discriminator go into training mode; the EMA copy stays in eval mode.
# NOTE(review): indentation was lost; these three lines presumably follow the if/else, not inside it.
self.net_g.train()
self.net_d.train()
self.net_g_ema.eval()
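根据配置文件 train_gfpgan_v1.yml,arcface_resnet18.pth 是 pretrain_network_identity(身份网络,见第6节)的预训练模型,而不是 net_g 的;该配置中 pretrain_network_g 为空,因此 net_g_ema 会走 else 分支,通过 self.model_ema(0) 复制 net_g 的权重。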
4.定义面部组件判别器(左眼、右眼、嘴巴)
# Facial-component discriminators are enabled only when all three network configs
# (left eye, right eye, mouth) are present in opt.
if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
self.use_facial_disc = True
else:
self.use_facial_disc = False
# NOTE(review): indentation was lost; everything below through cri_component is
# presumably the body of this `if`.
if self.use_facial_disc:
# left eye: build, move to device, print, then strictly load the 'params' key if a path is set
self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
self.print_network(self.net_d_left_eye)
load_path = self.opt['path'].get('pretrain_network_d_left_eye')
if load_path is not None:
self.load_network(self.net_d_left_eye, load_path, True, 'params')
# right eye: same procedure as the left eye
self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
self.print_network(self.net_d_right_eye)
load_path = self.opt['path'].get('pretrain_network_d_right_eye')
if load_path is not None:
self.load_network(self.net_d_right_eye, load_path, True, 'params')
# mouth: same procedure as the eyes
self.net_d_mouth = build_network(self.opt['network_d_mouth'])
self.net_d_mouth = self.model_to_device(self.net_d_mouth)
self.print_network(self.net_d_mouth)
load_path = self.opt['path'].get('pretrain_network_d_mouth')
if load_path is not None:
self.load_network(self.net_d_mouth, load_path, True, 'params')
# Put all three component discriminators into training mode
self.net_d_left_eye.train()
self.net_d_right_eye.train()
self.net_d_mouth.train()
# ----------- define the GAN loss for facial components ----------- #
self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
5.定义损失
# Pixel loss is optional: built only when 'pixel_opt' is configured (truthy), else None
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
# Perceptual loss is optional in the same way
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
# pyramid loss, component style loss and identity loss all use this L1 loss ('L1_opt' is required)
self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
# GAN loss (the original note says wgan; the actual type is whatever 'gan_opt' configures)
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
6.identity loss的定义
# Identity loss is enabled only when an identity network is configured
if 'network_identity' in self.opt:
self.use_identity = True
else:
self.use_identity = False
# NOTE(review): indentation was lost; the lines below presumably form the body of this `if`.
if self.use_identity:
# define the identity network
self.network_identity = build_network(self.opt['network_identity'])
self.network_identity = self.model_to_device(self.network_identity)
self.print_network(self.network_identity)
load_path = self.opt['path'].get('pretrain_network_identity')
# param_key=None: presumably loads the checkpoint without selecting a sub-key — confirm in load_network
if load_path is not None:
self.load_network(self.network_identity, load_path, True, None)
# The identity network is frozen: eval mode and no gradients for its parameters
self.network_identity.eval()
for param in self.network_identity.parameters():
param.requires_grad = False
# Regularization weight and discriminator update settings.
# r1_reg_weight and net_d_reg_every are indexed directly, so they must be present in train_opt;
# net_d_iters defaults to 1 and net_d_init_iters to 0.
self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator
self.net_d_iters = train_opt.get('net_d_iters', 1)
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
self.net_d_reg_every = train_opt['net_d_reg_every']
# Set up the optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()