当前位置: 首页 > 工具软件 > 3DDFA > 使用案例 >

新手学习3D人脸重建3DDFA程序

商和雅
2023-12-01


Face Alignment Across Large Poses: A 3D Solution

argparse模块

argparse模块的作用是用于解析命令行参数,基本使用:

import argparse
#创建解析器对象,description:描述程序
parser = argparse.ArgumentParser(description='3DDFA inference pipeline')
#添加参数,type:把从命令行输入的结果转成设置的类型
parser.add_argument('-f', '--files', nargs='+',
                    help='image files paths fed into network, single or multiple images')
parser.add_argument('-m', '--mode', default='gpu', type=str, help='gpu or cpu mode')
args = parser.parse_args()

main(args)进入main函数

#torch.load(model_path)返回的是一个 OrderedDict
checkpoint = torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)['state_dict']
'''
pytorch允许把在GPU上训练的模型加载到CPU上,也允许把在CPU上训练的模型加载到GPU上
torch.load(checkpoint_fp)   #CPU->CPU,GPU->GPU
torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)   #GPU->CPU
torch.load(checkpoint_fp, map_location=lambda storage, loc: storage.cuda(1))   #CPU->GPU1
'''

将模型的全部参数保存到model_dict

model_dict = model.state_dict()

加载训练好的模型

model.load_state_dict(model_dict)

加载dlib模块进行人脸检测和裁剪

if args.dlib_landmark:
	dlib_landmark_model = 'models/shape_predictor_68_face_landmarks.dat'
	face_regressor = dlib.shape_predictor(dlib_landmark_model)
if args.dlib_bbox:
	face_detector = dlib.get_frontal_face_detector()

torchvision.transforms是pytorch中的图像预处理包
用transforms.Compose()将多个步骤融合到一起

transform = transforms.Compose([ToTensorGjz(), NormalizeGjz(mean=127.5, std=128)])
'''
Normalize:Normalized an tensor image with mean and standard deviation
ToTensor:convert a PIL image to tensor (H*W*C) in range [0,255] to a torch.Tensor(C*H*W) in the range [0.0,1.0]
'''

对args.files中的每一张图像循环进行如下操作:

#默认使用dlib  landmark和bbox
#将检测到的所有人脸位置保存到rects中
rects = face_detector(img_ori, 1)   
#然后将每一个rect的左、上、右、下四个位置保存到roi_bbox,利用crop_img()裁剪图片
roi_box = parse_roi_box_from_landmark(pts)
img = crop_img(img_ori, roi_box)
#将图片resize到网络要求的大小,插值方式为最近邻插值
img = cv2.resize(img, dsize=(STD_SIZE, STD_SIZE), interpolation=cv2.INTER_LINEAR)
input = transform(img).unsqueeze(0)  #在图片第零维增加一个维度
#传入模型
param = model(input)
#降维然后flatten至一维
param = param.squeeze().cpu().numpy().flatten().astype(np.float32)
# 预测68个特征点或者dense特征点
pts68 = predict_68pts(param, roi_box)

未完…

 类似资料: