Cynthia

V1

2022/06/16阅读:28主题:自定义主题1

由浅入深了解Diffusion Model

前言

其实早在去年就看过大佬Lil关于diffusion model精彩的介绍What are Diffusion Models? [1]但是后面一直没深入研究,很快就忘细节了。最近Diffusion Model火到爆炸(GLIDE[2],DALLE2[3],Imagen[4],和一系列Image Editing方法等等),所以又重新建起来学习了下。恐怕diffusion拥有成为下一代图像生成模型的代表的潜力(或者已经是了?)本文主要是对Lil博客进行翻译整理,会添加一些细节的理解和对照代码的思考,主要是为了方便自己学习记录,如果有理解错误的地方还请指出。

什么是Diffusion Model(扩散模型)?

首先我们来看一下最近火爆各个公众号的text-to-image结果:

图1. DALLE2生成结果

图2. Imagen生成结果

上述图片的结果都非常惊人,无论从真实度还是还原度都几乎无可挑剔。这里我们从由浅入深来了解一下Diffusion Model。首先还是放一个各类生成模型对比图:

图3. 不同生成模型对比图(来源:Lil博客)

diffusion model和其他模型最大的区别是它的latent code(z)和原图是同尺寸大小的,当然最近也有基于压缩的latent diffusion model[5],不过是后话了。一句话概括diffusion model,即存在一系列高斯噪声( 轮),将输入图片 变为纯高斯噪声 。而我们的模型则负责将 复原回图片 。这样一来其实diffusion model和GAN很像,都是给定噪声 生成图片 ,但是要强调的是,这里噪声 与图片 同维度的。

diffusion model有很多种理解,这里介绍是基于denoising diffusion probabilistic models (DDPM)[6]的_。_

Diffusion前向过程

所谓前向过程,即往图片上加噪声的过程。虽然这个步骤无法做到图片生成,但是这是理解diffusion model以及构建训练样本GT至关重要的一步。

给定真实图片 ,diffusion前向过程通过 次累计对其添加高斯噪声,得到 ,如下图的q过程。这里需要给定一系列的高斯分布方差的超参数 .前向过程由于每个时刻 只与 时刻有关,所以也可以看做马尔科夫过程:

这个过程中,随着 的增大, 越来越接近纯噪声。当 是完全的高斯噪声(下面会证明,且与均值系数 的选择有关)。且实际中 随着t增大是递增的,即 。在GLIDE的code中, 是由0.0001 到0.02线性插值(以 为基准, 增加, 对应降低)。

图4. diffusion的前向(q)和逆向(p)过程,来源:DDPM

前向过程介绍结束前,需要讲述一下diffusion在实现和推导过程中要用到的两个重要特性。

特性1:重参数(reparameterization trick)

重参数技巧在很多工作(gumbel softmax, VAE)中有所引用。如果我们要从某个分布中随机采样(高斯分布)一个样本,这个过程是无法反传梯度的。而这个通过高斯噪声采样得到 的过程在diffusion中到处都是,因此我们需要通过重参数技巧来使得他可微。最通常的做法是吧随机性通过一个独立的随机变量( )引导过去。举个例子,如果要从高斯分布 采样一个z,我们可以写成:

上式的z依旧是有随机性的, 且满足均值为 方差为 的高斯分布。这里的 可以是由参数 的神经网络推断得到的。整个“采样”过程依旧梯度可导,随机性被转嫁到了 上。

特性2:任意时刻的 可以由 表示

能够通过 快速得到 对后续diffusion model的推断和推导有巨大作用。首先我们假设 ,并且 ,展开 可以得到:

由于独立高斯分布可加性,即 ,所以

因此可以混合两个高斯分布得到标准差为 的混合高斯分布,然而Eq(3)中的 仍然是标准高斯分布。而任意时刻的 满足 .

一开始笔者一直不清楚为什么Eq(1)中diffusion的均值每次要乘上 .明明 只是方差系数,怎么会影响均值呢?替换为任何一个新的超参数,保证它<1,也能够保证值域并且使得最后均值收敛到0(但是方差并不为1). 然而通过Eq(3)(4),可以发现当 , .所以 的均值系数能够稳定保证 最后收敛到方差为1的标准高斯分布,且在Eq(4)的推导中也更为简洁优雅。(注:很遗憾,笔者并没有系统地学习过随机过程,也许 就是diffusion model前向过程收敛到标准高斯分布的唯一解,读者有了解也欢迎评论)

Diffusion逆向(推断)过程

如果说前向过程(forward)是加噪的过程,那么逆向过程(reverse)就是diffusion的去噪推断过程。如果我们能够逐步得到逆转后的分布 ,就可以从完全的标准高斯分布 还原出原图分布 .在文献[7]中证明了如果 满足高斯分布且 足够小, 仍然是一个高斯分布。然而我们无法简单推断 ,因此我们使用深度学习模型(参数为 ,目前主流是U-Net+attention的结构)去预测这样的一个逆向的分布 (类似VAE) :

虽然我们无法得到逆转后的分布 ,但是如果知道 ,是可以通过贝叶斯公式得到 为:

过程如下:

上式(7-1)巧妙地将逆向过程全部变回了前向,即 ,而(7-2)分别写出其对应的高斯概率密度函数,(7-3)则整理成了 的高斯分布概率密度函数形式。一般的高斯概率密度函数的指数部分应该写为 ,因此稍加整理我们可以得到(6)中的方差和均值为:

根据特性2,我们得知 ,因此带入(8-2)可以得到

其中高斯分布 为深度模型所预测的噪声(用于去噪),可看做为 ,即得到:

这样一来,DDPM的每一步的推断可以总结为:

1) 每个时间步通过 来预测高斯噪声 ,随后根据(9)得到均值 .

2) 得到方差 ,DDPM中使用untrained ,且认为 结果近似,在GLIDE中则是根据网络预测trainable方差 .

3) 根据(5-2)得到 ,利用重参数得到 .

在x0和xt反复横跳的diffusion逆向过程

Diffusion训练

搞清楚diffusion的逆向过程之后,我们算是搞清楚diffusion的推断过程了。但是如何训练diffusion model以得到靠谱的 呢?通过对真实数据分布下,最大化模型预测分布的对数似然,即优化在 下的 交叉熵:

从图4可以得知这个过程很像VAE,即可以使用变分下限(VLB)来优化负对数似然。由于KL散度非负,可得到: