Tic商业评论

关注微信公众号【站长自定义模块】,定时推送前沿、专业、深度的商业资讯。

 找回密码
 立即注册

QQ登录

只需一步,快速开始

微信登录

微信扫码,快速开始

  • QQ空间
  • 回复
  • 收藏

GAN — Unrolled GAN(如何减少模式崩溃)

lijingle gan 2022-1-13 21:27 1716人围观

直觉:在任何游戏中,你都会提前预测对手接下来的几步棋,并相应地准备下一步棋。 在 Unrolled GAN 中,我们为生成器提供了一个机会,可以展开 k 步来了解鉴别器如何优化自身。 然后我们使用反向传播更新生成器,并使用最后 k 步计算的成本。 前瞻不鼓励生成器利用容易被鉴别器抵消的局部最优值。 否则,模型会振荡,甚至变得不稳定。 Unrolled GAN 降低了生成器过度拟合特定判别器的机会。 这减少了模式崩溃并提高了稳定性。
本文是 GAN 系列的一部分。 由于模式崩溃很常见,我们花一些时间来探索 Unrolled GAN,看看如何解决模式崩溃问题。

鉴别器训练
在 GAN 中,我们计算成本函数并使用反向传播来拟合判别器 D 和生成器 G 的模型参数。


我们重新绘制下图以强调模型参数 θ。 红色箭头显示了我们如何反向传播成本函数 f 以拟合模型参数。


这是成本函数和梯度下降。 (为了说明的目的,我们使用简单的梯度下降)


在下图中,我们添加了 SGD(梯度下降公式)来明确定义判别器参数的计算方式。

在 Unrolled GAN 中,我们训练判别器的方式与 GAN 完全相同。


生成器训练

Unrolled GAN 运行 k 个步骤来学习鉴别器如何针对特定生成器优化自身。 通常,我们使用 5 到 10 个展开步骤,这表明模型性能非常好。 下图将过程展开 3 次。



成本函数基于最新的鉴别器模型参数,而生成器的模型参数保持不变。


在每一步,我们应用梯度下降来优化判别器的新模型。



但如前所述,我们只使用第一步来更新判别器。 生成器使用展开来预测移动,但不用于鉴别器优化。 否则,我们可能会过度拟合特定生成器的判别器。


对于生成器,我们在所有 k 步中反向传播梯度。 这与 LSTM 如何展开以及梯度如何反向传播非常相似。 由于我们有 k 个展开的步骤,因此生成器还会累积参数更改 k 次(每一步一次),如上所示。


总而言之,Unrolled GAN 使用在最后一步计算的成本函数来执行生成器的反向传播,而鉴别器仅使用第一步。


代码

Unrolled GAN 的实现可以从这里找到。 实际上,这很简单。 展开 k 步的核心逻辑很简单:

for i in range(params['unrolling_steps'] - 1):
    cur_update_dict = graph_replace(update_dict, cur_update_dict)
    unrolled_loss = graph_replace(loss, cur_update_dict)

graph_replace 使用上一步中的最新鉴别器模型加载鉴别器。 这是在 TensorFlow 中构建计算图的核心逻辑。

with slim.arg_scope([slim.fully_connected],   
     weights_initializer=tf.orthogonal_initializer(gain=1.4)):
    samples = generator(noise, output_dim=params['x_dim'])
    real_score = discriminator(data)
    fake_score = discriminator(samples, reuse=True)

loss = tf.reduce_mean(
          tf.nn.sigmoid_cross_entropy_with_logits(logits=real_score, 
             labels=tf.ones_like(real_score)) +
          tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_score, 
             labels=tf.zeros_like(fake_score)))

...
updates = d_opt.get_updates(disc_vars, [], loss)
d_train_op = tf.group(*updates, name="d_train_op")
...
# Get dictionary mapping from variables to their update value
# after one optimization step
update_dict = extract_update_dict(updates)
cur_update_dict = update_dict
for i in range(params['unrolling_steps'] - 1):
    cur_update_dict = graph_replace(update_dict, cur_update_dict)
    unrolled_loss = graph_replace(loss, cur_update_dict)
...
g_train_op = g_train_opt.minimize(-unrolled_loss, var_list=gen_vars)
...
f, _, _ = sess.run([[loss, unrolled_loss], g_train_op, d_train_op])


实验

在下面的实验中,我们从一个玩具数据集开始,其中包含 8 个高斯分布的混合。 提供了一个不太复杂的生成器,第二行中的 GAN 设法生成了良好的数据质量,但未能实现多样性。 模式崩溃。 应用 Unrolled GAN,它发现了所有 8 种高质量模式(第一行)。



RNN 生成器特别容易受到模式崩溃的影响。 Unrolled GAN(第一行)设法发现所有 10 种模式,而常规 GAN 模型崩溃(第二行)。




路过

雷人

握手

鲜花

鸡蛋
我有话说......

TA还没有介绍自己。

电话咨询: 135xxxxxxx
关注微信