Tensorflow2.0实战之GAN

GAN 入门

自 2014 年 Ian Goodfellow 的《生成对抗网络(Generative Adversarial Networks)》论文发表以来,GAN 的进展突飞猛进,生成结果也越来越具有照片真实感。

就在三年前,Ian Goodfellow 在 reddit 上回答 GAN 是否可以应用在文本领域的问题时,还认为 GAN 不能扩展到文本领域。

Tensorflow2.0实战之GAN_数据集

“由于 GAN 定义在实值数据上,因此 GAN 不能应用于 NLP。

GAN 的工作原理是训练一个生成网络,输出合成数据,然后利用判别网络判别合成数据。判别网络根据合成数据输出的梯度告诉你该如何对合成数据进行微调,使其更真实。

因此只有当合成数据是基于连续数字时,才能对其进行微调。如果是基于离散的数字,就没有办法做微小的改变。

例如,如果输出像素值为 1.0 的图像,则下一步可以将该像素值更改为 1.0001。

但如果输出单词‘penguin’,不能在下一步直接将其更改为‘penguin+.001’,因为没有‘penguin+.001’这样的单词。你必须从‘penguin’直接转变到‘ostrich’。

由于所有的 NLP 都是基于离散的值,如单词、字符或字节,所以目前还没有人知道该如何将 GAN 应用于 NLP。”

但是现在,GAN 已经可用于生成各种内容,包括图像、视频、音频和文本。这些输出的合成数据既可以用于训练其他的模型,也可以用于创建一些有趣的项目。

GAN 原理

GAN 由两个神经网络组成,一个是合成新样本的生成器,另一个是对比训练样本与生成样本的判别器。判别器的目标是区分“真实”和“虚假”的输入(对样本来自模型分布还是真实分布进行分类)。这些样本可以是图像、视频、音频片段和文本。

Tensorflow2.0实战之GAN_2d_02

为了合成这些新的样本,生成器的输入为随机噪声,然后尝试从训练数据中学习到的分布中生成真实的图像。

判别器网络(卷积神经网络)输出相对于合成数据的梯度,其中包含着如何改变合成数据以使其更具真实感的信息。最终生成器收敛,它可以生成符合真实数据分布的样本,而判别器无法区分生成数据和真实数据。

ok,接下来我们就来实现一下

准备阶段

下载数据集


解压数据集

将下载好的数据集解压,放在工程目录下

Tensorflow2.0实战之GAN_数据集_03

加载数据集

加载数据集的代码,笔者这里直接提供给大家了,下面只是展示部分代码,文末会提供完整项目的代码链接

登录后复制

import multiprocessingimport tensorflow as tfdef make_anime_dataset
(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):    
@tf.function    def _map_fn(img):        img = tf.image.resize(img, [resize, resize])        
img = tf.clip_by_value(img, 0, 255)        img = img / 127.5 - 1        return img    
dataset = disk_image_batch_dataset(img_paths,                                          
batch_size,                                          drop_remainder=drop_remainder,
map_fn=_map_fn,                                          shuffle=shuffle,
repeat=repeat)    img_shape = (resize, resize, 3)    len_dataset = len(img_paths) // batch_size    
return dataset, img_shape, len_datasetdef batch_dataset(dataset,                  batch_size,
drop_remainder=True,                  n_prefetch_batch=1,                  filter_fn=None,
map_fn=None,                  n_map_threads=None,                  filter_after_map=False,
shuffle=True,                  shuffle_buffer_size=None,                  
repeat=None):1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.



构建网络

搭建Generator,Generator包含两个部分,init部分和前向传播的call部分,代码如下

登录后复制

class Generator(keras.Model):    def __init__(self):        super(Generator, self).__init__()        
# z:[b,100]-->[b,3*3*512]-->[b,3,3,512]-->[b,64,64,3]        self.fc=keras.layers.Dense(3*3*512)        
self.conv1=keras.layers.Conv2DTranspose(256,3,3,'valid')  # 反卷积        
self.bn1=keras.layers.BatchNormalization()        self.conv2=keras.layers.Conv2DTranspose(128,5,2,'valid')
self.bn2=keras.layers.BatchNormalization()        self.conv3=keras.layers.Conv2DTranspose(3,4,3,'valid')
def call(self, inputs, training=None, mask=None):        # [z,100]-->[z,3*3*512]        x=self.fc(inputs)
x=tf.reshape(x,[-1,3,3,512])        x=tf.nn.leaky_relu(x)        
x=tf.nn.leaky_relu(self.bn1(self.conv1(x),training=training))        
x=tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))        
x=self.conv3(x)        x=tf.tanh(x)        
return x1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.

