当前位置: 首页 > 工具软件 > GFPGAN > 使用案例 >

GFPGAN源代码分析(九)

昝光临
2023-12-01

2021SC@SDUSC

一、分析的代码片段

1.代码展示

class GFPGANer():

    def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # initialize the GFP-GAN
        if arch == 'clean':
            self.gfpgan = GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        else:
            self.gfpgan = GFPGANv1(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=True,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        # initialize face helper
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=self.device)

        if model_path.startswith('https://'):
            model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
        loadnet = torch.load(model_path)
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

2.代码作用分析

初始化模型,根据参数arch选择性初始化GFP-GAN,初始化face helper。

二、具体作用

1.初始化face helper

self.face_helper = FaceRestoreHelper(
    upscale,
    face_size=512,
    crop_ratio=(1, 1),
    det_model='retinaface_resnet50',
    save_ext='png',
    device=self.device)

2.读取model并继续初始化

loadnet = torch.load(model_path)
if 'params_ema' in loadnet:
    keyname = 'params_ema'
else:
    keyname = 'params'
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
self.gfpgan.eval()
self.gfpgan = self.gfpgan.to(self.device)

3.torch

包torch包含了多维疑是的数据结构及基于其上的多种数学操作。

torch.load()

torch.load() 使用 Python 的 解压工具(unpickling)来反序列化 pickled object 到对应存储设备上。首先在 CPU 上对压缩对象进行反序列化并且移动到它们保存的存储设备上,如果失败了(如:由于系统中没有相应的存储设备),就会抛出一个异常。用户可以通过 register_package 进行扩展,使用自己定义的标记和反序列化方法。

4. load_file_from_url( )

从指定url中读取model下载文件,如果路径是网址,则调用这个函数下载相应的model

def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
    """
    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')
    print('hub_dir',hub_dir)
    print('model_dir',model_dir)
    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')
    #做路径的拼接,并递归创建目录
    os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
 
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
 类似资料: