0%

今天,一则重磅消息席卷了 AI 圈:OpenAI 发布了视频模型 Sora,能根据文本生成长达一分钟的高质量 1920x1080 视频,生成能力远超此前只能生成 25 帧 576x1024 图像的顶尖视频生成模型 Stable Video Diffusion。

同时,OpenAI 也公布了一篇非常简短的技术报告。报告仅大致介绍了 Sora 的架构及应用场景,并未对模型的原理详加介绍。让我们来快速浏览一下这份报告,看看科研人员从这份报告中能学到什么。

官网链接:https://openai.com/sora

技术报告链接:https://openai.com/research/video-generation-models-as-world-simulators

这篇文章没怎么贴视频,感兴趣的话可以对照着原报告中的视频阅读。

LDM 与 DiT 的结合

简单来说,Sora 就是 Latent Diffusion Model (LDM) [1] 加上 Diffusion Transformer (DiT) [2]。我们先简要回顾一下这两种模型架构。

LDM 就是 Stable Diffusion 使用的模型架构。扩散模型的一大问题是计算需求大,难以拟合高分辨率图像。为了解决这一问题,实现 LDM时,会先训练一个几乎能无损压缩图像的自编码器,能把 512x512 的真实图像压缩成 64x64 的压缩图像并还原。接着,再训练一个扩散模型去拟合分辨率更低的压缩图像。这样,仅需少量计算资源就能训练出高分辨率的图像生成模型。

LDM 的扩散模型使用的模型是 U-Net。而根据其他深度学习任务中的经验,相比 U-Net,Transformer 架构的参数可拓展性强,即随着参数量的增加,Transformer 架构的性能提升会更加明显。这也是为什么大模型普遍都采用了 Transformer 架构。从这一动机出发,DiT 应运而生。DiT 在 LDM 的基础上,把 U-Net 换成了 Transformer。

顺带一提,Transformer 本来是用于文本任务的,它只能处理一维的序列数据。为了让 Transformer 处理二维图像,通常会把输入图像先切成边长为 $p$ 的图块,再把每个图块处理成一项数据。也就是说,原来边长为 $I$ 的正方形图片,经图块化后,变成了长度为 $(I/p)^2$ 的一维序列数据。

Transformer 是一种和顺序无关的计算。比如对于输入”abc”和”bca”,Transformer 会输出一模一样的值。为了描述数据的先后顺序,使用 Transformer 时,一般会给数据加一个位置编码。

Sora 是一个视频版的 DiT 模型。让我们看一下 Sora 在 DiT 上做了哪些改进。

时空自编码器

在此之前,许多工作都尝试把预训练 Stable Diffusion 拓展成视频生成模型。在拓展时,视频的每一帧都会单独输入进 Stable Diffusion 的自编码器,再重新构成一个压缩过的图像序列。而 VideoLDM[3] 工作发现,直接对视频使用之前的图像自编码器,会令输出视频出现闪烁的现象。为此,该工作对自编码器的解码器进行了微调,加入了一些能够处理时间维度的模块,使之能一次性处理整段压缩视频,并输出连贯的真实视频。

Sora 则是从头训练了一套能直接压缩视频的自编码器。相比之前的工作,Sora 的自编码器不仅能在空间上压缩图像,还能在时间上压缩视频长度。这估计是为什么 Sora 能生成长达一分钟的视频。

报告中提到,Sora 也能处理图像,即长度为1的视频。那么,自编码器怎么在时间上压缩长度为1的视频呢?报告中并没有给出细节。我猜测该自编码器在时间维度做了填充(比如时间被压缩成原来的 1/2,那么就对输入视频填充空数据直至视频长度为偶数),也可能是输入了视频长度这一额外约束信息。

时空压缩图块

输入视频经过自编码器后,会被转换成一段空间和时间维度上都变小的压缩视频。这段压缩视频就是 Sora 的 DiT 的拟合对象。在处理视频数据时,DiT 较 U-Net 又有一些优势。

之前基于 U-Net 的去噪模型在处理视频数据时(如 [3]),都需要额外加入一些和时间维度有关的操作,比如时间维度上的卷积、自注意力。而 Sora 的 DiT 是一种完全基于图块的 Transformer 架构。要用 DiT 处理视频数据,不需要这种设计,只要把视频看成一个 3D 物体,再把 3D 物体分割成「图块」,并重组成一维数据输入进 DiT 即可。和原本图像 DiT 一样,假设视频边长为 $I$,时长也为 $I$,要切成边长为 $p$ 的图块,最后会得到 $(I/p)^3$ 个数据。

报告没有给出视频图块化的细节。

处理任意分辨率、时长的视频

报告中反复提及,Sora 在训练和生成时使用的视频可以是任何分辨率(在 1920x1080 以内)、任何长宽比、任何时长的。这意味着视频训练数据不需要做缩放、裁剪等预处理。这些特性是绝大多数其他视频生成模型做不到的,让我们来着重分析一下这一特性的原理。

Sora 的这种性质还是得益于 Transformer 架构。前文提到,Transformer 的计算与输入顺序无关,必须用位置编码来指明每个数据的位置。尽管报告没有提及,我觉得 Sora 的 DiT 使用了类似于 $(x, y, t)$ 的位置编码来表示一个图块的时空位置。这样,不管输入的视频的大小如何,长度如何,只要给每个图块都分配一个位置编码,DiT 就能分清图块间的相对关系了。

相比以前的工作,Sora 的这种设计是十分新颖的。之前基于 U-Net 的 Stable Diffusion 为了保证所有训练数据可以统一被处理,输入图像都会被缩放与裁剪至同一大小。由于训练数据中有被裁剪的图像,模型偶尔也会生成被裁剪的图像。生成训练分辨率以外的图像时,模型的表现有时也会不太好。SDXL [4] 的解决方式是把裁剪的长宽做为额外信息输入进 U-Net。为了生成没有裁剪的图像,只要令输入的裁剪长宽为 0 即可。类似地,SDXL 也把图像分辨率做为额外输入,使得 U-Net 学习不同分辨率、长宽比的图像。相比 SDXL,Sora 的做法就简洁多了。

之前基于 DiT 的模型 (比如华为的 PixArt [5])似乎都没有利用到 Transformer 可以随意设置位置编码这一性质。DiT 在处理输入图块时,会先把图块变形成一维数据,再从左到右编号,即从从左到右,从上到下地给二维图块组编号。这种位置编码并没有保留图像的二维空间信息,因此,在这种编码下,模型的输入分辨率必须固定。比如对于下面这个$4\times4$的图块组,如果是从左到右、从上到下编码,模型等于是强行学习到了「1号在0号右边、4号在0号下面」这样的位置信息。如果输入的图块形状为 $4 \times 5$,那么图块间的相对关系就完全对不上了。而如果像 Sora 这样以视频图块的 $(x, y, t)$ 来生成位置编码的话,就没有这种问题了,输入视频可以是任何分辨率、任何长度。

Transformer 在视频生成的可拓展性

前文提过,Transformer 的特点就是可拓展性强,即模型越大,训练越久,效果越好。报告中展示了1倍、4倍、16倍某单位训练时间下的生成结果,可以看出模型确实一直有进步。

语言理解能力

之前大部分文生图扩散模型都是在人工标注的图片-文字数据集上训练的。后来大家发现,人工标注的图片描述质量较低,纷纷提出了各种提升标注质量的方法。Sora 复用了自家 DALL·E 3 的重标注技术,用一个训练的能生成详细描述的标注器来重新为训练视频生成标注。这种做法不仅解决了视频缺乏标注的问题,且相比人工标注质量更高。Sora 的部分结果展示了其强大了抽象理解能力(如理解人和猫之间的交互),这多半是因为视频标注模型足够强大,视频生成模型学到了视频标注模型的知识。但同样,视频标注模型的相关细节完全没有公开。

其他生成功能

  • 基于已有图像和视频进行生成:除了约束文本外,Sora 还支持在一个视频前后补充内容(如果是在一张图片后面补充内容,就是图生视频)。报告没有给出实现细节,我猜测是直接做了反演(inversion)再把反演得到的隐变量替换到随机初始隐变量中。
  • 视频编辑:报告明确写出,只用简单的 SDEdit (即目前 Stable Diffusion 中的图生图)即可实现视频编辑。
  • 视频内容融合:可能是对两个视频的初始隐变量做了插值。
  • 图像生成:当然,Sora 也可以生成图像。报告表明,Sora 可以生成最大 2048x2048 的图像。

涌现出的能力

通过学习大量数据,Sora 还涌现出一些意想不到的能力。

  • 3D 一致性:视频中包含自然的相机视角变换。之前的 Stable Video Diffusion 也有类似发现。
  • 长距离连贯性:AI 生成出来的视频往往有物体在中途突然消失的情况。而 Sora 有时候能克服这一问题。
  • 与世界的交互:比如在描述画画的视频中,画纸上的内容随画笔生成。
  • 模拟数字世界:报告展示了在输入文本有”Minecraft”时,模型能生成非常真实的 Minecraft 游戏视频。这大概只能说明模型的拟合能力太强了,以至于学会了生成 Minecraft 这一种特定风格的视频。

局限性

报告结尾还是给出了一些失败的生成示例,比如玻璃杯在桌子上没有摔碎。这表明模型还不能完全学会某些物理性质。然而,我觉得现阶段 Sora 已经展示了足够强大的学习能力。想模拟现有视频中已经包含的物理现象,只需要增加数据就行了。

总结

Sora 是一个惊艳的视频生成模型,它以卓越的生成能力(高分辨率、长时间)与生成质量令一众同期的视频生成模型黯然失色。Sora 的技术报告非常简短,不过我们从中还是可以学到一些东西。从技术贡献上来看,Sora 的创新主要有两点:

  1. 让 LDM 的自编码器也在视频时间维度上压缩。
  2. 使用了一种不限制输入形状的 DiT

其中,第二点贡献是非常有启发性的。DiT 能支持不同形状的输入,大概率是因为它以视频的3D位置生成位置编码,打破了一维编码的分辨率限制。后续大家或许会逐渐从 U-Net 转向 DiT 来建模扩散模型的去噪模型。

我认为 Sora 的成功有三个原因。前两个原因对应两项创新。第一,由于在时间维度上也进行了压缩,Sora 最终能生成长达一分钟的视频;第二,使用 DiT 不仅去除了视频空间、时间长度上的限制,还充分利用了 Transformer 本身的可拓展性,使训练一个视频生成大模型变得可能。第三个原因来自于视频标注模型。之前 Stable Diffusion 能够成功,很大程度上是因为有一个能够关联图像与文本的 CLIP 模型,且有足够多的带标注图片。相比图像,视频训练本来就少,带标注的视频就更难获得了。一个能够理解视频内容,生成详细视频标注的标注器,一定是让视频生成模型理解复杂文本描述的关键。除了这几点原因外,剩下的就是砸钱、扩大模型、加数据了。

Sora 显然会对 AIGC 社区产生一定影响。对于 AIGC 爱好者而言,他们或许会多了一些生成创意视频的方法,比如给部分帧让 Sora 来根据文本补全剩余帧。当然,目前 Sora 依然不能取代视频创作者,长视频的质量依然有待观察。对于正在开发相似应用的公司,我觉得他们应该要连夜撤销之前的方案,转换为这套没有分辨率限制的 DiT 的方案。他们的压力应该会很大。对于相关科研人员而言,除了学习这种较为新颖的 DiT 用法外,也没有太多收获了。这份技术报告透露出一股「我绝对不会开源」的意思。没有开源模型,普通的研究者也就什么都做不了。新技术的诞生绝对不可能靠一家公司,一个模型就搞定。像之前的 Stable Diffusion,也是先开源了一个基础模型,科研者和爱好者再补充了各种丰富的应用。我呼吁各大公司尽快训练并开源一个这种不限分辨率的 DiT,这样科研界或许会抛开 U-Net,基于 DiT 开发出新的扩散模型应用。

参考论文

  1. Latent Diffusion Model, Stable Difusion: High-Resolution Image Synthesis with Latent Diffusion Models
  2. DiT: Scalable Diffusion Models with Transformers
  3. VideoLDM: Align your Latents: High-Resolution Video Synthesis with Latent Diffusion Models
  4. SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis
  5. PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis

在使用预训练 Stable Diffusion (SD) 生成图像时,如果将其 U-Net 的自注意力层在某去噪时刻的输入 K, V 替换成另一幅参考图像的,则输出图像会和参考图像更加相似。许多无需训练的 SD 编辑科研工作都运用了此性质。尤其对于是对于视频编辑任务,如果在生成某一帧时将注意力输入替换成之前帧的,则输出视频会更加连贯。在这篇文章中,我们将快速学习 SD 自注意力替换技术的原理,并在 Diffusers 里实现一个基于此技术的视频编辑流水线。

注意力计算

我们先来回顾一下 Transformer 论文中提出的注意力机制。所有注意力机制都基于一种叫做放缩点乘注意力(Scaled Dot-Product Attention)的运算:

其中,$Q \in \mathbb{R}^{a \times d_k}, K \in \mathbb{R}^{b \times d_k}, V \in \mathbb{R}^{b \times d_v}$。注意力计算可以理解成先算 $a$ 个长度为 $d_k$ 的向量对 $b$ 个长度为 $d_k$ 的向量的相似度,再以此相似度为权重算 $a$ 个向量对 $b$ 个长度为 $d_v$ 的向量的加权和。

注意力计算是没有可学习参数的。为了加入参数,Transformer 设计了如下所示的注意力层,其中 $W^Q, W^K, W^V, W^O$ 都是参数。

一般在使用注意力层时,会让$K=V$。这种注意力叫做交叉注意力。交叉注意力可以理解成数据 $A$ 想从数据 $B$ 里提取信息,提取的根据是 $A$ 里每个向量和 $B$ 里每个向量的相似度。

交叉注意力的特例是自注意力,此时 $Q=K=V$ 。这表示数据里的向量两两之间交换了一次信息。

SD 中的自注意力替换

SD 的 U-Net 既用到了自注意力,也用到了交叉注意力。自注意力用于图像特征自己内部信息聚合。交叉注意力用于让生成图像对齐文本,其 Q 来自图像特征,K, V 来自文本编码。

由于自注意力其实可以看成一种特殊的交叉注意力,我们可以把自注意力的 K, V 替换成来自另一幅参考图像的特征。这样,扩散模型的生成图片会既和原本要生成的图像相似,又和参考图像相似。当然,用来替换的特征必须和原来的特征「格式一致」,不然就生成不了有意义的结果了。

什么叫「格式一致」呢?我们知道,扩散模型在采样时有很多步,U-Net 中又有许多自注意力层。每一步时的每一个自注意力层的输入都有自己的「格式」。也就是说,如果你要把某时刻某自注意力层的 K, V 替换,就得先生成参考图像,用生成参考图像过程中此时刻此自注意力层的输入替换,而不能用其他时刻或者其他自注意力层的。

一般这种编辑技术只会用在自注意力层而不是交叉注意力层上,这是因为 SD 中的交叉注意力是用来关联图像与文字的,另一幅图像的信息无法输入。当然,除了 SD,只要是用到了自注意力模块的扩散模型,都能用此方法编辑,只不过大部分工作都是基于 SD 开发的。

自注意力替换的应用

自注意力替换最常见的应用是提升 SD 视频编辑的连续性。在此任务中,一般会先正常编辑第一帧,再将后续帧的自注意力的 K, V 替换成第一帧的。这种技术在文献中一般被称为帧间注意力(cross-frame attention)。较早提出此论文的工作是 Text2Video-Zero。

自注意力替换也可以用于提升单幅图像编辑的保真度。一个例子是拖拽单幅图像的 DragonDiffusion。此应用可以拓展到图像插值上,比如 DiffMorpher 在图像插值时对两幅参考图像的自注意力输入等比例插值,再替换掉对应插值图像的自注意力的 K, V。

在 Diffusers 里实现自注意力替换

Diffusers 的 U-Net 专门提供了用于修改注意力计算的 AttentionProcessor 类。借助相关接口,我们可以方便地修改注意力的计算方法。在这个示例项目中,我们来用 Diffusers 实现一个参考第一帧和上一帧的注意力输入的 SD 视频编辑流水线。相比逐帧生成编辑图片,该流水线的结果会更加平滑一点。项目网址:https://github.com/SingleZombie/DiffusersExample/tree/main/ReplaceAttn

AttentionProcessor

在 Diffusers 中,U-Net 的每一个注意力模块都有一个 AttentionProcessor 类的实例。AttentionProcessor 类的 __call__ 方法描述了注意力计算的过程。如果我们想修改某些注意力模块的计算,就需要自己定义一个注意力处理类,其 __call__ 方法的参数需与 AttentionProcessor 的兼容。之后,我们再调用相关接口把原来的处理类换成我们自己写的处理类。下面我们将先看一下 AttentionProcessor 类的实现细节,再实现我们自己的
注意力处理类。

AttentionProcessor 类在 diffusers/models/attention_processor.py 文件里。它只有一个 __call__ 方法,其主要内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class AttnProcessor:

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
query = attn.to_q(hidden_states, *args)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states

key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)

query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states

方法参数中,hidden_states 是 Q, encoder_hidden_states 是 K, V。如果 K, V 没有传入(为 None),则 K, V 会被赋值成 Q。该方法的实现细节和 Tranformer 中的注意力层完全一样,此处就不多加解释了。一般替换注意力的输入时,我们不用改这个方法的实现,只会在需要的时候调用这个方法。

attention_processor.py 文件中还有一个功能类似的类 AttnProcessor2_0,它和 AttentionProcessor 的区别在于它调用了 PyTorch 2.0 起启用的算子 F.scaled_dot_product_attention 代替手动实现的注意力计算。这个算子更加高效,如果你确定 PyTorch 版本至少为 2.0,就可以用 AttnProcessor2_0 代替 AttentionProcessor

看完了 AttentionProcessor 类后,我们来看该怎么在 U-Net 里将原注意力处理类替换成我们自己写的。U-Net 类的 attn_processors 属性会返回一个词典,它的 key 是每个处理类所在位置,比如 down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor,它的 value 是每个处理类的实例。为了替换处理类,我们需要构建一个格式一样的词典attn_processor_dict,再调用 unet.set_attn_processor(attn_processor_dict) ,取代原来的 attn_processors。假如我们自己实现了处理类 MyAttnProcessor,我们可以编写下面的代码来实现替换:

1
2
3
4
5
6
7
8
attn_processor_dict = {}
for k in unet.attn_processors.keys():
if we_want_to_modify(k):
attn_processor_dict[k] = MyAttnProcessor()
else:
attn_processor_dict[k] = AttnProcessor()

unet.set_attn_processor(attn_processor_dict)

实现帧间注意力处理类

熟悉了 AttentionProcessor 类的相关内容,我们来编写自己的帧间注意力处理类。在处理第一帧时,该类的行为不变。对于之后的每一帧,该类的 K, V 输入会被替换成视频第一帧和上一帧的输入在序列长度维度上的拼接结果,即:

你是否会感到疑惑:为什么 K, V 的序列长度可以修改?别忘了,在注意力计算中,Q, K, V 的形状分别是:$Q \in \mathbb{R}^{a \times d_k}, K \in \mathbb{R}^{b \times d_k}, V \in \mathbb{R}^{b \times d_v}$。注意力计算只要求 K,V 的序列长度 $b$ 相同,并没有要求 Q, K 的序列长度相同。

现在,注意力计算不再是一个没有状态的计算,它的运算结果取决于第一帧和上一帧的输入。因此,我们在注意力处理类中需要额外维护这两个变量。我们可以按照如下代码编写类的构造函数。除了处理继承外,我们还需要创建两个数据词典来存储不同时间戳下第一帧和上一帧的注意力输入。

1
2
3
4
5
class CrossFrameAttnProcessor(AttnProcessor):
def __init__(self):
super().__init__()
self.first_maps = {}
self.prev_maps = {}

在运行方法中,我们根据 encoder_hidden_states 是否为空来判断该注意力是自注意力还是交叉注意力。我们仅修改自注意力。当该注意力为自注意力时,假设我们知道了当前时刻 t,我们就可以根据 t 获取当前时刻第一帧和前一帧的输入,并将它们拼接起来得到 cross_map。以此 cross_map 为当前注意力的 K, V,我们就实现了帧间注意力。

1
2
3
4
5
6
7
8
9
10
11
12
13
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, **kwargs):

if encoder_hidden_states is None:
# Is self attention
cross_map = torch.cat(
(self.first_maps[t], self.prev_maps[t]), dim=1)
res = super().__call__(attn, hidden_states, cross_map, **kwargs)

else:
# Is cross attention
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)

return res

由于 Diffusers 经常修改函数接口,在调用普通的注意力计算接口时,最好原封不动地按照 super().__call__(..., **kwargs) 写,不然这份代码就不能兼容后续版本的 Diffusers。

上述代码只描述了后续帧的行为。如前所述,我们的注意力计算有两种行为:对于第一帧,我们不修改注意力的计算过程,只缓存其输入;对于之后每一帧,我们替换注意力的输入,同时维护当前「上一帧」的输入。既然注意力在不同情况下有不同行为,我们就应该用一个变量来记录当前状态,让 __call__ 能根据此变量决定当前的行为。相关的伪代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, **kwargs):

if encoder_hidden_states is None:
# Is self attention
if self.state == FIRST_FRAME:
res = super().__call__(attn, hidden_states, cross_map, **kwargs)
# update maps
else:
cross_map = torch.cat(
(self.first_maps[t], self.prev_maps[t]), dim=1)
res = super().__call__(attn, hidden_states, cross_map, **kwargs)
# update maps

else:
# Is cross attention
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)

return res

在伪代码中,self.state 表示当前注意力的状态,它的值表明注意力计算是在处理第一帧还是后续帧。在视频编辑流水线中,我们应按照下面的伪代码,先编辑第一帧,再修改注意力状态后编辑后续帧。

1
2
3
4
edit(frames[0])
set_attn_state(SUBSEQUENT_FRAMES)
for i in range(1, len(frames)):
edit(frames[i])

现在,有一个问题:我们该怎么修改怎么每一个注意力模块的处理器的状态呢?显然,最直接的方式是想办法访问每一个注意力模块的处理器,再直接修改对象的属性。

1
2
3
4
modules = unet.get_attn_moduels
for module in modules:
if we_want_to_modify(module):
module.processor.state = ...

但是,每次都去遍历所有模块会让代码更加凌乱。同时,这样写也会带来代码维护上的问题:我们每次遍历注意力模块时,都可能要判断该注意力模块是否应该修改。而在用前面讲过的处理类替换方法 unet.set_attn_processor 时,我们也得判断一遍。同一段逻辑重复写在两个地方,非常不利于代码更新。

一种更优雅的实现方式是:我们定义一个状态管理类,所有注意力处理器都从同一个全局状态管理类对象里获取当前的状态信息。想修改每一个处理器的状态,不需要遍历所有对象,只需要改一次全局状态管理类对象就行了。

按照这种实现方式,我们先编写一个状态类。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class AttnState:
STORE = 0
LOAD = 1

def __init__(self):
self.reset()

@property
def state(self):
return self.__state

def reset(self):
self.__state = AttnState.STORE

def to_load(self):
self.__state = AttnState.LOAD

在注意力处理类中,我们在初始化时保存状态类对象的引用,在运行时根据状态类对象获取当前状态。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class CrossFrameAttnProcessor(AttnProcessor):

def __init__(self, attn_state: AttnState):
super().__init__()
self.attn_state = attn_state
self.first_maps = {}
self.prev_maps = {}

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, **kwargs):

if encoder_hidden_states is None:
# Is self attention

if self.attn_state.state == AttnState.STORE:
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)
else:
cross_map = torch.cat(
(self.first_maps[t], self.prev_maps[t]), dim=1)
res = super().__call__(attn, hidden_states, cross_map, **kwargs)
else:
# Is cross attention
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)

return res

到目前为止,假设已经维护好了之前的输入,我们的注意力处理类能执行两种不同的行为了。现在,我们来实现之前输入的维护。使用之前的注意力输入时,我们其实需要知道当前的时刻 t。当前的时刻也算是另一个状态,最好是也在状态管理类里维护。但为了简化我们的代码,我们可以偷懒让每个处理类自己维护当前时刻。具体做法是:如果知道了去噪迭代的总时刻数,我们就可以令当前时刻从0开始不断自增,直到最大时刻时,再重置为0。加入了时刻处理及之前输入维护的完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class AttnState:
STORE = 0
LOAD = 1

def __init__(self):
self.reset()

@property
def state(self):
return self.__state

@property
def timestep(self):
return self.__timestep

def set_timestep(self, t):
self.__timestep = t

def reset(self):
self.__state = AttnState.STORE
self.__timestep = 0

def to_load(self):
self.__state = AttnState.LOAD

class CrossFrameAttnProcessor(AttnProcessor):

def __init__(self, attn_state: AttnState):
super().__init__()
self.attn_state = attn_state
self.cur_timestep = 0
self.first_maps = {}
self.prev_maps = {}

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, **kwargs):

if encoder_hidden_states is None:
# Is self attention

tot_timestep = self.attn_state.timestep
if self.attn_state.state == AttnState.STORE:
self.first_maps[self.cur_timestep] = hidden_states.detach()
self.prev_maps[self.cur_timestep] = hidden_states.detach()
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)
else:
tmp = hidden_states.detach()
cross_map = torch.cat(
(self.first_maps[self.cur_timestep], self.prev_maps[self.cur_timestep]), dim=1)
res = super().__call__(attn, hidden_states, cross_map, **kwargs)
self.prev_maps[self.cur_timestep] = tmp

self.cur_timestep += 1
if self.cur_timestep == tot_timestep:
self.cur_timestep = 0
else:
# Is cross attention
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)

return res

代码中,tot_timestep 为总时刻数,cur_timestep 为当前时刻。每运算一次,cur_timestep 加一,直至总时刻时再归零。在处理第一帧时,我们把当前时刻的输入同时存入第一帧缓存 first_maps 和上一帧缓存 prev_maps 中。对于后续帧,我们先做替换过输入的注意力计算,再更新上一帧缓存 prev_maps

视频编辑流水线

准备好了我们自己写的帧间注意力处理类后,我们来编写一个简单的 Diffusers 视频处理流水线。该流水线基于 ControlNet 与图生图流水线,其主要代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class VideoEditingPipeline(StableDiffusionControlNetImg2ImgPipeline):
def __init__(
self,
...
):
super().__init__(...)
self.attn_state = AttnState()
attn_processor_dict = {}
for k in unet.attn_processors.keys():
if k.startswith("up"):
attn_processor_dict[k] = CrossFrameAttnProcessor(
self.attn_state)
else:
attn_processor_dict[k] = AttnProcessor()

self.unet.set_attn_processor(attn_processor_dict)

def __call__(self, *args, images=None, control_images=None, **kwargs):
self.attn_state.reset()
self.attn_state.set_timestep(
int(kwargs['num_inference_steps'] * kwargs['strength']))
outputs = [super().__call__(
*args, **kwargs, image=images[0], control_image=control_images[0]).images[0]]
self.attn_state.to_load()
for i in range(1, len(images)):
image = images[i]
control_image = control_images[i]
outputs.append(super().__call__(
*args, **kwargs, image=image, control_image=control_image).images[0])
return outputs

在构造函数中,我们创建了一个全局注意力状态对象 attn_state。它的引用会传给每一个帧间注意力处理对象。一般修改自注意力模块时,只会修改 U-Net 上采样部分的,而不会动下采样部分和中间部分的。因此,在过滤注意力模块时,我们的判断条件是 k.startswith("up")。把新的注意力处理器词典填完后,用 unet.set_attn_processor 更新所有的处理类对象。

1
2
3
4
5
6
7
8
9
10
self.attn_state = AttnState()
attn_processor_dict = {}
for k in unet.attn_processors.keys():
if k.startswith("up"):
attn_processor_dict[k] = CrossFrameAttnProcessor(
self.attn_state)
else:
attn_processor_dict[k] = AttnProcessor()

self.unet.set_attn_processor(attn_processor_dict)

__call__ 方法中,我们要基于原图像编辑流水线 super().__call__(),实现我们的视频编辑流水线。在这个过程中,我们的主要任务是维护好注意力管理对象中的状态。一开始,我们要把管理类重置,根据参数设置最大去噪时刻数。经重置后,注意力处理器的状态默认为 STORE,即会保存第一帧的输入。处理完第一帧后,我们运行 attn_state.to_load() 改变注意力处理器的状态,让它们每次做注意力运算时先读第一帧和上一帧的输入,再维护上一帧输入的缓存。

1
2
3
4
5
6
7
8
9
10
11
12
13
def __call__(self, *args, images=None, control_images=None,  **kwargs):
self.attn_state.reset()
self.attn_state.set_timestep(
int(kwargs['num_inference_steps'] * kwargs['strength']))
outputs = [super().__call__(
*args, **kwargs, image=images[0], control_image=control_images[0]).images[0]]
self.attn_state.to_load()
for i in range(1, len(images)):
image = images[i]
control_image = control_images[i]
outputs.append(super().__call__(
*args, **kwargs, image=image, control_image=control_image).images[0])
return outputs

运行该流水线的示例脚本在项目根目录下的 replace_attn.py 文件中。示例中使用的视频可以在 https://github.com/williamyang1991/Rerender_A_Video/blob/main/videos/pexels-koolshooters-7322716.mp4 下载,下载后应重命名为 woman.mp4。不使用和使用新注意力处理器的输出结果如下:

可以看出,虽然注意力替换不能解决生成视频的闪烁问题,但帧间的一致性提升了不少。将注意力替换技术和其他技术结合起来的话,我们就能得到一个不错的 SD 视频生成工具。

总结

扩散模型中的自注意力替换是一种常见的提升图片一致性的技术。该技术的实现方法是将扩散模型 U-Net 中自注意力的 K, V 输入替换成另一幅图片的。在这篇文章中,我们学习了一个较为复杂的基于 Diffusers 开发的自注意力替换示例项目,用于提升 SD 视频生成的一致性。在这个过程中,我们学习了和 AttentionProcessor 相关接口函数的使用,并了解了如何基于全局管理类实现一个代码可维护性强的多行为注意力处理类。如果你能看懂这篇文章的示例,那你在开发 Diffusers 的注意力处理类时基本上不会碰到任何难题。

项目网址:https://github.com/SingleZombie/DiffusersExample/tree/main/ReplaceAttn

如果你想进一步学习 Diffusers 中视频编辑流水线的开发,可以参考我给 Diffusers 写的流水线:https://github.com/huggingface/diffusers/tree/main/examples/community#Rerender_A_Video

如果你一直关注 Stable Diffusion (SD) 社区,那你一定不会对 “LoRA” 这个名词感到陌生。社区用户分享的 SD LoRA 模型能够修改 SD 的画风,使之画出动漫、水墨或像素等风格的图片。但实际上,LoRA 不仅仅能改变 SD 的画风,还有其他的妙用。在这篇文章中,我们会先简单学习 LoRA 的原理,再认识科研中 LoRA 的三种常见应用:1) 还原单幅图像;2)风格调整;3)训练目标调整,最后阅读两个基于 Diffusers 的 SD LoRA 代码实现示例。

LoRA 的原理

在认识 LoRA 之前,我们先来回顾一下迁移学习的有关概念。迁移学习指在一次新的训练中,复用之前已经训练过的模型的知识。如果你自己动手训练过深度学习模型,那你应该不经意间地使用到了迁移学习:比如你一个模型训练了 500 步,测试后发现效果不太理想,于是重新读取该模型的参数,又继续训练了 100 步。之前那个被训练过的模型叫做预训练模型(pre-trained model),继续训练预训练模型的过程叫做微调(fine-tune)。

知道了微调的概念,我们就能来认识 LoRA 了。LoRA 的全称是 Low-Rank Adaptation (低秩适配),它是一种 Parameter-Efficient Fine-Tuning (参数高效微调,PEFT) 方法,即在微调时只训练原模型中的部分参数,以加速微调的过程。相比其他的 PEFT 方法,LoRA 之所以能脱颖而出,是因为它有几个明显的优点:

  • 从性能上来看,使用 LoRA 时,只需要存储少量被微调过的参数,而不需要把整个新模型都保存下来。同时,LoRA 的新参数可以和原模型的参数合并到一起,不会增加模型的运算时间。
  • 从功能上来看,LoRA 维护了模型在微调中的「变化量」。通过用一个介于 0~1 之间的混合比例乘变化量,我们可以控制模型的修改程度。此外,基于同一个原模型独立训练的多个 LoRA 可以同时使用。

这些优点在 SD LoRA 中的体现为:

  • SD LoRA 模型一般都很小,一般只有几十 MB。
  • SD LoRA 模型的参数可以合并到 SD 基础模型里,得到一个新的 SD 模型。
  • 可以用一个 0~1 之间的比例来控制 SD LoRA 新画风的程度。
  • 可以把不同画风的 SD LoRA 模型以不同比例混合。

为什么 LoRA 能有这些优点呢?LoRA 名字中的 「低秩」又是什么意思呢?让我们从 LoRA 的优点入手,逐步揭示它原理。

上文提到过,LoRA 之所以那么灵活,是因为它维护了模型在微调过程中的变化量。那么,假设我们正在修改模型中的一个参数 $W \in \mathbb{R}^{d \times d}$,我们就应该维护它的变化量 $\Delta W \in \mathbb{R}^{d \times d}$,训练时的参数用 $W + \Delta W$ 表示。这样,想要在推理时控制模型的修改程度,只要添加一个 $\alpha \in [0, 1]$,令使用的参数为 $W + \alpha \Delta W$即可。

可是,这样做我们还是要记录一个和原参数矩阵一样大的参数矩阵 $\Delta W$,这就算不上是参数高效微调了。为此,LoRA 的作者提出假设:模型参数在微调时的变化量中蕴含的信息没有那么多。为了用更少的信息来表示参数的变化量$\Delta W$,我们可以把$\Delta W$拆解成两个低秩矩阵的乘积:

其中,$A \in \mathbb{R}^{r \times d}$, $B \in \mathbb{R}^{d \times r}$,$d$ 是一个比 $r$ 小得多的数。这样,通过用两个参数量少得多的矩阵 $A, B$ 来维护变化量,我们不仅提高了微调的效率,还保持了使用变化量来描述微调过程的灵活性。这就是 LoRA 的全部原理,它十分简单,用 $\Delta W = BA$ 这一行公式足以表示。

了解了 LoRA 的原理,我们再回头看前文提及的 LoRA 的四项优点。LoRA 模型由许多参数量较少的矩阵 $A, B$ 来表示,它可以被单独存储,且占用空间不大。由于 $\Delta W = BA$ 维护的其实是参数的变化量,我们既可以把它与预训练模型的参数加起来得到一个新模型以提高推理速度,也可以在线地用一个混合比例来灵活地组合新旧模型。LoRA 的最后一个优点是各个基于同一个原模型独立训练出来的 LoRA 模型可以混合使用。LoRA 甚至可以作用于被其他方式修改过的原模型,比如 SD LoRA 支持带 ControlNet 的 SD。这一点其实来自于社区用户的实践。一个可能的解释是,LoRA 用低秩矩阵来表示变化量,这种低秩的变化量恰好与其他方法的变化量「错开」,使得 LoRA 能向着一个不干扰其他方法的方向修改模型。

我们最后来学习一下 LoRA 的实现细节。LoRA 有两个超参数,除了上文中提到的$r$,还有一个叫$\alpha$的参数。LoRA 的作者在实现 LoRA 模块时,给修改量乘了一个 $\frac{\alpha}{r}$ 的系数,即对于输入$x$,带了 LoRA 模块后的输出为 $Wx + \frac{\alpha}{r}BAx$。作者解释说,调这个参数几乎等于调学习率,一开始令$\alpha=r$即可。在我们要反复调超参数$r$时,只要保持$\alpha$不变,就不用改其他超参数了(因为不加$\alpha$的话,改了$r$后,学习率等参数也得做相应调整以维持同样的训练条件)。当然,实际运用中,LoRA 的超参数很好调。一般令$r=4, 8, 16$即可。由于我们不怎么会变$r$,总是令$\alpha=r$就够了。

为了使用 LoRA,除了确定超参数外,我们还得指定需要被微调的参数矩阵。在 SD 中使用 LoRA 时,大家一般会对 SD 的 U-Net 的所有多头注意力模块的所有参数矩阵做微调。即对于多头注意力模块的四个矩阵 $W_Q, W_K, W_V, W_{out}$ 进行微调。

LoRA 在 SD 中的三种运用

LoRA 在 SD 的科研中有着广泛的应用。按照使用 LoRA 的动机,我们可以把 LoRA 的应用分成:1) 还原单幅图像;2)风格调整;3)训练目标调整。通过学习这些应用,我们能更好地理解 LoRA 的本质。

还原单幅图像

SD 只是一个生成任意图片的模型。为了用 SD 来编辑一张给定的图片,我们一般要让 SD 先学会生成一张一模一样的图片,再在此基础上做修改。可是,由于训练集和输入图片的差异,SD 或许不能生成完全一样的图片。解决这个问题的思路很简单粗暴:我们只用这一张图片来微调 SD,让 SD 在这张图片上过拟合。这样,SD 的输出就会和这张图片非常相似了。

较早介绍这种提高输入图片保真度方法的工作是 Imagic,只不过它采取的是完全微调策略。后续的 DragDiffusion 也用了相同的方法,并使用 LoRA 来代替完全微调。近期的 DiffMorpher 为了实现两幅图像间的插值,不仅对两幅图像单独训练了 LoRA,还通过两个 LoRA 间的插值来平滑图像插值的过程。

