import torch import torch.nn as nn import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.utils import save_image import os from anime_face_generator.dataset import ImageDataset batch_size = 32 num_epoch = 100 z_dimension = 100 dir_path = './wgan_img' # 创建文件夹 if not os.path.exists(dir_path): os.mkdir(dir_path) def to_img(x): """因为我们在生成器里面用了tanh""" out = 0.5 * (x + 1) return out dataset = ImageDataset() dataloader = DataLoader(dataset, batch_size=32, shuffle=False) class Generator(nn.Module): def __init__(self): super().__init__() self.gen = nn.Sequential( # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), # 上一步的输出形状:(512) x 4 x 4 nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), # 上一步的输出形状: (256) x 8 x 8 nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), # 上一步的输出形状: (256) x 16 x 16 nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True), # 上一步的输出形状:(256) x 32 x 32 nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False), nn.Tanh() # 输出范围 -1~1 故而采用Tanh # nn.Sigmoid() # 输出形状:3 x 96 x 96 ) def forward(self, x): x = self.gen(x) return x def weight_init(m): # weight_initialization: important for wgan class_name = m.__class__.__name__ if class_name.find('Conv') != -1: m.weight.data.normal_(0, 0.02) elif class_name.find('Norm') != -1: m.weight.data.normal_(1.0, 0.02) class Discriminator(nn.Module): def __init__(self): super().__init__() self.dis = nn.Sequential( nn.Conv2d(3, 64, 5, 3, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), # 输出 (64) x 32 x 32 nn.Conv2d(64, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True), # 输出 (128) x 16 x 16 nn.Conv2d(128, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True), # 输出 (256) x 8 x 8 nn.Conv2d(256, 512, 4, 2, 1, bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True), # 输出 (512) x 4 x 4 nn.Conv2d(512, 1, 4, 1, 0, bias=False), nn.Flatten(), # nn.Sigmoid() # 输出一个数(概率) ) def forward(self, x): x = self.dis(x) return x def weight_init(m): # weight_initialization: important for wgan class_name = m.__class__.__name__ if class_name.find('Conv') != -1: m.weight.data.normal_(0, 0.02) elif class_name.find('Norm') != -1: m.weight.data.normal_(1.0, 0.02) def save(model, filename="model.pt", out_dir="out/"): if model is not None: if not os.path.exists(out_dir): os.mkdir(out_dir) torch.save({'model': model.state_dict()}, out_dir + filename) else: print("[ERROR]:Please build a model!!!") import QuickModelBuilder as builder if __name__ == '__main__': one = torch.FloatTensor([1]).cuda() mone = -1 * one is_print = True # 创建对象 D = Discriminator() G = Generator() D.weight_init() G.weight_init() if torch.cuda.is_available(): D = D.cuda() G = G.cuda() lr = 2e-4 d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, ) g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, ) d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99) g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99) fake_img = None # ##########################进入训练##判别器的判断过程##################### for epoch in range(num_epoch): # 进行多个epoch的训练 pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader)) for i, img in enumerate(dataloader): num_img = img.size(0) real_img = img.cuda() # 将tensor变成Variable放入计算图中 # 这里的优化器是D的优化器 for param in D.parameters(): param.requires_grad = True # ########判别器训练train##################### # 分为两部分:1、真的图像判别为真;2、假的图像判别为假 # 计算真实图片的损失 d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0 real_out = D(real_img) # 将真实图片放入判别器中 d_loss_real = real_out.mean(0).view(1) d_loss_real.backward(one) # 计算生成图片的损失 z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声 z = z.reshape(num_img, z_dimension, 1, 1) fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离 fake_out = D(fake_img) # 判别器判断假的图片, d_loss_fake = fake_out.mean(0).view(1) d_loss_fake.backward(mone) d_loss = d_loss_fake - d_loss_real d_optimizer.step() # 更新参数 # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01 for parm in D.parameters(): parm.data.clamp_(-0.01, 0.01) # ==================训练生成器============================ # ###############################生成网络的训练############################### for param in D.parameters(): param.requires_grad = False # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D g_optimizer.zero_grad() # 梯度归0 z = torch.randn(num_img, z_dimension).cuda() z = z.reshape(num_img, z_dimension, 1, 1) fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片 output = D(fake_img) # 经过判别器得到的结果 # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss g_loss = torch.mean(output).view(1) # bp and optimize g_loss.backward(one) # 进行反向传播 g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数 # 打印中间的损失 pbar.set_right_info(d_loss=d_loss.data.item(), g_loss=g_loss.data.item(), real_scores=real_out.data.mean().item(), fake_scores=fake_out.data.mean().item(), ) pbar.update() try: fake_images = to_img(fake_img.cpu()) save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1)) except: pass if is_print: is_print = False real_images = to_img(real_img.cpu()) save_image(real_images, dir_path + '/real_images.png') pbar.finish() d_scheduler.step() g_scheduler.step() save(D, "wgan_D.pt") save(G, "wgan_G.pt")
到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索小牛知识库以前的文章或继续浏览下面的相关文章希望大家以后多多支持小牛知识库!
小伙伴会发现上面的页面右侧有一个重要的东西:软件操作人像动漫化的接口。这是一个Post请求,发送该请求的网址并不全,需要你提供自己的access_token。同时呢,发送Post请求不仅需要携带Headers,还需要携带一个Params参数,其中Headers是固定的,image参数是图片的Base64编码格式。
本文向大家介绍python实现人像动漫化的示例代码,包括了python实现人像动漫化的示例代码的使用技巧和注意事项,需要的朋友参考一下 利用百度api实现人像动漫化 百度API地址:https://ai.baidu.com/tech/imageprocess/selfie_anime 技术文档:https://ai.baidu.com/ai-doc/IMAGEPROCESS/Mk4i6olx5 注
本文向大家介绍Android实现调用摄像头,包括了Android实现调用摄像头的使用技巧和注意事项,需要的朋友参考一下 应用场景: 在Android开发过程中,有时需要调用手机自身设备的功能,本文侧重摄像头拍照功能的调用。 知识点介绍: 使用权限:调用手机自身设备功能(摄像头拍照功能),应该确保已经在AndroidManifest.xml中正确声明了对摄像头的使用及其它相关的feature 1.
本文向大家介绍关于ResNeXt网络的pytorch实现,包括了关于ResNeXt网络的pytorch实现的使用技巧和注意事项,需要的朋友参考一下 此处需要pip install pretrainedmodels 以上这篇关于ResNeXt网络的pytorch实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持呐喊教程。
本文向大家介绍python+openCV利用摄像头实现人员活动检测,包括了python+openCV利用摄像头实现人员活动检测的使用技巧和注意事项,需要的朋友参考一下 本文实例为大家分享了python+openCV利用摄像头实现人员活动检测的具体代码,供大家参考,具体内容如下 1.前言 最近在做个机器人比赛,其中一项要求是让机器人实现对是否有人员活动的检测,所以就先拿PC端写一下,准备移植到机器人
本文向大家介绍C#实现调用本机摄像头实例,包括了C#实现调用本机摄像头实例的使用技巧和注意事项,需要的朋友参考一下 本文实例源自一个项目,其中需要调用本机的摄像头进行拍照,分享给大家供大家参考之用。具体步骤如下: 硬件环境:联想C360一体机,自带摄像头 编写环境:vs2010 语言:C# WPF 实现步骤: 下载AForge类库,并添加引用: 在xaml界面中添加VideoSourcePlaye