https://github.com/taki0112/Self-Attention-GAN-Tensorflow
--dataset #数据集文件,需要自己下载
--celeba
---img1.jpg
---img2.jpg
--ops.py #图层文件
--utils.py #操作文件
--SAGAN.py#模型文件
--main.py#主函数文件
模型崩塌是指生成器学习到一种能够欺骗判别器的特征后,所有的学习特征都会向这个特征靠拢。具体表现就是GAN一旦生成一张能够欺骗判别器的图像之后,那么其他特征会与这个特征非常接近,导致最终生成的结果中有一样或者类似的图像,这个现象在DCGAN中非常明显,后面的提出的模型基本很少出现这种现象了。
(1)分批打乱数据
(2)期望值特征匹配
(3)更新历史均值
(4)one-side dlabel smoothing
(5)virtual batch normalization
python main.py --phase train --dataset celebA --gan_type hinge
总共有4个场景需要训练测试,
自动轮询这几个场景,一个场景训练完再训练另一个场景:
#!/usr/bin
echo 'start'
python main.py --phase train --dataset DJI_0501 --gan_type hinge
wait
python main.py --phase train --dataset Berghouse --gan_type hinge
wait
python main.py --phase train --dataset DJI_0862 --gan_type hinge
wait
python main.py --phase train --dataset Bluemlisalphutte --gan_type hinge
echo "end"
python main.py --phase test --dataset celebA --gan_type hinge
测试:四个场景都测试
#!/usr/bin
echo 'start'
python main.py --phase test --dataset DJI_0501 --gan_type hinge
wait
python main.py --phase test --dataset Bluemlisalphutte --gan_type hinge
wait
python main.py --phase test --dataset DJI_0862 --gan_type hinge
wait
python main.py --phase test --dataset Berghouse --gan_type hinge
echo "end"
测试源码是输入一张随机数据生成10张最好的相似图像
def test(self):
import time
from PIL import Image
import numpy as np
from sklearn import preprocessing
start_Time = time.time()
self.saver = tf.train.Saver()
could_load, checkpoint_counter = self.load(self.checkpoint_dir)
result_dir = os.path.join(self.result_dir, self.model_dir)
check_folder(result_dir)
if could_load:
print(" [*] Load SUCCESS")
else:
print(" [!] Load failed...")
tot_num_samples = min(self.sample_num, self.batch_size)
image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
""" random condition, random noise """
#原代码
for i in range(self.test_num) :
z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))
samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})
save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
[image_frame_dim, image_frame_dim],
result_dir + '/' + self.model_name + '_test_{}.png'.format(i))
end_Time = time.time()
print("process time % s " % (end_Time - start_Time))
修改测试代码:修改输入为指定图像,生成对应的图像
def test(self):
import time
from PIL import Image
import numpy as np
from sklearn import preprocessing
start_Time = time.time()
self.saver = tf.train.Saver()
could_load, checkpoint_counter = self.load(self.checkpoint_dir)
result_dir = os.path.join(self.result_dir, self.model_dir)
check_folder(result_dir)
if could_load:
print(" [*] Load SUCCESS")
else:
print(" [!] Load failed...")
tot_num_samples = min(self.sample_num, self.batch_size)
image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
""" random condition, random noise """
#原代码
#for i in range(self.test_num) :
# z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))
# samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})
# save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
# [image_frame_dim, image_frame_dim],
# result_dir + '/' + self.model_name + '_test_{}.png'.format(i))
#end_Time = time.time()
#print("process time % s " % (end_Time - start_Time))
path_list = os.listdir('./dataset/' + self.dataset_name)
path_list.sort()
data_num = len(path_list)
for filename in path_list:
print('filenamexxxxxxxxxxxx',filename)
z_sample=Image.open(os.path.join('./dataset',self.dataset_name,filename))
z_sample = z_sample.resize((64,64),Image.ANTIALIAS)
z_sample = np.array(z_sample,dtype='int8')
z_sample = z_sample.reshape(48, self.z_dim)
z_sample = preprocessing.scale(z_sample)
z_sample = z_sample.reshape(48, 1, 1, self.z_dim)
samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})
save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
[image_frame_dim, image_frame_dim],
result_dir + '/' + self.model_name + '_test_{}.png'.format(filename))
end_Time = time.time()
print("process time % s "% (end_Time-start_Time))
SAGAN
1、gan = SAGAN(sess, args)
2、Gan.build_model()—generator() +utils.py.ImageData()
3、Show_all_veriables()
4、Gan.train()
5、gan.visualize_results(args.epoch-1
6、Gan.test()
7、tmp = self.sess.run(self.out,feed_dict={self.input:[x[iself.img_size:(i+1)self.img_size,jself.img_size:(j+1)self.img_size]]})[0]
feed_dict喂入网络,self.out 网络输出
[0]代表有批次,其实输入应该为1wh*c,虽然只有一张图片,[0]就代表第一张即第一批次。
花絮:
在SAGAN调试之前进行Big-GAN的调研调试,未调通后进行SAGAN的调研调试,时间紧急未深究。
big-gan相关:
Large Scale GAN Training For High Fidelity Natural Image Synthesis
ICLR 2019
Deep Mind 团队
BigGAN 论文获得了8、7、10的评分
1、性能出色,称为“史上最强GAN图像生成器”
2、正交正则化用于生成器网络能够起到很好的效果,通过对隐变量的空间进行截断处理,能够在样本的真实性与多样性之间进行精细的平衡控制。
3、在类别控制的图像生成问题上取得了新高。
4、在ImageNet数据集下Inception Score 竟然比当前最好GAN模型SAGAN提高了100多分(接近2倍)
主要创新点:
1、证实了增大GAN的规模能够显著的提升建模的效果。
2、提出了两种简单而又具有一般性的框架改进,可以提高模型的伸缩性,并且改进了一种正则化策略来提升条件作用,证明了这些方法能够提升性能。
3、作为这些修改策略的副产品,本文提出的模型变得更服从截断技巧。截断技巧是一种简单的采样方法,能够在样本的逼真性、多样性之间做显示的、细粒度的控制。
4、发现了使大规模GAN不稳定的原因,对他们进行了经验性的分析,更进一步的,作者发现将已有的和新的技巧的组合使用能够降低这种不稳定性。但是完全的训练稳定性只有在巨大的性能代价下才能获得。对稳定性的分析通过生成器和判别器权重矩阵的奇异值分析而实现。
5、http://coder55.com/article/94002
6、https://aistudio.baidu.com/bdcpu5/user/210689/3240022/notebooks/3240022.ipynb
7、https://github.com/ajbrock/BigGAN-PyTorch
8、https://github.com/15805658608/tensorflow-GANs