风格调整

LoRA 在 SD 社区中最受欢迎的应用就是风格调整了。我们希望 SD 只生成某一画风,或者某一人物的图片。为此,我们只需要在一个符合我们要求的训练集上直接训练 SD LoRA 即可。

由于这种调整 SD 风格的方法非常直接,没有特别介绍这种方法的论文。稍微值得一提的是基于 SD 的视频模型 AnimateDiff,它用 LoRA 来控制输出视频的视角变换,而不是控制画风。

由于 SD 风格化 LoRA 已经被广泛使用,能否兼容 SD 风格化 LoRA 决定了一个工作是否易于在社区中传播。

训练目标调整

最后一个应用就有一点返璞归真了。LoRA 最初的应用就是把一个预训练模型适配到另一任务上。比如 GPT 一开始在大量语料中训练,随后在问答任务上微调。对于 SD 来说,我们也可以修改 U-Net 的训练目标,以提升 SD 的能力。

有不少相关工作用 LoRA 来改进 SD。比如 Smooth Diffusion 通过在训练目标中添加一个约束项并进行 LoRA 微调来使得 SD 的隐空间更加平滑。近期比较火的高速图像生成方法 LCM-LoRA 也是把原本作用于 SD 全参数上的一个模型蒸馏过程用 LoRA 来实现。

SD LoRA 应用总结

尽管上述三种 SD LoRA 应用的设计出发点不同,它们本质上还是在利用微调这一迁移学习技术来调整模型的数据分布或者训练目标。LoRA 只是众多高效微调方法中的一种,只要是微调能实现的功能,LoRA 基本都能实现,只不过 LoRA 更轻便而已。如果你想微调 SD 又担心计算资源不够,那么用 LoRA 准没错。反过来说,你想用 LoRA 在 SD 上设计出一个新应用,就要去思考微调 SD 能够做到哪些事。

Diffusers SD LoRA 代码实战

看完了原理,我们来尝试用 Diffusers 自己训一训 LoRA。我们会先学习 Diffusers 训练 LoRA 的脚本,再学习两个简单的 LoRA 示例: SD 图像插值与 SD 图像风格迁移。

项目网址:https://github.com/SingleZombie/DiffusersExample/tree/main/LoRA

Diffusers 脚本

我们将参考 Diffusers 中的 SD LoRA 文档 https://huggingface.co/docs/diffusers/training/lora ,使用官方脚本 examples/text_to_image/train_text_to_image_lora.py 训练 LoRA。为了使用这个脚本,建议直接克隆官方仓库,并安装根目录和 text_to_image 目录下的依赖文件。本文使用的 Diffusers 版本是 0.26.0,过旧的 Diffusers 的代码可能和本文展示的有所出入。目前,官方文档也描述的是旧版的代码。

1
2
3
4
5
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
cd examples/text_to_image
pip install -r requirements.txt

这份代码使用 accelerate 库管理 PyTorch 的训练。对同一份代码,只需要修改 accelerate 的配置,就能实现单卡训练或者多卡训练。默认情况下,用 accelerate launch 命令运行 Python 脚本会使用所有显卡。如果你需要修改训练配置,请参考相关文档使用 accelerate config 命令配置环境。

做好准备后,我们来开始阅读 examples/text_to_image/train_text_to_image_lora.py 的代码。这份代码写得十分易懂,复杂的地方都有注释。我们跳过命令行参数部分,直接从 main 函数开始读。

一开始,函数会配置 accelerate 库及日志记录器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()

随后的代码决定是否手动设置随机种子。保持默认即可。

1
2
3
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)

接着,函数会创建输出文件夹。如果我们想把模型推送到在线仓库上,函数还会创建一个仓库。我们的项目不必上传,忽略所有 args.push_to_hub 即可。另外,if accelerator.is_main_process: 表示多卡训练时只有主进程会执行这段代码块。

1
2
3
4
5
6
7
8
9
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id

准备完辅助工具后,函数正式开始着手训练。训练前,函数会先实例化好一切处理类,包括用于维护扩散模型中间变量的 DDPMScheduler,负责编码输入文本的 CLIPTokenizer, CLIPTextModel,压缩图像的VAE AutoencoderKL,预测噪声的 U-Net UNet2DConditionModel。参数 args.pretrained_model_name_or_path 是 Diffusers 在线仓库的地址(如runwayml/stable-diffusion-v1-5),或者本地的 Diffusers 模型文件夹。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)

函数还会设置各个带参数模型是否需要计算梯度。由于我们待会要优化的是新加入的 LoRA 模型,所有预训练模型都不需要计算梯度。另外,函数还会根据 accelerate 配置自动设置这些模型的精度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# Freeze the unet parameters before adding adapters
for param in unet.parameters():
param.requires_grad_(False)

# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

把预训练模型都调好了后,函数会配置 LoRA 模块并将其加入 U-Net 模型中。最近,Diffusers 更新了添加 LoRA 的方式。Diffusers 用 Attention 处理器来描述 Attention 的计算。为了把 LoRA 加入到 Attention 模块中,早期的 Diffusers 直接在 Attention 处理器里加入可训练参数。现在,为了和其他 Hugging Face 库统一,Diffusers 使用 PEFT 库来管理 LoRA。我们不需要关注 LoRA 的实现细节,只需要写一个 LoraConfig 就行了。

PEFT 中的 LoRA 文档参见 https://huggingface.co/docs/peft/conceptual_guides/lora

LoraConfig 中有四个主要参数: r, lora_alpha, init_lora_weights, target_modulesr, lora_alpha 的意义我们已经在前文中见过了,前者决定了 LoRA 矩阵的大小,后者决定了训练速度。默认配置下,它们都等于同一个值 args.rankinit_lora_weights 表示如何初始化训练参数,gaussian是论文中使用的方法。target_modules 表示 Attention 模块的哪些层需要添加 LoRA。按照通常的做法,会给所有层,即三个输入变换矩阵 to_k, to_q, to_v 和一个输出变换矩阵 to_out.0 加 LoRA。

创建了配置后,用 unet.add_adapter(unet_lora_config) 就可以创建 LoRA 模块。

1
2
3
4
5
6
7
8
9
10
11
12
13
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

unet.add_adapter(unet_lora_config)
if args.mixed_precision == "fp16":
for param in unet.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

更新完了 U-Net 的结构,函数会尝试启用 xformers 来提升 Attention 的效率。PyTorch 在 2.0 版本也加入了类似的 Attention 优化技术。如果你的显卡性能有限,且 PyTorch 版本小于 2.0,可以考虑使用 xformers

1
2
3
4
5
6
7
8
9
10
11
12
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers

xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
...
)
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

做完了 U-Net 的处理后,函数会过滤出要优化的模型参数,这些参数稍后会传递给优化器。过滤的原则很简单,如果参数要求梯度,就是待优化参数。

1
lora_layers = filter(lambda p: p.requires_grad, unet.parameters())

之后是优化器的配置。函数先是配置了一些细枝末节的训练选项,一般可以忽略。

1
2
3
4
5
6
7
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True

然后是优化器的选择。我们可以忽略其他逻辑,直接用 AdamW

1
2
3
4
5
6
7
8
9
10
11
12
# Initialize the optimizer
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"..."
)

optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW

选择了优化器类,就可以实例化优化器了。优化器的第一个参数是之前准备好的待优化 LoRA 参数,其他参数是 Adam 优化器本身的参数。

1
2
3
4
5
6
7
optimizer = optimizer_cls(
lora_layers,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)

准备了优化器,之后需要准备训练集。这个脚本用 Hugging Face 的 datasets 库来管理数据集。我们既可以读取在线数据集,也可以读取本地的图片文件夹数据集。在本文的示例项目中,我们将使用图片文件夹数据集。稍后我们再详细学习这样的数据集文件夹该怎么构建。相关的文档可以参考 https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
data_dir=args.train_data_dir,
)
else:
data_files = {}
if args.train_data_dir is not None:
data_files["train"] = os.path.join(args.train_data_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

训练 SD 时,每一个数据样本需要包含两项信息:图像数据与对应的文本描述。在数据集 dataset 中,每个数据样本包含了多项属性。下面的代码用于从这些属性中取出图像与文本描述。默认情况下,第一个属性会被当做图像数据,第二个属性会被当做文本。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names

# 6. Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)

准备好了数据集,接下来要定义数据预处理流程以创建 DataLoader。函数先定义了一个把文本标签预处理成 token ID 的 token 化函数。我们不需要修改它。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
return inputs.input_ids

接着,函数定义了图像数据的预处理流程。该流程是用 torchvision 中的 transforms 实现的。如代码所示,处理流程中包括了 resize 至指定分辨率 args.resolution、将图像长宽均裁剪至指定分辨率、随机翻转、转换至 tensor 和归一化。

经过这一套预处理后,所有图像的长宽都会被设置为 args.resolution 。统一图像的尺寸,主要的目的是对齐数据,以使多个数据样本能拼接成一个 batch。注意,数据预处理流程中包括了随机裁剪。如果数据集里的多数图片都长宽不一致,模型会倾向于生成被裁剪过的图片。为了解决这一问题,要么自己手动预处理图片,使训练图片都是分辨率至少为 args.resolution 的正方形图片,要么令 batch size 为 1 并取消掉随机裁剪。

1
2
3
4
5
6
7
8
9
10
11
12
# Preprocessing the datasets.
train_transforms = transforms.Compose(
[
transforms.Resize(
args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(
args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

定义了预处理流程后,函数对所有数据进行预处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["pixel_values"] = [
train_transforms(image) for image in images]
examples["input_ids"] = tokenize_captions(examples)
return examples

with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(
seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)

之后函数用预处理过的数据集创建 DataLoader。这里要注意的参数是 batch size args.train_batch_size 和读取数据的进程数 args.dataloader_num_workers 。这两个参数的用法和一般的 PyTorch 项目一样。args.train_batch_size 决定了训练速度,一般设置到不爆显存的最大值。如果要读取的数据过多,导致数据读取成为了模型训练的速度瓶颈,则应该提高 args.dataloader_num_workers

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"]
for example in examples])
pixel_values = pixel_values.to(
memory_format=torch.contiguous_format).float()
input_ids = torch.stack([example["input_ids"] for example in examples])
return {"pixel_values": pixel_values, "input_ids": input_ids}

# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)

如果想用更大的 batch size,显存又不够,则可以使用梯度累计技术。使用这项技术时,训练梯度不会每步优化,而是累计了若干步后再优化。args.gradient_accumulation_steps 表示要累计几步再优化模型。实际的 batch size 等于输入 batch size 乘 GPU 数乘梯度累计步数。下面的代码维护了训练步数有关的信息,并创建了学习率调度器。我们按照默认设置使用一个常量学习率即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True

lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)

# Prepare everything with our `accelerator`.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(
args.max_train_steps / num_update_steps_per_epoch)

在准备工作的最后,函数会用 accelerate 库记录配置信息。

1
2
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))

终于,要开始训练了。训练开始前,函数会准备全局变量并记录日志。

1
2
3
4
5
6
7
8
# Train!
total_batch_size = args.train_batch_size * \
accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
...
global_step = 0
first_epoch = 0

此时,如果设置了 args.resume_from_checkpoint,则函数会读取之前训练过的权重。一般继续训练时可以把该参数设为 latest,程序会自动找最新的权重。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = ...
else:
# Get the most recent checkpoint
path = ...

if path is None:
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0

随后,函数根据总步数和已经训练过的步数设置迭代器,正式进入训练循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):

训练的过程基本和 LDM 论文中展示的一致。一开始,要取出图像batch["pixel_values"] 并用 VAE 把它压缩进隐空间。

1
2
3
4
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(
dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor

再随机生成一个噪声。该噪声会套入扩散模型前向过程的公式,和输入图像一起得到 t 时刻的带噪图像。

1
2
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)

下一步,这里插入了一个提升扩散模型训练质量的小技巧,用上它后输出图像的颜色分布会更合理。原理见注释中的链接。args.noise_offset 默认为 0。如果要启用这个特性,一般令 args.noise_offset = 0.1

1
2
3
4
5
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn(
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
)

然后是时间戳的随机生成。

1
2
3
4
5
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

时间戳和前面随机生成的噪声一起经 DDPM 的前向过程得到带噪图片 noisy_latents

1
2
3
4
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(
latents, noise, timesteps)

再把文本 batch["input_ids"] 编码,为之后的 U-Net 前向传播做准备。

1
2
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]

在 U-Net 推理开始前,函数这里做了一个关于 U-Net 输出类型的判断。一般 U-Net 都是输出预测的噪声 epsilon,可以忽略这段代码。当 U-Net 是想预测噪声时,要拟合的目标是之前随机生成的噪声 noise

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Get the target for loss depending on the prediction type
if args.prediction_type is not None:
# set prediction_type of scheduler if defined
noise_scheduler.register_to_config(
prediction_type=args.prediction_type)

if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(
latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}")

之后把带噪图像、时间戳、文本编码输入进 U-Net,U-Net 输出预测的噪声。

1
2
3
# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps,
encoder_hidden_states).sample

有了预测值,下一步是算 loss。这里又可以选择是否使用一种加速训练的技术。如果使用,则 args.snr_gamma 推荐设置为 5.0。原 DDPM 的做法是直接算预测噪声和真实噪声的均方误差。

1
2
3
4
5
6
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(),
target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
...

训练迭代的最后,要用 accelerate 库来完成梯度计算和反向传播。在更新梯度前,可以通过设置 args.max_grad_norm 来裁剪梯度,以防梯度过大。args.max_grad_norm 默认为 1.0。代码中的 if accelerator.sync_gradients: 可以保证所有 GPU 都同步了梯度再执行后续代码。

1
2
3
4
5
6
7
8
9
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers
accelerator.clip_grad_norm_(
params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

一步训练结束后,更新和步数相关的变量。

1
2
3
4
5
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0

脚本默认每 args.checkpointing_steps 步保存一次中间结果。当需要保存时,函数会清理多余的 checkpoint,再把模型状态和 LoRA 模型分别保存下来。accelerator.save_state(save_path) 负责把模型及优化器等训练用到的所有状态存下来,后面的 StableDiffusionPipeline.save_lora_weights 负责存储 LoRA 模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = ...

if len(checkpoints) >= args.checkpoints_total_limit:
# remove ckpt
...

save_path = os.path.join(
args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)

unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)

StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)

logger.info(f"Saved state to {save_path}")

训练循环的最后,函数会更新进度条上的信息,并根据当前的训练步数决定是否停止训练。

1
2
3
4
5
6
logs = {"step_loss": loss.detach().item(
), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

训完每一个 epoch 后,函数会进行验证。默认的验证方法是新建一个图像生成 pipeline,生成一些图片并保存。如果有其他验证方法,如计算某一指标,可以自行编写这部分的代码。

1
2
3
4
5
6
7
8
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = DiffusionPipeline.from_pretrained(...)
...

所有训练结束后,函数会再存一次最终的 LoRA 模型权重。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)

unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet))
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)

if args.push_to_hub:
...

函数还会再测试一次模型。具体方法和之前的验证是一样的。

1
2
3
4
# Final inference
# Load previous pipeline
if args.validation_prompt is not None:
...

运行完了这里,函数也就结束了。

1
accelerator.end_training()

为了方便使用,我把这个脚本改写了一下:删除了部分不常用的功能,并且配置参数能通过配置文件而不是命令行参数传入。新的脚本为项目根目录下的 train_lora.py,示例配置文件在 cfg 目录下。

cfg 中的某个配置文件为例,我们来回顾一下训练脚本主要用到的参数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
{
"log_dir": "log",
"output_dir": "ckpt",
"data_dir": "dataset/mountain",
"ckpt_name": "mountain",
"gradient_accumulation_steps": 1,
"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
"rank": 8,
"enable_xformers_memory_efficient_attention": true,
"learning_rate": 1e-4,
"adam_beta1": 0.9,
"adam_beta2": 0.999,
"adam_weight_decay": 1e-2,
"adam_epsilon": 1e-08,
"resolution": 512,
"n_epochs": 200,
"checkpointing_steps": 500,
"train_batch_size": 1,
"dataloader_num_workers": 1,
"lr_scheduler_name": "constant",
"resume_from_checkpoint": false,
"noise_offset": 0.1,
"max_grad_norm": 1.0
}

需要关注的参数:output_dir 为输出 checkpoint 的文件夹,ckpt_name 为输出 checkpoint 的文件名。data_dir 是训练数据集所在文件夹。pretrained_model_name_or_path 为 SD 模型文件夹。rank 是决定 LoRA 大小的参数。learning_rate 是学习率。adam 打头的是 AdamW 优化器的参数。resolution 是训练图片的统一分辨率。n_epochs 是训练的轮数。checkpointing_steps 指每过多久存一次 checkpoint。train_batch_size 是 batch size。gradient_accumulation_steps 是梯度累计步数。

要修改这个配置文件,要先把文件夹的路径改对,填上训练时的分辨率,再通过 gradient_accumulation_stepstrain_batch_size 决定 batch size,接着填 n_epochs (一般训 10~20 轮就会过拟合)。最后就可以一边改 LoRA 的主要超参数 rank 一边反复训练了。

SD 图像插值

在这个示例中,我们来实现 DiffMorpher 工作的一小部分,完成一个简单的图像插值工具。在此过程中,我们将学会怎么在单张图片上训练 SD LoRA,以验证我们的训练环境。

这个工具的原理很简单:我们对两张图片分别训练一个 LoRA。之后,为了获取两张图片的插值,我们可以对两张图片 DDIM Inversion 的初始隐变量及两个 LoRA 分别插值,用插值过的隐变量在插值过的 SD LoRA 上生成图片就能得到插值图片。

该示例的所有数据和代码都已经在项目文件夹中给出。首先,我们看一下该怎么在单张图片上训 LoRA。训练之前,我们要准备一个数据集文件夹。数据集文件夹及包含所有图片及一个描述文件 metadata.jsonl。比如单图片的数据集文件夹的结构应如下所示:

1
2
3
├── mountain
│ ├── metadata.jsonl
│ └── mountain.jpg

metadata.jsonl 元数据文件的每一行都是一个 json 结构,包含该图片的路径及文本描述。单图片的元数据文件如下:

1
{"file_name": "mountain.jpg", "text": "mountain"}

如果是多图片,就应该是:

1
2
3
{"file_name": "mountain.jpg", "text": "mountain"}
{"file_name": "mountain_up.jpg", "text": "mountain"}
...

我们可以运行项目目录下的数据集测试文件 test_dataset.py 来看看 datasets 库的数据集对象包含哪些信息。

1
2
3
4
5
6
7
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="dataset/mountain")
print(dataset)
print(dataset["train"].column_names)
print(dataset["train"]['image'])
print(dataset["train"]['text'])

其输出大致为:

1
2
3
4
5
6
7
8
9
10
Generating train split: 1 examples [00:00, 66.12 examples/s]
DatasetDict({
train: Dataset({
features: ['image', 'text'],
num_rows: 1
})
})
['image', 'text']
[<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F0400246670>]
['mountain']

这说明数据集对象实际上是一个词典。默认情况下,数据集放在词典的 train 键下。数据集的 column_names 属性可以返回每项数据有哪些属性。在我们的数据集里,数据的 image 是图像数据,text 是文本标签。训练脚本默认情况下会把每项数据的第一项属性作为图像,第二项属性作为文本标签。我们的这个数据集定义与训练脚本相符。

认识了数据集,我们可以来训练模型了。用下面的两行命令就可以分别在两张图片上训练 LoRA。

1
2
python train_lora.py cfg/mountain.json
python train_lora.py cfg/mountain_up.json

如果要用所有显卡训练,则应该用 accelerate。当然,对于这个简单的单图片训练,不需要用那么多显卡。

1
2
accelerate launch train_lora.py cfg/mountain.json
accelerate launch train_lora.py cfg/mountain_up.json

这两个 LoRA 模型的配置文件我们已经在前文见过了。相比普通的风格化 LoRA,这两个 LoRA 的训练轮数非常多,有 200 轮。设置较大的训练轮数能保证模型在单张图片上过拟合。

训练结束后,项目的 ckpt 文件夹下会多出两个 LoRA 权重文件: mountain.safetensor, mountain_up.safetensor。我们可以用它们来做图像插值了。

图像插值的脚本为 morph.py,它的主要内容为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from inversion_pipeline import InversionPipeline

lora_path = 'ckpt/mountain.safetensor'
lora_path2 = 'ckpt/mountain_up.safetensor'
sd_path = 'runwayml/stable-diffusion-v1-5'


pipeline: InversionPipeline = InversionPipeline.from_pretrained(
sd_path).to("cuda")
pipeline.load_lora_weights(lora_path, adapter_name='a')
pipeline.load_lora_weights(lora_path2, adapter_name='b')

img1_path = 'dataset/mountain/mountain.jpg'
img2_path = 'dataset/mountain_up/mountain_up.jpg'
prompt = 'mountain'
latent1 = pipeline.inverse(img1_path, prompt, 50, guidance_scale=1)
latent2 = pipeline.inverse(img2_path, prompt, 50, guidance_scale=1)
n_frames = 10
images = []
for i in range(n_frames + 1):
alpha = i / n_frames
pipeline.set_adapters(["a", "b"], adapter_weights=[1 - alpha, alpha])
latent = slerp(latent1, latent2, alpha)
output = pipeline(prompt=prompt, latents=latent,
guidance_scale=1.0).images[0]
images.append(output)

对于每一个 Diffusers 的 Pipeline 类实例,都可以用 pipeline.load_lora_weights 来读取 LoRA 权重。如果我们在同一个模型上使用了多个 LoRA,为了区分它们,我们要加上 adapter_name 参数为每个 LoRA 命名。稍后我们会用到这些名称。

1
2
pipeline.load_lora_weights(lora_path, adapter_name='a')
pipeline.load_lora_weights(lora_path2, adapter_name='b')

读好了文件,使用已经写好的 DDIM Inversion 方法来得到两张图片的初始隐变量。

1
2
3
4
5
img1_path = 'dataset/mountain/mountain.jpg'
img2_path = 'dataset/mountain_up/mountain_up.jpg'
prompt = 'mountain'
latent1 = pipeline.inverse(img1_path, prompt, 50, guidance_scale=1)
latent2 = pipeline.inverse(img2_path, prompt, 50, guidance_scale=1)

最后开始生成不同插值比例的图片。根据混合比例 alpha,我们可以用 pipeline.set_adapters(["a", "b"], adapter_weights=[1 - alpha, alpha]) 来融合 LoRA 模型的比例。随后,我们再根据 alpha 对隐变量插值。用插值隐变量在插值 SD LoRA 上生成图片即可得到最终的插值图片。

1
2
3
4
5
6
7
8
9
n_frames = 10
images = []
for i in range(n_frames + 1):
alpha = i / n_frames
pipeline.set_adapters(["a", "b"], adapter_weights=[1 - alpha, alpha])
latent = slerp(latent1, latent2, alpha)
output = pipeline(prompt=prompt, latents=latent,
guidance_scale=1.0).images[0]
images.append(output)

下面两段动图中,左图和右图分别是无 LoRA 和有 LoRA 的插值结果。可见,通过 LoRA 权重上的插值,图像插值的过度会更加自然。

图片风格迁移

接下来,我们来实现最流行的 LoRA 应用——风格化 LoRA。当然,训练一个每张随机输出图片都质量很高的模型是很困难的。我们退而求其次,来实现一个能对输入图片做风格迁移的 LoRA 模型。

训练风格化 LoRA 对技术要求不高,其主要难点其实是在数据收集上。大家可以根据自己的需求,准备自己的数据集。我在本文中会分享我的实验结果。我希望把《弹丸论破》的画风——一种颜色渐变较多的动漫画风——应用到一张普通动漫画风的图片上。

由于我的目标是拟合画风而不是某一种特定的物体,我直接选取了 50 张左右的游戏 CG 构成训练数据集,且没有对图片做任何处理。训风格化 LoRA 时,文本标签几乎没用,我把所有数据的文本都设置成了游戏名 danganronpa

1
2
3
{"file_name": "1.png", "text": "danganronpa"}
...
{"file_name": "59.png", "text": "danganronpa"}

我的配置文件依然和前文的相同,LoRA rank 设置为 8。我一共训了 100 轮,但发现训练后期模型的过拟合很严重,其实令 n_epochs 为 10 到 20 就能有不错的结果。50 张图片训 10 轮最多几十分钟就训完。

由于训练图片的内容不够多样,且图片预处理时加入了随机裁剪,我的 LoRA 模型随机生成的图片质量较低。于是我决定在图像风格迁移任务上测试该模型。具体来说,我使用了 ControlNet Canny 加上图生图 (SDEdit)技术。相关的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel
from PIL import Image
import cv2
import numpy as np

lora_path = '...'
sd_path = 'runwayml/stable-diffusion-v1-5'
controlnet_canny_path = 'lllyasviel/sd-controlnet-canny'

prompt = '1 man, look at right, side face, Ace Attorney, Phoenix Wright, best quality, danganronpa'
neg_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, {multiple people}'
img_path = '...'
init_image = Image.open(img_path).convert("RGB")
init_image = init_image.resize((768, 512))
np_image = np.array(init_image)

# get canny image
np_image = cv2.Canny(np_image, 100, 200)
np_image = np_image[:, :, None]
np_image = np.concatenate([np_image, np_image, np_image], axis=2)
canny_image = Image.fromarray(np_image)
canny_image.save('tmp_edge.png')

controlnet = ControlNetModel.from_pretrained(controlnet_canny_path)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
sd_path, controlnet=controlnet
)
pipe.load_lora_weights(lora_path)

output = pipe(
prompt=prompt,
negative_prompt=neg_prompt,
strength=0.5,
guidance_scale=7.5,
controlnet_conditioning_scale=0.5,
num_inference_steps=50,
image=init_image,
cross_attention_kwargs={"scale": 1.0},
control_image=canny_image,
).images[0]
output.save("tmp.png")

StableDiffusionControlNetImg2ImgPipeline 是 Diffusers 中 ControlNet 加图生图的 Pipeline。使用它生成图片的重要参数有:

  • strength:0~1 之间重绘比例。越低越接近输入图片。
  • controlnet_conditioning_scale: 0~1 之间的 ControlNet 约束比例。越高越贴近约束。
  • cross_attention_kwargs={"scale": scale}:此处的 scale 是 0~1 之间的 LoRA 混合比例。越高越贴近 LoRA 模型的输出。

这里贴一下输入图片和两张编辑后的图片。



可以看出,输出图片中人物的画风确实得到了修改,颜色渐变更加丰富。我在几乎没有调试 LoRA 参数的情况下得到了这样的结果,可见虽然训练一个高质量的随机生成新画风的 LoRA 难度较高,但只是做风格迁移还是比较容易的。

尽管实验的经历不多,我还是基本上了解了 SD LoRA 风格化的能力边界。LoRA 风格化的本质还是修改输出图片的分布,数据集的质量基本上决定了生成的质量,其他参数的影响不会很大(包括训练图片的文本标签)。数据集最好手动裁剪至 512x512。如果想要生成丰富的风格化内容而不是只生成人物,就要丰富训练数据,减少人物数据的占比。训练时,最容易碰到的机器学习上的问题是过拟合问题。解决此问题的最简单的方式是早停,即不用最终的训练结果而用中间某一步的结果。如果你想实现改变输出数据分布以外的功能,比如精确生成某类物体、向模型中加入一些改变画风的关键词,那你应该使用更加先进的技术,而不仅仅是用最基本的 LoRA 微调。

总结

LoRA 是当今深度学习领域中常见的技术。对于 SD,LoRA 则是能够编辑单幅图片、调整整体画风,或者是通过修改训练目标来实现更强大的功能。LoRA 的原理非常简单,它其实就是用两个参数量较少的矩阵来描述一个大参数矩阵在微调中的变化量。Diffusers 库提供了非常便利的 SD LoRA 训练脚本。相信读完了本文后,我们能知道如何用 Diffusers 训练 LoRA,修改训练中的主要参数,并在简单的单图片 LoRA 编辑任务上验证训练的正确性。利用这些知识,我们也能把 LoRA 拓展到风格化生成及其他应用上。

本文的项目网址:https://github.com/SingleZombie/DiffusersExample/tree/main/LoRA

看完了Stable Diffusion的论文,在最后这篇文章里,我们来学习Stable Diffusion的代码实现。具体来说,我们会学习Stable Diffusion官方仓库及Diffusers开源库中有关采样算法和U-Net的代码,而不会学习有关训练、VAE、text encoder (CLIP) 的代码。如今大多数工作都只会用到预训练的Stable Diffusion,只学采样算法和U-Net代码就能理解大多数工作了。

建议读者在阅读本文之前了解DDPM、ResNet、U-Net、Transformer。

本文用到的Stable Diffusion版本是v1.5。Diffusers版本是0.25.0。为了提升可读性,本文对源代码做了一定的精简,部分不会运行到的分支会被略过。

算法梳理

在正式读代码之前,我们先用伪代码梳理一下Stable Diffusion的采样过程,并回顾一下U-Net架构的组成。实现Stable Diffusion的代码库有很多,各个库之间的API差异很大。但是,它们实际上都是在描述同一个算法,同一个模型。如果我们理解了算法和模型本身,就可以在学习时主动去找一个算法对应哪一段代码,而不是被动地去理解每一行代码在干什么。

LDM 采样算法

让我们从最早的DDPM开始,一步一步还原Latent Diffusion Model (LDM)的采样算法。DDPM的采样算法如下所示:

1
2
3
4
5
6
7
8
9
10
def ddpm_sample(image_shape):
ddpm_scheduler = DDPMScheduler()
unet = UNet()
xt = randn(image_shape)
T = 1000
for t in T ... 1:
eps = unet(xt, t)
std = ddpm_scheduler.get_std(t)
xt = ddpm_scheduler.get_xt_prev(xt, t, eps, std)
return xt

在DDPM的实现中,一般会有一个类专门维护扩散模型的$\alpha, \beta$等变量。我们这里把这个类称为DDPMScheduler。此外,DDPM会用到一个U-Net神经网络unet,用于计算去噪过程中图像应该去除的噪声eps。准备好这两个变量后,就可以用randn()从标准正态分布中采样一个纯噪声图像xt。它会被逐渐去噪,最终变成一幅图片。去噪过程中,时刻t会从总时刻T遍历至1(总时刻T一般取1000)。在每一轮去噪步骤中,U-Net会根据这一时刻的图像xt和当前时间戳t估计出此刻应去除的噪声eps,根据xteps就能知道下一步图像的均值。除了均值,我们还要获取下一步图像的方差,这一般可以从DDPM调度类中直接获取。有了下一步图像的均值和方差,我们根据DDPM的公式,就能采样出下一步的图像。反复执行去噪循环,xt会从纯噪声图像变成一幅有意义的图像。

DDIM对DDPM的采样过程做了两点改进:1) 去噪的有效步数可以少于T步,由另一个变量ddim_steps决定;2) 采样的方差大小可以由eta决定。因此,改进后的DDIM算法可以写成这样:

1
2
3
4
5
6
7
8
9
10
11
def ddim_sample(image_shape, ddim_steps = 20, eta = 0):
ddim_scheduler = DDIMScheduler()
unet = UNet()
xt = randn(image_shape)
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
for t in timesteps:
eps = unet(xt, t)
std = ddim_scheduler.get_std(t, eta)
xt = ddim_scheduler.get_xt_prev(xt, t, eps, std)
return xt

其中,ddim_steps是去噪循环的执行次数。根据ddim_steps,DDIM调度器可以生成所有被使用到的t。比如对于T=1000, ddim_steps=20,被使用到的就只有[1000, 950, 900, ..., 50]这20个时间戳,其他时间戳就可以跳过不算了。eta会被用来计算方差,一般这个值都会设成0

DDIM是早期的加速扩散模型采样的算法。如今有许多比DDIM更好的采样方法,但它们多数都保留了stepseta这两个参数。因此,在使用所有采样方法时,我们可以不用关心实现细节,只关注多出来的这两个参数。

在DDIM的基础上,LDM从生成像素空间上的图像变为生成隐空间上的图像。隐空间图像需要再做一次解码才能变回真实图像。从代码上来看,使用LDM后,只需要多准备一个VAE,并对最后的隐空间图像zt解码。

1
2
3
4
5
6
7
8
9
10
11
12
13
def ldm_ddim_sample(image_shape, ddim_steps = 20, eta = 0):
ddim_scheduler = DDIMScheduler()
vae = VAE()
unet = UNet()
zt = randn(image_shape)
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
for t in timesteps:
eps = unet(zt, t)
std = ddim_scheduler.get_std(t, eta)
zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
xt = vae.decoder.decode(zt)
return xt

而想用LDM实现文生图,则需要给一个额外的文本输入text。文本编码器会把文本编码成张量c,输入进unet。其他地方的实现都和之前的LDM一样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0):
ddim_scheduler = DDIMScheduler()
vae = VAE()
unet = UNet()
zt = randn(image_shape)
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]

text_encoder = CLIP()
c = text_encoder.encode(text)

for t = timesteps:
eps = unet(zt, t, c)
std = ddim_scheduler.get_std(t, eta)
zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
xt = vae.decoder.decode(zt)
return xt

最后这个能实现文生图的LDM就是我们熟悉的Stable Diffusion。Stable Diffusion的采样算法看上去比较复杂,但如果能够从DDPM开始把各个功能都拆开来看,理解起来就不是那么困难了。

U-Net 结构组成

Stable Diffusion代码实现中的另一个重点是去噪网络U-Net的实现。仿照上一节的学习方法,我们来逐步学习Stable Diffusion中的U-Net是怎么从最经典的纯卷积U-Net逐渐发展而来的。

最早的U-Net的结构如下图所示:

可以看出,U-Net的结构有以下特点:

  • 整体上看,U-Net由若干个大层组成。特征在每一大层会被下采样成尺寸更小的特征,再被上采样回原尺寸的特征。整个网络构成一个U形结构。
  • 下采样后,特征的通道数会变多。一般情况下,每次下采样后图像尺寸减半,通道数翻倍。上采样过程则反之。
  • 为了防止信息在下采样的过程中丢失,U-Net每一大层在下采样前的输出会作为额外输入拼接到每一大层上采样前的输入上。这种数据连接方式类似于ResNet中的「短路连接」。

DDPM则使用了一种改进版的U-Net。改进主要有两点:

  • 原来的卷积层被替换成了ResNet中的残差卷积模块。每一大层有若干个这样的子模块。对于较深的大层,残差卷积模块后面还会接一个自注意力模块。
  • 原来模型每一大层只有一个短路连接。现在每个大层下采样部分的每个子模块的输出都会额外输入到其对称的上采样部分的子模块上。直观上来看,就是短路连接更多了一点,输入信息更不容易在下采样过程中丢失。

最后,LDM提出了一种给U-Net添加额外约束信息的方法:把U-Net中的自注意力模块换成交叉注意力模块。具体来说,DDPM的U-Net的自注意力模块被换成了标准的Transformer模块。约束信息$C$可以作为Cross Attention的K, V输入进模块中。

Stable Diffusion的U-Net还在结构上有少许修改,该U-Net的每一大层都有Transformer块,而不是只有较深的大层有。

至此,我们已经学完了Stable Diffusion的采样原理和U-Net结构。接下来我们来看一看它们在不同框架下的代码实现。

Stable Diffusion 官方 GitHub 仓库

安装

克隆仓库后,照着官方Markdown文档安装即可。

1
git clone git@github.com:CompVis/stable-diffusion.git

先用下面的命令创建conda环境,此后ldm环境就是运行Stable Diffusiion的conda环境。

1
2
conda env create -f environment.yaml
conda activate ldm

之后去网上下一个Stable Diffusion的模型文件。比较常见一个版本是v1.5,该模型在Hugging Face上:https://huggingface.co/runwayml/stable-diffusion-v1-5 (推荐下载v1-5-pruned.ckpt)。下载完毕后,把模型软链接到指定位置。

1
2
mkdir -p models/ldm/stable-diffusion-v1/
ln -s <path/to/model.ckpt> models/ldm/stable-diffusion-v1/model.ckpt

准备完毕后,只要输入下面的命令,就可以生成实现文生图了。

1
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" 

在默认的参数下,“一幅骑着马的飞行员的照片”的绘制结果会被保存在outputs/txt2img-samples中。你也可以通过--outdir <dir>参数来指定输出到的文件夹。我得到的一些绘制结果为:

如果你在安装时碰到了错误,可以在搜索引擎上或者GitHub的issue里搜索,一般都能搜到其他人遇到的相同错误。

主函数

接下来,我们来探究一下scripts/txt2img.py的执行过程。为了方便阅读,我们可以简化代码中的命令行处理,得到下面这份精简代码。(你可以把这份代码复制到仓库根目录下的一个新Python脚本里并直接运行。别忘了修改代码中的模型路径)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from pytorch_lightning import seed_everything
from torch import autocast
from torchvision.utils import make_grid

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)

model.cuda()
model.eval()
return model


def main():
seed = 42
config = 'configs/stable-diffusion/v1-inference.yaml'
ckpt = 'ckpt/v1-5-pruned.ckpt'
outdir = 'tmp'
n_samples = batch_size = 3
n_rows = batch_size
n_iter = 2
prompt = 'a photograph of an astronaut riding a horse'
data = [batch_size * [prompt]]
scale = 7.5
C = 4
f = 8
H = W = 512
ddim_steps = 50
ddim_eta = 0.0

