Tic商业评论

关注微信公众号【站长自定义模块】,定时推送前沿、专业、深度的商业资讯。

 找回密码
 立即注册

QQ登录

只需一步,快速开始

微信登录

微信扫码,快速开始

  • QQ空间
  • 回复
  • 收藏

GAN — CGAN & InfoGAN(使用标签改进 GAN)

lijingle gan 2022-1-9 15:45 3110人围观

在判别模型(如分类)中,我们手工制作特征以使模型表现更好。 如果模型有足够的容量并且知道如何自己学习这些特征,则不需要这种做法。 在 GAN 中,训练模型非常重要。 它可以从标签中获得额外的帮助,以使模型表现更好。
在 CGAN(条件 GAN)中,标签充当潜在空间 z 的扩展,以更好地生成和区分图像。 下面的上图是常规 GAN,下图为生成器和判别器添加了标签,以更好地训练两个网络。


整个机制仍未完全了解。 这些标签可能会为 GAN 寻找什么提供一个重要的开端。 另一种可能是我们的视觉系统有偏见,对这些标签更敏感。 因此,生成的图像被认为更好。 作为 GAN 系列的一部分,本文研究了如何使用标签提高 GAN 的性能。


CGAN(条件GAN)

在 GAN 中,无法控制要生成的数据的模式。 条件 GAN 通过将标签 y 作为附加参数添加到生成器来改变这一点,并希望生成相应的图像。 我们还将标签添加到鉴别器输入中,以更好地区分真实图像。


在 MNIST 中,我们从均匀分布中对标签 y 进行采样以生成从 0 到 9 的数字。我们将此值编码为 1-hot 向量。 例如,值 3 将被编码为 (0, 0, 0, 1, 0, 0, 0, 0, 0, 0)。 我们将向量和噪声 z 输入生成器,以创建类似于“3”的图像。 对于鉴别器,我们将假定的标签作为 one-hot 向量添加到其输入中。

CGAN 的cost函数与 GAN 相同。


D(x|y) 和 G(z|y) 表明我们正在区分并生成给定标签 y 的图像。 (与其他图中的 D(x, y) 和 G(z, y) 相同。)这是 CGAN 的数据流。


在 CGAN 中,我们可以扩展机制以包含训练数据集可能提供的其他标签。 例如,如果已知数字的笔画大小,我们可以从正态分布中对其进行采样,并将其添加到生成图像中。


InfoGAN

CGAN 中的标签 y 在数据集中提供。 或者,我们可以使用我们的鉴别器来提取所有这些潜在特征。 在下面的示例中,我们从均匀分布中采样单个特征 c 并将其转换为 1-hot 向量。 然后生成器使用这个向量和 z 来生成图像。


当我们将图像输入判别器时,它会输出 D(x) 和一个附加输出:概率分布 Q(c|x)(给定图像 x 的 c 的概率分布。)例如,给定生成的图像类似于数字 “3”,Q 可以估计为 (0.1, 0, 0, 0.8, ...),这意味着图像是数字“0”的可能性为 0.1,而图像是数字“3”的可能性为 0.8。


我们用一个额外的项 I(x;y) 减去常规 GAN cost函数,以形成我们的新cost函数。


如果我们知道 y,I(互信息)衡量我们对 x 的了解程度。 如果图像 x 和估计的 c 完全不相关,则 I(c;x) 等于 0。 否则,如果判别器能正确预测 c,I 会很高,降低 InfoGAN 成本。

在没有证据的情况下,互信息 I 可以使用熵来估计。 我们使用 Q(c|x) 和 P(C) 来建立 I 的下界。


其中 H 代表熵。 当模型执行时,我将收敛到它的下限。 互信息的概念可能需要时间来理解。但是,它的代码非常简单。

# If the image is a "3"
# p_c = P(c) = [0, 0, 0, 1.0, 0, 0, 0, 0, 0, 0]
# The first term for I = - ∑ P(c) * log Q(c|x)
cross_H_p_q = tf.reduce_mean(
                -tf.reduce_sum(p_c * tf.log(Q_c_given_x + 1e-8), 1))
# The entropy of c: H(c) = - ∑ P(c) * log(P(c)) 
H_c = tf.reduce_mean(-tf.reduce_sum(p_c * tf.log(p_c + 1e-8), 1))

路过

雷人

握手

鲜花

鸡蛋
我有话说......
电话咨询: 135xxxxxxx
关注微信