搭建Discriminator,同上

登录后复制

class Discriminator(keras.Model):    def __init__(self):        super(Discriminator, self).__init__()        
# [b,64,64,3]-->[b,1]        self.conv1=keras.layers.Conv2D(64,5,3,'valid')        
self.conv2=keras.layers.Conv2D(128,5,3,'valid')        self.bn2=keras.layers.BatchNormalization()        
self.conv3=keras.layers.Conv2D(256,5,3,'valid')        self.bn3=keras.layers.BatchNormalization()        
# [b,h,w,c]-->[b,-1]        self.flatten=keras.layers.Flatten()        # [b,-1]-->[b,1]        
self.fc=keras.layers.Dense(1)    def call(self, inputs, training=None, mask=None):        
x=tf.nn.leaky_relu(self.conv1(inputs))        
x=tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))        
x=tf.nn.leaky_relu(self.bn3(self.conv3(x),training=training))        
x=self.flatten(x)        logits=self.fc(x)        
return logits1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.

训练GAN

定义相关数据,包括epoch,lr等等

这些数据可以自定义,笔者这里就不改动了

登录后复制

z_dim = 100    epochs = 50000    batch_size = 512    learning_rate = 0.0002    is_training = True1.2.3.4.5.

加载数据

登录后复制

img_path=glob.glob(r'E:\python_pro\TF2.0\GAN\faces\*.jpg')    dataset, 
img_shape, _ = make_anime_dataset(img_path, batch_size)1.2.

可以打印查看数据集信息:

登录后复制

(512, 64, 64, 3), (64, 64, 3)(512, 64, 64, 3) ,1.0, -1.01.2.

定义优化器,注意我们在开始训练时,需要新建训练GAN图片的文件,为查看数据提供持久化依据

登录后复制

for epoch in range(epochs):        batch_z = tf.random.
uniform([batch_size, z_dim], minval=-1., maxval=1.)        
batch_x = next(db_iter)        # train D        with tf.GradientTape() as tape:            
d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)        
grads = tape.gradient(d_loss, discriminator.trainable_variables)        
d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))        
with tf.GradientTape() as tape:            g_loss = g_loss_fn(generator, 
discriminator, batch_z, is_training)        grads = tape.gradient(g_loss, 
generator.trainable_variables)        g_optimizer.apply_gradients(zip(grads, 
generator.trainable_variables))        if epoch % 100 == 0:            
print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss))            
z = tf.random.uniform([100, z_dim])            fake_image = generator(z, training=False)            
img_path = os.path.join('GAN_IMAGE', 'gan%d.png'%epoch)            
save_result(fake_image.numpy(), 10, img_path, color_mode='P')1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.
16.17.18.19.20.21.22.23.24.

训练结果

接下来我们来看看,训练的效果图,注意,GAN的训练过程是非常非常非常慢的,大概训练十几个小时,才能有个比较好的效果,有的数据集甚至会训练几天之久,这个随数据集的大小和对最终效果的要求来定的。笔者这个数据集比较的简单,只是给大家做演示,好了,废话就不过多的说了,上图

Tensorflow2.0实战之GAN_2d_04Tensorflow2.0实战之GAN_数据_05Tensorflow2.0实战之GAN_数据_06Tensorflow2.0实战之GAN_2d_07

上述分别是训练了100epoch、500、1500、4000的效果图,可以看到随着训练的次数增加,效果因为越来越好了

总结

大家在训练GAN时,还是需要一个好一些的GPU显卡才行,这样可以体验GPU给我们带来的加速效果。这样会使得训练的速度大大加快。

笔者水平有限,如有表述不准确的地方还请谅解,有错误的地方欢迎大家批评指正。

最后还是希望大家动手实践实践,共同进步。

最终的代码链接:https://github.com/huzixuan1/TF_2.0/tree/master/GAN



免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空