seed_everything(seed)

config = OmegaConf.load(config)
model = load_model_from_config(config, ckpt)

device = torch.device(
"cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
sampler = DDIMSampler(model)

os.makedirs(outdir, exist_ok=True)
outpath = outdir

sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
grid_count = len(os.listdir(outpath)) - 1

start_code = None
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
all_samples = list()
for n in trange(n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if scale != 1.0:
uc = model.get_learned_conditioning(
batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [C, H // f, W // f]
samples_ddim, _ = sampler.sample(S=ddim_steps,
conditioning=c,
batch_size=n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code)

x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp(
(x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

all_samples.append(x_samples_ddim)
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)

# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))
img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1

print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.")


if __name__ == "__main__":
main()

抛开前面一大堆初始化操作,代码的核心部分只有下面几行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
uc = None
if scale != 1.0:
uc = model.get_learned_conditioning(
batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [C, H // f, W // f]
samples_ddim, _ = sampler.sample(S=ddim_steps,
conditioning=c,
batch_size=n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code)

x_samples_ddim = model.decode_first_stage(samples_ddim)

我们来逐行分析一下这段代码。一开始的几行是执行Classifier-Free Guidance (CFG)。uc表示的是CFG中的无约束下的约束张量。scale表示的是执行CFG的程度,scale不等于1.0即表示启用CFG。model.get_learned_conditioning表示用CLIP把文本编码成张量。对于文本约束的模型,无约束其实就是输入文本为空字符串("")。因此,在代码中,若启用了CFG,则会用CLIP编码空字符串,编码结果为uc

如果你没学过CFG,也不用担心。你可以暂时不要去理解上面这段话。等读完了后文中有关CFG的代码后,你差不多就能理解CFG的用法了。

1
2
3
4
uc = None
if scale != 1.0:
uc = model.get_learned_conditioning(
batch_size * [""])

之后的几行是在把用户输入的文本编码成张量。同样,model.get_learned_conditioning表示用CLIP把输入文本编码成张量c

1
2
3
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)

接着是用扩散模型的采样器生成图片。在这份代码中,sampler是DDIM采样器,sampler.sample函数直接完成了图像生成。

1
2
3
4
5
6
7
8
9
10
shape = [C, H // f, W // f]
samples_ddim, _ = sampler.sample(S=ddim_steps,
conditioning=c,
batch_size=n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code)

最后,LDM生成的隐空间图片被VAE解码成真实图片。函数model.decode_first_stage负责图片解码。x_samples_ddim在后续的代码中会被后处理成正确格式的RGB图片,并输出至文件里。

1
x_samples_ddim = model.decode_first_stage(samples_ddim)

Stable Diffusion 官方实现的主函数主要就做了这些事情。这份实现还是有一些凌乱的。采样算法的一部分内容被扔到了主函数里,另一部分放到了DDIM采样器里。在阅读官方实现的源码时,既要去读主函数里的内容,也要去读采样器里的内容。

接下来,我们来看一看DDIM采样器的部分代码,学完采样算法的剩余部分的实现。

DDIM 采样器

回头看主函数的前半部分,DDIM采样器是在下面的代码里导入的:

1
from ldm.models.diffusion.ddim import DDIMSampler

跳转到ldm/models/diffusion/ddim.py文件,我们可以找到DDIMSampler类的实现。

先看一下这个类的构造函数。构造函数主要是把U-Net model给存了下来。后文中的self.model都指的是U-Net。

1
2
3
4
5
6
7
8
9
10
11
12
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule

# in main

config = OmegaConf.load(config)
model = load_model_from_config(config, ckpt)
model = model.to(device)
sampler = DDIMSampler(model)

再沿着类的self.sample方法,看一下DDIM采样的实现代码。以下是self.sample方法的主要内容。这个方法其实就执行了一个self.make_schedule,之后把所有参数原封不动地传到了self.ddim_sampling里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
...
):
if conditioning is not None:
...

self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')

samples, intermediates = self.ddim_sampling(...)

self.make_schedule用于预处理扩散模型的中间计算参数。它的大部分实现细节可以略过。DDIM用到的有效时间戳列表就是在这个函数里设置的,该列表通过make_ddim_timesteps获取,并保存在self.ddim_timesteps中。此外,由ddim_eta决定的扩散模型的方差也是在这个方法里设置的。大致扫完这个方法后,我们可以直接跳到self.ddim_sampling的代码。

1
2
3
4
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
...

穿越重重的嵌套,我们总算能看到DDIM采样的实现方法self.ddim_sampling了。它的主要内容如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
@torch.no_grad()
def ddim_sampling(self, ...):
device = self.model.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
timesteps = self.ddim_timesteps
intermediates = ...
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]

iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)

outs = self.p_sample_ddim(img, cond, ts, ...)
img, pred_x0 = outs

return img, intermediates

这段代码和我们之前自己写的伪代码非常相似。一开始,方法获取了在make_schedule里初始化的DDIM有效时间戳列表self.ddim_timesteps,并预处理成一个iterator。该迭代器用于控制DDIM去噪循环。每一轮循环会根据当前时刻的图像img和时间戳ts计算下一步的图像img。具体来说,代码每次用当前的时间戳step创建一个内容全部为step,形状为(b,)的张量ts。该张量会和当前的隐空间图像img,约束信息张量cond一起传给执行一轮DDIM去噪的p_sample_ddim方法。p_sample_ddim方法会返回下一步的图像img。最后,经过多次去噪后,ddim_sampling方法将去噪后的隐空间图像img返回。

p_sample_ddim里的p_sample看上去似乎意义不明,实际上这个叫法来自于DDPM论文。在DDPM论文中,扩散模型的前向过程用字母$q$表示,反向过程用字母$p$表示。因此,反向过程的一轮去噪在代码里被叫做p_sample

最后来看一下p_sample_ddim这个方法,它的主体部分如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
@torch.no_grad()
def p_sample_ddim(self, x, c, t, ...):
b, *_, device = *x.shape, x.device

if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)


# Prepare variables
...

# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0

方法的内容大致可以拆成三段:首先,方法调用U-Net self.model,使用CFG来计算除这一轮该去掉的噪声e_t。然后,方法预处理出DDIM的中间变量。最后,方法根据DDIM的公式,计算出这一轮去噪后的图片x_prev。我们着重看第一部分的代码。

不启用CFG时,方法直接通过self.model.apply_model(x, t, c)调用U-Net,算出这一轮的噪声e_t。而想启用CFG,需要输入空字符串的约束张量unconditional_conditioning,且CFG的强度unconditional_guidance_scale不为1。CFG的执行过程是:对U-Net输入不同的约束c,先用空字符串约束得到一个预测噪声e_t_uncond,再用输入的文本约束得到一个预测噪声e_t。之后令e_t = et_uncond + scale * (e_t - e_t_uncond)scale大于1,即表明我们希望预测噪声更加靠近有输入文本的那一个。直观上来看,scale越大,最后生成的图片越符合输入文本,越偏离空文本。下面这段代码正是实现了上述这段逻辑,只不过代码使用了一些数据拼接技巧,让空字符串约束下和输入文本约束下的结果在一次U-Net推理中获得。

1
2
3
4
5
6
7
8
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

p_sample_ddim 方法的后续代码都是在实现下面这个DDIM采样公式。代码工工整整地计算了公式中的predicted_x0, dir_xt, noise,非常易懂,没有需要特别注意的地方。

我们已经看完了p_sample_ddim的代码。该方法可以实现一步去噪操作。多次调用该方法去噪后,我们就能得到生成的隐空间图片。该图片会被返回到main函数里,被VAE的解码器解码成普通图片。至此,我们就学完了Stable Diffusion官方仓库的采样代码。

对照下面这份我们之前写的伪代码,我们再来梳理一下Stable Diffusion官方仓库的代码逻辑。官方仓库的采样代码一部分在main函数里,另一部分在ldm/models/diffusion/ddim.py里。main函数主要完成了编码约束文字、解码隐空间图像这两件事。剩下的DDIM采样以及各种Diffusion图像编辑功能都是在ldm/models/diffusion/ddim.py文件中实现的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0)
ddim_scheduler = DDIMScheduler()
vae = VAE()
unet = UNet()
zt = randn(image_shape)
eta = input()
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]

text_encoder = CLIP()
c = text_encoder.encode(text)

for t = timesteps:
eps = unet(zt, t, c)
std = ddim_scheduler.get_std(t, eta)
zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
xt = vae.decoder.decode(zt)
return xt

在学习代码时,要着重学习DDIM采样器部分的代码。大部分基于Diffusion的图像编辑技术都是在DDIM采样的中间步骤中做文章,只要学懂了DDIM采样的代码,学相关图像编辑技术就会非常轻松。除此之外,和LDM相关的文字约束编码、隐空间图像编码解码的接口函数也需要熟悉,不少技术会调用到这几项功能。

还有一些Diffusion相关工作会涉及U-Net的修改。接下来,我们就来看Stable Diffusion官方仓库中U-Net的实现。

U-Net

我们来回头看一下main函数和DDIM采样中U-Net的调用逻辑。和U-Net有关的代码如下所示。LDM模型类 model在主函数中通过load_model_from_config从配置文件里创建,随后成为了sampler的成员变量。在DDIM去噪循环中,LDM模型里的U-Net会在self.model.apply_model方法里被调用。

1
2
3
4
5
6
7
8
# main.py
config = 'configs/stable-diffusion/v1-inference.yaml'
config = OmegaConf.load(config)
model = load_model_from_config(config, ckpt)
sampler = DDIMSampler(model)

# ldm/models/diffusion/ddim.py
e_t = self.model.apply_model(x, t, c)

为了知道U-Net是在哪个类里定义的,我们需要打开配置文件 configs/stable-diffusion/v1-inference.yaml。该配置文件有这样一段话:

1
2
3
4
5
6
model:
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
conditioning_key: crossattn
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel

根据这段话,我们知道LDM类定义在ldm/models/diffusion/ddpm.pyLatentDiffusion里,U-Net类定义在ldm/modules/diffusionmodules/openaimodel.pyUNetModel里。一个LDM类有一个U-Net类的实例。我们先简单看一看LatentDiffusion类的实现。

ldm/models/diffusion/ddpm.py原本来自DDPM论文的官方仓库,内含DDPM类的实现。DDPM类维护了扩散模型公式里的一些变量,同时维护了U-Net类的实例。LDM的作者基于之前DDPM的代码进行开发,定义了一个继承自DDPMLatentDiffusion类。除了DDPM本身的功能外,LatentDiffusion还维护了VAE(self.first_stage_model),CLIP(self.cond_stage_model)。也就是说,LatentDiffusion主要维护了扩散模型中间变量、U-Net、VAE、CLIP这四类信息。这样,所有带参数的模型都在LatentDiffusion里,我们可以从一个checkpoint文件中读取所有的模型的参数。相关代码定义代码如下:

把所有模型定义在一起有好处也有坏处。好处在于,用户想使用Stable Diffusion时,只需要下载一个checkpoint文件就行了。坏处在于,哪怕用户只改了某个子模型(如U-Net),为了保存整个模型,他还是得把其他子模型一起存下来。这其中存在着信息冗余,十分不灵活。Diffusers框架没有把模型全存在一个文件里,而是放到了一个文件夹里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class DDPM(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(self,
unet_config,
...):
self.model = DiffusionWrapper(unet_config, conditioning_key)


class LatentDiffusion(DDPM):
"""main class"""
def __init__(self,
first_stage_config,
cond_stage_config,
...):

self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)

我们主要关注LatentDiffusion类的apply_model方法,它用于调用U-Net self.modelapply_model看上去有很长,但略过了我们用不到的一些代码后,整个方法其实非常短。一开始,方法对输入的约束信息编码cond做了一个前处理,判断约束是哪种类型。如论文里所描述的,LDM支持两种约束:将约束与输入拼接、将约束注入到交叉注意力层中。方法会根据self.model.conditioning_keyconcat还是crossattn,使用不同的约束方式。Stable Diffusion使用的是后者,即self.model.conditioning_key == crossattn。做完前处理后,方法执行了x_recon = self.model(x_noisy, t, **cond)。接下来的处理交给U-Net self.model来完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
# hybrid case, cond is exptected to be a dict
pass
else:
if not isinstance(cond, list):
cond = [cond]
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
cond = {key: cond}

x_recon = self.model(x_noisy, t, **cond)

if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon

现在,我们跳转到ldm/modules/diffusionmodules/openaimodel.pyUNetModel类里。UNetModel只定义了神经网络层的运算,没有多余的功能。我们只需要看它的__init__方法和forward方法。我们先来看较为简短的forward方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)

h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
return self.out(h)

forward方法的输入是x, timesteps, context,分别表示当前去噪时刻的图片、当前时间戳、文本约束编码。根据这些输入,forward会输出当前时刻应去除的噪声eps。一开始,方法会先对timesteps使用Transformer论文中介绍的位置编码timestep_embedding,得到时间戳的编码t_embt_emb再经过几个线性层,得到最终的时间戳编码emb。而context已经是CLIP处理过的编码,它不需要做额外的预处理。时间戳编码emb和文本约束编码context随后会注入到U-Net的所有中间模块中。

1
2
3
4
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)

经过预处理后,方法开始处理U-Net的计算。中间结果h会经过U-Net的下采样模块input_blocks,每一个子模块的临时输出都会被保存进一个栈hs里。

1
2
3
4
 h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)

接着,h会经过U-Net的中间模块。

1
h = self.middle_block(h, emb, context)

随后,h开始经过U-Net的上采样模块output_blocks。此时每一个编码器子模块的临时输出会从栈hs里弹出,作为对应解码器子模块的额外输入。额外输入hs.pop()会与中间结果h拼接到一起输入进子模块里。

1
2
3
4
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)

最后,h会被输出层转换成一个通道数正确的eps张量。

1
return self.out(h)

这段代码的数据连接图如下所示:

在阅读__init__前,我们先看一下待会会用到的另一个模块类TimestepEmbedSequential的定义。在PyTorch中,一系列输入和输出都只有一个变量的模块在串行连接时,可以用串行模块类nn.Sequential来把多个模块合并简化成一个模块。而在扩散模型中,多数模块的输入是x, t, c三个变量,输出是一个变量。为了也能用类似的串行模块类把扩散模型的模块合并在一起,代码中包含了一个TimestepEmbedSequential类。它的行为类似于nn.Sequential,只不过它支持x, t, c的输入。forward中用到的多数模块都是通过TimestepEmbedSequential创建的。

1
2
3
4
5
6
7
8
9
10
11
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):

def forward(self, x, emb, context=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
else:
x = layer(x)
return x

看完了数据的计算过程,我们回头来看各个子模块在__init__方法中是怎么被详细定义的。__init__的主要内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class UNetModel(nn.Module):
def __init__(self, ...):

self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)

self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)

for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(...)]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...))

self.input_blocks.append(TimestepEmbedSequential(*layers))
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(...)
if resblock_updown
else Downsample(...)
)
)

self.middle_block = TimestepEmbedSequential(
ResBlock(...),
AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...),
ResBlock(...),
)

self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(...)
]
ch = model_channels * mult
if ds in attention_resolutions:
layers.append(
AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResBlock(...)
if resblock_updown
else Upsample(...)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)

__init__方法的代码很长。在阅读这样的代码时,我们不需要每一行都去细读,只需要理解代码能拆成几块,每一块在做什么即可。__init__方法其实就是定义了forward中用到的5个模块,我们一个一个看过去即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class UNetModel(nn.Module):
def __init__(self, ...):

self.time_embed = ...

self.input_blocks = nn.ModuleList(...)
for level, mult in enumerate(channel_mult):
...

self.middle_block = ...

self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
...
self.out = ...

先来看time_embed。回忆一下,在forward里,输入的整数时间戳会被正弦编码timestep_embedding(即Transformer中的位置编码)编码成一个张量。之后,时间戳编码处理模块time_embed用于进一步提取时间戳编码的特征。从下面的代码中可知,它本质上就是一个由两个普通线性层构成的模块。

1
2
3
4
5
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)

再来看U-Net最后面的输出模块out。输出模块的结构也很简单,它主要包含了一个卷积层,用于把中间变量的通道数从dims变成model_channels

1
2
3
4
5
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)

接下来,我们把目光聚焦在U-Net的三个核心模块上:input_blocks, middle_block, output_blocks。这三个模块的组成都很类似,都用到了残差块ResBlock和注意力块。稍有不同的是,input_blocks的每一大层后面都有一个下采样模块,output_blocks的每一大层后面都有一个上采样模块。上下采样模块的结构都很常规,与经典的U-Net无异。我们把学习的重点放在残差块和注意力块上。我们先看这两个模块的内部实现细节,再来看它们是怎么拼接起来的。

Stable Diffusion的U-Net中的ResBlock和原DDPM的U-Net的ResBlock功能完全一样,都是在普通残差块的基础上,支持时间戳编码的额外输入。具体来说,普通的残差块是由两个卷积模块和一条短路连接构成的,即y = x + conv(conv(x))。如果经过两个卷积块后数据的通道数发生了变化,则要在短路连接上加一个转换通道数的卷积,即y = conv(x) + conv(conv(x))

在这种普通残差块的基础上,扩散模型中的残差块还支持时间戳编码t的输入。为了把t和输入x的信息融合在一起,t会和经过第一个卷积后的中间结果conv(x)加在一起。可是,t的通道数和conv(x)的通道数很可能会不一样。通道数不一样的数据是不能直接加起来的。为此,每一个残差块中都有一个用于转换t通道数的线性层。这样,tconv(x)就能相加了。整个模块的计算可以表示成y=conv(x) + conv(conv(x) + linear(t))。残差块的示意图和源代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class ResBlock(TimestepBlock):
def __init__(self, ...):
super().__init__()
...

self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)

self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)

if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

def forward(self, x, emb):
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h

代码中的in_layers是第一个卷积模块,out_layers是第二个卷积模块。skip_connection是用于调整短路连接通道数的模块。若输入输出的通道数相同,则该模块是一个恒等函数,不对数据做任何修改。emb_layers是调整时间戳编码通道数的线性层模块。这些模块的定义都在ResBlock__init__里。它们的结构都很常规,没有值得注意的地方。我们可以着重阅读模型的forward方法。

如前文所述,在forward中,输入x会先经过第一个卷积模块in_layers,再与经过了emb_layers调整的时间戳编码emb相加后,输入进第二个卷积模块out_layers。最后,做完计算的数据会和经过了短路连接的原输入skip_connection(x)加在一起,作为整个残差块的输出。

1
2
3
4
5
6
7
8
def forward(self, x, emb):
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h

这里有一点实现细节需要注意。时间戳编码emb_out的形状是[n, c]。为了把它和形状为[n, c, h, w]的图片加在一起,需要把它的形状变成[n, c, 1, 1]后再相加(形状为[n, c, 1, 1]的数据在与形状为[n, c, h, w]的数据做加法时形状会被自动广播成[n, c, h, w])。在PyTorch中,x=x[..., None]可以在一个数据最后加一个长度为1的维度。比如对于形状为[n, c]tt[..., None]的形状就会是[n, c, 1]

残差块的内容到此结束。我们接着来看注意力模块。在看模块的具体实现之前,我们先看一下源代码中有哪几种注意力模块。在U-Net的代码中,注意力模型是用以下代码创建的:

1
2
3
4
if ds in attention_resolutions:
layers.append(
AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...)
)

第一行if ds in attention_resolutions:用于控制在U-Net的哪几个大层。Stable Diffusion每一大层都用了注意力模块,可以忽略这一行。随后,代码根据是否设置use_spatial_transformer来创建AttentionBlock或是SpatialTransformerAttentionBlock是DDPM中采样的普通自注意力模块,而SpatialTransformer是LDM中提出的支持额外约束的标准Transfomer块。Stable Diffusion使用的是SpatialTransformer。我们就来看一看这个模块的实现细节。

如前所述,SpatialTransformer使用的是标准的Transformer块,它和Transformer中的Transformer块完全一致。输入x先经过一个自注意力层,再过一个交叉注意力层。在此期间,约束编码c会作为交叉注意力层的K, V输入进模块。最后,数据经过一个全连接层。每一层的输入都会和输出做一个残差连接。

当然,标准Transformer是针对一维序列数据的。要把Transformer用到图像上,则需要把图像的宽高拼接到同一维,即对张量做形状变换n c h w -> n c (h * w)。做完这个变换后,就可以把数据直接输入进Transformer模块了。
这些图像数据与序列数据的适配都是在SpatialTransformer类里完成的。SpatialTransformer类并没有直接实现Transformer块的细节,仅仅是U-Net和Transformer块之间的一个过渡。Transformer块的实现在它的一个子模块里。我们来看它的实现代码。

SpatialTransformer有两个卷积层proj_in, proj_out,负责图像通道数与Transformer模块通道数之间的转换。SpatialTransformertransformer_blocks才是真正的Transformer模块。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class SpatialTransformer(nn.Module):

def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)

self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)

self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
)

self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))

forward中,图像数据在进出Transformer模块前后都会做形状和通道数上的适配。运算结束后,结果和输入之间还会做一个残差连接。context就是约束信息编码,它会接入到交叉注意力层上。

1
2
3
4
5
6
7
8
9
10
11
12

def forward(self, x, context=None):
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in

每一个Transformer模块的结构完全符合上文的示意图。如果你之前学过Transformer,那这些代码你会十分熟悉。我们快速把这部分代码浏览一遍。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint

def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x

自注意力层和交叉注意力层都是用CrossAttention类实现的。该模块与Transformer论文中的多头注意力机制完全相同。当forward的参数context=None时,模块其实只是一个提取特征的自注意力模块;而当context为约束文本的编码时,模块就是一个根据文本约束进行运算的交叉注意力模块。该模块用不到mask,相关的代码可以忽略。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.heads = heads

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)

def forward(self, x, context=None, mask=None):
h = self.heads

q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

if exists(mask):
...

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)

Transformer块的内容到此结束。看完了SpatialTransformerResBlock,我们可以回头去看模块之间是怎么拼接的了。先来看U-Net的中间块。它其实就是一个ResBlock接一个SpatialTransformer再接一个ResBlock

1
2
3
4
5
self.middle_block = TimestepEmbedSequential(
ResBlock(...),
SpatialTransformer(...),
ResBlock(...),
)

下采样块input_blocks和上采样块output_blocks的结构几乎一模一样,区别只在于每一大层最后是做下采样还是上采样。这里我们以下采样块为例来学习一下这两个块的结构。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)

for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(...)]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...))

self.input_blocks.append(TimestepEmbedSequential(*layers))
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(...)
if resblock_updown
else Downsample(...)
)
)

上采样块一开始是一个调整输入图片通道数的卷积层,它的作用和self.out输出层一样。

1
2
3
4
5
6
7
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)

之后正式进行上采样块的构造。此处代码有两层循环,外层循环表示正在构造哪一个大层,内层循环表示正在构造该大层的哪一组模块。也就是说,共有len(channel_mult)个大层,每一大层都有num_res_blocks组相同的模块。在Stable Diffusion中,channel_mult=[1, 2, 4, 4], num_res_blocks=2

1
2
3
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
...

每一组模块由一个ResBlock和一个SpatialTransformer构成。

1
2
3
4
5
6
7
8
9
10
11
layers = [
ResBlock(...)
]
ch = mult * model_channels
if ds in attention_resolutions:
...
layers.append(
SpatialTransformer(...)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
...

构造完每一组模块后,若现在还没到最后一个大层,则添加一个下采样模块。Stable Diffusion有4个大层,只有运行到前3个大层时才会添加下采样模块。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
...
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(...)
if resblock_updown
else Downsample(...)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2

至此,我们已经学完了Stable Diffusion的U-Net的主要实现代码。让我们来总结一下。U-Net是一种先对数据做下采样,再做上采样的网络结构。为了防止信息丢失,下采样模块和对应的上采样模块之间有残差连接。下采样块、中间块、上采样块都包含了ResBlockSpatialTransformer两种模块。ResBlock是图像网络中常使用的残差块,而SpatialTransformer是能够融合图像全局信息并融合不同模态信息的Transformer块。Stable Diffusion的U-Net的输入除了有图像外,还有时间戳t和约束编码ct会先过几个嵌入层和线性层,再输入进每一个ResBlock中。c会直接输入到所有Transformer块的交叉注意力块中。

Diffusers

Diffusers是由Hugging Face维护的一套Diffusion框架。这个库的代码被封装进了一个Python模块里,我们可以在安装了Diffusers的Python环境中用import diffusers随时调用该库。相比之下,Diffusers的代码架构更加清楚,且各类Stable Diffusion的新技术都会及时集成进Diffusers库中。

由于我们已经在上文中学过了Stable Diffusion官方源码,在学习Diffusers代码时,我们只会大致过一过每一段代码是在做什么,而不会赘述Stable Diffusion的原理。

安装

安装该库时,不需要克隆仓库,只需要直接用pip即可。

1
pip install --upgrade diffusers[torch]

之后,随便在某个地方创建一个Python脚本文件,输入官方的示例项目代码。

1
2
3
4
5
6
from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipeline.to("cuda")
pipeline("An image of a squirrel in Picasso style").images[0].save('output.jpg')

运行代码后,”一幅毕加索风格的松鼠图片”的绘制结果会保存在output.jpg中。我得到的结果如下:

在Diffusers中,from_pretrained函数可以直接从Hugging Face的模型仓库中下载预训练模型。比如,示例代码中from_pretrained("runwayml/stable-diffusion-v1-5", ...)指的就是从模型仓库https://huggingface.co/runwayml/stable-diffusion-v1-5中获取模型。

如果在当前网络下无法从命令行中访问Hugging Face,可以先想办法在网页上访问上面的模型仓库,手动下载v1-5-pruned.ckpt。之后,克隆Diffusers的GitHub仓库,再用Diffusers的工具把Stable Diffusion模型文件转换成Diffusers支持的模型格式。

1
2
3
git clone git@github.com:huggingface/diffusers.git
cd diffusers
python scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path <src> --dump_path <dst>

比如,假设你的模型文件存在ckpt/v1-5-pruned.ckpt,你想把输出的Diffusers的模型文件存在ckpt/sd15,则应该输入:

1
python scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path ckpt/v1-5-pruned.ckpt --dump_path ckpt/sd15 

之后修改示例脚本中的路径,就可以成功运行了。
1
2
3
4
5
6
from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained("ckpt/sd15", torch_dtype=torch.float16)
pipeline.to("cuda")
pipeline("An image of a squirrel in Picasso style").images[0].save('output.jpg')

对于其他的原版SD checkpoint(比如在civitai上下载的),也可以用同样的方式把它们转换成Diffusers兼容的版本。

采样

Diffusers使用Pipeline来管理一类图像生成算法。和图像生成相关的模块(如U-Net,DDIM采样器)都是Pipeline的成员变量。打开Diffusers版Stable Diffusion模型的配置文件model_index.json(在 https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json 网页上直接访问或者在本地的模型文件夹中找到),我们能看到该模型使用的Pipeline:

1
2
3
4
{
"_class_name": "StableDiffusionPipeline",
...
}

diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py中,我们能找到StableDiffusionPipeline类的定义。所有Pipeline类的代码都非常长,一般我们可以忽略其他部分,只看运行方法__call__里的内容。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
...
):

# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks

# 1. Check inputs. Raise error if not correct
self.check_inputs(...)

# 2. Define call parameters
batch_size = ...

device = self._execution_device

# 3. Encode input prompt


prompt_embeds, negative_prompt_embeds = self.encode_prompt(...)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(...)

# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
...

# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
...
)[0]

# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]


# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()


if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None

...

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

虽然这段代码很长,但代码中的关键内容和我们在本文开头写的伪代码完全一致。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0)
ddim_scheduler = DDIMScheduler()
vae = VAE()
unet = UNet()
zt = randn(image_shape)
eta = input()
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]

text_encoder = CLIP()
c = text_encoder.encode(text)

for t = timesteps:
eps = unet(zt, t, c)
std = ddim_scheduler.get_std(t, eta)
zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
xt = vae.decoder.decode(zt)
return xt

我们可以对照着上面的伪代码来阅读这个方法。经过Diffusers框架本身的一些前处理后,方法先获取了约束文本的编码。

1
2
3
# 3. Encode input prompt
# c = text_encoder.encode(text)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(...)

方法再从采样器里获取了要用到的时间戳,并随机生成了一个初始噪声。

1
2
3
4
5
6
7
8
9
10
11
12
13
# Preprocess
...

# 4. Prepare timesteps
# timesteps = ddim_scheduler.get_timesteps(T, ddim_steps)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

# 5. Prepare latent variables
# zt = randn(image_shape)
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
...
)

做完准备后,方法进入去噪循环。循环一开始是用U-Net算出当前应去除的噪声noise_pred。由于加入了CFG,U-Net计算的前后有一些对数据形状处理的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# eps = unet(zt, t, c)

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
...
)[0]

# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

有了应去除的噪声,方法会调用扩散模型采样器对当前的噪声图片进行更新。Diffusers把采样的逻辑全部封装进了采样器的step方法里。对于包括DDIM在内的所有采样器,都可以调用这个通用的接口,完成一步采样。eta等采样器参数会通过**extra_step_kwargs传入采样器的step方法里。

1
2
3
4
5
# std = ddim_scheduler.get_std(t, eta)
# zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

经过若干次循环后,我们得到了隐空间下的生成图片。我们还需要调用VAE把隐空间图片解码成普通图片。代码中的self.vae.decode(latents / self.vae.config.scaling_factor, ...)用于解码图片。

1
2
3
4
5
6
7
8
9
10
11
12
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None

...

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

就这样,我们很快就看完了Diffusers的采样代码。相比之下,Diffusers的封装确实更合理,主要的图像生成逻辑都写在Pipeline类的__call__里,剩余逻辑都封装在VAE、U-Net、采样器等各自的类里。

U-Net

接下来我们来看Diffusers中的U-Net实现。还是打开模型配置文件model_index.json,我们可以找到U-Net的类名。

1
2
3
4
5
6
7
8
{
...
"unet": [
"diffusers",
"UNet2DConditionModel"
],
...
}

diffusers/models/unet_2d_condition.py文件中,我们可以找到类UNet2DConditionModel。由于Diffusers集成了非常多新特性,整个文件就像一锅大杂烩一样,掺杂着各种功能的实现代码。不过,这份U-Net的实现还是基于原版Stable Diffusion的U-Net进行开发的,原版代码的每一部分都能在这份代码里找到对应。在阅读代码时,我们可以跳过无关的功能,只看我们在Stable Diffusion官方仓库中见过的部分。

先看初始化函数的主要内容。初始化函数依然主要包括time_proj, time_embedding, down_blocks, mid_block, up_blocks, conv_in, conv_out这几个模块。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
def __init__(...):
...
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
...
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
...
self.time_embedding = TimestepEmbedding(...)
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
for i, down_block_type in enumerate(down_block_types):
...
down_block = get_down_block(...)

if mid_block_type == ...
self.mid_block = ...

for i, up_block_type in enumerate(up_block_types):
up_block = get_up_block(...)

self.conv_out = nn.Conv2d(...)

其中,较为重要的down_blocks, mid_block, up_blocks都是根据模块类名称来创建的。我们可以在Diffusers的Stable Diffusion模型文件夹的U-Net的配置文件unet/config.json中找到对应的模块类名称。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
{
...
"down_block_types": [
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D"
],
"mid_block_type": "UNetMidBlock2DCrossAttn",
"up_block_types": [
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D"
],
...
}

diffusers/models/unet_2d_blocks.py中,我们可以找到这几个模块类的定义。和原版代码一样,这几个模块的核心组件都是残差块和Transformer块。在Diffusers中,残差块叫做ResnetBlock2D,Transformer块叫做Transformer2DModel。这几个类的执行逻辑和原版仓库的也几乎一样。比如CrossAttnDownBlock2D的定义如下:

1
2
3
4
5
6
class CrossAttnDownBlock2D(nn.Module):
def __init__(...):
for i in range(num_layers):
resnets.append(ResnetBlock2D(...))
if not dual_cross_attention:
attentions.append(Transformer2DModel(...))

接着我们来看U-Net的forward方法。忽略掉其他功能的实现,该方法的主要内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
...):

# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0

# 1. time
timesteps = timestep
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb, timestep_cond)

# 2. pre-process
sample = self.conv_in(sample)

# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
...)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
...)

# 5. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
...)

# 6. post-process
sample = self.conv_out(sample)

return UNet2DConditionOutput(sample=sample)


该方法和原版仓库的实现差不多,唯一要注意的是栈相关的实现。在方法的下采样计算中,每个downsample_block会返回多个残差输出的元组res_samples,该元组会拼接到栈down_block_res_samples的栈顶。在上采样计算中,代码会根据当前的模块个数,从栈顶一次取出len(upsample_block.resnets)个残差输出。

1
2
3
4
5
6
7
8
9
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(...)
down_block_res_samples += res_samples

for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
sample = upsample_block(...)

现在,我们已经看完了Diffusers中U-Net的主要内容。可以看出,Diffusers的U-Net包含了很多功能,一般情况下是难以自己更改这些代码的。有没有什么办法能方便地修改U-Net的实现呢?由于很多工作都需要修改U-Net的Attention,Diffusers给U-Net添加了几个方法,用于精确地修改每一个Attention模块的实现。我们来学习一个修改Attention模块的示例。

U-Net类的attn_processors属性会返回一个词典,它的key是每个Attention运算类所在位置,比如down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor,它的value是每个Attention运算类的实例。默认情况下,每个Attention运算类都是AttnProcessor,它的实现在diffusers/models/attention_processor.py文件中。

为了修改Attention运算的实现,我们需要构建一个格式一样的词典attn_processor_dict,再调用unet.set_attn_processor(attn_processor_dict),取代原来的attn_processors。假如我们自己实现了另一个Attention运算类MyAttnProcessor,我们可以编写下面的代码来修改Attention的实现:

1
2
3
4
5
6
7
8
9

attn_processor_dict = {}
for k in unet.attn_processors.keys():
if we_want_to_modify(k):
attn_processor_dict[k] = MyAttnProcessor()
else:
attn_processor_dict[k] = AttnProcessor()

unet.set_attn_processor(attn_processor_dict)

MyAttnProcessor的唯一要求是,它需要实现一个__call__方法,且方法参数与AttnProcessor的一致。除此之外,我们可以自由地实现Attention处理的细节。一般来说,我们可以先把原来AttnProcessor的实现代码复制过去,再对某些细节做修改。

总结

在这篇文章中,我们学习了Stable Diffusion的原版实现和Diffusers实现的主要内容:采样算法和U-Net。具体来说,在原版仓库中,采样的实现一部分在主函数中,一部分在DDIM采样器类中。U-Net由一个简明的PyTorch模块类实现,其中比较重要的子模块是残差块和Transformer块。相比之下,Diffusers实现的封装更好,功能更多。Diffusers用一个Pipeline类来维护采样过程。Diffusers的U-Net实现与原版完全相同,且支持更复杂的功能。此外,Diffusers还给U-Net提供了精确修改Attention计算的接口。

不管是哪个Stable Diffusion的框架,都会提供一些相同的原子操作。各种基于Stable Diffusion的应用都应该基于这些原子操作开发,而无需修改这些操作的细节。在学习时,我们应该注意这些操作在不同的框架下的写法是怎么样的。常用的原子操作包括:

  • VAE的解码和编码
  • 文本编码器(CLIP)的编码
  • 用U-Net预测当前图像应去除的噪声
  • 用采样器计算下一去噪迭代的图像

在原版仓库中,相关的实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
# VAE的解码和编码
model.decode_first_stage(...)
model.encode_first_stage(...)

# 文本编码器(CLIP)的编码
model.get_learned_conditioning(...)

# 用U-Net预测当前图像应去除的噪声
model.apply_model(...)

# 用采样器计算下一去噪迭代的图像
p_sample_ddim(...)

在Diffusers中,相关的实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
# VAE的解码和编码
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
latents = self.vae.encode(image).latent_dist.sample(generator) * self.vae.config.scaling_factor

# 文本编码器(CLIP)的编码
self.encode_prompt(...)

# 用U-Net预测当前图像应去除的噪声
self.unet(..., return_dict=False)[0]

# 用采样器计算下一去噪迭代的图像
self.scheduler.step(..., return_dict=False)[0]

如今zero-shot(无需训练)的Stable Diffusion编辑技术一般只会修改采样算法和Attention计算,需训练的编辑技术有时会在U-Net里加几个模块。只要我们熟悉了普通的Stable Diffusion是怎么样生成图像的,知道原来U-Net的结构是怎么样的,我们在阅读新论文的源码时就可以把这份代码与原来的代码进行对比,只看那些有修改的部分。相信读完了本文后,我们不仅加深了对Stable Diffusion本身的理解,以后学习各种新出的Stable Diffusion编辑技术时也会更加轻松。

在上一篇文章中,我们梳理了基于自编码器(AE)的图像生成模型的发展脉络,并引出了Stable Diffusion的核心思想。简单来说,Stable Diffusion是一个两阶段的图像生成模型,它先用一个AE压缩图像,再在压缩图像所在的隐空间上用DDPM生成图像。在这篇文章中,我们来精读Stable Diffusion的论文:High-Resolution Image Synthesis with Latent Diffusion Models

注意:如果你从未学习过扩散模型,Stable Diffusion并不是你应该的读的第一篇论文。请参照我[上篇文章]的早期工作总结,至少在学会了DDPM后再来学习Stable Diffusion。

