为什么梯度下降法对 GAN 的训练效果并不好?

  • 我们来探讨一下生成对抗网络的 优化器。进而讨论为什么梯度下降法并不太适合生成对抗网络,GAN 的训练为什么常常会 失败
  • 本文是 Make Your First GAN With PyTorch 的附录 D,本书的介绍详见这篇文章


1. 梯度下降法适用于 GAN 训练吗?

我们构建并训练神经网络,使用的 梯度下降法(gradient descent) ,通过找到损失函数的下降通道来找到可学习参数的组合,使得误差最小化,达到训练网络的目的。

  • 比如前面手写数字分类的网络,使用了 Adam 优化器,效果较为理想。

但是 GAN 的动态性不同于简单的神经网络,生成器和鉴别器网络的目标是相反的。

  • GAN 与对抗游戏是类似的,都是一个选手试图最大化某个目标,而另一个选手是为 了最小化这个目标,每个选手都会试图抵消另一方之前的动作的优势。
  • 那么,梯度下降法适用于类似的对抗游戏吗?

2. 简单的对抗游戏

考虑一个非常简单的目标函数:

f = x ⋅ y f=x \cdot yf=xy

  • 其中一个选手控制 x xx 的值,试图通过改变 x xx最大化 目标函数 f ff
  • 另一个选手则控制 y yy,试图使得目标函数 f ff 的值 最小

我们将这个函数可视化,下面的图片从三个不同的角度展示了 f = x ⋅ y f = x \cdot yf=xy 的表面图:

在这里插入图片描述

  • 可以看到, f = x ⋅ y f = x \cdot yf=xy 表面是 马鞍形(saddle) 的。
  • 这意味着,我们沿着一个方向运动时(对应某个值单调变化),目标值将先上升后下降;但是沿另一个方向,目标值会先下降后上升。

下面的图片使用颜色和箭头来表示 f ff 的值和导数的大小及变化方向:

在这里插入图片描述

  • 考虑一下,如果我们是对抗游戏的一方,该怎么做能获得最佳(或者避免最差)的收益呢?

如果使用直觉来说的话,我们可能会说最好的答案是在马鞍的中心点 ( x , y ) = ( 0 , 0 ) (x, y) = (0, 0)(x,y)=(0,0)

  • 在这个点,如果某个选手设置 x = 0 x = 0x=0,第二个选手无论如何选择 $y $的值,都不会影响 f ff 的值;
  • 类似的,如果先设定 y = 0 y = 0y=0x xx 的任何取值都不会改变 y yy 的值。

可以认为,这个值将使得两个选手同等快乐——或者同等不快乐。

”直觉“ 可能只能让我们避免 “不输”(而且,似乎并没有什么依据),下面我们基于 梯度下降 来模拟两个选手,找到对两个选手都很满意的解决方案。

我们可以依赖于目标函数 f ff 的梯度,对参数 x xxy yy 进行小幅度的调整:

x → x + l r ∗ δ f / δ x y → y − l r ∗ δ f / δ y x \rightarrow x + lr *\delta f / \delta x \\ y \rightarrow y - lr *\delta f / \delta yxx+lrδf/δxyylrδf/δy

  • 上面两个式子我们称之为 更新法则(update rules);两个式子的符号不同,是因为 y yy 通过降低梯度来最小化 f ff,而 x xx 是通过增加梯度来最大化 f ff。另外, l r lrlr 是学习率。

f = x ⋅ y f = x \cdot yf=xy 的导数进行计算,可以将上式进行改写:

x → x + l r ∗ y y → y − l r ∗ x x \rightarrow x + lr * y \\ y \rightarrow y - lr * xxx+lryyylrx

我们可以撰写一些代码,首先随机指定 x xxy yy 的初始值,然后重复地应用更新法则来获得连续的 x xxy yy 的值。

下图显示了随着训练进程,x xxy yy 是如何变化的:

在这里插入图片描述
可以看到 x xxy yy 的值并不收敛,而且振荡的幅度越来越大。

  • 尝试不同的初始值,结果也大同小异。减小学习率只不过是延迟了这种不可避免的 分歧(divergence)

这很不好。

  • 上述结果,显示出梯度下降并不能找到这个简单对抗游戏的解决方案,甚至更差一些,这个方法会导致灾难性的分歧。

下图将 x xxy yy 放在一起展示,可以看到这两个值是环绕着理想点 (0,0) 的,但是距离越来越远。

在这里插入图片描述

3. 梯度下降法并不是对抗游戏的理想选择

本文,我们使用一个简单的目标函数显示梯度下降并不能发现对抗游戏的解决方案。

事实上,它不仅仅是找不到解决方案,而是灾难性地出现分歧。

这是不是意味着 GAN 训练一般都会失败?

实际上,使用有意义的数据的实际 GAN 的函数一般非常复杂,这可以减少失控 分歧的概率,这也是我们系列文章中 GAN 的训练都相当好的原因。

但是本文的分析确实解释训练 GAN 困难和混沌的原因。

环绕一个好的解决方案运行也可以解释为什么许多简单的 GAN 似乎通过扩展训练而不是提高图像质量,来解决不同的模式坍塌问题。

  • 无论如何,从根本上讲,梯度下降对于 GAN 来说是错误的,即使它在很多情况下工作得很好。寻找针对 GAN 中对抗性动态的优化技术是当前一个开放的研究问题,一些研究人员已经发表了令人鼓舞的结果。

版权声明:本文为soih0718原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。