码迷,mamicode.com
首页 > Web开发 > 详细

Wasserstein Generative Adversarial Nets(WGAN)

时间:2018-04-05 21:00:26      阅读:369      评论:0      收藏:0      [点我收藏+]

标签:有趣   模式   dropbox   src   imp   输出   sample   get   https   

GAN目前是机器学习中非常受欢迎的研究方向。主要包括有两种类型的研究,一种是将GAN用于有趣的问题,另一种是试图增加GAN的模型稳定性。

事实上,稳定性在GAN训练中是非常重要的。起初的GAN模型在训练中存在一些问题,e.g., 模式塌陷生成器演化成非常窄的分布,只覆盖数据分布中的单一模式)。模式塌陷的含义是发生器只能产生非常相似的样本(例如MNIST中的单个数字),即所产生的样本不是多样的。这当然违反了GAN初衷

GAN中的另一个问题是没有指很好的指标或度量说明模型的收敛性生成器鉴别器损失并没有告诉我们关于这方面的任何信息。当然,我们可以通过查看生成器产生的数据来监控训练过程。但是,这是一个愚蠢的手动过程。所以,我们需要一个可解释指标告诉我们训练过程的好坏。

Wasserstein GAN

Wasserstein GAN(WGAN)是一种新提出的GAN算法,可以在一定程度解决上述两个问题。对于WGAN背后的直觉和理论背景,可以查看相关资料

整个算法的伪代码如下:

技术分享图片

我们可以看到该算法与原始GAN算法非常相似。 但是,对于WGAN,我们根据上面的代码需要注意到下几点:
  1. 损失函数中没有log。判别器D(X)的输出不再是一个概率(标量),同时也就意味着没有sigmoid激活函数
  2. 对于判别器D(X)的权重W进行裁剪
  3. 训练判别器的次数生成器
  4. 采用RMSProp优化器,代替原先的ADAM优化器
  5. 非常低的learning rate, α=0.00005

WGAN TensorFlow implementation

GAN的基本实现可以在上一篇文章中介绍过。 我们只需要稍微修改下传统的GAN。 首先,让我们更新我们的判别器D(X)

技术分享图片
""" Vanilla GAN """
def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    out = tf.matmul(D_h1, D_W2) + D_b2
    return tf.nn.sigmoid(out)

""" WGAN """
def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    out = tf.matmul(D_h1, D_W2) + D_b2
    return out
View Code

接下来,修改loss函数,去掉log

技术分享图片
""" Vanilla GAN """
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))

""" WGAN """
D_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)
G_loss = -tf.reduce_mean(D_fake)
View Code

在每次梯度下降更新后,裁剪判别器D(X)的权重:

# theta_D is list of D‘s params
clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in theta_D]

然后,只需要训练更多次的判别器D(X)就行了

技术分享图片
D_solver = (tf.train.RMSPropOptimizer(learning_rate=5e-5)
            .minimize(-D_loss, var_list=theta_D))
G_solver = (tf.train.RMSPropOptimizer(learning_rate=5e-5)
            .minimize(G_loss, var_list=theta_G))

for it in range(1000000):
    for _ in range(5):
        X_mb, _ = mnist.train.next_batch(mb_size)

        _, D_loss_curr, _ = sess.run([D_solver, D_loss, clip_D], feed_dict={X: X_mb, z: sample_z(mb_size, z_dim)})

    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={z: sample_z(mb_size, z_dim)})
View Code

 

Wasserstein Generative Adversarial Nets(WGAN)

标签:有趣   模式   dropbox   src   imp   输出   sample   get   https   

原文地址:https://www.cnblogs.com/skykill/p/8724147.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!