摘要与引言

论文摘要的大意如下:扩散模型的生成效果很好,但是,在像素空间上训练和推理扩散模型的计算开销都很大。为了在不降低质量与易用性的前提下用较少的计算资源训练扩散模型,我们在一个预训练过的自编码器的隐空间上使用扩散模型。相较以往的工作,在这种表示下训练扩散模型首次在减少计算复杂度和维持图像细节间达到几近最优的平衡点,极大地提升了视觉保真度。通过向模型架构中引入交叉注意力层,我们把扩散模型变成了强大而灵活的带约束图像生成器,它支持常见的约束,如文字、边界框,且能够以纯卷积方式实现高分辨率的图像合成。我们的隐扩散模型(latent diffusion model, LDM) 在使用比像素扩散模型少得多的计算资源的前提下,在各项图像合成任务上取得最优成果或顶尖成果。

整理一下。论文提出了一种叫LDM的图像生成模型。论文想解决的问题是减少像素空间扩散模型的运算开销。为此,LDM借助了VQVAE「先压缩、再生成」的想法,把扩散模型用在AE的隐空间上,在几乎不降低生成质量的前提下减少了计算量。另外,LDM还支持带约束图像合成及纯卷积图像超分辨率。

在上一篇回顾LDM早期工作的文章中,我们已经理解了LDM想解决的问题及解决问题的思路。因此,在读完摘要后,我们接下来读文章时只需要关注LDM的两个创新点:

  1. LDM的AE是怎么设计以达到压缩比例与质量的平衡的。
  2. LDM怎么实现带约束的图像合成。

引言基本是摘要的扩写。首先,引言大致介绍了图像合成任务的背景,提及了扩散模型近期的突出表现。随后,引言介绍了本文想解决的主要问题:扩散模型的训练和推理太耗时了,需要在不降低效果的前提下减少扩散模型的运算量。最后,引言揭示了本工作的解决方法:使用类似VQGAN的两阶段图像生成方法。

引言的前两部分没有什么关键信息,而最后一部分介绍了本工作改进扩散模型的动机,值得一读。如下图所示,DDPM的论文展示了从不同去噪时刻的同一个噪声图像开始的不同生成结果,比如$\mathbf{x}_{750}$指从时刻$t=750$的去噪图像开始,多次以不同随机数执行DDPM的反向过程,生成的多幅图像。LDM作者认为,DDPM的这一实验表明,扩散模型的图像生成分两个阶段:先是对语义进行压缩,再是对图像的感知细节压缩。正因此,随机对早期的噪声图像去噪,生成图像的内容会更多样;而随机对后期的噪声图像去噪,生成图像只是在细节上有所不同。LDM的作者认为,扩散模型的大量计算都浪费在了生成整幅图像的细节上,不如只让扩散模型描述比较关键的语义压缩部分,而让自编码器(AE)负责感知细节压缩部分。

引言在结尾总结了本工作的贡献:

  1. 相比之前按序列处理图像的纯Transformer的方法,扩散模型能更好地处理二维数据。因此,LDM生成隐空间图像时不需要那么重的压缩比例(比如DIV2K数据集上,LDM只需要将图像下采样4倍,而之前的纯Transformer方法要下采样8倍或16倍),图像在压缩时能有更高的保真度,整套方法能更高效地生成高分辨率图像。
  2. 在大幅降低计算开销的前提下在多项图像生成任务上取得了顶尖成果。
  3. 相比于之前同时训练图像压缩模型和图像生成模型的方法,该方法分步训练两个模型,训练起来更加简单。
  4. 对于有着稠密约束的任务(如超分辨率、补全、语义生成),该方法的模型能换成一个纯卷积版本的,且能生成边长为1024的图像。
  5. 该工作设计了一种通用的约束机制,该机制基于交叉注意力,支持多模态训练。作者训练了多种带约束的模型。
  6. 作者把工作开源了,并提供了预训练模型。

我们来整理一下这些贡献。读论文时,可以忽略第6条。第2条是成果,与方法设计无关。第1、3条主要描述了提出两阶段图像生成建模方法的贡献。第4条是把方法拓展到稠密约束任务的贡献。第5条是提出了新约束机制的贡献。所以,在学习论文的方法时,我们还是主要关注摘要里就提过的那两个创新点。在读完引言后,我们可以把阅读目标再细化一下:

  1. LDM的AE是怎么设计以达到压缩比例与质量的平衡的。与纯基于Transformer的VQGAN相比,它有什么不同。
  2. LDM怎么用交叉注意力机制实现带约束的图像生成。

相关工作

作者主要从两个角度回顾了早期工作:不同架构的图像生成模型与两阶段的图像合成方法。其回顾逻辑与本系列的第一篇文章类似,在此就不过多介绍了。除了介绍早期工作外,作者重申了引言中的对比结果,强调了LDM相对于扩散模型的创新和相对于两阶段图像生成模型的创新。

方法

在方法章节中,作者先是大致介绍了使用LDM这种两阶段图像生成架构的优点,再分三部分详细介绍了论文的实现细节:图像压缩AE的实现、LDM的实现、约束的实现。开头的介绍和AE的实现相对比较重要,我们放在一起详细阅读;相对于DDPM,LDM几乎没有做任何修改,只是把要拟合的图片从真实图片换成了压缩图片,这一部分我们会快速浏览一遍;而添加约束的方法有所创新,我们会详细阅读一遍。

AE与两阶段图像生成模型

我们来先读3.1节,看一看AE的具体实现方法,再回头读第3节开头介绍的两阶段图像生成模型的优点。

LDM配套的图像压缩模型(论文中称之为”感知压缩模型”)和VQGAN几乎完全一样。该压缩模型的原型是一个AE。普通的AE会用原图像和重建图像的重建误差(L1误差或者L2误差)来训练。在普通的AE的基础上,该压缩模型参考了GAN的误差设置方法,使用感知误差代替重建误差,并添加了基于patch的对抗误差。

但该图像压缩模型的输出与VQGAN有所不同。我们先回忆一下VQGAN的原理。VQGAN的输出会接到Transformer里,Transformer的输入必须是离散的。因此,VQGAN必须要额外完成两件事:1)让连续输出变成离散输出;2)用正则化方法防止过拟合。为此,VQGAN使用了VQVAE里的向量离散化操作,该操作能同时完成这两件事。

而LDM的压缩模型的输出会接入一个扩散模型里,扩散模型的输入是连续的。因此,LDM的压缩模型只需要额外完成使用正则化方法这一件事。该压缩模型不必像VQGAN一样非得用向量离散化来完成正则化。如我们在第一篇文章中讨论的,作者在LDM的压缩模型中使用了两种正则化方法:VQ正则化与KL正则化。前者来自于VQVAE,后者来自于VAE。

该压缩模型相较VQGAN有一项明显的优势。VQGAN的Transformer只能按一维序列来处理图像(通过把二维图像reshape成一维),且只能处理较小的压缩图像($16\times16$)。而本身用于二维图像生成的LDM能更好地利用二维信息,因此可以处理更大的压缩图像($64\times 64$)。这样,LDM的压缩模型的压缩程度不必那么重,其保真度会比VQGAN高。

看完了3.1节,我们来回头看第3节开头介绍了LDM的三项优点:1)通过规避在高维图像空间上训练扩散模型,作者开发出了一个因在低维空间上采样而计算效率大幅提升的扩散模型;2)作者发掘了扩散模型中来自U-Net架构的归纳偏置(inductive bias),使得它们能高效地处理有空间结构的数据(比如二维图像),避免像之前基于Transformer的方法一样使用激进、有损质量的压缩比例;3)本工作的压缩模型是通用的,它的隐空间能用来训练多种图像生成模型。第一个优点是相对于DDPM。第二个是优点是相对于使用Transformer的VQGAN,我们在上一段已经分析过了。第三个优点是相对于之前那些换一个任务就需要换一个压缩模型的两阶段图像生成模型。

归纳偏置可以简单理解为某个学习算法对一类数据的优势。比如CNN结构适合处理图像数据。

隐扩散模型(LDM)

在DDPM中,一个参数为$\theta$的神经网络$\epsilon_\theta$会根据当前时刻$t$的带噪图片$x_t$预测本时刻的噪声$\epsilon_\theta(x_t, t)$。网络的学习目标是让预测的噪声和真实的噪声$\epsilon$一致。

LDM的原理和DDPM完全一样,只不过训练图片从像素空间上的真实图片$x_0$变成了隐空间上的压缩图片$z_0$,每一轮的带噪图片由$x_t$变成了隐空间上的带噪图片$z_t$。在训练时,相比DDPM,只需要多对$x_0$用一次编码器变成$z_0$即可。

如果你在理解这部分内容时有疑问,请去阅读DDPM的相关文章。LDM的具体结构我们会在第三篇代码阅读文章中讨论。

约束机制

让模型支持带约束图像生成,其实就是想办法把额外的约束信息输入进扩散模型中。显然,最简单的添加约束的方法就是把额外的信息和扩散模型原本的输入$z_t$拼接起来。如果约束是一个值,就把相同的值拼接到$z_t$的每一个像素上;如果约束本身具有空间结构(如语音分割图片),就可以把约束重采样至和$z_t$一样的分辨率,再逐像素拼接。除了直接的拼接外,作者在LDM中还使用了另一种融合约束信息的方法。

DDPM中含有自注意力层。自注意力操作其实基于注意力操作$Attention(Q, K, V)$,它可以解释成一个数据库中存储了许多数据$V$,数据的索引(键)是$K$,现在要用查询$Q$查询数据库里的数据并返回查询结果。注意力操作有几种用法,第一种用法是交叉注意力$CrossAttn(A, B)=Attention(W_Q \cdot A, W_K \cdot B, W_V \cdot B)$,可以理解成数据$A$, $B$做了一次信息融合;第二种用法是自注意力$SelfAttn(A)=Attention(W_Q \cdot A, W_K \cdot A, W_V \cdot A)$,可以理解成数据$A$自己做了一次特征提取。

既然交叉注意力操作可以融合两类信息,何不把DDPM的自注意力层换成交叉注意力层,把$K$, $V$换成来自约束的信息,以实现带约束图像生成呢?如下图所示,通过把用编码器$\tau_\theta$编码过的约束信息输入进扩散模型交叉注意力层的$K$, $V$,LDM实现了带约束图像生成。这里的实现细节我们会在第三篇代码阅读文章中讨论。

根据论文中实验的设计,对于作用于全局的约束,如文本描述,使用交叉注意力较好;对于有空间信息的约束,如语义分割图片,则用拼接的方式较好。

实验

在这一章里,作者按照介绍方法的顺序,依次探究了图像压缩模型、无约束图像生成、带约束图像合成的实验结果。我们主要关心前两部分的实验结果。

感知压缩程度的折衷

论文首先讨论了图像压缩模型在不同的下采样比例$f$下的实验结果,其中$f\in\{1, 2, 4, 8, 16, 32\}$。这些实验分两部分,第一部分是训练速度上的实验,第二部分是采样速度与效果上的实验。

在ImageNet上以不同下采样比例$f$训练一定步数后LDM的采样指标对比结果如下图所示。其中,FID指标越低越好,Inception Score越高越好。结果显示,无论下采样比例$f$是过大还是过小都会降低训练速度。作者分析了$f$较小或较大时训练速度慢的原因:$f$过小时扩散模型把过多的精力放在了本应由压缩模型负责的感知压缩上;$f$过大时图像信息在压缩中损失过多。LDM-$\{4\text{-}16\}$的表现相对好一些。

在实验的第二部分中,作者比较了不同采样比例$f$的LDM在CelebA-HQ(下图左侧)和ImageNet(下图右侧)上的采样速度和采样效果。下图中,横坐标为吞吐量,越靠右表示采样速度越快。同一个模型的不同实验结果表示使用不同DDIM采样步数时的实验结果,每一条线上的结果从右到左分别是DDIM采样步数取$\{10, 20, 50, 100, 200\}$的采样结果(DDIM步数越少,采样速度越快,生成图片质量越低)。对于CelebA-HQ上的实验,若采样步数较多,则还是LDM-$\{4, 8\}$效果较好,只有在采样步数较少时压缩比更高的LDM才有优势。而对于ImageNet上的实验,$f$太小或太大的结果都很差,整体上还是LDM-$\{4, 8\}$的结果较好。

综上,根据实验,作者认为$f$取适中的$4$或$8$比较妥当。下采样比例$f=8$也正是Stable Diffusion采用的配置。

图像生成效果

在这一节中,作者在几个常见的数据集上对比了LDM与其他模型的无约束图像生成效果。作者主要比较了两类指标:表示采样质量的FID和表示数据分布覆盖率的精确率及召回率(Precision-and-Recall)。

在介绍具体结果之前,先对这个不太常见的精确率及召回率指标做一个解释。精确率及召回率常用于分类等有确定答案的任务中,分别表示所有被分类为正的样本中有多少是分对了的、所有真值为正的样本中有多少是被成功分类成正的。而无约束图像生成中的精确率及召回率的解释可以参加论文Improved Precision and Recall Metric for Assessing
Generative Models
。如下图所示,设真实分布为蓝色,生成模型的分布为红色,则红色样本落在蓝色分布的比例为精确率,蓝色样本落在红色分布的比例为召回率。简单来说,精确率能描述采样质量,召回率能描述生成分布与真实分布的覆盖情况。

接下来,我们回头来看论文展示的无约束图像生成对比结果,如下图所示。整体上看,LDM的表现还不错。虽然在FID指标上无法超过GAN或其他扩散模型,但是在精确率和召回率上还是颇具优势。唯一没有被LDM战胜的是LSUN-Bedrooms上的ADM模型,但作者提到,相比ADM,LDM只用了一半的参数,且只需四分之一的训练资源。

带约束图像合成

这一节里,作者展示了LDM的文生图能力。论文中的LDM用了一个从头训练的基于Transformer的文本编码器,与后续使用CLIP的Stable Diffusion差别较大。这一部分的结果没那么重要,大致看一看就好。

本文的文生图模型是一个在LAION-400M数据集上训练的KL约束LDM。它的文本编码器是一个Transformer,编码后的特征会以交叉注意力的形式传入LDM。采样时,LDM使用了Classifier-Free Guidance。

Classifier-Free Guidance可以让输出图片更符合文本约束。这是一种适用于所有扩散模型的采样策略,并非要和LDM绑定,感兴趣可以去阅读相关论文。

LDM与其他模型的文生图效果对比如下图所示。虽然这个版本的LDM并没有显著优于其他模型,但它的参数量是最少的。

LDM在类别约束的图像合成上表现也很不错,超越了当时的其他模型。其结果在此略过。

剩余的带约束图像合成任务都可以看成是图像转图像任务,比如图像超分辨率是低质量图像到高质量图像的转换、语义生成是把语义分割图像转换成一幅合成图像。要添加这些约束,只需要把这些任务的输入图片和LDM原本的输入$z_t$拼接起来即可。比如对于图像超分辨率,可以把输入图片直接与隐空间图片$z_t$拼接,解码后图片会被自然上采样$f$倍;对于语义生成,可以把下采样$f$倍的语义分割图与$z_t$拼接。论文用这些任务上的实验证明了LDM的泛用性。由于这部分实验与LDM的主要知识无关,具体实验结果就不在此详细介绍了。

总结

论文末尾探讨了LDM的两大不足。首先,尽管LDM的计算需求比其他像素空间上的扩散模型要少得多,但受制于扩散模型本身的串行采样,它的采样速度还是比GAN慢上许多。其次,LDM使用了一个自编码器来压缩图像,重建图像带来的精度损失会成为某些需要精准像素值的任务的性能瓶颈。

论文最后再次总结了此方法的贡献。LDM的主要贡献其实只有两点:在不损失效果的情况下用两阶段的图像生成方法大幅提升了训练和采样效率、借助交叉注意力实现了各任务通用的约束机制。这两个贡献总结得非常精准。之后的Stable Diffusion之所以大受欢迎,第一就是因为它采样所需的计算资源不多,大众能使用消费级显卡完成图像生成,第二就是因为它强大的文字转图片生成效果。

我们再从知识学习的角度总结一下LDM。LDM的核心知识是DDPM和VQGAN。如果你能看懂之前这两篇论文,那你一下子就能明白LDM是的核心思想是什么,看论文时只需要精读交叉注意力约束机制那一段即可,其他实验内容在现在看来已经价值不大了。由于近两年有大量基于Stable Diffusion开发的工作,相比论文,阅读源代码的重要性会大很多。我们会在下一篇文章里详细学习Stable Diffusion的官方源码和最常用的Stable Diffusion第三方实现——Diffusers框架。

在2022年的这波AI绘画浪潮中,Stable Diffusion无疑是最受欢迎的图像生成模型。究其原因,第一,Stable Diffusion通过压缩图像尺寸显著提升了扩散模型的运行效率,使得每个用户能在自己的商业级显卡上运行模型;第二,有许多基于Stable Diffusion的应用,比如Stable Diffusion自带的文生图、图像补全,以及ControlNet、LoRA、DreamBooth等插件式应用;第三,得益于前两点,Stable Diffusion已经形成了一个庞大的用户社群,大家互相分享模型,交流心得。

不仅是大众,Stable Diffusion也吸引了大量科研人员,很多本来研究GAN的人纷纷转来研究扩散模型。然而,许多人在学习Stable Diffusion时却犯了难:又是公式扎堆的扩散模型,又是VAE,又是U-Net,这该怎么学起呀?

其实,一上来就读Stable Diffusion是很难读懂的。而如果你把之前的一些更基础的文章读懂,再回头来读Stable Diffusion,就会畅行无阻了。在这篇及之后的几篇文章中,我将从科研的角度对Stable Diffusion做一个全面的解读。在第一篇文章中,我将面向完全没接触过图像生成的读者,从头介绍Stable Diffusion是怎样从早期工作中一步一步诞生的;在第二篇文章中,我将详细解读Stable Diffusion的论文;在最后的第三篇文章中,我将带领大家阅读Stable Diffusion的官方源码,以及一些流行的开源库的Stable Diffusion实现。后续我还会写其他和Stable Diffusion相关的文章,比如ControlNet的介绍。

从自编码器谈起

包括Stable Diffusion在内,很多图像生成模型都可以看成是一种非常简单的模型——自编码器——的改进版。要谈Stable Diffusion是怎么逐渐诞生的,其实就是在谈自编码器是一步一步进化的。我们的学习就从自编码器开始。

尽管PNG、JPG等图像压缩方法已经非常成熟,但我们会想,会不会还有更好的图像压缩算法呢?图像压缩,其实就是找两个映射,一个把图片编码成压缩数据,另一个把压缩数据解码回图片。我们知道,神经网络理论上可以拟合任何映射。那我们干脆用两个神经网络来拟合两种映射,以实现一个图像压缩算法。负责编码的神经网络叫编码器(Encoder),负责解码的神经网络叫做解码器(Decoder)

光定义了神经网络还不够,我们还需要给两个神经网络设置一个学习目标。在运行过程中,神经网络应该满足一个显然的约束:编码再解码后的重建图像应该和原图像尽可能一致,即二者的均方误差应该尽可能小。这样,我们只需要随便找一张图片,通过编码器和解码器得到重建图像,就能训练神经网络了。我们不需要给图片打上标签,整个训练过程是自监督的。所以我们说,整套模型是一个自编码器(Autoencoder,AE)

图像压缩模型AE为什么会和图像生成扯上关系呢?你可以试着把AE的输入图像和编码器遮住,只看解码部分。把一个压缩数据解码成图像,换个角度看,不就是在根据某一数据生成图像嘛。

很可惜,AE并不是一个合格的图像生成模型。我们常说的图像生成,具体是指让程序生成各种各样的图片。为了让程序生成不同的图片,我们一般是让程序根据随机数(或是随机向量)来生成图片。而普通的AE会有过拟合现象,这导致AE的解码器只认得训练集里的图片经编码器解码出来的压缩数据,而不认得随机生成的压缩数据,进而也无法达到图像生成的要求。

所谓过拟合,就是指模型只能处理训练数据,而不能推广到一般的数据上。举一个极端的例子,如下图所示,编码器和解码器直接记忆了整个数据集,把所有图片压缩成了一个数字。也就是模型把编码器当成一个图片到数字的词典,把解码器当成一个数字到图片的词典。这样,不管数据集有多大,所有图片都可以被压缩成一个数字。这样的AE确实压缩能力很强,但它完全没用,因为它过拟合了,处理不了训练集以外的数据。

过拟合现象在普通版AE中是不可避免的。为了利用AE的解码器来生成图片,许多工作都在试图克服AE的过拟合现象。AE的改进思路很多,在这篇文章中,我们仅把AE的改进路线粗略地分成两种:解决过拟合问题以直接用AE做图像生成、用AE压缩图像间接实现图像生成。

第一条路线:VAE 和 DDPM

在第一条改进路线中,许多后续工作都试图用更高级的数学模型来解决AE的过拟合问题。变分自编码器(Variational Autoencoder, VAE) 就是其中的代表。

VAE对AE做了若干改动。第一,VAE让编码器的输出不再是一个确定的数据,而是一个正态分布中的一个随机数据。更具体一点,训练时,编码器会同时输出一个均值和方差。随后,模型会从这个均值和方差表达的正态分布里随机采样一个数据,作为解码器的输入。直观上看,这一改动就是在AE的基础上,让编码器多输出了一个方差,使得原AE编码器的输出发生了一点随机扰动。

这一改动可以缓解过拟合现象。这是为什么呢?我们可以这样想:原来的AE之所以会过拟合,是因为它强行记住了训练集里每一个数据的编码输出。现在,我们在VAE里让编码器不再输出一个固定值,而是随机输出一个在均值附近的值。这样的话,VAE就不能死记硬背了,必须要找出数据中的规律。

VAE的第二项改动是多添加一个学习目标,让编码器的输出和标准正态分布尽可能相似。前面我们谈过,图像生成模型一般会根据一个随机向量来生成图像。最常用的产生随机向量的方法是去标准正态分布里采样。也就是说,在用VAE生成图像时,我们会抛掉编码器,用下图所示的流程来生成图像。如果我们不约束编码器的输出分布,不让它输出一个和标准正态分布很相近的分布的话,解码器就不能很好地根据来自标准正态分布的随机向量生成图像了。

综上,VAE对AE做了两项改进:使编码器输出一个正态分布,且该分布要尽可能和标准正态分布相似。训练时,模型从编码器输出的分布里随机采样一个数据作为解码器的输入;图像采样(图像生成)时,模型从标准正态分布里随机采样一个数据作为解码器的输入。VAE的误差函数由两部分组成:原图像和重建图像的重建误差、编码器输出和标准正态分布之间的误差。VAE要最小化重建误差,最大化编码器输出与标准正态分布的相似度。

分布与分布之间的误差可以用一个叫KL散度的指标表示。所以,在上面那个误差函数公式中,负的相似度应该被替换成KL散度。VAE的这两项改动本质上都是在解决AE的过拟合问题,所以,VAE的改动可以被看成一种正则化方法。我们可以把VAE的正则化方法简称为KL正则化

在机器学习中,正则化方法就是「降低模型过拟合的方法」的简称。

VAE确实能减轻AE的过拟合。然而,由于VAE只是让重建图像和原图像的均方误差(重建误差)尽可能小,而没有对重建图像的质量施加更多的约束,VAE的重建结果和图像生成结果都非常模糊。以下是VAE在CelebA数据集上图像生成结果。

在众多对VAE的改进方法中,一个叫做去噪扩散概率模型(Denoising Diffusion Probabilistic Model, DDPM) 的图像生成模型脱颖而出。DDPM正是当今扩散模型的开山鼻祖。我们来看一下DDPM是怎样基于VAE对图像生成建模的。

VAE之所以效果不好,很可能是因为它的约束太少了。VAE的编码和解码都是用神经网络表示的。神经网络是一个黑盒,我们不好对神经网络的中间步骤施加约束,只好在编码器的输出(某个正态分布)和解码器的输出(重建图像)上施加约束。能不能让VAE的编码和解码过程更可控一点呢?

DDPM的设计灵感来自热力学:一个分布可以通过一系列简单的变化(如添加高斯噪声)逐渐变成另一个分布。恰好,VAE的编码器不正是想让来自训练集的图像(训练集分布)变成标准正态分布吗?既然如此,就不要用一个可学习的神经网络来表示VAE的编码器了,干脆用一些预定义好的加噪声操作来表示编码过程。可以从数学上证明,经过了多次加噪声操作后,最后的图像分布会是一个标准正态分布。

既然编码是加噪声,那解码时就应该去掉噪声。DDPM的解码器也不再是一个不可解释的神经网络,而是一个能预测若干个去噪结果的神经网络。

相比只有两个约束条件的VAE,DDPM的约束条件就多得多了。在DDPM中,第t个去噪操作应该尽可能抵消掉第t个加噪操作。

让我们来更具体地认识一下DDPM的学习目标。所谓添加噪声,就是在一个均值约等于当前图像的正态分布上采样。比如要对图像$\mathbf{x}$添加噪声,我们可以在$\mathcal{N}(0.9\mathbf{x},\mathbf{I})$这个分布里采样一张新图像。新的图像每个像素的均值是原来的0.9倍左右,且新图像会出现很多噪声。我们设$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$为第$t$步加噪声的正态分布。经过一些数学推导,我们可以求出这一步操作的逆操作$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$,这个加噪声逆操作也是一个正态分布。既然如此,我们可以设第$t$步去噪声也为一个正态分布$p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$,让第$t$步去噪声和第$t$步加噪声的逆操作尽可能相似。

总结一下,DDPM对VAE做了如下改动:

  1. 编码器是一系列不可学习(固定)的加噪声操作
  2. 解码器是一系列可学习的去噪声操作
  3. 图像尺寸自始至终不变

相比于VAE,DDPM的编码过程和解码过程的定义更加明确,可以施加的约束更多。因此,如下图所示,它的生成效果会比VAE好很多。同时,DDPM和VAE类似,它在编码时会从分布里采样,而不是只输出一个固定值,不会出现AE的过拟合问题。

DDPM的图像生成结果

DDPM的生成效果确实很好。但是,由于DDPM始终会对同一个尺寸的数据进行操作,图像的尺寸极大地影响了DDPM的运行速度,用DDPM生成高分辨率图像需要耗费大量计算资源。因此,想要用DDPM生成高质量图像,还得经过另一条路线。

第二条路线:VQVAE

在AE的第二条改进路线中,一些工作干脆放弃使用AE做图像生成,转而利用AE的图像压缩能力,把图像生成拆成两步来做:先用AE的编码器把图像压缩成更小的图像,再用另一个图像生成模型生成小图像,并用AE的解码器把小图像重建回真实图像。

为什么会有这么奇怪的图像生成方法呢?这得从另一类图像生成模型讲起。在机器翻译模型Transformer横空出世后的一段时间里,有很多工作都想把Transformer用在图像生成上。但是,原本用来生成文本的Transformer无法直接应用在图像上。在自然语言处理(NLP)中,一个句子可以用若干个单词表示。而每个单词又是用一个整数表示。所以,Transformer生成句子时,实际上是在生成若干个离散的整数,也就是生成一个离散向量。而在图像生成模型中,每个像素的颜色值是一个连续的浮点数。想把Transformer直接用在图像生成上,就得想办法把图像用离散向量表示。我们知道,AE可以把图像编码成一个连续向量。能不能做一些修改,让AE把图像编码成一个离散向量呢?

Vector Quantised-Variational AutoEncoder (VQVAE) 就是一个能把图像编码成离散向量的AE(虽然作者在取名时用了VAE)。我们来简单看一下VQVAE是怎样把图像编码成离散向量的。

假设我们有了一个能编码出离散向量的AE。

由于神经网络不能很好地处理离散数据,我们要引入NLP里的通常做法,加一个把离散向量映射成连续向量的嵌入层。

现在我们再回头讨论怎么让编码器输出一个离散向量。我们可以让AE的编码器保持不变,还是输出一个连续向量,再通过一个「向量离散化」操作,把连续向量变成离散向量。这个操作会把编码器的输出对齐到嵌入层的向量上,其原理类似于把0.99和1.01离散化成1,只不过它是对向量整体考虑,而不是对每一个数单独考虑。向量离散化操作的具体原理我们不在此处细究。

忽略掉实现细节,我们可以认为VQVAE能够把图像压缩成离散向量。更准确地说,VQVAE能把图像等比例压缩成离散的「小图像」。压缩成二维图像而不是一维向量,能够保留原图像的一些空间特性,为之后第二步图像生成铺路。

整理一下,VQVAE是一个能把图像压缩成离散小图像的AE。为了用VQVAE生成图像,需要执行一个两阶段的图像生成流程:

  • 训练时,先训练一个图像压缩模型(VQVAE),再训练一个生成压缩图像的模型(比如Transformer)
  • 生成时,先用第二个模型生成出一个压缩图像,再用第一个模型的解码器把压缩图像复原成真实图像

之所以要执行两阶段的图像生成流程,而不是只用第二个模型生成大图像,有两个原因。第一个原因是前面提到的,Transformer等生成模型只支持生成离散图像,需要用另一个模型把连续的颜色值变成离散值以兼容这些模型。第二个原因是为了减少模型的运算量。以Transformer为例,Transformer的运算次数大致与像素数的平方成正比,拿Transformer生成高分辨率图像的运算开销是不可接受的。而如果用一个AE把图像压缩一下的话,用Transformer就可行了。

VQVAE给后续工作带来了三条启发:第一,可以用AE把图像压缩成离散向量;第二,如果一个图像生成模型生成高分辨率的图像的计算代价太高,可以先用AE把图像压缩,再生成压缩图像。这两条启发对应上一段提到的使用VQVAE的两条动机。

而第三条启发就比较有意思了。在讨论VQVAE的过程中,我们完全没有考虑过拟合的事。这是因为经过了向量离散化操作后,解码器的输入已经不再是编码器的输出,而是嵌入层里的向量了。这种做法杜绝了AE的死记硬背,缓解了过拟合现象。这样,我们可以换一个角度看待VQVAE:编码器还是AE的编码器,编码器的输出是连续向量,后续的向量离散化操作和嵌入层全部都是解码器的一部分。从这个角度看,VQVAE其实提出了一个由向量离散化和嵌入层组成的正则化模块。这个模块和VAE的KL散度约束一样,都解决了AE的过拟合问题。我们把VQVAE的正则化方法叫做VQ正则化

VQVAE论文提出的图像生成方法效果一般。和普通的AE一样,VQVAE在训练时只用了重建误差来约束图像质量,重建图像的细节依然很模糊。且VQVAE配套的第二阶段图像生成模型不是较为强力的Transformer,而是一个基于CNN的图像生成模型。

后续的VQGAN论文对VQVAE进行了改进。对于一阶段的图像压缩模型,VQGAN在VQVAE的基础上引入了生成对抗网络(GAN)中一些监督误差,提高了图像压缩模型的重建质量;对于两阶段的图像生成模型,该方法使用了Transformer。凭借这些改动,VQGAN方法能够生成高质量的高清图片。并且,通过把额外的约束条件(如语义分割图像、文字)输入进Transformer,VQGAN方法能够实现带约束的图像生成。以下是VQGAN方法根据语义分割图像生成的高清图片。

图像生成模型可以是无约束或带约束的。无约束图像生成模型只需要输入一个随机向量,训练数据不需要任何标注,可以进行无监督训练。带约束图像生成模型会在无约束图像生成模型的基础上多加一些输入,并给每个训练图像打上描述约束的标签,执行监督训练。比如要训练文生图模型,就要给每个训练图片带上文字描述。

路线的交汇点——Stable Diffusion

看完上面这两条AE的改进路线,相信你已经能够猜出Stable Diffusion的核心思想了。让我们看看Stable Diffusion是怎么从这两条路径中汲取灵感的。

在发布了VQGAN后,德国的CompVis实验室开始探索起VQGAN的改进方法。VQGAN能把图像边长压缩16倍,而VQGAN配套的Transformer只能一次生成$16 \times 16$的图片。也就是说,整套方法一次只能生成$256 \times 256$的图片。为了生成分辨率更高的图片,VQGAN方法需要借助滑动窗口。能不能让模型一次性生成分辨率更高的图片呢?制约VQGAN方法生成分辨率的主要因素是Transformer。如果能把Transformer换成一个效率更高,能生成更高分辨率的图像的模型,不就能生成比$256\times256$更大的图片了吗?CompVis实验室开始把目光着眼于DDPM上。

于是,在发布VQGAN的一年后,CompVis实验室又发布了名为High-Resolution Image Synthesis with Latent Diffusion Models的论文,提出了一种叫做隐扩散模型(latent diffusion model, LDM) 的图像生成模型。通过与AI公司Stability AI合作,借助他们庞大的算力资源训练LDM,CompVis实验室发布了商业名为Stable Diffusion的开源文生图AI绘画模型。

LDM其实就是在VQGAN方法的基础上,把图像生成模型从Transformer换成了DDPM。或者从另一个角度说,为了让DDPM生成高分辨率图像,LDM利用了VQVAE的第二条启发:先用AE把图像压缩,再用DDPM生成压缩图像。LDM的AE一般是把图像边长压缩8倍,DDPM生成$64 \times 64$的压缩图像,整套LDM能生成$512 \times 512$的图像。

和Transformer不同,DDPM处理的图像是用连续向量表示的。因此,在LDM中使用VQGAN做图像压缩时,不一定需要向量离散化操作,只需要在AE的基础上加一点轻微的正则化就行。作者在实现LDM时讨论了两类正则化,一类是VAE的KL正则化,一类是VQ正则化(对应VQVAE的第三条启发),两种正则化都能取得不错的效果。

LDM依然可以实现带约束的图像生成。用DDPM替换掉Transformer后,额外的约束会输入进DDPM中。作者在论文中讨论了几种把约束输入进DDPM的方式。

在搞懂了早期工作后,理解Stable Diffusion的核心思想就是这么简单。让我们把Stable Diffusion的发展过程及主要结构总结一下。Stable Diffusion由两类AE的变种发展而来,一类是有强大生成能力却需要耗费大量运算资源的DDPM,一类是能够以较高保真度压缩图像的VQVAE。Stable Diffusion是一个两阶段的图像生成模型,它先用一个使用KL正则化或VQ正则化的VQGAN来实现图像压缩,再用DDPM生成压缩图像。可以把额外的约束(如文字)输入进DDPM以实现带约束图像生成。

相关论文

本文仅仅对Stable Diffusion的早期工作做了一个简单的梳理。要把Stable Diffusion吃透,还需要多读一些早期论文。我来把早期论文按重要性分个类。

图像生成必读文章

Neural Discrete Representation Learning (VQVAE): https://arxiv.org/abs/1711.00937

Taming Transformers for High-Resolution Image Synthesis (VQGAN): https://arxiv.org/abs/2012.09841

Denoising Diffusion Probabilistic Models (DDPM): https://arxiv.org/abs/2006.11239

图像生成选读文章

Auto-Encoding Variational Bayes (VAE): https://arxiv.org/abs/1312.6114 提出VAE的文章。数学公式较多,只需要了解VAE的大致结构就好,不需要详细阅读论文。

Pixel Recurrent Neural Networks (PixelCNN): https://arxiv.org/abs/1601.06759 提出了一种拟合离散分布的图像生成模型,自回归图像生成模型的代表。这是VQVAE使用的第二阶段图像生成模型。有兴趣可以了解一下。

Deep Unsupervised Learning using Nonequilibrium Thermodynamics: https://arxiv.org/abs/1503.03585 DDPM的前作,首个提出扩散模型思想的文章。其核心原理和DDPM几乎完全一致,但是模型结构和优化目标不够先进,生成效果没有改进后的DDPM好。数学公式较多,不必细读,可以在学习DDPM时对比着阅读。

Denoising Diffusion Implicit Models (DDIM): https://arxiv.org/abs/2010.02502 一种加速DDPM采样的方法,广泛运用在包含Stable Diffusion在内的扩散模型中。推荐阅读。

Classifier-Free Diffusion Guidance: https://arxiv.org/abs/2207.12598 一种让扩散模型的输出更加贴近约束的方法,广泛运用在包含Stable Diffusion在内的扩散模型中,用于生成更符合文字描述的图片。推荐阅读。

Generative Adversarial Networks (GAN): https://arxiv.org/abs/1406.2661 以及 A Style-Based Generator Architecture for Generative Adversarial Networks (StyleGAN): https://arxiv.org/abs/1812.04948 可以了解一下GAN是怎么确保图像生成质量的,并认识CelebAHQ和FFHQ这两个常用的人脸数据集。

其他必读文章

Deep Residual Learning for Image Recognition (ResNet): https://arxiv.org/abs/1512.03385 深度学习的经典文章。其中提出的残差连接被用到了DDPM中。

Attention Is All You Need (Transformer): https://arxiv.org/abs/1706.03762 深度学习的经典文章。其中提出的自注意力模块被用到了DDPM中。

其他选读文章

Learning Transferable Visual Models From Natural Language Supervision (CLIP): https://arxiv.org/abs/2103.00020 提出了对齐文本和图像的方法。绝大多数文生图模型的核心。

U-Net: Convolutional Networks for Biomedical Image Segmentation (U-Net): https://arxiv.org/abs/1505.04597 一种被广泛运用的神经网络架构。DDPM的神经网络的主架构。U-Net的结构很简单,可以不用去读论文,直接看代码。

我的解读文章

我对这上面的很多论文都做过解读。如果你在阅读论文的时候碰到了困难,欢迎阅读我的解读。

