
WGAN-tensorflow

A TensorFlow implementation of WGAN

Wasserstein GAN

This is a TensorFlow implementation of WGAN on MNIST and SVHN.

Requirement

tensorflow >= 1.0.0

numpy

matplotlib

opencv-python (imported as cv2)

Usage

Train: Use WGAN.ipynb, set the parameters in the second cell, and choose the dataset you want to run on. You can use TensorBoard to visualize the training.

Generation: Use generate_from_ckpt.ipynb and set ckpt_dir in the second cell. Don't forget to change the dataset type accordingly.

Notes

  1. All data is downloaded automatically; the SVHN download script is modified from this.

  2. All parameters default to the values recommended in the original paper. Diters is the number of critic updates per generator step; in the original PyTorch version it is 5 unless iterstep < 25 or iterstep % 500 == 0. Since the critic is supposed to be trained close to optimality, it is reasonable to give it extra updates at the beginning and every 500 steps, so I borrowed this schedule without tuning (see the first sketch after these notes). The learning rates for the generator and the critic are both set to 5e-5; since the gradient norms stay relatively high during training (around 1e3), I suggest no drastic changes to the learning rates.

  3. The MLP version may take longer to generate sharp images.

  4. In this implementation the critic loss is tf.reduce_mean(fake_logit - true_logit) and the generator loss is tf.reduce_mean(-fake_logit). The whole system still works if you put a - in front of both of them; it doesn't matter. Recall that the critic objective in its duality form is $\sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g}[f(x)]$, and the set of 1-Lipschitz functions is symmetric under negation. Substituting $f$ with $-f$ gives $\mathbb{E}_{x \sim \mathbb{P}_g}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)]$, the negative of the original form. The original PyTorch implementation takes the second form and this implementation takes the first; both work equally well. You might want to add the - and try it out (see the second sketch after these notes).

  5. Set the device you want to run on in the code: search for tf.device and change it accordingly. It runs on gpu:0 by default.

  6. Improved WGAN is added, but somehow the gradient norm is already close to 1, so the squared-gradient penalty has almost no effect; I couldn't figure out why (see the third sketch after these notes).
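
Below are three short sketches for notes 2, 4, and 6. First, the critic-update schedule from note 2, as used in the original PyTorch code; train_critic_step, train_generator_step, and max_iterations are illustrative placeholders, not names from this repository.

```python
def train_critic_step():
    # Placeholder: one critic optimizer step on a batch of real
    # and generated samples.
    pass

def train_generator_step():
    # Placeholder: one generator optimizer step.
    pass

max_iterations = 10000  # illustrative

for iterstep in range(max_iterations):
    # Extra critic updates at the start and every 500 steps so the
    # critic stays close to optimal; the original PyTorch code uses
    # 100 updates in these cases and 5 otherwise.
    Diters = 100 if iterstep < 25 or iterstep % 500 == 0 else 5
    for _ in range(Diters):
        train_critic_step()
    train_generator_step()
```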
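
Second, the two equivalent loss formulations from note 4, assuming true_logit and fake_logit are the critic's outputs on real and generated batches (TF 1.x placeholders stand in for the actual networks).

```python
import tensorflow as tf

# Stand-ins for the critic's outputs on real and generated samples.
true_logit = tf.placeholder(tf.float32, [None, 1])
fake_logit = tf.placeholder(tf.float32, [None, 1])

# Form used in this implementation.
critic_loss = tf.reduce_mean(fake_logit - true_logit)
generator_loss = tf.reduce_mean(-fake_logit)

# Equivalent form with the signs flipped, as in the original PyTorch
# version: the critic simply learns -f instead of f, because the set
# of 1-Lipschitz functions is closed under negation.
critic_loss_alt = tf.reduce_mean(true_logit - fake_logit)
generator_loss_alt = tf.reduce_mean(fake_logit)
```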
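
Third, a sketch of the gradient penalty from Improved WGAN for note 6, assuming NHWC image batches; critic, lam, and the shapes here are assumptions for illustration, not the exact code in this repository.

```python
import tensorflow as tf

def gradient_penalty(critic, real, fake, lam=10.0):
    # WGAN-GP: penalize the critic's gradient norm away from 1 on
    # random interpolates between real and generated samples.
    batch_size = tf.shape(real)[0]
    eps = tf.random_uniform([batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
    interp = eps * real + (1.0 - eps) * fake
    grads = tf.gradients(critic(interp), [interp])[0]
    # Per-sample gradient norm over height, width, and channels.
    norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return lam * tf.reduce_mean(tf.square(norms - 1.0))

# Smoke test: a toy linear critic on 8x8 single-channel images.
toy_real = tf.zeros([4, 8, 8, 1])
toy_fake = tf.ones([4, 8, 8, 1])
toy_critic = lambda x: tf.reduce_mean(x, axis=[1, 2, 3])
gp = gradient_penalty(toy_critic, toy_real, toy_fake)
```

The penalty would be added to the critic loss above; gradients of the summed critic outputs give per-sample gradients here because each output depends only on its own interpolated sample.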

Paper: Wasserstein GAN, https://arxiv.org/pdf/1701.07875.pdf