轻松理解 VQ-VAE:首个提出 codebook 机制的生成模型

VQGAN 论文与源码解读:前Diffusion时代的高清图像生成模型

扩散模型(Diffusion Model)详解:直观理解、数学原理、PyTorch 实现

抛开数学,轻松学懂 VAE(附 PyTorch 实现)

冷门的自回归生成模型 ~ 详解 PixelCNN 大家族

DDIM 简明讲解与 PyTorch 实现:加速扩散模型采样的通用方法

用18支画笔作画的AI ~ StyleGAN特点浅析

ResNet 论文概览与精读

Attention Is All You Need (Transformer) 论文精读

相比于多数图像生成模型,去噪扩散概率模型(Denoising Diffusion Probabilistic Model, DDPM)的采样速度非常慢。这是因为DDPM在采样时通常要做1000次去噪操作。但如果你玩过基于扩散模型的图像生成应用的话,你会发现,大多数应用只需要20次去噪即可生成图像。这是为什么呢?原来,这些应用都使用了一种更快速的采样方法——去噪扩散隐式模型(Denoising Diffusion Implicit Model, DDIM)。

基于DDPM,DDIM论文主要提出了两项改进。第一,对于一个已经训练好的DDPM,只需要对采样公式做简单的修改,模型就能在去噪时「跳步骤」,在一步去噪迭代中直接预测若干次去噪后的结果。比如说,假设模型从时刻$T=100$开始去噪,新的模型可以在每步去噪迭代中预测10次去噪操作后的结果,也就是逐步预测时刻$t=90, 80, …, 0$的结果。这样,DDPM的采样速度就被加速了10倍。第二,DDIM论文推广了DDPM的数学模型,从更高的视角定义了DDPM的前向过程(加噪过程)和反向过程(去噪过程)。在这个新数学模型下,我们可以自定义模型的噪声强度,让同一个训练好的DDPM有不同的采样效果。

在这篇文章中,我将言简意赅地介绍DDIM的建模方法,并给出我的DDIM PyTorch实现与实验结果。本文不会深究DDIM的数学推导,对这部分感兴趣的读者可以去阅读我在文末给出的参考资料。

回顾 DDPM

DDIM是建立在DDPM之上的一篇工作。在正式认识DDIM之前,我们先回顾一下DDPM中的一些关键内容,再从中引出DDIM的改进思想。

DDPM是一个特殊的VAE。它的编码器是$T$步固定的加噪操作,解码器是$T$步可学习的去噪操作。模型的学习目标是让每一步去噪操作尽可能抵消掉对应的加噪操作。

DDPM的加噪和去噪操作其实都是在某个正态分布中采样。因此,我们可以用概率$q, p$分别表示加噪和去噪的分布。比如 $q(\mathbf{x}_t|\mathbf{x}_{t-1})$ 就是由第 $t-1$ 时刻的图像到第 $t$ 时刻的图像的加噪声分布, $p(\mathbf{x}_{t-1}|\mathbf{x}_{t})$ 就是由第 $t$ 时刻的图像到第 $t-1$ 时刻的图像的去噪声分布。这样,我们可以说网络的学习目标是让 $p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$ 尽可能与 $q(\mathbf{x}_t | \mathbf{x}_{t-1})$ 和互逆。

但是,「互逆」并不是一个严格的数学表述。更具体地,我们应该让分布$p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$和分布$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$尽可能相似。$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$和$p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$的关系就和VAE中原图像与重建图像的关系一样。

$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$是不好求得的,但在给定了输入数据$\mathbf{x}_{0}$时,$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0})$是可以用贝叶斯公式求出来的:

我们不必关心具体的求解方法,只需要知道从等式右边的三项$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)、q(\mathbf{x}_{t-1} | \mathbf{x}_0)、q(\mathbf{x}_{t} | \mathbf{x}_0)$可以推导出等式左边的那一项。在DDPM中,$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$是一个定义好的式子,且$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}) = q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$。根据$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$,可以推出$q(\mathbf{x}_{t} | \mathbf{x}_0)$。知道了$q(\mathbf{x}_{t} | \mathbf{x}_0)$,$q(\mathbf{x}_{t-1} | \mathbf{x}_0)$也就知道了(把公式里的$t$换成$t-1$就行了)。这样,在DDPM中,等式右边的式子全部已知,等式左边的$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0})$可以直接求出来。

上述推理过程可以简单地表示为:知道$q(\mathbf{x}_{t} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$,就知道了神经网络的学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$。这几个公式在DDPM中的具体形式如下:

其中,只有参数$\beta_t$是可调的。$\bar{\alpha}_t$是根据$\beta_t$算出的变量,其计算方法为:$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$。

由于学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$里只有一个未知变量$\mathbf{x}_0$,DDPM把学习目标简化成了只让神经网络根据$\mathbf{x}_{t}$拟合公式里的$\mathbf{x}_{0}$(更具体一点,是拟合从$\mathbf{x}_{0}$到$\mathbf{x}_{t}$的噪声)。也就是说,在训练时,$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$的公式不会被用到,只有$\mathbf{x}_{t}$和$\mathbf{x}_{0}$两个量之间的公式$q(\mathbf{x}_{t} | \mathbf{x}_0)$会被用到。只有在采样时,$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$的公式才会被用到。训练目标的推理过程可以总结为:

理解「DDPM的训练目标只有$\mathbf{x}_{0}$」对于理解DDIM非常关键。如果你在回顾DDPM时出现了问题,请再次阅读DDPM的相关介绍文章。

加速 DDPM

我们再次审视一下DDPM的推理过程:首先有$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}) = q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$。根据$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$,可以推出$q(\mathbf{x}_{t} | \mathbf{x}_0)$。知道$q(\mathbf{x}_{t} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$,由贝叶斯公式,就知道了学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$。

根据这一推理过程,DDIM论文的作者想到,假如我们把贝叶斯公式中的$t$替换成$t_2$, $t-1$替换成$t_1$,其中$t_2$是比$t_1$大的任意某一时刻,那么我们不就可以从$t_2$到$t_1$跳步骤去噪了吗?比如令$t_2 = t_1 + 10$,我们就可以求出去除10次噪声的公式,去噪的过程就快了10倍。

修改之后,$q(\mathbf{x}_{t_1} | \mathbf{x}_0)$和$q(\mathbf{x}_{t_2} | \mathbf{x}_0)$依然很好求,只要把$t_1$, $t_2$代入普通的$q(\mathbf{x}_{t} | \mathbf{x}_0)$公式里就行。

但是,$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$怎么求呢?原来的$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t};\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})$来自于DDPM的定义,我们能直接把公式拿来用。能不能把$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的公式稍微修改一下,让它兼容$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$呢?

修改$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的思路如下:假如我们能把公式中的$\beta_t$换成一个由$t$和$t-1$决定的变量,我们就能把$t$换成$t_2$,$t-1$换成$t_1$,也就得到了$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$。

那怎么修改$\beta_t$的形式呢?很简单。我们知道$\beta_t$决定了$\bar{\alpha}_t$:$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$。那么我们用$\bar{\alpha}_t$除以$\bar{\alpha}_{t-1}$,不就得到了$1-\beta_t$了吗?也就是说:

我们把这个用$\bar{\alpha}_t$和$\bar{\alpha}_{t-1}$表示的$\beta_t$套入$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的公式里,再把$t$换成$t_2$,$t-1$换成$t_1$,就得到了$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$。有了这一项,贝叶斯公式等式右边那三项我们就全部已知,可以求出$q(\mathbf{x}_{t_1} | \mathbf{x}_{t_2}, \mathbf{x}_0)$,也就是可以一次性得到多个时刻后的去噪结果。

在这个过程中,我们只是把DDPM公式里的$\bar{\alpha}_t$换成$\bar{\alpha}_{t2}$,$\bar{\alpha}_{t-1}$换成$\bar{\alpha}_{t1}$,公式推导过程完全不变。网络的训练目标$\mathbf{x}_{0}$也没有发生改变,只是采样时的公式需要修改。这意味着我们可以先照着原DDPM的方法训练,再用这种更快速的方式采样。

我们之前只讨论了$t_1$到$t_2$为固定值的情况。实际上,我们不一定要间隔固定的时刻去噪一次,完全可以用原时刻序列的任意一个子序列来去噪。比如去噪100次的DDPM的去噪时刻序列为[99, 98, ..., 0],我们可以随便取一个长度为10的子序列:[99, 98, 77, 66, 55, 44, 33, 22, 1, 0],按这些时刻来去噪也能让采样速度加速10倍。但实践中没人会这样做,一般都是等间距地取时刻。

这样看来,在采样时,只有部分时刻才会被用到。那我们能不能顺着这个思路,干脆训练一个有效时刻更短(总时刻$T$不变)的DDPM,以加速训练呢?又或者保持有效训练时刻数不变,增大总时刻$T$呢?DDIM论文的作者提出了这些想法,认为这可以作为后续工作的研究方向。

从 DDPM 到 DDIM

除了加速DDPM外,DDIM论文还提出了一种更普遍的DDPM。在这种新的数学模型下,我们可以任意调节采样时的方差大小。让我们来看一下这个数学模型的推导过程。

DDPM的学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$由$q(\mathbf{x}_{t} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$决定。具体来说,在求解正态分布$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$时,我们会将它的均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$设为未知量,并将条件$q(\mathbf{x}_{t} | \mathbf{x}_0)$、$q(\mathbf{x}_{t-1} | \mathbf{x}_0)$、$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$代入,求解出确定的$\tilde{\mu}_t$和$\tilde{\beta}_t$。

在上文我们分析过,DDPM训练时只需要拟合$\mathbf{x}_0$,只需要用到$\mathbf{x}_0$和$\mathbf{x}_t$的关系$q(\mathbf{x}_{t} | \mathbf{x}_0)$。在不修改训练过程的前提下,我们能不能把限制$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$去掉(即$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$可以是任意一个正态分布,而不是我们提前定义好的一个正态分布),得到一个更普遍的DDPM呢?

这当然是可以的。根据基础的解方程知识,我们知道,去掉一个方程后,会多出一个自由变量。取消了$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的限制后,均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$就不能同时确定下来了。我们可以令方差$\tilde{\beta}_t$为自由变量,并让$\tilde{\mu}_t$用含$\tilde{\beta}_t$的式子表示出来。这样,我们就得到了一个方差可变的更一般的DDPM。

让我们来看一下这个新模型的具体公式。原来的DDPM的加噪声逆操作的分布为:

新的分布公式为:

新公式是旧公式的一个推广版本。如果我们把DDPM的方差$(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$代入新公式里的$\tilde{\beta}_t$,就能把新公式还原成DDPM的公式。和DDPM的公式一样,我们也可以把$\mathbf{x}_{0}$拆成$\mathbf{x}_{t}$和噪声$\epsilon$表示的式子。

现在采样时方差可以随意取了,我们来讨论一种特殊的方差取值——$\tilde{\beta}_t=0$。也就是说,扩散模型的反向过程变成了一个没有噪声的确定性过程。给定随机噪声$\mathbf{x}_{T}$,我们只能得到唯一的采样结果$\mathbf{x}_{0}$。这种结果确定的概率模型被称为隐式概率模型(implicit probabilistic model)。所以,论文作者把方差为0的这种扩散模型称为DDIM(Denoising Diffusion Implicit Model)。

为了方便地选取方差值,作者将方差改写为

其中,$\eta\in[0, 1]$。通过选择不同的$\eta$,我们实际上是在DDPM和DDIM之间插值。$\eta$控制了插值的比例。$\eta=0$,模型是DDIM;$\eta=1$,模型是DDPM。

除此之外,DDPM论文曾在采样时使用了另一种方差取值:$\tilde{\beta}_t=\beta_t$,即去噪方差等于加噪方差。实验显示这个方差的采样结果还不错。我们可以把这个取值也用到DDIM论文提出的方法里,只不过这个方差值不能直接套进上面的公式。在代码实现部分我会介绍该怎么在DDIM方法中使用这个方差。

注意,在这一节的推导过程中,我们依然没有修改DDPM的训练目标。我们可以把这种的新的采样方法用在预训练的DDPM上。当然,我们可以在使用新的采样方法的同时也使用上一节的加速采样方法。

实验

到这里为止,我们已经学完了DDIM论文的两大内容:加速采样、更换采样方差。加速采样的意义很好理解,它能大幅减少采样时间。可更换采样方差有什么意义呢?我们看完论文中的实验结果就知道了。

论文展示了新采样方法在不同方差、不同采样步数下的FID指标(越小越好)。其中,$\hat{\sigma}$表示使用DDPM中的$\tilde{\beta}_t=\beta_t$方差取值。实验结果非常有趣。在使用采样加速(步数比总时刻1000要小)时,$\eta=0$的DDIM的表现最好,而$\hat{\sigma}$的情况则非常差。而当$\eta$增大,模型越来越靠近DDPM时,用$\hat{\sigma}$的结果会越来越好。而在DDPM中,用$\hat{\sigma}$的结果是最好的。

从这个实验结果中,我们可以得到一条很简单的实践指南:如果使用了采样加速,一定要用效果最好的DDIM;而使用原DDPM的话,可以维持原论文提出的$\tilde{\beta}_t=\beta_t$方差取值。

总结

DDIM论文提出了DDPM的两个拓展方向:加速采样、变更采样方差。通过同时使用这两个方法,我们能够在不重新训练DDPM、尽可能不降低生成质量的前提下,让扩散模型的采样速度大幅提升(一般可以快50倍)。让我们再从头理一理提出DDIM方法的思考过程。

为了能直接使用预训练的DDPM,我们希望在改进DDPM时不更改DDPM的训练过程。而经过简化后,DDPM的训练目标只有拟合$\mathbf{x}_{0}$,训练时只会用到前向过程公式$q(\mathbf{x}_{t} | \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t}; \sqrt{\bar{\alpha}_t}\mathbf{x}_{0}, (1-\bar{\alpha}_t)\mathbf{I})$。所以,我们的改进应该建立在公式$q(\mathbf{x}_{t} | \mathbf{x}_0)$完全不变的前提下。

通过对DDPM反向过程公式的简单修改,也就是把$t$改成$t_2$,$t-1$改成$t_1$,我们可以把去噪一步的公式改成去噪多步的公式,以大幅加速DDPM。可是,这样改完之后,采样的质量会有明显的下降。

我们可以猜测,减少了采样迭代次数后,采样质量之所以下降,是因为每次估计的去噪均值更加不准确。而每次去噪迭代中的噪声(由方差项决定的那一项)放大了均值的不准确性。我们能不能干脆让去噪时的方差为0呢?为了让去噪时的方差可以自由变动,我们可以去掉DDPM的约束条件。由于贝叶斯公式里的$q(\mathbf{x}_{t} | \mathbf{x}_0)$不能修改,我们只能去掉$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的限制。去掉限制后,方差就成了自由变量。我们让去噪方差为0,让采样过程没有噪声。这样,就得到了本文提出的DDIM模型。实验证明,在采样迭代次数减少后,使用DDIM的生成结果是最优的。

在本文中,我较为严格地区分了DDPM和DDIM的叫法:DDPM指DDPM论文中提出的有1000个扩散时刻的模型,它的采样方差只有两种取值($\tilde{\beta}_t=(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$, $\tilde{\beta}_t=\beta_t$)。DDIM指DDIM论文中提出的$\eta=0$的推广版DDPM模型。DDPM和DDIM都可以使用采样加速。但是,从习惯上我们会把没有优化加速的DDPM称为”DDPM”,把$\eta$可以任取,采样迭代次数可以任取的采样方法统称为”DDIM”。一些开源库中会有叫DDIMSampler的类,调节$\eta$的参数大概会命名为eta,调节迭代次数的参数大概会命名为ddim_num_steps。一般我们令eta=0ddim_num_steps=20即可。

DDIM的代码实现没有太多的学习价值,只要在DDPM代码的基础上把新数学公式翻译成代码即可。其中唯一值得注意的就是如何在DDIM中使用DDPM的方差$\tilde{\beta}_t=\beta_t$。对此感兴趣的话可以阅读我接下来的代码实现介绍。

在这篇解读中,我略过了DDIM论文中的大部分数学推导细节。对DDIM数学模型的推导过程感兴趣的话,可以阅读我在参考文献中推荐的文章,或者看一看原论文。

DDIM PyTorch 实现

在这个项目中,我们将对一个在CelebAHQ上预训练的DDPM执行DDIM采样,尝试复现论文中的那个FID表格,以观察不同etaddim_steps对于采样结果的影响。

代码仓库:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/ddim

DDPM 基础项目

DDIM只是DDPM的一种采样改进策略。为了复现DDIM的结果,我们需要一个DDPM基础项目。由于DDPM并不是本文的重点,在这一小节里我将简要介绍我的DDPM实现代码的框架。

我们的实验需要使用CelebAHQ数据集,请在 https://www.kaggle.com/datasets/badasstechie/celebahq-resized-256x256 下载该数据集并解压到项目的data/celebA/celeba_hq_256目录下。另外,我在Hugging Face上分享了一个在64x64 CelebAHQ上训练的DDPM模型:https://huggingface.co/SingleZombie/dldemos/tree/main/ckpt/ddim ,请把它放到项目的dldemos/ddim目录下。

先运行dldemos/ddim/dataset.py下载MNIST,再直接运行dldemos/ddim/main.py,代码会自动完成MNIST上的训练,并执行步数1000的两种采样和步数20的三种采样,同时将结果保存在目录work_dirs中。以下是我得到的MNIST DDPM采样结果(存储在work_dirs/diffusion_ddpm_sigma_hat.jpg中)。

为了查看64x64 CelebAHQ上的采样结果,可以在dldemos/ddim/main.py的main函数里把config_id改成2,再注释掉训练函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 0 for MNIST. See configs.py
config_id = 2
cfg = configs[config_id]
n_steps = 1000
device = 'cuda'
model_path = cfg['model_path']
img_shape = cfg['img_shape']
to_bgr = False if cfg['dataset_type'] == 'MNIST' else True

net = UNet(n_steps, img_shape, cfg['channels'], cfg['pe_dim'],
cfg.get('with_attn', False), cfg.get('norm_type', 'ln'))
ddpm = DDPM(device, n_steps)

# train(ddpm,
# net,
# cfg['dataset_type'],
# resolution=(img_shape[1], img_shape[2]),
# batch_size=cfg['batch_size'],
# n_epochs=cfg['n_epochs'],
# device=device,
# ckpt_path=model_path)

以下是我得到的CelebAHQ DDPM采样结果(存储在work_dirs/diffusion_ddpm_sigma_hat.jpg中)。

项目目录下的configs.py存储了训练配置,dataset.py定义了DataLoadernetwork.py定义了U-Net的结构,ddpm.pyddim.py分别定义了普通的DDPM前向过程和采样以及DDIM采样,dist_train.py提供了并行训练脚本,dist_sample.py提供了并行采样脚本,main.py提供了单卡运行的所有任务脚本。

在这个项目中,我们的主要的目标是基于其他文件,编写ddim.py。我们先来看一下原来的DDPM类是怎么实现的,再仿照它的接口写一个DDIM类。

实现 DDIM 采样

在我的设计中,DDPM类不是一个神经网络(torch.nn.Module),它仅仅维护了扩散模型的alpha等变量,并描述了前向过程和反向过程。

DDPM类中,我们可以在初始化函数里定义好要用到的self.betas, self.alphas, self.alpha_bars变量。如果在工程项目中,我们可以预定义好更多的常量以节约采样时间。但在学习时,我们可以少写一点代码,让项目更清晰一点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class DDPM():

def __init__(self,
device,
n_steps: int,
min_beta: float = 0.0001,
max_beta: float = 0.02):
betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
alphas = 1 - betas
alpha_bars = torch.empty_like(alphas)
product = 1
for i, alpha in enumerate(alphas):
product *= alpha
alpha_bars[i] = product
self.betas = betas
self.n_steps = n_steps
self.alphas = alphas
self.alpha_bars = alpha_bars

前向过程就是把正态分布的公式$q(\mathbf{x}_{t} | \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t}; \sqrt{\bar{\alpha}_t}\mathbf{x}_{0}, (1-\bar{\alpha}_t)\mathbf{I})$翻译一下。

1
2
3
4
5
6
def sample_forward(self, x, t, eps=None):
alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)
if eps is None:
eps = torch.randn_like(x)
res = eps * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x
return res

在反向过程中,我们从self.n_steps1枚举时刻t(代码中时刻和数组下标有1的偏差),按照公式算出每一步的去噪均值和方差,执行去噪。算法流程如下:

参数simple_var=True表示令方差$\sigma_t^2=\beta_t$,而不是$(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def sample_backward(self, img_or_shape, net, device, simple_var=True):
if isinstance(img_or_shape, torch.Tensor):
x = img_or_shape
else:
x = torch.randn(img_or_shape).to(device)
net = net.to(device)
for t in tqdm(range(self.n_steps - 1, -1, -1), "DDPM sampling"):
x = self.sample_backward_step(x, t, net, simple_var)

return x

def sample_backward_step(self, x_t, t, net, simple_var=True):

n = x_t.shape[0]
t_tensor = torch.tensor([t] * n,
dtype=torch.long).to(x_t.device).unsqueeze(1)
eps = net(x_t, t_tensor)

if t == 0:
noise = 0
else:
if simple_var:
var = self.betas[t]
else:
var = (1 - self.alpha_bars[t - 1]) / (
1 - self.alpha_bars[t]) * self.betas[t]
noise = torch.randn_like(x_t)
noise *= torch.sqrt(var)

mean = (x_t -
(1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
eps) / torch.sqrt(self.alphas[t])
x_t = mean + noise

return x_t

接下来,我们来实现DDIM类。DDIMDDPM的推广,我们可以直接用DDIM类继承DDPM类。它们共享初始化函数与前向过程函数。

1
2
3
4
5
6
7
8
class DDIM(DDPM):

def __init__(self,
device,
n_steps: int,
min_beta: float = 0.0001,
max_beta: float = 0.02):
super().__init__(device, n_steps, min_beta, max_beta)

我们要修改的只有反向过程的实现函数。整个函数的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def sample_backward(self,
img_or_shape,
net,
device,
simple_var=True,
ddim_step=20,
eta=1):
if simple_var:
eta = 1
ts = torch.linspace(self.n_steps, 0,
(ddim_step + 1)).to(device).to(torch.long)
if isinstance(img_or_shape, torch.Tensor):
x = img_or_shape
else:
x = torch.randn(img_or_shape).to(device)
batch_size = x.shape[0]
net = net.to(device)
for i in tqdm(range(1, ddim_step + 1),
f'DDIM sampling with eta {eta} simple_var {simple_var}'):
cur_t = ts[i - 1] - 1
prev_t = ts[i] - 1

ab_cur = self.alpha_bars[cur_t]
ab_prev = self.alpha_bars[prev_t] if prev_t >= 0 else 1

t_tensor = torch.tensor([cur_t] * batch_size,
dtype=torch.long).to(device).unsqueeze(1)
eps = net(x, t_tensor)
var = eta * (1 - ab_prev) / (1 - ab_cur) * (1 - ab_cur / ab_prev)
noise = torch.randn_like(x)

first_term = (ab_prev / ab_cur)**0.5 * x
second_term = ((1 - ab_prev - var)**0.5 -
(ab_prev * (1 - ab_cur) / ab_cur)**0.5) * eps
if simple_var:
third_term = (1 - ab_cur / ab_prev)**0.5 * noise
else:
third_term = var**0.5 * noise
x = first_term + second_term + third_term

return x

我们来把整个函数过一遍。先看一下函数的参数。相比DDPM,DDIM的采样会多出两个参数:ddim_step, eta。如正文所述,ddim_step表示执行几轮去噪迭代,eta表示DDPM和DDIM的插值系数。

1
2
3
4
5
6
7
def sample_backward(self,
img_or_shape,
net,
device,
simple_var=True,
ddim_step=20,
eta=1):

在开始迭代前,要做一些预处理。根据论文的描述,如果使用了DDPM的那种简单方差,一定要令eta=1。所以,一开始我们根据simple_vareta做一个处理。之后,我们要准备好迭代时用到的时刻。整个迭代过程中,我们会用到从self.n_steps0等间距的ddim_step+1个时刻(self.n_steps是初始时刻,不在去噪迭代中)。比如总时刻self.n_steps=100ddim_step=10ts数组里的内容就是[100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0]

1
2
3
4
5
6
7
8
9
10
if simple_var:
eta = 1
ts = torch.linspace(self.n_steps, 0,
(ddim_step + 1)).to(device).to(torch.long)
if isinstance(img_or_shape, torch.Tensor):
x = img_or_shape
else:
x = torch.randn(img_or_shape).to(device)
batch_size = x.shape[0]
net = net.to(device)

做好预处理后,进入去噪循环。在for循环中,我们从1ddim_step遍历ts的下标,从时刻数组ts里取出较大的时刻cur_t(正文中的$t_2$)和较小的时刻prev_t(正文中的$t_1$)。由于self.alpha_bars存储的是t=1, t=2, ..., t=n_steps时的变量,时刻和数组下标之间有一个1的偏移,我们要把ts里的时刻减去1得到时刻在self.alpha_bars里的下标,再取出对应的变量ab_cur, ab_prev。注意,在当前时刻为0时,self.alpha_bars是没有定义的。但由于self.alpha_bars表示连乘,我们可以特别地令当前时刻为0(prev_t=-1)时的alpha_bar=1

1
2
3
4
5
6
7
for i in tqdm(range(1, ddim_step + 1),
f'DDIM sampling with eta {eta} simple_var {simple_var}'):
cur_t = ts[i - 1] - 1
prev_t = ts[i] - 1

ab_cur = self.alpha_bars[cur_t]
ab_prev = self.alpha_bars[prev_t] if prev_t >= 0 else 1

准备好时刻后,我们使用和DDPM一样的方法,用U-Net估计生成x_t时的噪声eps,并准备好DDPM采样算法里的噪声noise(公式里的$\mathbf{z}$)。
与DDPM不同,在计算方差var时(公式里的$\sigma_t^2$),我们要给方差乘一个权重eta

1
2
3
4
5
t_tensor = torch.tensor([cur_t] * batch_size,
dtype=torch.long).to(device).unsqueeze(1)
eps = net(x, t_tensor)
var = eta * (1 - ab_prev) / (1 - ab_cur) * (1 - ab_cur / ab_prev)
noise = torch.randn_like(x)

接下来,我们要把之前算好的所有变量用起来,套入DDIM的去噪均值计算公式中。

也就是(设$\sigma_t^2 = \tilde{\beta}_t$, $\mathbf{z}$为来自标准正态分布的噪声):

由于我们只有噪声$\epsilon$,要把$\mathbf{x}_0=(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon)/\sqrt{\bar{\alpha}_t}$代入,得到不含$\mathbf{x}_0$的公式:

我在代码里把公式的三项分别命名为first_term, second_term, third_term,以便查看。

特别地,当使用DDPM的$\hat{\sigma_t}$方差取值(令$\sigma_t^2=\beta_t=\hat{\sigma_t}^2$)时,不能把这个方差套入公式中,不然$\sqrt{1-{\bar{\alpha}}_{t}-\sigma_t^2}$的根号里的数会小于0。DDIM论文提出的做法是,只修改后面和噪声$\mathbf{z}$有关的方差项,前面这个根号里的方差项保持$\sigma_t^2=(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$ ($\eta=1$)的取值。

当然,上面这些公式全都是在描述$t$到$t-1$。当描述$t_2$到$t_1$时,只需要把$\beta_t$换成$1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}$,再把所有$t$换成$t_2$,$t-1$换成$t_1$即可。

把上面的公式和处理逻辑翻译成代码,就是这样:

1
2
3
4
5
6
7
8
first_term = (ab_prev / ab_cur)**0.5 * x
second_term = ((1 - ab_prev - var)**0.5 -
(ab_prev * (1 - ab_cur) / ab_cur)**0.5) * eps
if simple_var:
third_term = (1 - ab_cur / ab_prev)**0.5 * noise
else:
third_term = var**0.5 * noise
x = first_term + second_term + third_term

这样,下一刻的x就算完了。反复执行循环即可得到最终的结果。

实验

写完了DDIM采样后,我们可以编写一个随机生成图片的函数。由于DDPMDDIM的接口非常相似,我们可以用同一套代码实现DDPM或DDIM的采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def sample_imgs(ddpm,
net,
output_path,
img_shape,
n_sample=64,
device='cuda',
simple_var=True,
to_bgr=False,
**kwargs):
if img_shape[1] >= 256:
max_batch_size = 16
elif img_shape[1] >= 128:
max_batch_size = 64
else:
max_batch_size = 256

net = net.to(device)
net = net.eval()

index = 0
with torch.no_grad():
while n_sample > 0:
if n_sample >= max_batch_size:
batch_size = max_batch_size
else:
batch_size = n_sample
n_sample -= batch_size
shape = (batch_size, *img_shape)
imgs = ddpm.sample_backward(shape,
net,
device=device,
simple_var=simple_var,
**kwargs).detach().cpu()
imgs = (imgs + 1) / 2 * 255
imgs = imgs.clamp(0, 255).to(torch.uint8)

img_list = einops.rearrange(imgs, 'n c h w -> n h w c').numpy()
output_dir = os.path.splitext(output_path)[0]
os.makedirs(output_dir, exist_ok=True)
for i, img in enumerate(img_list):
if to_bgr:
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
cv2.imwrite(f'{output_dir}/{i+index}.jpg', img)

# First iteration
if index == 0:
imgs = einops.rearrange(imgs,
'(b1 b2) c h w -> (b1 h) (b2 w) c',
b1=int(batch_size**0.5))
imgs = imgs.numpy()
if to_bgr:
imgs = cv2.cvtColor(imgs, cv2.COLOR_RGB2BGR)
cv2.imwrite(output_path, imgs)

index += batch_size

为了生成大量图片以计算FID,在这个函数中我加入了很多和batch有关的处理。剔除这些处理代码以及图像存储后处理代码,和采样有关的核心代码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def sample_imgs(ddpm,
net,
output_path,
img_shape,
n_sample=64,
device='cuda',
simple_var=True,
to_bgr=False,
**kwargs):

net = net.to(device)
net = net.eval()

with torch.no_grad():
shape = (n_sample, *img_shape)
imgs = ddpm.sample_backward(shape,
net,
device=device,
simple_var=simple_var,
**kwargs).detach().cpu()

如果是用DDPM采样,把参数表里的那些参数填完就行了;如果是DDIM采样,则需要在kwargs里指定ddim_stepeta

使用这个函数,我们可以进行不同ddim_step和不同eta下的64x64 CelebAHQ采样实验,以尝试复现DDIM论文的实验结果。

我们先准备好变量。

1
2
3
4
5
net = UNet(n_steps, img_shape, cfg['channels'], cfg['pe_dim'],
cfg.get('with_attn', False), cfg.get('norm_type', 'ln'))
ddpm = DDPM(device, n_steps)
ddim = DDIM(device, n_steps)
net.load_state_dict(torch.load(model_path))

第一组实验是总时刻保持1000,使用$\hat{\sigma}_t$(标准DDPM)和$\eta=0$(标准DDIM)的实验。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
sample_imgs(ddpm,
net,
'work_dirs/diffusion_ddpm_sigma_hat.jpg',
img_shape,
device=device,
to_bgr=to_bgr)
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddpm_eta_0.jpg',
img_shape,
device=device,
to_bgr=to_bgr,
ddim_step=1000,
simple_var=False,
eta=0)

把参数n_samples改成30000,就可以生成30000张图像,以和30000张图像的CelebAHQ之间算FID指标。由于总时刻1000的采样速度非常非常慢,建议使用dist_sample.py并行采样。

算FID指标时,可以使用torch fidelity库。使用pip即可安装此库。

1
pip install torch-fidelity

之后就可以使用命令fidelity来算指标了。假设我们把降采样过的CelebAHQ存储在data/celebA/celeba_hq_64,把我们生成的30000张图片存在work_dirs/diffusion_ddpm_sigma_hat,就可以用下面的命令算FID指标。

1
fidelity --gpu 0 --fid --input1 work_dirs/diffusion_ddpm_sigma_hat --input2 data/celebA/celeba_hq_64

整体来看,我的模型比论文差一点,总的FID会高一点。各个配置下的对比结果也稍有出入。在第一组实验中,使用$\hat{\sigma}_t$时,我的FID是13.68;使用$\eta=0$时,我的FID是13.09。而论文中用$\hat{\sigma}_t$时的FID比$\eta=0$时更低。

我们还可以做第二组实验,测试ddim_step=20(我设置的默认步数)时使用$\eta=0$, $\eta=1$, $\hat{\sigma}_t$的生成效果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddim_sigma_hat.jpg',
img_shape,
device=device,
simple_var=True,
to_bgr=to_bgr)
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddim_eta_1.jpg',
img_shape,
device=device,
simple_var=False,
eta=1,
to_bgr=to_bgr)
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddim_eta_0.jpg',
img_shape,
device=device,
simple_var=False,
eta=0,
to_bgr=to_bgr)

我的FID结果是:

1
2
3
eta=0: 17.80
eta=1: 24.00
sigma hat: 213.16

这里得到的实验结果和论文一致。减少采样迭代次数后,生成质量略有降低。同采样步数下,eta=0最优。使用sigma hat的结果会有非常多的噪声,差得完全不能看。

综合上面两个实验来看,不管什么情况下,使用eta=0,得到的结果都不会太差。

从生成速度上来看,在64x64 CelebAHQ上生成256张图片,ddim_step=20时只要3秒不到,而ddim_step=1000时要200秒。基本上是步数减少到几分之一就提速几倍。可见,DDIM加速采样对于扩散模型来说是必要的。

参考文献及学习提示

如果对DDIM公式推导及其他数学知识感兴趣,欢迎阅读苏剑林的文章:
https://spaces.ac.cn/archives/9181。

DDIM的论文为Denoising diffusion implicit models(https://arxiv.org/abs/2010.02502)。

我在本文使用的公式符号都基于DDPM论文,与上面两篇文章使用的符号不一样。比如DDIM论文里的$\alpha$在本文中是用$\bar{\alpha}$表示。

DDIM论文在介绍新均值公式时很不友好地在3.1节直接不加解释地给出了公式的形式,并在附录B中以先给结论再证明这种和逻辑思维完全反过来的方法介绍了公式的由来。建议去阅读苏剑林的文章,看看是怎么按正常的思考方式正向推导出DDIM公式。

除了在3.1节直接甩给你一个公式外,DDIM论文后面的地方都很好读懂。DDIM后面还介绍了一些比较有趣的内容,比如4.3节介绍了扩散模型和常微分方程的关系,它可以帮助我们理解为什么DDPM会设置$T=1000$这么长的加噪步数。5.3节中作者介绍了如何用DDIM在两幅图像间插值。

要回顾DDPM的知识,欢迎阅读我之前的文章:DDPM详解。

在过去的大半年里,以Stable Diffusion为代表的AI绘画是世界上最为火热的AI方向之一。或许大家会有疑问,Stable Diffusion里的这个”Diffusion”是什么意思?其实,扩散模型(Diffusion Model)正是Stable Diffusion中负责生成图像的模型。想要理解Stable Diffusion的原理,就一定绕不过扩散模型的学习。

Stable Diffusion以「毕加索笔下的《最后的晚餐》」为题的绘画结果

在这篇文章里,我会由浅入深地对最基础的去噪扩散概率模型(Denoising Diffusion Probabilistic Models, DDPM)进行讲解。我会先介绍扩散模型生成图像的基本原理,再用简单的数学语言对扩散模型建模,最后给出扩散模型的一份PyTorch实现。本文不会堆砌过于复杂的数学公式,哪怕你没有相关的数学背景,也能够轻松理解扩散模型的原理。

扩散模型与图像生成

在认识扩散模型之前,我们先退一步,看看一般的神经网络模型是怎么生成图像的。显然,为了生成丰富的图像,一个图像生成程序要根据随机数来生成图像。通常,这种随机数是一个满足标准正态分布的随机向量。这样,每次要生成新图像时,只需要从标准正态分布里随机生成一个向量并输入给程序就行了。

而在AI绘画程序中,负责生成图像的是一个神经网络模型。神经网络需要从数据中学习。对于图像生成任务,神经网络的训练数据一般是一些同类型的图片。比如一个绘制人脸的神经网络会用人脸照片来训练。也就是说,神经网络会学习如何把一个向量映射成一张图片,并确保这个图片和训练集的图片是一类图片。

可是,相比其他AI任务,图像生成任务对神经网络来说更加困难一点——图像生成任务缺乏有效的指导。在其他AI任务中,训练集本身会给出一个「标准答案」,指导AI的输出向标准答案靠拢。比如对于图像分类任务,训练集会给出每一幅图像的类别;对于人脸验证任务,训练集会给出两张人脸照片是不是同一个人;对于目标检测任务,训练集会给出目标的具体位置。然而,图像生成任务是没有标准答案的。图像生成数据集里只有一些同类型图片,却没有指导AI如何画得更好的信息。

为了解决这一问题,人们专门设计了一些用于生成图像的神经网络架构。这些架构中比较出名的有生成对抗模型(GAN)和变分自编码器(VAE)。

GAN的想法是,既然不知道一幅图片好不好,就干脆再训练一个神经网络,用于辨别某图片是不是和训练集里的图片长得一样。生成图像的神经网络叫做生成器,鉴定图像的神经网络叫做判别器。两个网络互相对抗,共同进步。

VAE则使用了逆向思维:学习向量生成图像很困难,那就再同时学习怎么用图像生成向量。这样,把某图像变成向量,再用该向量生成图像,就应该得到一幅和原图像一模一样的图像。每一个向量的绘画结果有了一个标准答案,可以用一般的优化方法来指导网络的训练了。VAE中,把图像变成向量的网络叫做编码器,把向量转换回图像的网络叫做解码器。其中,解码器就是负责生成图像的模型。

一直以来,GAN的生成效果较好,但训练起来比VAE麻烦很多。有没有和GAN一样强大,训练起来又方便的生成网络架构呢?扩散模型正是满足这些要求的生成网络架构。

扩散模型是一种特殊的VAE,其灵感来自于热力学:一个分布可以通过不断地添加噪声变成另一个分布。放到图像生成任务里,就是来自训练集的图像可以通过不断添加噪声变成符合标准正态分布的图像。从这个角度出发,我们可以对VAE做以下修改:1)不再训练一个可学习的编码器,而是把编码过程固定成不断添加噪声的过程;2)不再把图像压缩成更短的向量,而是自始至终都对一个等大的图像做操作。解码器依然是一个可学习的神经网络,它的目的也同样是实现编码的逆操作。不过,既然现在编码过程变成了加噪,那么解码器就应该负责去噪。而对于神经网络来说,去噪任务学习起来会更加有效。因此,扩散模型既不会涉及GAN中复杂的对抗训练,又比VAE更强大一点。

具体来说,扩散模型由正向过程反向过程这两部分组成,对应VAE中的编码和解码。在正向过程中,输入$\mathbf{x}_0$会不断混入高斯噪声。经过$T$次加噪声操作后,图像$\mathbf{x}_T$会变成一幅符合标准正态分布的纯噪声图像。而在反向过程中,我们希望训练出一个神经网络,该网络能够学会$T$个去噪声操作,把$\mathbf{x}_T$还原回$\mathbf{x}_0$。网络的学习目标是让$T$个去噪声操作正好能抵消掉对应的加噪声操作。训练完毕后,只需要从标准正态分布里随机采样出一个噪声,再利用反向过程里的神经网络把该噪声恢复成一幅图像,就能够生成一幅图片了。

高斯噪声,就是一幅各处颜色值都满足高斯分布(正态分布)的噪声图像。

总结一下,图像生成网络会学习如何把一个向量映射成一幅图像。设计网络架构时,最重要的是设计学习目标,让网络生成的图像和给定数据集里的图像相似。VAE的做法是使用两个网络,一个学习把图像编码成向量,另一个学习把向量解码回图像,它们的目标是让复原图像和原图像尽可能相似。学习完毕后,解码器就是图像生成网络。扩散模型是一种更具体的VAE。它把编码过程固定为加噪声,并让解码器学习怎么样消除之前添加的每一步噪声。

扩散模型的具体算法

上一节中,我们只是大概了解扩散模型的整体思想。这一节,我们来引入一些数学表示,来看一看扩散模型的训练算法和采样算法具体是什么。为了便于理解,这一节会出现一些不是那么严谨的数学描述。更加详细的一些数学推导会放到下一节里介绍。

前向过程

在前向过程中,来自训练集的图像$\mathbf{x}_0$会被添加$T$次噪声,使得$x_T$为符合标准正态分布。准确来说,「加噪声」并不是给上一时刻的图像加上噪声值,而是从一个均值与上一时刻图像相关的正态分布里采样出一幅新图像。如下面的公式所示,$\mathbf{x}_{t - 1}$是上一时刻的图像,$\mathbf{x}_{t}$是这一时刻生成的图像,该图像是从一个均值与$\mathbf{x}_{t - 1}$有关的正态分布里采样出来的。

多数文章会说前向过程是一个马尔可夫过程。其实,马尔可夫过程的意思就是当前时刻的状态只由上一时刻的状态决定,而不由更早的状态决定。上面的公式表明,计算$\mathbf{x}_t$,只需要用到$\mathbf{x}_{t - 1}$,而不需要用到$\mathbf{x}_{t - 2}, \mathbf{x}_{t - 3}…$,这符合马尔可夫过程的定义。

绝大多数扩散模型会把这个正态分布设置成这个形式:

这个正态分布公式乍看起来很奇怪:$\sqrt{1 - \beta_t}$是哪里冒出来的?为什么会有这种奇怪的系数?别急,我们先来看另一个问题:假如给定$\mathbf{x}_{0}$,也就是从训练集里采样出一幅图片,该怎么计算任意一个时刻$t$的噪声图像$\mathbf{x}_{t}$呢?

我们不妨按照公式,从$\mathbf{x}_{t}$开始倒推。$\mathbf{x}_{t}$其实可以通过一个标准正态分布的样本$\epsilon_{t-1}$算出来:

再往前推几步:

由正态分布的性质可知,均值相同的正态分布「加」在一起后,方差也会加到一起。也就是$\mathcal{N}(0, \sigma_1^2 I)$与$\mathcal{N}(0, \sigma_2^2 I)$合起来会得到$\mathcal{N}(0, (\sigma_1^2+\sigma_2^2) I)$。根据这一性质,上面的公式可以化简为:

再往前推一步的话,结果是:

我们已经能够猜出规律来了,可以一直把公式推到$\mathbf{x}_{0}$。令$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$,则:

有了这个公式,我们就可以讨论加噪声公式为什么是$\mathbf{x}_t \sim \mathcal{N}(\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})$了。这个公式里的$\beta_t$是一个小于1的常数。在DDPM论文中,$\beta_t$从$\beta_1=10^{-4}$到$\beta_T=0.02$线性增长。这样,$\beta_t$变大,$\alpha_t$也越小,$\bar{\alpha}_t$趋于0的速度越来越快。最后,$\bar{\alpha}_T$几乎为0,代入$\mathbf{x}_T = \sqrt{\bar{\alpha}_T}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_T}\epsilon$, $\mathbf{x}_T$就满足标准正态分布了,符合我们对扩散模型的要求。上述推断可以简单描述为:加噪声公式能够从慢到快地改变原图像,让图像最终均值为0,方差为$\mathbf{I}$。

大家不妨尝试一下,设加噪声公式中均值和方差前的系数分别为$a, b$,按照上述过程计算最终分布的方差。只有$a^2 + b^2 = 1$才能保证最后$\mathbf{x}_T$的方差系数为1。

反向过程

在正向过程中,我们人为设置了$T$步加噪声过程。而在反向过程中,我们希望能够倒过来取消每一步加噪声操作,让一幅纯噪声图像变回数据集里的图像。这样,利用这个去噪声过程,我们就可以把任意一个从标准正态分布里采样出来的噪声图像变成一幅和训练数据长得差不多的图像,从而起到图像生成的目的。

现在问题来了:去噪声操作的数学形式是怎么样的?怎么让神经网络来学习它呢?数学原理表明,当$\beta_t$足够小时,每一步加噪声的逆操作也满足正态分布。

其中,当前时刻加噪声逆操作的均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$由当前的时刻$t$、当前的图像$\mathbf{x}_{t}$决定。因此,为了描述所有去噪声操作,神经网络应该输入$t$、$\mathbf{x}_{t}$,拟合当前的均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$。

不要被上文的「去噪声」、「加噪声逆操作」绕晕了哦。由于加噪声是固定的,加噪声的逆操作也是固定的。理想情况下,我们希望去噪操作就等于加噪声逆操作。然而,加噪声的逆操作不太可能从理论上求得,我们只能用一个神经网络去拟合它。去噪声操作和加噪声逆操作的关系,就是神经网络的预测值和真值的关系。

现在问题来了:加噪声逆操作的均值和方差是什么?

直接计算所有数据的加噪声逆操作的分布是不太现实的。但是,如果给定了某个训练集输入$\mathbf{x}_0$,多了一个限定条件后,该分布是可以用贝叶斯公式计算的(其中$q$表示概率分布):

等式左边的$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1};\tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})$表示加噪声操作的逆操作,它的均值和方差都是待求的。右边的$q(\mathbf{x}_{t} | \mathbf{x}_{t-1}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t};\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})$是加噪声的分布。而由于$\mathbf{x}_0$已知,$q(\mathbf{x}_{t-1} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_0)$两项可以根据前面的公式$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon_t$得来:

这样,等式右边的式子全部已知。我们可以把公式套入,算出给定$\mathbf{x}_0$时的去噪声分布。经计算化简,分布的均值为:

其中,$\epsilon_t$是用公式算$\mathbf{x}_t$时从标准正态分布采样出的样本,它来自公式

分布的方差为:

注意,$\beta_t$是加噪声的方差,是一个常量。那么,加噪声逆操作的方差$\tilde{\beta}_t$也是一个常量,不与输入$\mathbf{x}_0$相关。这下就省事了,训练去噪网络时,神经网络只用拟合$T$个均值就行,不用再拟合方差了。

知道了均值和方差的真值,训练神经网络只差最后的问题了:该怎么设置训练的损失函数?加噪声逆操作和去噪声操作都是正态分布,网络的训练目标应该是让每对正态分布更加接近。那怎么用损失函数描述两个分布尽可能接近呢?最直观的想法,肯定是让两个正态分布的均值尽可能接近,方差尽可能接近。根据上文的分析,方差是常量,只用让均值尽可能接近就可以了。

那怎么用数学公式表达让均值更接近呢?再观察一下目标均值的公式:

神经网络拟合均值时,$\mathbf{x}_{t}$是已知的(别忘了,图像是一步一步倒着去噪的)。式子里唯一不确定的只有$\epsilon_t$。既然如此,神经网络干脆也别预测均值了,直接预测一个噪声$\epsilon_\theta(\mathbf{x}_{t}, t)$(其中$\theta$为可学习参数),让它和生成$\mathbf{x}_{t}$的噪声$\epsilon_t$的均方误差最小就行了。对于一轮训练,最终的误差函数可以写成

这样,我们就认识了反向过程的所有内容。总结一下,反向过程中,神经网络应该让$T$个去噪声操作拟合对应的$T$个加噪声逆操作。每步加噪声逆操作符合正态分布,且在给定某个输入时,该正态分布的均值和方差是可以用解析式表达出来的。因此,神经网络的学习目标就是让其输出的去噪声分布和理论计算的加噪声逆操作分布一致。经过数学计算上的一些化简,问题被转换成了拟合生成$\mathbf{x}_{t}$时用到的随机噪声$\epsilon_t$。

训练算法与采样算法

理解了前向过程和反向过程后,训练神经网络的算法和采样图片(生成图片)的算法就呼之欲出了。

以下是DDPM论文中的训练算法:

让我们来逐行理解一下这个算法。第二行是指从训练集里取一个数据$\mathbf{x}_{0}$。第三行是指随机从$1, …, T$里取一个时刻用来训练。我们虽然要求神经网络拟合$T$个正态分布,但实际训练时,不用一轮预测$T$个结果,只需要随机预测$T$个时刻中某一个时刻的结果就行。第四行指随机生成一个噪声$\epsilon$,该噪声是用于执行前向过程生成$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon$的。之后,我们把$\mathbf{x}_t$和$t$传给神经网络$\epsilon_\theta(\mathbf{x}_{t}, t)$,让神经网络预测随机噪声。训练的损失函数是预测噪声和实际噪声之间的均方误差,对此损失函数采用梯度下降即可优化网络。

DDPM并没有规定神经网络的结构。根据任务的难易程度,我们可以自己定义简单或复杂的网络结构。这里只需要把$\epsilon_\theta(\mathbf{x}_{t}, t)$当成一个普通的映射即可。

训练好了网络后,我们可以执行反向过程,对任意一幅噪声图像去噪,以实现图像生成。这个算法如下:

第一行的$\mathbf{x}_{T}$就是从标准正态分布里随机采样的输入噪声。要生成不同的图像,只需要更换这个噪声。后面的过程就是扩散模型的反向过程。令时刻从$T$到$1$,计算这一时刻去噪声操作的均值和方差,并采样出$\mathbf{x}_{t-1}$。均值是用之前提到的公式计算的:

而方差$\sigma_t^2$的公式有两种选择,两个公式都能产生差不多的结果。实验表明,当$\mathbf{x}_{0}$是特定的某个数据时,用上一节推导出来的方差最好。

而当$\mathbf{x}_{0} \sim \mathcal{N}(0, \mathbf{I})$时,只需要令方差和加噪声时的方差一样即可。

循环执行去噪声操作。最后生成的$\mathbf{x}_{0}$就是生成出来的图像。

特别地,最后一步去噪声是不用加方差项的。为什么呢,观察公式$\sigma_t^2=\frac{1-\bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \cdot \beta_t$。当$t=1$时,分子会出现$\bar{\alpha}_{t-1}=\bar{\alpha}_0$这一项。$\bar{\alpha}_t$是一个连乘,理论上$t$是从$1$开始的,在$t=0$时没有定义。但我们可以特别地令连乘的第0项$\bar{\alpha}_0=1$。这样,$t=1$时方差项的分子$1-\bar{\alpha}_{t-1}$为$0$,不用算这一项了。

当然,这一解释从数学上来说是不严谨的。据论文说,这部分的解释可以参见朗之万动力学。

数学推导的补充 (选读)

理解了训练算法和采样算法,我们就算是搞懂了扩散模型,可以去编写代码了。不过,上文的描述省略了一些数学推导的细节。如果对扩散模型更深的原理感兴趣,可以阅读一下本节。

加噪声逆操作均值和方差的推导

上一节,我们根据下面几个式子

一步就给出了$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})$的均值和方差。

现在我们来看一下推导均值和方差的思路。

首先,把其他几个式子带入贝叶斯公式的等式右边。

由于多个正态分布的乘积还是一个正态分布,我们知道$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$也可以用一个正态分布公式$\mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})$表达,它最后一定能写成这种形式:

问题就变成了怎么把开始那个很长的式子化简,算出$\tilde{\mu}_t$和$\tilde{\beta}_t$。

方差$\tilde{\beta}_t$可以从指数函数的系数得来,比较好求。系数为

所以,方差为:

接下来只要关注指数函数的指数部分。指数部分一定是一个关于的$\mathbf{x}_{t-1}$的二次函数,只要化简成$(\mathbf{x}_{t-1}-C)^2$的形式,再除以一下$-2$倍方差,就可以得到均值了。

指数部分为:

$\mathbf{x}_{t-1}$只在前两项里有。把和$\mathbf{x}_{t-1}$有关的项计算化简,可以计算出均值:

回想一下,在去噪声中,神经网络的输入是$\mathbf{x}_{t}$和$t$。也就是说,上式中$\mathbf{x}_{t}$已知,只有$\mathbf{x}_{0}$一个未知量。要算均值,还需要算出$\mathbf{x}_{0}$。$\mathbf{x}_{0}$和$\mathbf{x}_{t}$之间是有一定联系的。$\mathbf{x}_{t}$是$\mathbf{x}_{0}$在正向过程中第$t$步加噪声的结果。而根据正向过程的公式倒推:

把这个$\mathbf{x}_{0}$带入均值公式,均值最后会化简成我们熟悉的形式。

优化目标

上一节,我们只是简单地说神经网络的优化目标是让加噪声和去噪声的均值接近。而让均值接近,就是让生成$\mathbf{x}_t$的噪声$\epsilon_t$更接近。实际上,这个优化目标是经过简化得来的。扩散模型最早的优化目标是有一定的数学意义的。

扩散模型,全称为扩散概率模型(Diffusion Probabilistic Model)。最简单的一类扩散模型,是去噪扩散概率模型(Denoising Diffusion Probabilistic Model),也就是常说的DDPM。DDPM的框架主要是由两篇论文建立起来的。第一篇论文是首次提出扩散模型思想的Deep Unsupervised Learning using Nonequilibrium Thermodynamics。在此基础上,Denoising Diffusion Probabilistic Models对最早的扩散模型做出了一定的简化,让图像生成效果大幅提升,促成了扩散模型的广泛使用。我们上一节看到的公式,全部是简化后的结果。

扩散概率模型的名字之所以有「概率」二字,是因为这个模型是在描述一个系统的概率。准确来说,扩散模型是在描述经反向过程生成出某一项数据的概率。也就是说,扩散模型$p_{\theta}(\mathbf{x}_0)$是一个有着可训练参数$\theta$的模型,它描述了反向过程生成出数据$\mathbf{x}_0$的概率。$p_{\theta}(\mathbf{x}_0)$满足$p_{\theta}(\mathbf{x}_0)=\int p_{\theta}(\mathbf{x}_{0:T})d\mathbf{x}_{1:T}$,其中$p_{\theta}(\mathbf{x}_{0:T})$就是我们熟悉的反向过程,只不过它是以概率计算的形式表达:

我们上一节里见到的优化目标,是让去噪声操作$p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t})$和加噪声操作的逆操作$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$尽可能相似。然而,这个描述并不确切。扩散模型原本的目标,是最大化$p_{\theta}(\mathbf{x}_0)$这个概率,其中$\mathbf{x}_0$是来自训练集的数据。换个角度说,给定一个训练集的数据$\mathbf{x}_0$,经过前向过程和反向过程,扩散模型要让复原出$\mathbf{x}_0$的概率尽可能大。这也是我们在本文开头认识VAE时见到的优化目标。

最大化$p_{\theta}(\mathbf{x}_0)$,一般会写成最小化其负对数值,即最小化$-log p_{\theta}(\mathbf{x}_0)$。使用和VAE类似的变分推理,可以把优化目标转换成优化一个叫做变分下界(variational lower bound, VLB)的量。它最终可以写成:

这里的$D_{KL}(P||Q)$表示分布P和Q之间的KL散度。KL散度是衡量两个分布相似度的指标。如果$P, Q$都是正态分布,则它们的KL散度可以由一个简单的公式给出。关于KL散度的知识可以参见我之前的文章:从零理解熵、交叉熵、KL散度。

其中,第一项$D_{KL}(q(\mathbf{x}_T|\mathbf{x}_0) || p_\theta(\mathbf{x}_T))$和可学习参数$\theta$无关(因为可学习参数只描述了每一步去噪声操作,也就是只描述了$p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t})$),可以不去管它。那么这个优化目标就由两部分组成:

  1. 最小化$D_{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t}))$表示的是最大化每一个去噪声操作和加噪声逆操作的相似度。
  2. 最小化$- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})$就是已知$\mathbf{x}_{1}$时,让最后复原原图$\mathbf{x}_{0}$概率更高。

我们分别看这两部分是怎么计算的。

对于第一部分,我们先回顾一下正态分布之间的KL散度公式。设一维正态分布$P, Q$的公式如下:

而对于$D_{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t}))$,根据前文的分析,我们知道,待求方差$\Sigma_\theta(\mathbf{x}_{t}, t)$可以直接由计算得到。

两个正态分布方差的比值是常量。所以,在计算KL散度时,不用管方差那一项了,只需要管均值那一项。

由根据之前的均值公式

这一部分的优化目标可以化简成

DDPM论文指出,如果把前面的系数全部丢掉的话,模型的效果更好。最终,我们就能得到一个非常简单的优化目标:

这就是我们上一节见到的优化目标。

当然,还没完,别忘了优化目标里还有$- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})$这一项。它的形式为:

只管后面有$\theta$的那一项(注意,$\alpha_1=\bar{\alpha}_1=1-\beta_1$):

这和那些KL散度项$t=1$时的形式相同,我们可以用相同的方式简化优化目标,只保留$|| \epsilon_1-\epsilon_\theta(\mathbf{x}_{1}, 1)||^2$。这样,损失函数的形式全都是$||\epsilon_t-\epsilon_{\theta}(\mathbf{x}_{t}, t)||^2$了。

DDPM论文里写$- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})$这一项可以直接满足简化后的公式$t=1$时的情况,而没有去掉系数的过程。我在网上没找到文章解释这一点,只好按自己的理解来推导这个误差项了。不论如何,推导的过程不是那么重要,重要的是最后的简化形式。

总结

图像生成任务就是把随机生成的向量(噪声)映射成和训练图像类似的图像。为此,扩散模型把这个过程看成是对纯噪声图像的去噪过程。通过学习把图像逐步变成纯噪声的逆操作,扩散模型可以把任何一个纯噪声图像变成有意义的图像,也就是完成图像生成。

对于不同程度的读者,应该对本文有不同的认识。

对于只想了解扩散模型大概原理的读者,只需要阅读第一节,并大概了解:

  • 图像生成任务的通常做法
  • 图像生成任务需要监督
  • VAE通过把图像编码再解码来训练一个解码器
  • 扩散模型是一类特殊的VAE,它的编码固定为加噪声,解码固定为去噪声

对于想认真学习扩散模型的读者,只需读懂第二节的主要内容:

  • 扩散模型的优化目标:让反向过程尽可能成为正向过程的逆操作
  • 正向过程的公式
  • 反向过程的做法(采样算法)
  • 加噪声逆操作的均值和方差在给定$\mathbf{x}_{0}$时可以求出来的,加噪声逆操作的均值就是去噪声的学习目标
  • 简化后的损失函数与训练算法

对有学有余力对数学感兴趣的读者,可以看一看第三节的内容:

  • 加噪声逆操作均值和方差的推导
  • 扩散模型最早的优化目标与DDPM论文是如何简化优化目标的

我个人认为,由于扩散模型的优化目标已经被大幅度简化,除非你的研究目标是改进扩散模型本身,否则没必要花过多的时间钻研数学原理。在学习时,建议快点看懂扩散模型的整体思想,搞懂最核心的训练算法和采样算法,跑通代码。之后就可以去看较新的论文了。

在附录中,我给出了一份DDPM的简单实现。欢迎大家参考,并自己动手复现一遍DDPM。

参考资料与学习建议

网上绝大多数的中英文教程都是照搬 https://lilianweng.github.io/posts/2021-07-11-diffusion-models/ 这篇文章的。这篇文章像教科书一样严谨,适合有一定数学基础的人阅读,但不适合给初学者学习。建议在弄懂扩散模型的大概原理后再来阅读这篇文章补充细节。

多数介绍扩散模型的文章对没学过相关数学知识的人来说很不友好,我在阅读此类文章时碰到了大量的问题:为什么前向公式里有个$\sqrt{1-\beta}$?为什么突然冒出一个快速算$\mathbf{x}_{t}$的公式?为什么反向过程里来了个贝叶斯公式?优化目标是什么?$-log p_{\theta}(\mathbf{x}_0)$是什么?为什么优化目标里一大堆项,每一项的意义又是什么?为什么最后莫名其妙算一个$\epsilon$?为什么采样时$t=0$就不用加方差项了?好不容易,我才把这些问题慢慢搞懂,并在本文做出了解释。希望我的解答能够帮助到同样有这些困惑的读者。想逐步学习扩散模型,可以先看懂我这篇文章的大概讲解,再去其他文章里学懂一些细节。无论是教,还是学,最重要的都是搞懂整体思路,知道动机,最后再去强调细节。

再强烈推荐一位作者写的DDPM系列介绍:https://kexue.fm/archives/9119 。这位作者是全网为数不多的能令我敬佩的作者。早知道有这些文章,我也没必要自己写一遍了。

这里还有篇文章给出了扩散模型中数学公式的详细推导,并补充了变分推理的背景介绍,适合从头学起:https://arxiv.org/abs/2208.11970

想深入学习DDPM,可以看一看最重要的两篇论文:Deep Unsupervised Learning using Nonequilibrium ThermodynamicsDenoising Diffusion Probabilistic Models。当然,后者更重要一些,里面的一些实验结果仍有阅读价值。

我在代码复现时参考了这篇文章。相对于网上的其他开源DDPM实现,这份代码比较简短易懂,更适合学习。不过,这份代码有一点问题。它的神经网络不够强大,采样结果会有一点问题。

附录:代码复现

在这个项目中,我们要用PyTorch实现一个基于U-Net的DDPM,并在MNIST数据集(经典的手写数字数据集)上训练它。模型几分钟就能训练完,我们可以方便地做各种各样的实验。

后续讲解只会给出代码片段,完整的代码请参见 https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/ddpm 。git clone 仓库并安装后,可以直接运行目录里的main.py训练模型并采样。

获取数据集

PyTorch的torchvision提供了获取了MNIST的接口,我们只需要用下面的函数就可以生成MNIST的Dataset实例。参数中,root为数据集的下载路径,download为是否自动下载数据集。令download=True的话,第一次调用该函数时会自动下载数据集,而第二次之后就不用下载了,函数会读取存储在root里的数据。

1
mnist = torchvision.datasets.MNIST(root='data/mnist', download=True)

我们可以用下面的代码来下载MNIST并输出该数据集的一些信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torchvision
from torchvision.transforms import ToTensor
def download_dataset():
mnist = torchvision.datasets.MNIST(root='data/mnist', download=True)
print('length of MNIST', len(mnist))
id = 4
img, label = mnist[id]
print(img)
print(label)

# On computer with monitor
# img.show()

img.save('work_dirs/tmp.jpg')
tensor = ToTensor()(img)
print(tensor.shape)
print(tensor.max())
print(tensor.min())

if __name__ == '__main__':
download_dataset()

执行这段代码,输出大致为:

1
2
3
4
5
6
length of MNIST 60000
<PIL.Image.Image image mode=L size=28x28 at 0x7FB3F09CCE50>
9
torch.Size([1, 28, 28])
tensor(1.)
tensor(0.)

第一行输出表明,MNIST数据集里有60000张图片。而从第二行和第三行输出中,我们发现每一项数据由图片和标签组成,图片是大小为28x28的PIL格式的图片,标签表明该图片是哪个数字。我们可以用torchvision里的ToTensor()把PIL图片转成PyTorch张量,进一步查看图片的信息。最后三行输出表明,每一张图片都是单通道图片(灰度图),颜色值的取值范围是0~1。

我们可以查看一下每张图片的样子。如果你是在用带显示器的电脑,可以去掉img.show那一行的注释,直接查看图片;如果你是在用服务器,可以去img.save的路径里查看图片。该图片的应该长这个样子:

我们可以用下面的代码预处理数据并创建DataLoader。由于DDPM会把图像和正态分布关联起来,我们更希望图像颜色值的取值范围是[-1, 1]。为此,我们可以对图像做一个线性变换,减0.5再乘2。

1
2
3
4
5
6
def get_dataloader(batch_size: int):
transform = Compose([ToTensor(), Lambda(lambda x: (x - 0.5) * 2)])
dataset = torchvision.datasets.MNIST(root='data/mnist',
transform=transform)
return DataLoader(dataset, batch_size=batch_size, shuffle=True)

DDPM 类

在代码中,我们要实现一个DDPM类。它维护了扩散过程中的一些常量(比如$\alpha$),并且可以计算正向过程和反向过程的结果。

先来实现一下DDPM类的初始化函数。一开始,我们遵从论文的配置,用torch.linspace(min_beta, max_beta, n_steps)min_betamax_beta线性地生成n_steps个时刻的$\beta$。接着,我们根据公式$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$,计算每个时刻的alphaalpha_bar。注意,为了方便实现,我们让t的取值从0开始,要比论文里的$t$少1。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

class DDPM():

# n_steps 就是论文里的 T
def __init__(self,
device,
n_steps: int,
min_beta: float = 0.0001,
max_beta: float = 0.02):
betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
alphas = 1 - betas
alpha_bars = torch.empty_like(alphas)
product = 1
for i, alpha in enumerate(alphas):
product *= alpha
alpha_bars[i] = product
self.betas = betas
self.n_steps = n_steps
self.alphas = alphas
self.alpha_bars = alpha_bars

部分实现会让 DDPM 继承torch.nn.Module,但我认为这样不好。DDPM本身不是一个神经网络,它只是描述了前向过程和后向过程的一些计算。只有涉及可学习参数的神经网络类才应该继承 torch.nn.Module

准备好了变量后,我们可以来实现DDPM类的其他方法。先实现正向过程方法,该方法会根据公式$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon_t$计算正向过程中的$\mathbf{x}_t$。

1
2
3
4
5
6
def sample_forward(self, x, t, eps=None):
alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)
if eps is None:
eps = torch.randn_like(x)
res = eps * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x
return res

这里要解释一些PyTorch编程上的细节。这份代码中,self.alpha_bars是一个一维Tensor。而在并行训练中,我们一般会令t为一个形状为(batch_size, )Tensor。PyTorch允许我们直接用self.alpha_bars[t]self.alpha_bars里取出batch_size个数,就像用一个普通的整型索引来从数组中取出一个数一样。有些实现会用torch.gatherself.alpha_bars里取数,其作用是一样的。

我们可以随机从训练集取图片做测试,看看它们在前向过程中是怎么逐步变成噪声的。

接下来实现反向过程。在反向过程中,DDPM会用神经网络预测每一轮去噪的均值,把$\mathbf{x}_t$复原回$\mathbf{x}_0$,以完成图像生成。反向过程即对应论文中的采样算法。

其实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def sample_backward(self, img_shape, net, device, simple_var=True):
x = torch.randn(img_shape).to(device)
net = net.to(device)
for t in range(self.n_steps - 1, -1, -1):
x = self.sample_backward_step(x, t, net, simple_var)
return x

def sample_backward_step(self, x_t, t, net, simple_var=True):
n = x_t.shape[0]
t_tensor = torch.tensor([t] * n,
dtype=torch.long).to(x_t.device).unsqueeze(1)
eps = net(x_t, t_tensor)

if t == 0:
noise = 0
else:
if simple_var:
var = self.betas[t]
else:
var = (1 - self.alpha_bars[t - 1]) / (
1 - self.alpha_bars[t]) * self.betas[t]
noise = torch.randn_like(x_t)
noise *= torch.sqrt(var)

mean = (x_t -
(1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
eps) / torch.sqrt(self.alphas[t])
x_t = mean + noise

return x_t

其中,sample_backward是用来给外部调用的方法,而sample_backward_step是执行一步反向过程的方法。

sample_backward会随机生成纯噪声x(对应$\mathbf{x}_T$),再令tn_steps - 10,调用sample_backward_step

1
2
3
4
5
6
def sample_backward(self, img_shape, net, device, simple_var=True):
x = torch.randn(img_shape).to(device)
net = net.to(device)
for t in range(self.n_steps - 1, -1, -1):
x = self.sample_backward_step(x, t, net, simple_var)
return x

sample_backward_step中,我们先准备好这一步的神经网络输出eps。为此,我们要把整型的t转换成一个格式正确的Tensor。考虑到输入里可能有多个batch,我们先获取batch size n,再根据它来生成t_tensor

1
2
3
4
5
6
def sample_backward_step(self, x_t, t, net, simple_var=True):

n = x_t.shape[0]
t_tensor = torch.tensor([t] * n,
dtype=torch.long).to(x_t.device).unsqueeze(1)
eps = net(x_t, t_tensor)

之后,我们来处理反向过程公式中的方差项。根据伪代码,我们仅在t非零的时候算方差项。方差项用到的方差有两种取值,效果差不多,我们用simple_var来控制选哪种取值方式。获取方差后,我们再随机采样一个噪声,根据公式,得到方差项。

1
2
3
4
5
6
7
8
9
10
if t == 0:
noise = 0
else:
if simple_var:
var = self.betas[t]
else:
var = (1 - self.alpha_bars[t - 1]) / (
1 - self.alpha_bars[t]) * self.betas[t]
noise = torch.randn_like(x_t)
noise *= torch.sqrt(var)

最后,我们把eps和方差项套入公式,得到这一步更新过后的图像x_t

1
2
3
4
5
6
mean = (x_t -
(1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
eps) / torch.sqrt(self.alphas[t])
x_t = mean + noise

return x_t

稍后完成了训练后,我们再来看反向过程的输出结果。

训练算法

接下来,我们先跳过神经网络的实现,直接完成论文里的训练算法。

再回顾一遍伪代码。首先,我们要随机选取训练图片$\mathbf{x}_{0}$,随机生成当前要训练的时刻$t$,以及随机生成一个生成$\mathbf{x}_{t}$的高斯噪声。之后,我们把$\mathbf{x}_{t}$和$t$输入进神经网络,尝试预测噪声。最后,我们以预测噪声和实际噪声的均方误差为损失函数做梯度下降。

为此,我们可以用下面的代码实现训练。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
from dldemos.ddpm.dataset import get_dataloader, get_img_shape
from dldemos.ddpm.ddpm import DDPM
import cv2
import numpy as np
import einops

batch_size = 512
n_epochs = 100


def train(ddpm: DDPM, net, device, ckpt_path):
# n_steps 就是公式里的 T
# net 是某个继承自 torch.nn.Module 的神经网络
n_steps = ddpm.n_steps
dataloader = get_dataloader(batch_size)
net = net.to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), 1e-3)

for e in range(n_epochs):
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
eps = torch.randn_like(x).to(device)
x_t = ddpm.sample_forward(x, t, eps)
eps_theta = net(x_t, t.reshape(current_batch_size, 1))
loss = loss_fn(eps_theta, eps)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(net.state_dict(), ckpt_path)

代码的主要逻辑都在循环里。首先是完成训练数据$\mathbf{x}_{0}$、$t$、噪声的采样。采样$\mathbf{x}_{0}$的工作可以交给PyTorch的DataLoader完成,每轮遍历得到的x就是训练数据。$t$的采样可以用torch.randint函数随机从[0, n_steps - 1]取数。采样高斯噪声可以直接用torch.randn_like(x)生成一个和训练图片x形状一样的符合标准正态分布的图像。

1
2
3
4
5
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
eps = torch.randn_like(x).to(device)

之后计算$\mathbf{x}_{t}$并将其和$t$输入进神经网络net。计算$\mathbf{x}_{t}$的任务会由DDPM类的sample_forward方法完成,我们在上文已经实现了它。

1
2
x_t = ddpm.sample_forward(x, t, eps)
eps_theta = net(x_t, t.reshape(current_batch_size, 1))

得到了预测的噪声eps_theta,我们调用PyTorch的API,算均方误差并调用优化器即可。

1
2
3
4
5
6
7
8
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), 1e-3)

...
loss = loss_fn(eps_theta, eps)
optimizer.zero_grad()
loss.backward()
optimizer.step()

去噪神经网络

在DDPM中,理论上我们可以用任意一种神经网络架构。但由于DDPM任务十分接近图像去噪任务,而U-Net又是去噪任务中最常见的网络架构,因此绝大多数DDPM都会使用基于U-Net的神经网络。

我一直想训练一个尽可能简单的模型。经过多次实验,我发现DDPM的神经网络很难训练。哪怕是对于比较简单的MNIST数据集,结构差一点的网络(比如纯ResNet)都不太行,只有带了残差块和时序编码的U-Net才能较好地完成去噪。注意力模块倒是可以不用加上。

由于神经网络结构并不是DDPM学习的重点,我这里就不对U-Net的写法做解说,而是直接贴上代码了。代码中大部分内容都和普通的U-Net无异。唯一要注意的地方就是时序编码。去噪网络的输入除了图像外,还有一个时间戳t。我们要考虑怎么把t的信息和输入图像信息融合起来。大部分人的做法是对t进行Transformer中的位置编码,把该编码加到图像的每一处上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import torch
import torch.nn as nn
import torch.nn.functional as F
from dldemos.ddpm.dataset import get_img_shape


class PositionalEncoding(nn.Module):

def __init__(self, max_seq_len: int, d_model: int):
super().__init__()

# Assume d_model is an even number for convenience
assert d_model % 2 == 0

pe = torch.zeros(max_seq_len, d_model)
i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
j_seq = torch.linspace(0, d_model - 2, d_model // 2)
pos, two_i = torch.meshgrid(i_seq, j_seq)
pe_2i = torch.sin(pos / 10000**(two_i / d_model))
pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))
pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(max_seq_len, d_model)

self.embedding = nn.Embedding(max_seq_len, d_model)
self.embedding.weight.data = pe
self.embedding.requires_grad_(False)

def forward(self, t):
return self.embedding(t)


class ResidualBlock(nn.Module):

def __init__(self, in_c: int, out_c: int):
super().__init__()
self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(out_c)
self.actvation1 = nn.ReLU()
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(out_c)
self.actvation2 = nn.ReLU()
if in_c != out_c:
self.shortcut = nn.Sequential(nn.Conv2d(in_c, out_c, 1),
nn.BatchNorm2d(out_c))
else:
self.shortcut = nn.Identity()

def forward(self, input):
x = self.conv1(input)
x = self.bn1(x)
x = self.actvation1(x)
x = self.conv2(x)
x = self.bn2(x)
x += self.shortcut(input)
x = self.actvation2(x)
return x


class ConvNet(nn.Module):

def __init__(self,
n_steps,
intermediate_channels=[10, 20, 40],
pe_dim=10,
insert_t_to_all_layers=False):
super().__init__()
C, H, W = get_img_shape() # 1, 28, 28
self.pe = PositionalEncoding(n_steps, pe_dim)

self.pe_linears = nn.ModuleList()
self.all_t = insert_t_to_all_layers
if not insert_t_to_all_layers:
self.pe_linears.append(nn.Linear(pe_dim, C))

self.residual_blocks = nn.ModuleList()
prev_channel = C
for channel in intermediate_channels:
self.residual_blocks.append(ResidualBlock(prev_channel, channel))
if insert_t_to_all_layers:
self.pe_linears.append(nn.Linear(pe_dim, prev_channel))
else:
self.pe_linears.append(None)
prev_channel = channel
self.output_layer = nn.Conv2d(prev_channel, C, 3, 1, 1)

def forward(self, x, t):
n = t.shape[0]
t = self.pe(t)
for m_x, m_t in zip(self.residual_blocks, self.pe_linears):
if m_t is not None:
pe = m_t(t).reshape(n, -1, 1, 1)
x = x + pe
x = m_x(x)
x = self.output_layer(x)
return x


class UnetBlock(nn.Module):

def __init__(self, shape, in_c, out_c, residual=False):
super().__init__()
self.ln = nn.LayerNorm(shape)
self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
self.activation = nn.ReLU()
self.residual = residual
if residual:
if in_c == out_c:
self.residual_conv = nn.Identity()
else:
self.residual_conv = nn.Conv2d(in_c, out_c, 1)

def forward(self, x):
out = self.ln(x)
out = self.conv1(out)
out = self.activation(out)
out = self.conv2(out)
if self.residual:
out += self.residual_conv(x)
out = self.activation(out)
return out


class UNet(nn.Module):

def __init__(self,
n_steps,
channels=[10, 20, 40, 80],
pe_dim=10,
residual=False) -> None:
super().__init__()
C, H, W = get_img_shape()
layers = len(channels)
Hs = [H]
Ws = [W]
cH = H
cW = W
for _ in range(layers - 1):
cH //= 2
cW //= 2
Hs.append(cH)
Ws.append(cW)

self.pe = PositionalEncoding(n_steps, pe_dim)

self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
self.pe_linears_en = nn.ModuleList()
self.pe_linears_de = nn.ModuleList()
self.downs = nn.ModuleList()
self.ups = nn.ModuleList()
prev_channel = C
for channel, cH, cW in zip(channels[0:-1], Hs[0:-1], Ws[0:-1]):
self.pe_linears_en.append(
nn.Sequential(nn.Linear(pe_dim, prev_channel), nn.ReLU(),
nn.Linear(prev_channel, prev_channel)))
self.encoders.append(
nn.Sequential(
UnetBlock((prev_channel, cH, cW),
prev_channel,
channel,
residual=residual),
UnetBlock((channel, cH, cW),
channel,
channel,
residual=residual)))
self.downs.append(nn.Conv2d(channel, channel, 2, 2))
prev_channel = channel

self.pe_mid = nn.Linear(pe_dim, prev_channel)
channel = channels[-1]
self.mid = nn.Sequential(
UnetBlock((prev_channel, Hs[-1], Ws[-1]),
prev_channel,
channel,
residual=residual),
UnetBlock((channel, Hs[-1], Ws[-1]),
channel,
channel,
residual=residual),
)
prev_channel = channel
for channel, cH, cW in zip(channels[-2::-1], Hs[-2::-1], Ws[-2::-1]):
self.pe_linears_de.append(nn.Linear(pe_dim, prev_channel))
self.ups.append(nn.ConvTranspose2d(prev_channel, channel, 2, 2))
self.decoders.append(
nn.Sequential(
UnetBlock((channel * 2, cH, cW),
channel * 2,
channel,
residual=residual),
UnetBlock((channel, cH, cW),
channel,
channel,
residual=residual)))

prev_channel = channel

self.conv_out = nn.Conv2d(prev_channel, C, 3, 1, 1)

def forward(self, x, t):
n = t.shape[0]
t = self.pe(t)
encoder_outs = []
for pe_linear, encoder, down in zip(self.pe_linears_en, self.encoders,
self.downs):
pe = pe_linear(t).reshape(n, -1, 1, 1)
x = encoder(x + pe)
encoder_outs.append(x)
x = down(x)
pe = self.pe_mid(t).reshape(n, -1, 1, 1)
x = self.mid(x + pe)
for pe_linear, decoder, up, encoder_out in zip(self.pe_linears_de,
self.decoders, self.ups,
encoder_outs[::-1]):
pe = pe_linear(t).reshape(n, -1, 1, 1)
x = up(x)

pad_x = encoder_out.shape[2] - x.shape[2]
pad_y = encoder_out.shape[3] - x.shape[3]
x = F.pad(x, (pad_x // 2, pad_x - pad_x // 2, pad_y // 2,
pad_y - pad_y // 2))
x = torch.cat((encoder_out, x), dim=1)
x = decoder(x + pe)
x = self.conv_out(x)
return x


convnet_small_cfg = {
'type': 'ConvNet',
'intermediate_channels': [10, 20],
'pe_dim': 128
}

convnet_medium_cfg = {
'type': 'ConvNet',
'intermediate_channels': [10, 10, 20, 20, 40, 40, 80, 80],
'pe_dim': 256,
'insert_t_to_all_layers': True
}
convnet_big_cfg = {
'type': 'ConvNet',
'intermediate_channels': [20, 20, 40, 40, 80, 80, 160, 160],
'pe_dim': 256,
'insert_t_to_all_layers': True
}

unet_1_cfg = {'type': 'UNet', 'channels': [10, 20, 40, 80], 'pe_dim': 128}
unet_res_cfg = {
'type': 'UNet',
'channels': [10, 20, 40, 80],
'pe_dim': 128,
'residual': True
}


def build_network(config: dict, n_steps):
network_type = config.pop('type')
if network_type == 'ConvNet':
network_cls = ConvNet
elif network_type == 'UNet':
network_cls = UNet

network = network_cls(n_steps, **config)
return network

实验结果与采样

把之前的所有代码综合一下,我们以带残差块的U-Net为去噪网络,执行训练。

1
2
3
4
5
6
7
8
9
10
11
if __name__ == '__main__':
n_steps = 1000
config_id = 4
device = 'cuda'
model_path = 'dldemos/ddpm/model_unet_res.pth'

config = unet_res_cfg
net = build_network(config, n_steps)
ddpm = DDPM(device, n_steps)

train(ddpm, net, device=device, ckpt_path=model_path)

按照默认训练配置,在3090上花5分钟不到,训练30~40个epoch即可让网络基本收敛。最终收敛时loss在0.023~0.024左右。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
batch size: 512
epoch 0 loss: 0.23103461712201437 elapsed 7.01s
epoch 1 loss: 0.0627968365987142 elapsed 13.66s
epoch 2 loss: 0.04828845852613449 elapsed 20.25s
epoch 3 loss: 0.04148937337398529 elapsed 26.80s
epoch 4 loss: 0.03801360730528831 elapsed 33.37s
epoch 5 loss: 0.03604260584712028 elapsed 39.96s
epoch 6 loss: 0.03357676289876302 elapsed 46.57s
epoch 7 loss: 0.0335664684087038 elapsed 53.15s
...
epoch 30 loss: 0.026149748386939366 elapsed 204.64s
epoch 31 loss: 0.025854381563266117 elapsed 211.24s
epoch 32 loss: 0.02589433005253474 elapsed 217.84s
epoch 33 loss: 0.026276464049021404 elapsed 224.41s
...
epoch 96 loss: 0.023299352884292603 elapsed 640.25s
epoch 97 loss: 0.023460942271351815 elapsed 646.90s
epoch 98 loss: 0.023584651704629263 elapsed 653.54s
epoch 99 loss: 0.02364126600921154 elapsed 660.22s

训练这个网络时,并没有特别好的测试指标,我们只能通过观察采样图像来评价网络的表现。我们可以用下面的代码调用DDPM的反向传播方法,生成多幅图像并保存下来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def sample_imgs(ddpm,
net,
output_path,
n_sample=81,
device='cuda',
simple_var=True):
net = net.to(device)
net = net.eval()
with torch.no_grad():
shape = (n_sample, *get_img_shape()) # 1, 3, 28, 28
imgs = ddpm.sample_backward(shape,
net,
device=device,
simple_var=simple_var).detach().cpu()
imgs = (imgs + 1) / 2 * 255
imgs = imgs.clamp(0, 255)
imgs = einops.rearrange(imgs,
'(b1 b2) c h w -> (b1 h) (b2 w) c',
b1=int(n_sample**0.5))

imgs = imgs.numpy().astype(np.uint8)

cv2.imwrite(output_path, imgs)

一切顺利的话,我们可以得到一些不错的生成结果。下图是我得到的一些生成图片:

大部分生成的图片都对应一个阿拉伯数字,它们和训练集MNIST里的图片非常接近。这算是一个不错的生成结果。

如果神经网络的拟合能力较弱,生成结果就会差很多。下图是我训练一个简单的ResNet后得到的采样结果:

可以看出,每幅图片都很乱,基本对应不上一个数字。这就是一个较差的训练结果。

如果网络再差一点,可能会生成纯黑或者纯白的图片。这是因为网络的预测结果不准,在反向过程中,图像的均值不断偏移,偏移到远大于1或者远小于-1的值了。

总结一下,在复现DDPM时,最主要是要学习DDPM论文的两个算法,即训练算法和采样算法。两个算法很简单,可以轻松地把它们翻译成代码。而为了成功完成复现,还需要花一点心思在编写U-Net上,尤其是注意处理时间戳的部分。

前段时间我写了一篇VQVAE的解读,现在再补充一篇VQVAE的PyTorch实现教程。在这个项目中,我们会实现VQVAE论文,在MNIST和CelebAHQ两个数据集上完成图像生成。具体来说,我们会先实现并训练一个图像压缩网络VQVAE,它能把真实图像编码成压缩图像,或者把压缩图像解码回真实图像。之后,我们会训练一个生成压缩图像的生成网络PixelCNN。

代码仓库:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/VQVAE

项目运行示例

如果你只是想快速地把项目运行起来,可以只阅读本节。

在本地安装好项目后,运行python dldemos/VQVAE/dataset.py来下载MNIST数据集。之后运行python dldemos/VQVAE/main.py,这个脚本会完成以下四个任务:

  1. 训练VQVAE
  2. 用VQVAE重建数据集里的随机数据
  3. 训练PixelCNN
  4. 用PixelCNN+VQVAE随机生成图片

第二步得到的重建结果大致如下(每对图片中左图是原图,右图是重建结果):

第四步得到的随机生成结果大致如下:

如果你要使用CelebAHQ数据集,请照着下一节的指示把CelebAHQ下载到指定目录,再执行python dldemos/VQVAE/main.py -c 4

数据集准备

MNIST数据集可以用PyTorch的API自动下载。我们可以用下面的代码下载MNIST数据集并查看数据的格式。从输出中可知,MNIST的图片形状为[1, 28, 28],颜色取值范围为[0, 1]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def download_mnist():
mnist = torchvision.datasets.MNIST(root='data/mnist', download=True)
print('length of MNIST', len(mnist))
id = 4
img, label = mnist[id]
print(img)
print(label)

# On computer with monitor
# img.show()

img.save('work_dirs/tmp_mnist.jpg')
tensor = transforms.ToTensor()(img)
print(tensor.shape)
print(tensor.max())
print(tensor.min())

我们可以用下面的代码把它封成简单的Dataset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class MNISTImageDataset(Dataset):

def __init__(self, img_shape=(28, 28)):
super().__init__()
self.img_shape = img_shape
self.mnist = torchvision.datasets.MNIST(root='data/mnist')

def __len__(self):
return len(self.mnist)

def __getitem__(self, index: int):
img = self.mnist[index][0]
pipeline = transforms.Compose(
[transforms.Resize(self.img_shape),
transforms.ToTensor()])
return pipeline(img)

接下来准备CelebAHQ。CelebAHQ数据集原本的图像大小是1024x1024,但我们这个项目用不到这么大的图片。我在kaggle上找到了一个256x256的CelebAHQ (https://www.kaggle.com/datasets/badasstechie/celebahq-resized-256x256),所有文件加起来只有300MB左右,很适合我们项目。请在该页面下载压缩包,并把压缩包解压到项目的`data/celebA/celeba_hq_256`目录下。

下载完数据后,我们可以写一个简单的从目录中读取图片的Dataset类。和MNIST的预处理流程不同,我这里给CelebAHQ的图片加了一个中心裁剪的操作,一来可以让人脸占比更大,便于模型学习,二来可以让该类兼容CelebA数据集(CelebA数据集的图片不是正方形,需要裁剪)。这个操作是可选的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class CelebADataset(Dataset):

def __init__(self, root, img_shape=(64, 64)):
super().__init__()
self.root = root
self.img_shape = img_shape
self.filenames = sorted(os.listdir(root))

def __len__(self) -> int:
return len(self.filenames)

def __getitem__(self, index: int):
path = os.path.join(self.root, self.filenames[index])
img = Image.open(path)
pipeline = transforms.Compose([
transforms.CenterCrop(168),
transforms.Resize(self.img_shape),
transforms.ToTensor()
])
return pipeline(img)

有了数据集类后,我们可以用它们生成Dataloader

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
CELEBA_DIR = 'data/celebA/img_align_celeba'
CELEBA_HQ_DIR = 'data/celebA/celeba_hq_256'
def get_dataloader(type,
batch_size,
img_shape=None,
dist_train=False,
num_workers=4,
**kwargs):
if type == 'CelebA':
if img_shape is not None:
kwargs['img_shape'] = img_shape
dataset = CelebADataset(CELEBA_DIR, **kwargs)
elif type == 'CelebAHQ':
if img_shape is not None:
kwargs['img_shape'] = img_shape
dataset = CelebADataset(CELEBA_HQ_DIR, **kwargs)
elif type == 'MNIST':
if img_shape is not None:
dataset = MNISTImageDataset(img_shape)
else:
dataset = MNISTImageDataset()
if dist_train:
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers)
return dataloader, sampler
else:
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
return dataloader

我们可以利用Dataloader来查看CelebAHQ数据集的内容及数据格式。

1
2
3
4
5
6
7
8
9
10
11
12
13
if os.path.exists(CELEBA_HQ_DIR):
dataloader = get_dataloader('CelebAHQ', 16)
img = next(iter(dataloader))
print(img.shape)
N = img.shape[0]
img = einops.rearrange(img,
'(n1 n2) c h w -> c (n1 h) (n2 w)',
n1=int(N**0.5))
print(img.shape)
print(img.max())
print(img.min())
img = transforms.ToPILImage()(img)
img.save('work_dirs/tmp_celebahq.jpg')

从输出中可知,CelebAHQ的颜色取值范围同样是[0, 1]。经我们的预处理流水线得到的图片如下。

实现并训练 VQVAE

要用VQVAE做图像生成,其实要训练两个模型:一个是用于压缩图像的VQVAE,另一个是生成压缩图像的PixelCNN。这两个模型是可以分开训练的。我们先来实现并训练VQVAE。

VQVAE的架构非常简单:一个编码器,一个解码器,外加中间一个嵌入层。损失函数为图像的重建误差与编码器输出与其对应嵌入之间的误差。

VQVAE的编码器和解码器的结构也很简单,仅由普通的上/下采样层和残差块组成。具体来说,编码器先是有两个3x3卷积+2倍下采样卷积的模块,再有两个残差块(ReLU, 3x3卷积, ReLU, 1x1卷积);解码器则反过来,先有两个残差块,再有两个3x3卷积+2倍上采样反卷积的模块。为了让代码看起来更清楚一点,我们不用过度封装,仅实现一个残差块模块,再用残差块和PyTorch自带模块拼成VQVAE。

先实现残差块。注意,由于模型比较简单,残差块内部和VQVAE其他地方都可以不使用BatchNorm。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class ResidualBlock(nn.Module):

def __init__(self, dim):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
self.conv2 = nn.Conv2d(dim, dim, 1)

def forward(self, x):
tmp = self.relu(x)
tmp = self.conv1(tmp)
tmp = self.relu(tmp)
tmp = self.conv2(tmp)
return x + tmp

有了残差块类后,我们可以直接实现VQVAE类。我们先在初始化函数里把模块按顺序搭好。编码器和解码器的结构按前文的描述搭起来即可。嵌入空间(codebook)其实就是个普通的嵌入层。此处我仿照他人代码给嵌入层显式初始化参数,但实测下来和默认的初始化参数方式差别不大。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class VQVAE(nn.Module):

def __init__(self, input_dim, dim, n_embedding):
super().__init__()
self.encoder = nn.Sequential(nn.Conv2d(input_dim, dim, 4, 2, 1),
nn.ReLU(), nn.Conv2d(dim, dim, 4, 2, 1),
nn.ReLU(), nn.Conv2d(dim, dim, 3, 1, 1),
ResidualBlock(dim), ResidualBlock(dim))
self.vq_embedding = nn.Embedding(n_embedding, dim)
self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding,
1.0 / n_embedding)
self.decoder = nn.Sequential(
nn.Conv2d(dim, dim, 3, 1, 1),
ResidualBlock(dim), ResidualBlock(dim),
nn.ConvTranspose2d(dim, dim, 4, 2, 1), nn.ReLU(),
nn.ConvTranspose2d(dim, input_dim, 4, 2, 1))
self.n_downsample = 2

之后,我们来实现模型的前向传播。这里的逻辑就略显复杂了。整体来看,这个函数完成了编码、取最近邻、解码这三步。其中,取最近邻的部分最为复杂。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def forward(self, x):
# encode
ze = self.encoder(x)

# ze: [N, C, H, W]
# embedding [K, C]
embedding = self.vq_embedding.weight.data
N, C, H, W = ze.shape
K, _ = embedding.shape
embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
ze_broadcast = ze.reshape(N, 1, C, H, W)
distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)
nearest_neighbor = torch.argmin(distance, 1)
# make C to the second dim
zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)
# stop gradient
decoder_input = ze + (zq - ze).detach()

# decode
x_hat = self.decoder(decoder_input)
return x_hat, ze, zq

我们来详细看一看取最近邻的实现。取最近邻时,我们要用到两块数据:编码器输出ze与嵌入矩阵embeddingze可以看成一个形状为[N, H, W]的数组,数组存储了长度为C的向量。而嵌入矩阵里有K个长度为C的向量。
1
2
3
4
5
# ze: [N, C, H, W]
# embedding [K, C]
embedding = self.vq_embedding.weight.data
N, C, H, W = ze.shape
K, _ = embedding.shape

为了求N*H*W个向量在嵌入矩阵里的最近邻,我们要先算这每个向量与嵌入矩阵里K个向量的距离。在算距离前,我们要把embeddingze的形状变换一下,保证(embedding_broadcast - ze_broadcast)**2的形状为[N, K, C, H, W]。我们对这个临时结果的第2号维度(C所在维度)求和,得到形状为[N, K, H, W]distance。它的含义是,对于N*H*W个向量,每个向量到嵌入空间里K个向量的距离分别是多少。
1
2
3
embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
ze_broadcast = ze.reshape(N, 1, C, H, W)
distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)

有了距离张量后,我们再对其1号维度(K所在维度)求最近邻所在下标。

1
nearest_neighbor = torch.argmin(distance, 1)

有了下标后,我们可以用self.vq_embedding(nearest_neighbor)从嵌入空间取出最近邻了。别忘了,nearest_neighbor的形状是[N, H, W]self.vq_embedding(nearest_neighbor)的形状会是[N, H, W, C]。我们还要把C维度转置一下。

1
2
# make C to the second dim
zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)

最后,我们用论文里提到的停止梯度算子,把zq变形一下。这样,算误差的时候用的是zq,算梯度时ze会接收解码器传来的梯度。

1
2
# stop gradient
decoder_input = ze + (zq - ze).detach()

求最近邻的部分就到此结束了。最后再补充一句,前向传播函数不仅返回了重建结果x_hat,还返回了ze, zq。这是因为我们待会要在训练时根据ze, zq求损失函数。

准备好了模型类后,假设我们已经用某些超参数初始化好了模型model,我们可以用下面的代码训练VQVAE。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def train_vqvae(model: VQVAE,
img_shape=None,
device='cuda',
ckpt_path='dldemos/VQVAE/model.pth',
batch_size=64,
dataset_type='MNIST',
lr=1e-3,
n_epochs=100,
l_w_embedding=1,
l_w_commitment=0.25):
print('batch size:', batch_size)
dataloader = get_dataloader(dataset_type,
batch_size,
img_shape=img_shape,
use_lmdb=USE_LMDB)
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr)
mse_loss = nn.MSELoss()
tic = time.time()
for e in range(n_epochs):
total_loss = 0

for x in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)

x_hat, ze, zq = model(x)
l_reconstruct = mse_loss(x, x_hat)
l_embedding = mse_loss(ze.detach(), zq)
l_commitment = mse_loss(ze, zq.detach())
loss = l_reconstruct + \
l_w_embedding * l_embedding + l_w_commitment * l_commitment
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * current_batch_size
total_loss /= len(dataloader.dataset)
toc = time.time()
torch.save(model.state_dict(), ckpt_path)
print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
print('Done')

先看一下训练函数的参数。其他参数都没什么特别的,只有误差权重l_w_embedding=1,l_w_commitment=0.25需要讨论一下。误差函数有三项,但论文只给了第三项的权重(0.25),默认第二项的权重为1。我在实现时把第二项的权重l_w_embedding也加上了。

1
2
3
4
5
6
7
8
9
10
def train_vqvae(model: VQVAE,
img_shape=None,
device='cuda',
ckpt_path='dldemos/VQVAE/model.pth',
batch_size=64,
dataset_type='MNIST',
lr=1e-3,
n_epochs=100,
l_w_embedding=1,
l_w_commitment=0.25):

再来把函数体过一遍。一开始,我们可以用传来的参数把dataloader初始化一下。

1
2
3
4
5
print('batch size:', batch_size)
dataloader = get_dataloader(dataset_type,
batch_size,
img_shape=img_shape,
use_lmdb=USE_LMDB)

再把模型的状态调好,并准备好优化器和算均方误差的函数。

1
2
3
4
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr)
mse_loss = nn.MSELoss()

准备好变量后,进入训练循环。训练的过程比较常规,唯一要注意的就是误差计算部分。由于我们把复杂的逻辑都放在了模型类中,这里我们可以直接先用model(x)得到重建图像x_hat和算误差的ze, zq,再根据论文里的公式算3个均方误差,最后求一个加权和,代码比较简明。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
for e in range(n_epochs):
for x in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)

x_hat, ze, zq = model(x)
l_reconstruct = mse_loss(x, x_hat)
l_embedding = mse_loss(ze.detach(), zq)
l_commitment = mse_loss(ze, zq.detach())
loss = l_reconstruct + \
l_w_embedding * l_embedding + l_w_commitment * l_commitment
optimizer.zero_grad()
loss.backward()
optimizer.step()

训练完毕后,我们可以用下面的代码来测试VQVAE的重建效果。所谓重建,就是模拟训练的过程,随机取一些图片,先编码后解码,看解码出来的图片和原图片是否一致。为了获取重建后的图片,我们只需要直接执行前向传播函数model(x)即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def reconstruct(model, x, device, dataset_type='MNIST'):
model.to(device)
model.eval()
with torch.no_grad():
x_hat, _, _ = model(x)
n = x.shape[0]
n1 = int(n**0.5)
x_cat = torch.concat((x, x_hat), 3)
x_cat = einops.rearrange(x_cat, '(n1 n2) c h w -> (n1 h) (n2 w) c', n1=n1)
x_cat = (x_cat.clip(0, 1) * 255).cpu().numpy().astype(np.uint8)
if dataset_type == 'CelebA' or dataset_type == 'CelebAHQ':
x_cat = cv2.cvtColor(x_cat, cv2.COLOR_RGB2BGR)
cv2.imwrite(f'work_dirs/vqvae_reconstruct_{dataset_type}.jpg', x_cat)

vqvae = ...
dataloader = get_dataloader(...)
img = next(iter(dataloader)).to(device)
reconstruct(vqvae, img, device, cfg['dataset_type'])

训练压缩图像生成模型 PixelCNN

有了一个VQVAE后,我们要用另一个模型对VQVAE的离散空间采样,也就是训练一个能生成压缩图片的模型。我们可以按照VQVAE论文的方法,使用PixelCNN来生成压缩图片。

PixelCNN 的原理及实现方法就不在这里过多介绍了。详情可以参见我之前的PixelCNN解读文章。简单来说,PixelCNN给每个像素从左到右,从上到下地编了一个序号,让每个像素仅由之前所有像素决定。采样时,PixelCNN按序号从左上到右下逐个生成图像的每一个像素;训练时,PixelCNN使用了某种掩码机制,使得每个像素只能看到编号更小的像素,并行地输出每一个像素的生成结果。

PixelCNN具体的训练示意图如下。模型的输入是一幅图片,每个像素的取值是0~255;模型给图片的每个像素输出了一个概率分布,即表示此处颜色取0,取1,……,取255的概率。由于神经网络假设数据的输入符合标准正态分布,我们要在数据输入前把整型的颜色转换成0~1之间的浮点数。最简单的转换方法是除以255。

以上是训练PixelCNN生成普通图片的过程。而在训练PixelCNN生成压缩图片时,上述过程需要修改。压缩图片的取值是离散编码。离散编码和颜色值不同,它不是连续的。你可以说颜色1和颜色0、2相近,但不能说离散编码1和离散编码0、2相近。因此,为了让PixelCNN建模离散编码,需要把原来的除以255操作换成一个嵌入层,使得网络能够读取离散编码。

反映在代码中,假设我们已经有了一个普通的PixelCNN模型GatedPixelCNN,我们需要在整个模型的最前面套一个嵌入层,嵌入层的嵌入个数等于离散编码的个数(color_level),嵌入长度等于模型的特征长度(p)。由于嵌入层会直接输出一个长度为p的向量,我们还需要把第一个模块的输入通道数改成p

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from dldemos.pixelcnn.model import GatedPixelCNN, GatedBlock

import torch.nn as nn


class PixelCNNWithEmbedding(GatedPixelCNN):

def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
super().__init__(n_blocks, p, linear_dim, bn, color_level)
self.embedding = nn.Embedding(color_level, p)
self.block1 = GatedBlock('A', p, p, bn)

def forward(self, x):
x = self.embedding(x)
x = x.permute(0, 3, 1, 2).contiguous()
return super().forward(x)

有了一个能处理离散编码的PixelCNN后,我们可以用下面的代码来训练PixelCNN。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def train_generative_model(vqvae: VQVAE,
model,
img_shape=None,
device='cuda',
ckpt_path='dldemos/VQVAE/gen_model.pth',
dataset_type='MNIST',
batch_size=64,
n_epochs=50):
print('batch size:', batch_size)
dataloader = get_dataloader(dataset_type,
batch_size,
img_shape=img_shape,
use_lmdb=USE_LMDB)
vqvae.to(device)
vqvae.eval()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
loss_fn = nn.CrossEntropyLoss()
tic = time.time()
for e in range(n_epochs):
total_loss = 0
for x in dataloader:
current_batch_size = x.shape[0]
with torch.no_grad():
x = x.to(device)
x = vqvae.encode(x)

predict_x = model(x)
loss = loss_fn(predict_x, x)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * current_batch_size
total_loss /= len(dataloader.dataset)
toc = time.time()
torch.save(model.state_dict(), ckpt_path)
print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
print('Done')
gen_model = PixelCNNWithEmbedding(cfg['pixelcnn_n_blocks'],
cfg['pixelcnn_dim'],
cfg['pixelcnn_linear_dim'], True,
cfg['n_embedding'])
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
train_generative_model(vqvae,
gen_model,
img_shape=(img_shape[1], img_shape[2]),
device=device,
ckpt_path=cfg['gen_model_path'],
dataset_type=cfg['dataset_type'],
batch_size=cfg['batch_size_2'],
n_epochs=cfg['n_epochs_2'])

训练部分的核心代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
loss_fn = nn.CrossEntropyLoss()
for x in dataloader:
current_batch_size = x.shape[0]
with torch.no_grad():
x = x.to(device)
x = vqvae.encode(x)

predict_x = model(x)
loss = loss_fn(predict_x, x)
optimizer.zero_grad()
loss.backward()
optimizer.step()

这段代码的意思是说,从训练集里随机取图片x,再将图片压缩成离散编码x = vqvae.encode(x)。这时,x既是PixelCNN的输入,也是PixelCNN的拟合目标。把它输入进PixelCNN,PixelCNN会输出每个像素的概率分布。用交叉熵损失函数约束输出结果即可。

训练完毕后,我们可以用下面的函数来完成整套图像生成流水线。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def sample_imgs(vqvae: VQVAE,
gen_model,
img_shape,
n_sample=81,
device='cuda',
dataset_type='MNIST'):
vqvae = vqvae.to(device)
vqvae.eval()
gen_model = gen_model.to(device)
gen_model.eval()

C, H, W = img_shape
H, W = vqvae.get_latent_HW((C, H, W))
input_shape = (n_sample, H, W)
x = torch.zeros(input_shape).to(device).to(torch.long)
with torch.no_grad():
for i in range(H):
for j in range(W):
output = gen_model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist, 1)
x[:, i, j] = pixel[:, 0]

imgs = vqvae.decode(x)

imgs = imgs * 255
imgs = imgs.clip(0, 255)
imgs = einops.rearrange(imgs,
'(n1 n2) c h w -> (n1 h) (n2 w) c',
n1=int(n_sample**0.5))

imgs = imgs.detach().cpu().numpy().astype(np.uint8)
if dataset_type == 'CelebA' or dataset_type == 'CelebAHQ':
imgs = cv2.cvtColor(imgs, cv2.COLOR_RGB2BGR)

cv2.imwrite(f'work_dirs/vqvae_sample_{dataset_type}.jpg', imgs)

抛掉前后处理,和图像生成有关的代码如下。一开始,我们要随便创建一个空图片x,用于储存PixelCNN生成的压缩图片。之后,我们按顺序遍历每个像素,把当前图片输入进PixelCNN,让PixelCNN预测下一个像素的概率分布prob_dist。我们再用torch.multinomial从概率分布中采样,把采样的结果填回图片。遍历结束后,我们用VQVAE的解码器把压缩图片变成真实图片。

1
2
3
4
5
6
7
8
9
10
11
12
13
C, H, W = img_shape
H, W = vqvae.get_latent_HW((C, H, W))
input_shape = (n_sample, H, W)
x = torch.zeros(input_shape).to(device).to(torch.long)
with torch.no_grad():
for i in range(H):
for j in range(W):
output = gen_model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist, 1)
x[:, i, j] = pixel[:, 0]

imgs = vqvae.decode(x)

至此,我们已经实现了用VQVAE做图像生成的四个任务:训练VQVAE、重建图像、训练PixelCNN、随机生成图像。完整的main函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
if __name__ == '__main__':
os.makedirs('work_dirs', exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument('-c', type=int, default=0)
parser.add_argument('-d', type=int, default=0)
args = parser.parse_args()
cfg = get_cfg(args.c)

device = f'cuda:{args.d}'

img_shape = cfg['img_shape']

vqvae = VQVAE(img_shape[0], cfg['dim'], cfg['n_embedding'])
gen_model = PixelCNNWithEmbedding(cfg['pixelcnn_n_blocks'],
cfg['pixelcnn_dim'],
cfg['pixelcnn_linear_dim'], True,
cfg['n_embedding'])
# 1. Train VQVAE
train_vqvae(vqvae,
img_shape=(img_shape[1], img_shape[2]),
device=device,
ckpt_path=cfg['vqvae_path'],
batch_size=cfg['batch_size'],
dataset_type=cfg['dataset_type'],
lr=cfg['lr'],
n_epochs=cfg['n_epochs'],
l_w_embedding=cfg['l_w_embedding'],
l_w_commitment=cfg['l_w_commitment'])

# 2. Test VQVAE by visualizaing reconstruction result
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
dataloader = get_dataloader(cfg['dataset_type'],
16,
img_shape=(img_shape[1], img_shape[2]))
img = next(iter(dataloader)).to(device)
reconstruct(vqvae, img, device, cfg['dataset_type'])

# 3. Train Generative model (Gated PixelCNN in our project)
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))

train_generative_model(vqvae,
gen_model,
img_shape=(img_shape[1], img_shape[2]),
device=device,
ckpt_path=cfg['gen_model_path'],
dataset_type=cfg['dataset_type'],
batch_size=cfg['batch_size_2'],
n_epochs=cfg['n_epochs_2'])

# 4. Sample VQVAE
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
gen_model.load_state_dict(torch.load(cfg['gen_model_path']))
sample_imgs(vqvae,
gen_model,
cfg['img_shape'],
device=device,
dataset_type=cfg['dataset_type'])

实验

VQVAE有两个超参数:嵌入个数n_embedding、特征向量长度dim。论文中n_embedding=512dim=256。而经我实现发现,用更小的参数量也能达到不错的效果。

所有实验的配置文件我都放在了该项目目录下config.py文件中。对于MNIST数据集,我使用的模型超参数为:dim=32, n_embedding=32。VQVAE重建结果如下所示。可以说重建得几乎完美(每对图片左图为原图,右图为重建结果)。

而对于CelebAHQ数据集,我测试了不同输入尺寸下的不同VQVAE,共有4组配置。

  1. shape=(3, 128, 128) dim=128 n_embedding=64
  2. shape=(3, 128, 128) dim=128 n_embedding=128
  3. shape=(3, 64, 64) dim=128 n_embedding=64
  4. shape=(3, 64, 64) dim=128 n_embedding=32

实验的结果很好预测。对于同尺寸的图片,嵌入数越多重建效果越好。这里我只展示下第一组和第二组的重建结果。

可以看出,VQVAE的重建效果还不错。但由于只使用了均方误差,重建图片在细节上还是比较模糊。重建效果还是很重要的,它决定了该方法做图像生成的质量上限。后续有很多工作都试图提升VQVAE的重建效果。

接下来来看一下随机图像生成的实验。PixelCNN主要有模块数n_blocks、特征长度dim,输出线性层特征长度linear_dim这三个超参数。其中模块数一般是固定的,而输出线性层就被用了一次,其特征长度的影响不大。最需要调节的是特征长度dim。对于MNIST,我的超参数设置为

  • n_blocks=15 dim=128 linear_dim=32.

对于CelebAHQ,我的超参数设置为

  • n_blocks=15 dim=384 linear_dim=256.

PixelCNN的训练时间主要由输入图片尺寸和dim决定,训练难度主要由VQVAE的嵌入个数(即多分类的类别数)决定。PixelCNN训起来很花时间。如果时间有限,在CelebAHQ上建议只训练最小最简单的第4组配置。我在项目中提供了PixelCNN的并行训练脚本,比如用下面的命令可以用4张卡在1号配置下并行训练。

1
torchrun --nproc_per_node=4 dldemos/VQVAE/dist_train_pixelcnn.py -c 1

来看一下实验结果。MNIST上的采样结果还是非常不错的。

CelebAHQ上的结果会差一点。以下是第4组配置(图像边长64,嵌入数32)的采样结果。大部分图片都还行,起码看得出是一张人脸。但64x64的图片本来就分辨率不高,加上VQVAE解码的损耗,放大来看人脸还是比较模糊的。

第1组配置(图像边长128,嵌入数64)的PixelCNN实在训练得太慢了,我只训了一个半成品模型。由于部分生成结果比较吓人,我只挑了几个还能看得过去的生成结果。可以看出,如果把模型训完的话,边长128的模型肯定比边长64的模型效果更好。

参考资料

网上几乎找不到在CelebAHQ上训练的VQVAE PyTorch项目。我在实现这份代码时,参考了以下项目:

实验经历分享

别看VQVAE的代码不难,我做这些实验时还是经历了不少波折的。

一开始,我花一天就把代码写完了,并完成了MNIST上的实验。我觉得在MNIST上做实验的难度太低,不过瘾,就准备把数据集换成CelebA再做一些实验。结果这一做就是两个星期。

换成CelebA后,我碰到的第一个问题是VQVAE训练速度太慢。我尝试减半模型参数,训练时间却减小得不明显。我大致猜出是数据读取占用了大量时间,用性能分析工具一查,果然如此。原来我在DataLoader中一直只用了一个线程,加上num_workers=4就好了。我还把数据集打包成LMDB格式进一步加快数据读取速度。

之后,我又发现VQVAE在CelebA上的重建效果很差。我尝试增加模型参数,没起作用。我又怀疑是64x64的图片质量太低,模型学不到东西,就尝试把输入尺寸改成128x128,并把数据集从CelebA换成CelebAHQ,重建效果依然不行。我调了很多参数,发现了一些奇怪的现象:在嵌入层前使用和不使用BatchNorm对结果的影响很大,且显式初始化嵌入层会让模型的误差一直居高不下。我实在是找不到问题,就拿代码对着别人的PyTorch实现一行一行比较过去。总算,我发现我在使用嵌入层时是用vq_embedding.weight.data[x](因为前面已经获取了这个矩阵,这样写比较自然),别人是用vq_embedding(x)。我的写法会把嵌入层排除在梯度计算外,嵌入层根本得不到优化。我说怎么换了一个嵌入层的初始化方法模型就根本训不动了。改完bug之后,只训了5个epoch,新模型的误差比原来训练数小时的模型要低了。新模型的重建效果非常好。

总算,任务完成了一半,现在只剩PixelCNN要训练了。我先尝试训练输入为128x128,嵌入数64的模型,采样结果很差。为了加快实验速度,我把输入尺寸减小到64x64,再次训练,采样结果还是不行。根据我之前的经验,PixelCNN的训练难度主要取决于类别数。于是,我把嵌入的数量从64改成了32,并大幅增加PixelCNN的参数量,再次训练。过了很久,训练误差终于降到0.08左右。我一测,这次的采样结果还不错。

这样看来,之前的采样效果不好,是输入128x128,嵌入数64的实验太难了。我毕竟只是想做一个demo,在一个小型实验上成功就行了,没必要花时间去做更耗时的实验。按理说,我应该就此收手。但是,我就是咽不下这一口气,就是想在128x128的实验上成功。我再次加大了PixelCNN的参数量,用128x128的配置,大火慢炖,训练了一天一夜。第二天一早起来,我看到这回的误差也降到了0.08。上次的实验误差降到这个程度时实验已经成功了。我迫不及待地去测试采样效果,却发现采样效果还是稀烂。没办法,我选择投降,开始写这篇文章,准备收工。

写到PixelCNN介绍的那一章节时,我正准备讲解代码。看到PixelCNN训练之前预处理除以color_level那一行时,我楞了一下:这行代码是用来做什么的来着?这段代码全是从PixelCNN项目里复制过来的。当时是做普通图片的图像生成,所以要对输入颜色做一个预处理,把整数颜色变成0~1之间的浮点数。但现在是在生成压缩图片,不能这样处理啊!我恍然大悟,知道是在处理离散输入时做错了。应该多加一个嵌入层,把离散值转换成向量。由于VQVAE的重点不在生成模型上,原论文根本没有强调PixelCNN在离散编码上的实现细节。网上几乎所有文章也都没谈这一点。因此,我在实现PixelCNN时,直接不假思索地把原来的代码搬了过来,根本没想过这种地方会出现bug。

把这处bug改完后,我再次开启训练。这下所有模型的采样结果都正常了。误差降到0.5左右就已经有不错的采样结果了,原来我之前把误差降到0.08完全是无用功。太气人了。

这次的实验让我学到了很多东西。首先是PyTorch编程上的一些注意事项:

  • 调用embedding.weight.data[x]是传不了梯度的。
  • 如果读数据时有费时的处理操作(读写硬盘、解码),要在Dataloader里设置num_workers

另外,在测试一个模型是否实现成功时有一个重要的准则:

  • 不要仅在简单的数据集(如MNIST)上测试。测试成功可能只是暴力拟合的结果。只有在一个难度较大的数据集上测试成功才能说模型没有问题。

在观察模型是否训成功时,还需要注意:

  • 训练误差降低不代表模型更优。训练误差的评价方法和模型实际使用方法可能完全不同。不能像我这样偷懒不加测试指标。

除了学到的东西外,我还有一些感想。在别人的项目的基础上修改、照着他人代码复现、完全自己动手从零开始写,对于深度学习项目来说,这三种实现方式的难度是依次递增的。改别人的项目,你可能去配置文件里改一两个数字就行了。而照着他人代码复现,最起码你能把代码改成和他人的代码一模一样,然后再去比较哪一块错了。自己动手写,则是有bug都找不到可以参考的地方了。说深度学习的算法难以调试,难就难在这里。效果不好,你很难说清是训练代码错了、超参数没设置好、训练流程错了,或是测试代码错了。可以出错的地方太多了,通常的代码调试手段难以用在深度学习项目上。

对于想要在深度学习上有所建树的初学者,我建议一定要从零动手复现项目。很多工程经验是难以总结的,只有踩了一遍坑才能知道。除了凭借经验外,还可以掌握一些特定的工程方法来减少bug的出现。比如运行训练之前先拿性能工具分析一遍,看看代码是否有误,是否可以提速;又比如可以训练几步后看所有可学习参数是否被正确修改。

2022年中旬,以扩散模型为核心的图像生成模型将AI绘画带入了大众的视野。实际上,在更早的一年之前,就有了一个能根据文字生成高清图片的模型——VQGAN。VQGAN不仅本身具有强大的图像生成能力,更是传承了前作VQVAE把图像压缩成离散编码的思想,推广了「先压缩,再生成」的两阶段图像生成思路,启发了无数后续工作。

VQGAN生成出的高清图片

在这篇文章中,我将对VQGAN的论文和源码中的关键部分做出解读,提炼出VQGAN中的关键知识点。由于VQGAN的核心思想和VQVAE如出一辙,我不会过多地介绍VQGAN的核心思想,强烈建议读者先去学懂VQVAE,再来看VQGAN。

VQGAN 核心思想

VQGAN的论文名为Taming Transformers for High-Resolution Image Synthesis,直译过来是「驯服Transformer模型以实现高清图像合成」。可以看出,该方法是在用Transformer生成图像。可是,为什么这个模型叫做VQGAN,是一个GAN呢?这是因为,VQGAN使用了两阶段的图像生成方法:

  • 训练时,先训练一个图像压缩模型(包括编码器和解码器两个子模型),再训练一个生成压缩图像的模型。
  • 生成时,先用第二个模型生成出一个压缩图像,再用第一个模型复原成真实图像。

其中,第一个图像压缩模型叫做VQGAN,第二个压缩图像生成模型是一个基于Transformer的模型。

为什么会有这种乍看起来非常麻烦的图像生成方法呢?要理解VQGAN的这种设计动机,有两条路线可以走。两条路线看待问题的角度不同,但实际上是在讲同一件事。

第一条路线是从Transformer入手。Transformer已经在文本生成领域大展身手。同时,Transformer也在视觉任务中开始崭露头角。相比擅长捕捉局部特征的CNN,Transformer的优势在于它能更好地融合图像的全局信息。可是,Transformer的自注意力操作开销太大,只能生成一些分辨率较低的图像。因此,作者认为,可以综合CNN和Transformer的优势,先用基于CNN的VQGAN把图像压缩成一个尺寸更小、信息更丰富的小图像,再用Transformer来生成小图像。

第二条路线是从VQVAE入手。VQVAE是VQGAN的前作,它有着和VQGAN一模一样两阶段图像生成方法。不同的是,VQVAE没有使用GAN结构,且其配套的压缩图像生成模型是基于CNN的。为提升VQVAE的生成效果,作者提出了两项改进策略:1) 图像压缩模型VQVAE仅使用了均方误差,压缩图像的复原结果较为模糊,可以把图像压缩模型换成GAN;2) 在生成压缩图片这个任务上,基于CNN的图像生成模型比不过Transformer,可以用Transformer代替原来的CNN。

第一条思路是作者在论文的引言中描述的,听起来比较高大上;而第二条思路是读者读过文章后能够自然总结出来的,相对来说比较清晰易懂。如果你已经理解了VQVAE,你能通过第二条思路瞬间弄懂VQGAN的原理。说难听点,VQGAN就是一个改进版的VQVAE。然而,VQGAN的改进非常有效,且使用了若干技巧来实现带约束(比如根据文字描述)的高清图像生成,有非常多地方值得学习。

在下文中,我将先补充VQVAE的背景以方便讨论,再介绍VQGAN论文的四大知识点:VQGAN的设计细节、生成压缩图像的Transformer的设计细节、带约束图像生成的实现方法、高清图像生成的实现方法。

VQVAE 背景知识补充

VQVAE的学习目标是用一个编码器把图像压缩成离散编码,再用一个解码器把图像尽可能地还原回原图像。

通俗来说,VQVAE就是把一幅真实图像压缩成一个小图像。这个小图像和真实图像有着一些相同的性质:小图像的取值和像素值(0-255的整数)一样,都是离散的;小图像依然是二维的,保留了某些空间信息。因此,VQVAE的示意图画成这样会更形象一些:

但小图像和真实图像有一个关键性的区别:与像素值不同,小图像的离散取值之间没有关联。真实图像的像素值其实是一个连续颜色的离散采样,相邻的颜色值也更加相似。比如颜色254和颜色253和颜色255比较相似。而小图像的取值之间是没有关联的,你不能说编码为1与编码为0和编码为2比较相似。由于神经网络不能很好地处理这种离散量,在实际实现中,编码并不是以整数表示的,而是以类似于NLP中的嵌入向量的形式表示的。VAE使用了嵌入空间(又称codebook)来完成整数序号到向量的转换。

为了让任意一个编码器输出向量都变成一个固定的嵌入向量,VQVAE采取了一种离散化策略:把每个输出向量$z_e(x)$替换成嵌入空间中最近的那个向量$z_q(x)$。$z_e(x)$的离散编码就是$z_q(x)$在嵌入空间的下标。这个过程和把254.9的输出颜色值离散化成255的整数颜色值的原理类似。

VQVAE的损失函数由两部分组成:重建误差和嵌入空间误差。

其中,重建误差就是输入和输出之间的均方误差。

嵌入空间误差为解码器输出向量$z_e(x)$和它在嵌入空间对应向量$z_q(x)$的均方误差。

作者在误差中还使用了一种「停止梯度」的技巧。这个技巧在VQGAN中被完全保留,此处就不过多介绍了。

图像压缩模型 VQGAN

回顾了VQVAE的背景知识后,我们来正式认识VQGAN的几个创新点。第一点,图像压缩模型VQVAE被改进成了VQGAN。

一般VAE重建出来出来的图像都会比较模糊。这是因为VAE只使用了均方误差,而均方误差只能保证像素值尽可能接近,却不能保证图像的感知效果更加接近。为此,作者把GAN的一些方法引入VQVAE,改造出了VQGAN。

具体来说,VQGAN有两项改进。第一,作者用感知误差(perceptual loss)代替原来的均方误差作为VQGAN的重建误差。第二,作者引入了GAN的对抗训练机制,加入了一个基于图块的判别器,把GAN误差加入了总误差。

计算感知误差的方法如下:把两幅图像分别输入VGG,取出中间某几层卷积层的特征,计算特征图像之间的均方误差。如果你之前没学过相关知识,请搜索”perceptual loss”。

基于图块的判别器,即判别器不为整幅图输出一个真或假的判断结果,而是把图像拆成若干图块,分别输出每个图块的判断结果,再对所有图块的判断结果取一个均值。这只是GAN的一种改进策略而已,没有对GAN本身做太大的改动。如果你之前没学过相关知识,请搜索”PatchGAN”。

这样,总的误差可以写成:

其中,$\lambda$是控制两种误差比例的权重。作者在论文中使用了一个公式来自适应地设置$\lambda$。和普通的GAN一样,VQGAN的编码器、解码器(即生成器)、codebook会最小化误差,判别器会最大化误差。

用VQGAN代替VQVAE后,重建图片中的模糊纹理清晰了很多。

有了一个保真度高的图像压缩模型,我们可以进入下一步,训练一个生成压缩图像的模型。

基于 Transformer 的压缩图像生成模型

如前所述,经VQGAN得到的压缩图像与真实图像有一个本质性的不同:真实图像的像素值具有连续性,相邻的颜色更加相似,而压缩图像的像素值则没有这种连续性。压缩图像的这一特性让寻找一个压缩图像生成模型变得异常困难。多数强大的真实图像生成模型(比如GAN)都是输出一个连续的浮点颜色值,再做一个浮点转整数的操作,得到最终的像素值。而对于压缩图像来说,这种输出连续颜色的模型都不适用了。因此,之前的VQVAE使用了一个能建模离散颜色的PixelCNN模型作为压缩图像生成模型。但PixelCNN的表现不够优秀。

恰好,功能强大的Transformer天生就支持建模离散的输出。在NLP中,每个单词都可以用一个离散的数字表示。Transformer会不断生成表示单词的数字,以达到生成句子的效果。

Transformer 随机生成句子的过程

为了让Transformer生成图像,我们可以把生成句子的一个个单词,变成生成压缩图像的一个个像素。但是,要让Transformer生成二维图像,还需要克服一个问题:在生成句子时,Transformer会先生成第一个单词,再根据第一个单词生成第二个单词,再根据第一、第二个单词生成第三个单词……。也就是说,Transformer每次会根据之前所有的单词来生成下一单词。而图像是二维数据,没有先后的概念,怎样让像素和文字一样有先后顺序呢?

VQGAN的作者使用了自回归图像生成模型的常用做法,给图像的每个像素从左到右,从上到下规定一个顺序。有了先后顺序后,图像就可以被视为一个一维句子,可以用Transfomer生成句子的方式来生成图像了。在第$i$步,Transformer会根据前$i - 1$个像素$s_{ < i}$生成第$i$个像素$s_i$,

带约束的图像生成

在生成新图像时,我们更希望模型能够根据我们的需求生成图像。比如,我们希望模型生成「一幅优美的风景画」,又或者希望模型在一幅草图的基础上作画。这些需求就是模型的约束。为了实现带约束的图像生成,一般的做法是先有一个无约束(输入是随机数)的图像生成模型,再在这个模型的基础上把一个表示约束的向量插入进图像生成的某一步。

把约束向量插入进模型的方法是需要设计的,插入约束向量的方法往往和模型架构有着密切关系。比如假设一个生成模型是U-Net架构,我们可以把约束向量和当前特征图拼接在一起,输入进U-Net的每一大层。

为了实现带约束的图像生成,VQGAN的作者再次借鉴了Transformer实现带约束文字生成的方法。许多自然语言处理任务都可以看成是带约束的文字生成。比如机器翻译,其实可以看成在给定一种语言的句子的前提下,让模型「随机」生成一个另一种语言的句子。比如要把「简要访问非洲」翻译成英语,我们可以对之前无约束文字生成的Transformer做一些修改。

也就是说,给定约束的句子$c$,在第$i$步,Transformer会根据前$i-1$个输出单词$s_{ < i}$以及$c$生成第$i$个单词$s_i$。表示约束的单词被添加到了所有输出之前,作为这次「随机生成」的额外输入。

上述方法并不是唯一的文字生成方法。这种文字生成方法被称为”decoder-only”。实际上,也有使用一个编码器来额外维护约束信息的文字生成方法。最早的Transformer就用到了带编码器的方法。

我们同样可以把这种思想搬到压缩图像生成里。比如对于MNIST数据集,我们希望模型只生成0~9这些数字中某一个数字的手写图像。也就是说,约束是类别信息,约束的取值是0~9。我们就可以把这个0~9的约束信息添加到Transformer的输入$s_{ < i}$之前,以实现由类别约束的图像生成。

但这种设计又会产生一个新的问题。假设约束条件不能简单地表示成整数,而是一些其他类型的数据,比如语义分割图像,那该怎么办呢?对于这种以图像形式表示的约束,作者的做法是,再训练另一个VQGAN,把约束图像压缩成另一套压缩图片。这一套压缩图片和生成图像的压缩图片有着不同的codebook,就像两种语言有着不同的单词一样。这样,约束图像也变成了一系列的整数,可以用之前的方法进行带约束图像生成了。

生成高清图像

由于Transformer注意力计算的开销很大,作者在所有配置中都只使用了$16 \times 16$的压缩图像,再增大压缩图像尺寸的话计算资源就不够了。而另一方面,每张图像在VQGAN中的压缩比例是有限的。如果图像压缩得过多,则VQGAN的重建质量就不够好了。因此,设边长压缩了$f$倍,则该方法一次能生成的图片的最大尺寸是$16f \times 16f$。在多项实验中,$f=16$的表现都较好。这样算下来,该方法一次只能生成$256 \times 256$的图片。这种尺寸的图片还称不上高清图片。

为了生成更大尺寸的图片,作者先训练好了一套能生成$256 \times 256$的图片的VQGAN+Transformer,再用了一种基于滑动窗口的采样机制来生成大图片。具体来说,作者把待生成图片划分成若干个$16\times16$像素的图块,每个图块对应压缩图像的一个像素。之后,在每一轮生成时,只有待生成图块周围的$16\times16$个图块($256\times256$个像素)会被输入进VQGAN和Transformer,由Transformer生成一个新的压缩图像像素,再把该压缩图像像素解码成图块。(在下面的示意图中,每个方块是一个图块,transformer的输入是$3\times3$个图块)

这个滑动窗口算法不是那么好理解,需要多想一下才能理解它的具体做法。在理解这个算法时,你可能会有这样的问题:上面的示意图中,待生成像素有的时候在最左边,有的时候在中间,有的时候在右边,每次约束它的像素都不一样。这么复杂的约束逻辑怎么编写?其实,Transformer自动保证了每个像素只会由之前的像素约束,而看不到后面的像素。因此,在实现时,只需要把待生成像素框起来,直接用Transformer预测待生成像素即可,不需要编写额外的约束逻辑。

如果你没有学过Transformer的话,理解这部分会有点困难。Transformer可以根据第1~k-1个像素并行地生成第2~k个像素,且保证生成每个像素时不会偷看到后面像素的信息。因此,假设我们要生成第i个像素,其实是预测了所有第2~k个像素的结果,再取出第i个结果,填回待生成图像。

由于论文篇幅有限,作者没有对滑动窗口机制做过多的介绍,也没有讲带约束的滑动窗口是怎么实现的。如果你在理解这一部分时碰到了问题,不用担心,这很正常。稍后我们会在代码阅读章节彻底理解滑动窗口的实现方法。我也是看了代码才看懂此处的做法。

作者在论文中解释了为什么用滑动窗口生成高清图像是合理的。作者先是讨论了两种情况,只要满足这两种情况中的任意一种,拿滑动窗口生成图像就是合理的。第一种情况是数据集的统计规律是几乎空间不变,也就是说训练集图片每$256\times256$个像素的统计规律是类似的。这和我们拿$3\times3$卷积卷图像是因为图像每$3\times3$个像素的统计规律类似的原理是一样的。第二种情况是有空间上的约束信息。比如之前提到的用语义分割图来指导图像生成。由于语义分割也是一张图片,它给每个待生成像素都提供了额外信息。这样,哪怕是用滑动窗口,在局部语义的指导下,模型也足以生成图像了。

若是两种情况都不满足呢?比如在对齐的人脸数据集上做无约束生成。在对齐的人脸数据集里,每张图片中人的五官所在的坐标是差不多的,图片的空间不变性不满足;做无约束生成,自然也没有额外的空间信息。在这种情况下,我们可以人为地添加一个坐标约束,即从左到右、从上到下地给每个像素标一个序号,把每个滑动窗口里的坐标序号做为约束。有了坐标约束后,就还原成了上面的第二种情况,每个像素有了额外的空间信息,基于滑动窗口的方法依然可行。

学完了论文的四大知识点,我们知道VQGAN是怎么根据约束生成高清图像的了。接下来,我们来看看论文的实验部分,看看作者是怎么证明方法的有效性的。

实验

在实验部分,作者先是分别验证了基于Transformer的压缩图像生成模型较以往模型的优越性(4.1节)、VQGAN较以往模型的优越性(4.4节末尾)、使用VQGAN做图像压缩的必要性及相关消融实验(4.3节),再把整个生成方法综合起来,在多项图像生成任务上与以往的图像生成模型做定量对比(4.4节),最后展示了该方法惊艳的带约束生成效果(4.2节)。

在论文4.1节中,作者验证了基于Transformer的压缩图像生成模型的有效性。之前,压缩图像都是使用能输出离散分布的PixelCNN系列模型来生成的。PixelCNN系列的最强模型是PixelSNAIL。为确保公平,作者对比了相同训练时间、相同训练步数下两个网络在不同训练集下的负对数似然(NLL)指标。结果表明,基于Transformer的模型确实训练得更快。

对于直接能建模离散分布的模型来说,NLL就是交叉熵损失函数。

在论文4.4节末尾,作者将VQGAN和之前的图像压缩模型对比,验证了引入感知误差和GAN结构的有效性。作者汇报了各模型重建图像集与原数据集(ImageNet的训练集和验证集)的FID(指标FID是越低越好)。同时,结果也说明,增大codebook的尺寸或者编码种类都能提升重建效果。

在论文4.3节中,作者验证了使用VQGAN的必要性。作者训了两个模型,一个直接让Transformer做真实图像生成,一个用VQGAN把图像边长压缩2倍,再用Transformer生成压缩图像。经比较,使用了VQGAN后,图像生成速度快了10多倍,且图像生成效果也有所提升。

另外,作者还做了有关图像边长压缩比例$f$的消融实验。作者固定让Transformer生成$16 \times 16$的压缩图片,即每次训练时用到的图像尺寸都是$16f \times 16f$。之后,作者训练训练了不同$f$下的模型,用各个模型来生成图片。结果显示$f=16$时效果最好。这是因为,在固定Transformer的生成分辨率的前提下,$f$越小,Transformer的感受野越小。如果Transformer的感受野过小,就学习不到足够的信息。

在论文4.4节中,作者探究了VQGAN+Transformer在多项基准测试(benchmark)上的结果。

首先是语义图像合成(根据语义分割图像来生成)任务。本文的这套方法还不错。

接着是人脸生成任务。这套方法表现还行,但还是比不过专精于某一任务的GAN。

作者还比较了各模型在ImageNet上的生成结果。这一比较的数据量较多,欢迎大家自行阅读原论文。

在论文4.2节中,作者展示了多才多艺的VQGAN+Transformer在各种约束下的图像生成结果。这些图像都是按照默认配置生成的,大小为$256\times256$。

作者还展示了使用了滑动窗口算法后,模型生成的不同分辨率的图像。

本文开头的那张高清图片也来自论文。

总结

VQGAN是一个改进版的VQVAE,它将感知误差和GAN引入了图像压缩模型,把压缩图像生成模型替换成了更强大的Transformer。相比纯种的GAN(如StyleGAN),VQGAN的强大之处在于它支持带约束的高清图像生成。VQGAN借助NLP中”decoder-only”策略实现了带约束图像生成,并使用滑动窗口机制实现了高清图像生成。虽然在某些特定任务上VQGAN还是落后于其他GAN,但VQGAN的泛化性和灵活性都要比纯种GAN要强。它的这些潜力直接促成了Stable Diffusion的诞生。

如果你是读完了VQVAE再来读的VQGAN,为了完全理解VQGAN,你只需要掌握本文提到的4个知识点:VQVAE到VQGAN的改进方法、使用Transformer做图像生成的方法、使用”decoder-only”策略做带约束图像生成的方法、用滑动滑动窗口生成任意尺寸的图片的思想。

代码阅读

在代码阅读章节中,我将先简略介绍官方源码的项目结构以方便大家学习,再介绍代码中的几处核心代码。具体来说,我会介绍模型是如何组织配置文件的、模型的定义代码在哪、训练代码在哪、采样代码在哪,同时我会主要分析VQGAN的结构、Transformer的结构、损失函数、滑动窗口采样算法这几部分的代码。

官方源码地址:https://github.com/CompVis/taming-transformers。

官方的Git仓库里有很多很大的图片,且git记录里还藏了一些很大的数据,整个Git仓库非常大。如果你的网络不好,建议以zip形式下载仓库,或者只把代码部分下载下来。

项目结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
├─assets
├─configs
├─scripts
└─taming
├─data
│ └─conditional_builder
├─models
└─modules
├─diffusionmodules
├─discriminator
├─losses
├─misc
├─transformer
└─vqvae

configs目录下存放的是模型配置文件。VQGAN和Transformer的模型配置是分开来放的。每个模型配置文件都会指向一个Python模型类,比如taming.models.vqgan.VQModel,配置里的参数就是模型类的初始化参数。我们可用通过阅读配置文件找到模型的定义位置。

运行脚本包括根目录下的main.pyscripts文件夹下的脚本。main.py是用于训练的。scripts文件夹下有各种采样脚本和数据集可视化脚本。

taming是源代码的主目录。其data子文件夹下放置了各数据集的预处理代码,models放置了VQGAN和Transformer PyTorch模型的定义代码,modules则放置了模型中用到的模块,主要包括VQGAN编码解码模块(diffusionmodules)、判别器模块(discriminator)、误差模块(losses)、Transformer模块(transformer)、codebook模块(vqvae)。

VQGAN 模型结构

打开configs\faceshq_vqgan.yaml,我们能够找到高清人脸生成任务使用的VQGAN模型配置。我们来学习一下这个模型的定义方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
model:
base_learning_rate: 4.5e-6
target: taming.models.vqgan.VQModel
params:
embed_dim: 256
n_embed: 1024
ddconfig:
...

lossconfig:
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
params:
...

从配置文件的target字段中,我们知道VQGAN定义在模块taming.models.vqgan.VQModel中。我们可以打开taming\models\vqgan.py这个文件,查看其中VQModel类的代码。

首先先看一下初始化函数。初始化函数主要是初始化了encoderdecoderlossquantize这几个模块,我们可以从文件开头的import语句中找到这几个模块的定义位置。不过,先不急,我们来继续看一下模型的前向传播函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer

class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap, sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.image_key = image_key
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor

模型的前向传播逻辑非常清晰。self.encoder可以把一张图片变为特征,self.decoder可以把特征变回图片。self.quant_convpost_quant_conv则分别完成了编码器到codebook、codebook到解码器的通道数转换。self.quantize实现了VQVAE和VQGAN中那个找codebook里的最近邻、替换成最近邻的操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info

def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec

def forward(self, input):
quant, diff, _ = self.encode(input)
dec = self.decode(quant)
return dec, diff

接下来,我们再看一看VQGAN的各个模块的定义。编码器和解码器的定义都可以在taming\modules\diffusionmodules\model.py里找到。VQGAN使用的编码器和解码器基于DDPM论文中的U-Net架构(而此架构又可以追溯到PixelCNN++的模型架构)。相比于最经典的U-Net,此U-Net每一层由若干个残差块和若干个自注意力块构成。为了把这个U-Net用到VQGAN里,U-Net的下采样部分和上采样部分被拆开,分别做成了VQGAN的编码器和解码器。

此处代码过长,我就只贴出部分关键代码了。以下是编码器的__init__函数和forward函数的关键代码。self.down存储了U-Net各层的模块。对于第i层,down[i].block是所有残差块,down[i].attn是所有自注意力块,down[i].downsample是下采样操作。它们在forward里会被依次调用。解码器的结构与之类似,只不过下采样变成了上采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, **ignore_kwargs):
super().__init__()
...
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)

...


def forward(self, x):
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
...

return h

之后,我们再看看离散化层的代码,即把编码器的输出变成codebook里的嵌入的实现代码。作者在taming\modules\vqvae\quantize.py中提供了VQVAE原版的离散化操作以及若干个改进过的离散化操作。我们就来看一下原版的离散化模块VectorQuantizer是怎么实现的。

离散化模块的初始化非常简洁,主要是初始化了一个嵌入层。

1
2
3
4
5
6
7
8
9
10
class VectorQuantizer(nn.Module):
def __init__(self, n_e, e_dim, beta):
super(VectorQuantizer, self).__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta

self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

在前向传播时,作者先是算出了编码器输出z和所有嵌入的距离d,再用argmin算出了最近邻嵌入的下标min_encodings,最后根据下标取出解码器输入z_q。同时,该函数还计算了其他几个可能用到的量,比如和codebook有关的误差 loss。注意,在计算lossz_q时,作者都使用到了停止梯度算子(.detach())。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def forward(self, z):
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.matmul(z_flattened, self.embedding.weight.t())

## could possible replace this here
# #\start...
# find closest encodings
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.n_e).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)

z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
#.........\end


# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)

# preserve gradients
z_q = z + (z_q - z).detach()

# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()

return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

VQGAN的三个主要模块已经看完了。最后,我们来看一下误差的定义。误差的定义在taming\modules\losses\vqperceptual.pyVQLPIPSWithDiscriminator类里。误差类名里的LPIPS(Learned Perceptual Image Patch Similarity,学习感知图像块相似度)就是感知误差的全称,”WithDiscriminator”表示误差是带了判定器误差的。我们来把这两类误差分别看一下。

说实话,这个误差模块乱得一塌糊涂,一边自己在算误差,一边又维护了codebook误差和重建误差的权重,最后会把自己维护的两个误差和其他误差合在一起输出。功能全部耦合在一起。我们就跳过这个类的实现细节,主要关注self.perceptual_lossself.discriminator是怎么调用其他模块的。

1
2
3
4
5
6
7
8
9
10
from taming.modules.losses.lpips import LPIPS
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, ...):
super().__init__()

self.perceptual_loss = LPIPS().eval()

self.discriminator = NLayerDiscriminator...

感知误差模块在taming\modules\losses\vqperceptual.py文件里。这个文件来自GitHub项目 PerceptualSimilarity。

感知误差可以简单地理解为两张图片在VGG中几个卷积层输出的误差的加权和。加权的权重是可以学习的。作者使用的是已经学习好的感知误差。感知误差的初始化函数如下。其中,self.lin0等模块就是算权重的模块,self.net是VGG。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False

在算误差时,先是把图像inputtarget都输入进VGG,获取各层输出outs0, outs1,再求出两个图像的输出的均方误差diffs,最后用lins给各层误差加权,求和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val

GAN的判别器写在taming\modules\discriminator\model.py文件里。这个文件来自GitHub上的 pytorch-CycleGAN-and-pix2pix 项目。这个判别器非常简单,就是一个全卷积网络。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d

kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]

nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]

sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)

def forward(self, input):
"""Standard forward."""
return self.main(input)

Transformer 模型结构

此方法使用的Transformer是GPT2。我们先看一下该项目封装Transformer的模型类taming.models.cond_transformer.Net2NetTransformer,再稍微看一下GPT类taming.modules.transformer.mingpt.GPT的具体实现。

Net2NetTransformer主要是实现了论文中提到的带约束生成。它会把输入x和约束c分别用一个VQGAN转成压缩图像,把图像压扁成一维,再调用GPT。我们来看一下这个类的主要内容。

初始化函数主要是初始化了输入图像的VQGAN self.first_stage_model、约束图像的VQGAN self.cond_stage_model、Transformer self.transformer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Net2NetTransformer(pl.LightningModule):
def __init__(self,
transformer_config,
first_stage_config,
cond_stage_config,
permuter_config=None,
ckpt_path=None,
ignore_keys=[],
first_stage_key="image",
cond_stage_key="depth",
downsample_cond_size=-1,
pkeep=1.0,
sos_token=0,
unconditional=False,
):
super().__init__()
self.be_unconditional = unconditional
self.sos_token = sos_token
self.first_stage_key = first_stage_key
self.cond_stage_key = cond_stage_key
self.init_first_stage_from_ckpt(first_stage_config)
self.init_cond_stage_from_ckpt(cond_stage_config)
if permuter_config is None:
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
self.permuter = instantiate_from_config(config=permuter_config)
self.transformer = instantiate_from_config(config=transformer_config)

if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.downsample_cond_size = downsample_cond_size
self.pkeep = pkeep

def init_first_stage_from_ckpt(self, config):
model = instantiate_from_config(config)
model = model.eval()
model.train = disabled_train
self.first_stage_model = model

def init_cond_stage_from_ckpt(self, config):
...
self.cond_stage_model = ...

模型的前向传播函数如下。一开始,函数调用encode_to_zencode_to_c,根据self.cond_stage_modelself.first_stage_model把约束图像和输入图像编码成压扁至一维的压缩图像。之后函数做了一个类似Dropout的操作,根据self.pkeep随机替换掉约束编码。最后,函数把约束编码和输入编码拼接起来,使用通常方法调用Transformer。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def forward(self, x, c):
# one step to produce the logits
_, z_indices = self.encode_to_z(x)
_, c_indices = self.encode_to_c(c)

if self.training and self.pkeep < 1.0:
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
device=z_indices.device))
mask = mask.round().to(dtype=torch.int64)
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
a_indices = mask*z_indices+(1-mask)*r_indices
else:
a_indices = z_indices

cz_indices = torch.cat((c_indices, a_indices), dim=1)

# target includes all sequence elements (no need to handle first one
# differently because we are conditioning)
target = z_indices
# make the prediction
logits, _ = self.transformer(cz_indices[:, :-1])
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
logits = logits[:, c_indices.shape[1]-1:]

return logits, target

GPT2的结构不是本文的重点,我们就快速把模型结构过一遍了。GPT2的模型定义在taming.modules.transformer.mingpt.GPT里。GPT2的结构并不复杂,就是一个只有解码器的Transformer。前向传播时,数据先通过嵌入层self.tok_emb,再经过若干个Transformer模块self.blocks,最后过一个LayerNorm层self.ln_f和线性层self.head

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class GPT(nn.Module):

def forward(self, idx, embeddings=None, targets=None):
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector

if embeddings is not None: # prepend explicit embeddings
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

t = token_embeddings.shape[1]
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x)

# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

return logits, loss

每个Transformer块就是非常经典的自注意力加全连接层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Block(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(), # nice
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.resid_pdrop),
)

def forward(self, x, layer_past=None, return_present=False):
# TODO: check that training still works
if return_present: assert not self.training
# layer past: tuple of length two with B, nh, T, hs
attn, present = self.attn(self.ln1(x), layer_past=layer_past)

x = x + attn
x = x + self.mlp(self.ln2(x))
if layer_past is not None or return_present:
return x, present
return x

基于滑动窗口的带约束图像生成

看完了所有模型的结构,我们最后来学习一下论文中没能详细介绍的滑动窗口算法。在scripts\taming-transformers.ipynb里有一个采样算法的最简实现,我们就来学习一下这份代码。

这份代码可以根据一幅语义分割图像来生成高清图像。一开始,代码会读入模型和语义分割图像。大致的代码为:

1
2
3
4
5
6
7
from taming.models.cond_transformer import Net2NetTransformer
model = Net2NetTransformer(**config.model.params)
from PIL import Image
import numpy as np
segmentation_path = "data/sflckr_segmentations/norway/25735082181_999927fe5a_b.png"
segmentation = Image.open(segmentation_path)
...


之后,代码把约束图像用对应的VQGAN编码进压缩空间,得到c_indices。由于待生成图像为空,我们可以随便生成一个待生成图像的压缩图像z_indices,代码中使用了randint初始化待生成的压缩图像。

1
2
3
4
5
6
7
8
c_code, c_indices = model.encode_to_c(segmentation)
z_indices = torch.randint(codebook_size, z_indices_shape, device=model.device)

idx = z_indices
idx = idx.reshape(z_code_shape[0],z_code_shape[2],z_code_shape[3])

cidx = c_indices
cidx = cidx.reshape(c_code.shape[0],c_code.shape[2],c_code.shape[3])

最后就是最关键的滑动窗口采样部分了。我们先稍微浏览一遍代码,再详细地一行一行看过去。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
temperature = 1.0
top_k = 100

for i in range(0, z_code_shape[2]-0):
if i <= 8:
local_i = i
elif z_code_shape[2]-i < 8:
local_i = 16-(z_code_shape[2]-i)
else:
local_i = 8
for j in range(0,z_code_shape[3]-0):
if j <= 8:
local_j = j
elif z_code_shape[3]-j < 8:
local_j = 16-(z_code_shape[3]-j)
else:
local_j = 8

i_start = i-local_i
i_end = i_start+16
j_start = j-local_j
j_end = j_start+16

patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = cidx[:, i_start:i_end, j_start:j_end]
cpatch = cpatch.reshape(cpatch.shape[0], -1)
patch = torch.cat((cpatch, patch), dim=1)
logits,_ = model.transformer(patch[:,:-1])
logits = logits[:, -256:, :]
logits = logits.reshape(z_code_shape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]

logits = logits/temperature

if top_k is not None:
logits = model.top_k_logits(logits, top_k)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx[:,i,j] = torch.multinomial(probs, num_samples=1)

x_sample = model.decode_to_img(idx, z_code_shape)
show_image(x_sample)

一开始的temperaturetop_k是得到logit后的采样参数,和滑动窗口算法无关。
1
2
temperature = 1.0
top_k = 100

进入生成图像循环后,i, j分别表示压缩图像的竖索引和横索引,i_start, i_end, j_start, j_end是滑动窗口上下左右边界。

1
2
3
4
5
6
7
8
for i in range(0, z_code_shape[2]-0):
...
for j in range(0,z_code_shape[3]-0):
...
i_start = i-local_i
i_end = i_start+16
j_start = j-local_j
j_end = j_start+16

为了获取这四个滑动窗口的范围,代码用了若干条件语句计算待生成像素在滑动窗口里的相对位置local_i, local_j

1
2
3
4
5
6
7
8
9
10
11
12
13
14
for i in range(0, z_code_shape[2]-0):
if i <= 8:
local_i = i
elif z_code_shape[2]-i < 8:
local_i = 16-(z_code_shape[2]-i)
else:
local_i = 8
for j in range(0,z_code_shape[3]-0):
if j <= 8:
local_j = j
elif z_code_shape[3]-j < 8:
local_j = 16-(z_code_shape[3]-j)
else:
local_j = 8

得到了滑动窗口的边界后,代码用滑动窗口从约束图像的压缩图像和待生成图像的压缩图像上各取出一个图块,并拼接起来。

1
2
3
4
5
patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = cidx[:, i_start:i_end, j_start:j_end]
cpatch = cpatch.reshape(cpatch.shape[0], -1)
patch = torch.cat((cpatch, patch), dim=1)

之后,只需要把拼接的图块直接输入进Transformer,得到输出logits,再用local_i,local_j去输出图块的对应位置取出下一个压缩图像像素的概率分布,就可以随机生成下一个压缩图像像素了。如前文所述,Transformer类会把二维的图块压扁到一维,输入进GPT。同时,GPT会自动保证前面的像素看不到后面的像素,我们不需要人为地指定约束像素。这个地方的调用逻辑其实非常简单。

1
2
3
4
logits,_ = model.transformer(patch[:,:-1])
logits = logits[:, -256:, :]
logits = logits.reshape(z_code_shape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]

最后只要从logits里采样,把采样出的压缩图像像素填入idx,就完成了一步生成。

1
2
3
4
5
6
7
logits = logits/temperature

if top_k is not None:
logits = model.top_k_logits(logits, top_k)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx[:,i,j] = torch.multinomial(probs, num_samples=1)

反复执行循环,就能将压缩图像生成完毕。最后将压缩图像过一遍VQGAN的解码器即可得到最终的生成图像。

1
2
x_sample = model.decode_to_img(idx, z_code_shape)
show_image(x_sample)

参考资料

VQGAN论文:https://arxiv.org/abs/2012.09841

VQGAN GitHub:https://github.com/CompVis/taming-transformers

如果你需要补充学习早期工作,欢迎阅读我之前的文章。

Transformer解读

PixelCNN解读

VQVAE解读