0%

近期,最受开源社区欢迎的文生图模型 Stable Diffusion 的最新版本 Stable Diffusion 3 开放了源码和模型参数。开发者宣称,Stable Diffusion 3 使用了全新的模型结构和文本编码方法,能够生成更符合文本描述且高质量的图片。得知 Stable Diffusion 3 开源后,社区用户们纷纷上手测试,在网上分享了许多测试结果。而在本文中,我将面向之前已经熟悉 Stable Diffusion 的科研人员,快速讲解 Stable Diffusion 3 论文的主要内容及其在 Diffusers 中的源码。对于 Stable Diffusion 3 中的一些新技术,我并不会介绍其细节,而是会讲清其设计动机并指明进一步学习的参考文献。

内容索引

本文会从多个角度简单介绍 SD3,具体要介绍的方面如下所示。读者可以根据自己的需求,跳转到感兴趣的部分阅读。

流匹配原理简介

流匹配是一种定义图像生成目标的方法,它可以兼容当前扩散模型的训练目标。流匹配中一个有代表性的工作是整流 (rectified flow),它也正是 SD3 用到的训练目标。我们会在本文中通过简单的可视化示例学习流匹配的思想。

SD3 中的 DiT

我们会从一个简单的类 ViT 架构开始,学习 SD3 中的去噪网络 DiT 模型是怎么一步一步搭起来的。读者不需要提前学过 DiT,只需要了解 Transformer 的结构,并大概知道视觉任务里的 Transformer 会做哪些通用的修改(如图块化),即可学懂 SD3 里的 DiT。

SD3 模型与训练策略改进细节

除了将去噪网络从 U-Net 改成 DiT 外,SD3 还在模型结构与训练策略上做了很多小改进:

  • 改变训练时噪声采样方法
  • 将一维位置编码改成二维位置编码
  • 提升 VAE 隐空间通道数
  • 对注意力 QK 做归一化以确保高分辨率下训练稳定

本文会简单介绍这些改进。

大型消融实验

对于想训练大型文生图模型的开发者,SD3 论文提供了许多极有价值的大型消融实验结果。本文会简单分析论文中的两项实验结果:各训练目标在文生图任务中的表现、SD3 的参数扩增实验结果。

SD3 Diffusers 源码解读

本文会介绍如何配置 Diffusers 环境以用代码运行 SD3,并简单介绍相比于 SD,SD3 的采样代码和模型代码有哪些变动。

论文阅读

核心贡献

介绍 Stable Diffusion 3 (SD3) 的文章标题为 Scaling Rectified Flow Transformers for High-Resolution Image Synthesis。与其说它是一篇技术报告,更不如说它是一篇论文,因为它确实是按照撰写学术论文的一般思路,将正文的叙述重点放到了方法的核心创新点上,而没有过多叙述工程细节。正如其标题所示,这篇文章的内容很简明,就是用整流 (rectified flow) 生成模型、Transformer 神经网络做了模型参数扩增实验,以实现高质量文生图大模型。

由于这是一篇实验主导而非思考主导的文章,论文的开头没有太多有价值的内容。从我们读者学习论文的角度,文章的核心贡献如下:

从方法设计上:

  • 首次在大型文生图模型上使用了整流模型。
  • 用一种新颖的 Diffusion Transformer (DiT) 神经网络来更好地融合文本信息。
  • 使用了各种小设计来提升模型的能力。如使用二维位置编码来实现任意分辨率的图像生成。

从实验上:

  • 开展了一场大规模、系统性的实验,以验证哪种扩散模型/整流模型的学习目标最优。
  • 开展了扩增模型参数的实验 (scaling study),以证明提升参数量能提升模型的效果。

整流模型简介

由于 SD3 最后用了整流模型来建模图像生成,所以文章是从一种称为流匹配 (Flow Matching) 的角度而非更常见的扩散模型的角度来介绍各种训练目标。鉴于 SD3 并没有对其他论文中提出的整流模型做太多更改,我们在阅读本文时可以主要关注整流的想法及其与扩散模型的关系,后续再从其他论文中学习整流的具体原理。在此,我们来大致认识一下流匹配与整流的想法。

所谓图像生成,其实就是让神经网络模型学习一个图像数据集所表示的分布,之后从分布里随机采样。比如我们想让模型生成人脸图像,就是要让模型学习一个人脸图像集的分布。为了直观理解,我们可以用二维点来表示一张图像的数据。比如在下图中我们希望学习红点表示的分布,即我们希望随机生成点,生成的点都落在红点处,而不是落在灰点处。

我们很难表示出一个适合采样的复杂分布。因此,我们会把学习一个分布的问题转换成学习一个简单好采样的分布到复杂分布的映射。一般这个简单分布都是标准正态分布。如下图所示,我们可以用简单的算法采样在原点附近的来自标准正态分布的蓝点,我们要想办法得到蓝点到红点的映射方法。

学习这种映射依然是很困难的。而近年来包括扩散模型在内的几类生成模型用一种巧妙的方法来学习这种映射:从纯噪声(标准正态分布里的数据)到真实数据的映射很难表示,但从真实数据到纯噪声的逆映射很容易表示。所以,我们先人工定义从图像数据集到噪声的变换路线(红线),再让模型学习逆路线(蓝线)。让噪声数据沿着逆路线走,就实现了图像生成。

我们又可以用一种巧妙的方法间接学习图像生成路线。知道了预定义的数据到噪声的路线后,我们其实就知道了数据在路线上每一位置的速度(红箭头)。那么,我们可以以每一位置的反向速度(蓝箭头)为真值,学习噪声到真实数据的速度场。这样的学习目标被称为流匹配。

对于不同的扩散模型及流匹配模型,其本质区别在于图像到噪声的路线的定义方式。在扩散模型中,图像到噪声的路线是由一个复杂的公式表示的。而整流模型将图像到噪声的路线定义为了直线。比如根据论文的介绍,整流中 $t$ 时刻数据 $z_t$ 由真实图像 $x_0$ 变换成纯噪声 $\epsilon$ 的位置为:

而较先进的扩散模型 EDM 提出的路线公式为($b_t$ 是一个形式较为复杂的变量):

由于整流最后学习出来的生成路线近乎是直线,这种模型在设计上就支持少步数生成。

虽然整流模型是这样宣传的,但实际上 SD3 还是默认用了 28 步来生成图像。单看这篇文章,原整流论文里的很多设计并没有用上。对整流感兴趣的话,可以去阅读原论文 Flow straight and fast: Learning to generate and transfer data with rectified flow

流匹配模型和扩散模型的另一个区别是,流匹配模型天然支持 image2image 任务。从纯噪声中生成图像只是流匹配模型的一个特例。

非均匀训练噪声采样

在学习这样一种生成模型时,会先随机采样一个时刻 $t \in [0, 1]$,根据公式获取此时刻对应位置在生成路线上的速度,再让神经网络学习这个速度。直观上看,刚开始和快到终点的路线很好学,而路线的中间处比较难学。因此,在采样时刻 $t$ 时,SD3 使用了一种非均匀采样分布。

如下图所示,SD3 主要考虑了两种公式: mode(左)和 logit-norm (右)。二者的共同点是中间多,两边少。mode 相比 logit-norm,在开始和结束时概率不会过分接近 0。

网络整体架构

以上内容都是和训练相关的理论基础,下面我们来看多数用户更加熟悉的文生图架构。

从整体架构上来看,和之前的 SD 一样,SD3 主要基于隐扩散模型(latent diffusion model, LDM)。这套方法是一个两阶段的生成方法:先用一个 LDM 生成隐空间低分辨率的图像,再用一个自编码器把图像解码回真实图像。

扩散模型 LDM 会使用一个神经网络模型来对噪声图像去噪。为了实现文生图,该去噪网络会以输入文本为额外约束。相比之前多数扩散模型,SD3 的主要改进是把去噪模型的结构从 U-Net 变为了 DiT。

DiT 的论文为 Scalable Diffusion Models with Transformers。如果只是对 DiT 的结构感兴趣的话,可以去直接通过读 SD3 的源码来学习。读 DiT 论文时只需要着重学习 AdaLayerNormZero 模块。

提升自编码器通道数

在当时设计整套自编码器 + LDM 的生成架构时,SD 的开发者并没有仔细改进自编码器,用了一个能把图像下采样 8 倍,通道数变为 4 的隐空间图像。比如输入 $512 \times 512 \times 3$ 的图像会被自编码器编码成 $64 \times 64 \times 4$。而近期有些工作发现,这个自编码器不够好,提升隐空间的通道数能够提升自编码器的重建效果。因此,SD3 把隐空间图像的通道数从 $4$ 改为了 $16$。

多模态 DiT (MM-DiT)

SD3 的去噪模型是一个 Diffusion Transformer (DiT)。如果去噪模型只有带噪图像这一种输入的话,DiT 则会是一个结构非常简单的模型,和标准 ViT 一样:图像过图块化层 (Patching) 并与位置编码相加,得到序列化的数据。这些数据会像标准 Transformer 一样,经过若干个子模块,再过反图块层得到模型输出。DiT 的每个子模块 DiT-Block 和标准 Transformer 块一样,由 LayerNorm, Self-Attention, 一对一线性层 (Pointwise Feedforward, FF) 等模块构成。

图块化层会把 $2\times2$ 个像素打包成图块,反图块化层则会把图块还原回像素。

然而,扩散模型中的去噪网络一定得支持带约束生成。这是因为扩散模型约束于去噪时刻 $t$。此外,作为文生图模型,SD3 还得支持文本约束。DiT 及本文的 MM-DiT 把模型设计的重点都放在了处理额外约束上。

我们先看一下模块是怎么处理较简单的时刻约束的。此处,如下图所示,SD3 的模块保留了 DiT 的设计,用自适应 LayerNorm (Adaptive LayerNorm, AdaLN) 来引入额外约束。具体来说,过了 LayerNorm 后,数据的均值、方差会根据时刻约束做调整。另外,过完 Attention 层或 FF 层后,数据也会乘上一个和约束相关的系数。

我们再来看文本约束的处理。文本约束以两种方式输入进模型:与时刻编码拼接、在注意力层中融合。具体数据关联细节可参见下图。如图所示,为了提高 SD3 的文本理解能力,描述文本 (“Caption”) 经由三种编码器编码,得到两组数据。一组较短的数据会经由 MLP 与文本编码加到一起;另一组数据会经过线性层,输入进 Transformer 的主模块中。

将约束编码与时刻编码相加是一种很常见的做法。此前 U-Net 去噪网络中处理简单约束(如 ImageNet 类型约束)就是用这种方法。

SD3 的 DiT 的子模块结构图如下所示。我们可以分几部分来看它。先看时刻编码 $y$ 的那些分支。和标准 DiT 子模块一样,$y$ 通过修改 LayerNorm 后数据的均值、方差及部分层后的数据大小来实现约束。再看输入的图像编码 $x$ 和文本编码 $c$。二者以相同的方式做了 DiT 里的 LayerNorm, FF 等操作。不过,相比此前多数基于 DiT 的模型,此模块用了一种特殊的融合注意力层。具体来说,在过注意力层之前,$x$ 和 $c$ 对应的 $Q, K, V$ 会分别拼接到一起,而不是像之前的模型一样,$Q$ 来自图像,$K, V$ 来自文本。过完注意力层,输出的数据会再次拆开,回到原本的独立分支里。由于 Transformer 同时处理了文本、图像的多模态信息,所以作者将模型取名为 MM-DiT (Multimodal DiT)。

论文里讲:「这个结构可以等价于两个模态各有一个 Transformer,但是在注意力操作时做了拼接,使得两种表示既可以在独自的空间里工作也可以考虑到另一个表示。」然而,我不太喜欢这种尝试去凭空解读神经网络中间表示的表述。仅从数据来源来看,过了一个注意力层后,图像信息和文本信息就混在了一起。你很难说,也很难测量,之后的 $x$ 主要是图像信息,$c$ 主要是文本信息。只能说 $x, c$ 都蕴含了多模态的信息。之前 SD U-Net 里的 $x, c$ 可以认为是分别包含了图像信息和文本信息,因为之前的 $x$ 保留了二维图像结构,而 $c$ 仅由文本信息决定。

比例可变的位置编码

此前多数方法在使用类 ViT 架构时,都会把图像的图块从左上到右下编号,把二维图块拆成一维序列,再用这种一维位置编码来对待图块。

这样做有一个很大的坏处:生成的图像的分辨率是无法修改的。比如对于上图,假如采样时输入大小不是 $4 \times 4$,而是 $4 \times 5$,那么 $0$ 号图块的下面就是 $5$ 而不是 $4$ 了,模型训练时学习到的图块之间的位置关系全部乱套。

解决此问题的方法很简单,只需要将一维的编码改为二维编码。这样 Transformer 就不会搞混二维图块间的关系了。

SD3 的 MM-DiT 一开始是在 $256^2$ 固定分辨率上训练的。之后在高分辨率图像上训练时,开发者用了一些巧妙的位置编码设置技巧,让不同比例的高分辨率图像也能共享之前学到的这套位置编码。详细公式请参见原论文。

训练数据预处理

看完了模块设计,我们再来看一下 SD3 在训练中的一些额外设计。在大规模训练前,开发者用三个方式过滤了数据:

  1. 用了一个 NSFW 过滤器过滤图片,似乎主要是为了过滤色情内容。
  2. 用美学打分器过滤了美学分数太低的图片。
  3. 移除了看上去语义差不多的图片。

虽然开发者们自信满满地向大家介绍了这些数据过滤技术,但根据社区用户们的反馈,可能正是因为色情过滤器过分严格,导致 SD3 经常会生成奇怪的人体。

由于在训练 LDM 时,自编码器和文本编码器是不变的,因此可以提前处理好所有训练数据的图像编码和文本编码。当然,这是一项非常基础的工程技巧,不应该写在正文里的。

用 QK 归一化提升训练稳定度

按照之前高分辨率文生图模型的训练方法,SD3 会先在 $256^2$ 的图片上训练,再在高分辨率图片上微调。然而,开发者发现,开始微调后,混合精度训练常常会训崩。根据之前工作的经验,这是由于注意力输入的熵会不受控制地增长。解决方法也很简单,只要在做注意力计算之前对 Q, K 做一次归一化就行,具体做计算的位置可以参考上文模块图中的 “RMSNorm”。不过,开发者也承认,这个技巧并不是一个长久之策,得具体问题具体分析。看来这种 DiT 模型在大规模训练时还是会碰到许多训练不稳定的问题,且这些问题没有一个通用解。

哪种扩散模型训练目标最适合文生图任务?

最后我们来看论文的实验结果部分。首先,为了寻找最好的扩散模型/流匹配模型,开发者开展了一场声势浩大的实验。实验涉及 61 种训练公式,其中的可变项有:

  • 对于普通扩散模型,考虑 $\epsilon$- 或 $\mathbf{v}$-prediction,考虑线性或 cosine 噪声调度。
  • 对于整流,考虑不同的噪声调度。
  • 对于 EDM,考虑不同的噪声调度,且尽可能与整流的调度机制相近以保证可比较。

在训练时,除了训练目标公式可变外,优化算法、模型架构、数据集、采样器都不可变。所有模型在 ImageNet 和 CC12M 数据集上训练,在 COCO-2014 验证集上评估 FID 和 CLIP Score。根据评估结果,可以选出每个模型的最优停止训练的步数。基于每种目标下的最优模型,开发者对模型进行最后的排名。由于在最终评估时,仍有采样步数、是否使用 EMA 模型等可变采样配置,开发者在所有 24 种采样配置下评估了所有模型,并用一种算法来综合所有采样配置的结果,得到一个所有模型的最终排名。最终的排名结果如下面的表 1 所示。训练集上的一些指标如表 2 所示。

根据实验结果,我们可以得到一些直观的结论:整流领先于扩散模型。惊人的是,较新推出的 EDM 竟然没有战胜早期的 LDM (“eps/linear”)。

当然,我个人认为,应该谨慎看待这份实验结果。一般来说,大家做图像生成会用一个统一的指标,比如 ImageNet 上的 FID。这篇论文相当于是新提出了一种昂贵的评价方法。这种评价方法是否合理,是否能得到公认还犹未可知。另外,想说明一个生成模型的拟合能力不错,用 ImageNet 上的 FID 指标就足够有说服力了,大家不会对一个简单的生成模型有太多要求。然而,对于大型文生图模型,大家更关心的是模型的生成效果,而 FID 和 CLIP Score 并不能直接反映文生图模型的质量。因此,光凭这份实验结果,我们并不能说整流一定比之前的扩散模型要好。

会关注这份实验结果的应该都是公司里的文生图开发者。我建议体量小的公司直接参考这份实验结果,无脑使用整流来代替之前的训练目标。而如果有能力做同等级的实验的话,则不应该错过改良后的扩散模型,如最新的 EDM2,说不定以后还会有更好的文生图训练目标。

参数扩增实验结果

现在多数生成模型都会做参数扩增实验,即验证模型表现随参数量增长而增长,确保模型在资源足够的情况下可以被训练成「大模型」。SD3 也做了类似的实验。开发者用参数 $d$ 来控制 MM-DiT 的大小,Transformer 块的个数为 $d$,且所有特征的通道数与 $d$ 成正比。开发者在 $256^2$ 的数据上训练了所有模型 500k 步,每 50k 步在 CoCo 数据集上统计验证误差。最终所有评估指标如下图所示。可以说,所有指标都表明,模型的表现的确随参数量增长而增长。更多结果请参见论文。

Diffusers 源码阅读

测试脚本

我们来阅读一下 SD3 在最流行的扩散模型框架 Diffusers 中的源码。在读源码前,我们先来跑通官方的示例脚本。

由于使用协议的限制,SD3 的环境搭起来稍微有点麻烦。首先,我们要确保 Diffuers 和 Transformers 都用的是最新版本。

1
pip install --upgrade diffusers transformers

之后,我们要注册 HuggingFace 账号,再在 SD3 的模型网站 https://huggingface.co/stabilityai/stable-diffusion-3-medium 里确认同意某些使用协议。之后,我们要设置 Access Token。具体操作如下所示,先点右上角的 “settings”,再点左边的 “Access Tokens”,创建一个新 token。将这个 token 复制保存在本地后,点击 token 右上角选项里的 “Edit Permission”,在权限里开启 “… public gated repos …”。

最后,我们用命令行登录 HuggingFace 并使用 SD3。先用下面的命令安装 HuggingFace 命令行版。

1
pip install -U "huggingface_hub[cli]"

再输入 huggingface-cli login,命令行会提示输入 token 信息。把刚刚保存好的 token 粘贴进去,即可完成登录。

1
2
3
huggingface-cli login

Enter your token (input will not be visible): 在这里粘贴 token

做完准备后,我们就可以执行下面的测试脚本了。注意,该脚本会自动下载模型,我们需要保证当前环境能够访问 HuggingFace。执行完毕后,生成的 $1024 \times 1024$ 大小的图片会保存在 tmp.png 里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe(
"A cat holding a sign that says hello world",
negative_prompt="",
num_inference_steps=28,
guidance_scale=7.0,
).images[0]

image.save('tmp.png')

我得到的图片如下所示。看起来 SD3 理解文本的能力还是挺强的。

模型组件

接下来我们来快速浏览一下 SD3 流水线 StableDiffusion3Pipeline 的源码。在 IDE 里使用源码跳转功能可以在 diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py 里找到该类的源码。

通过流水线的 __init__ 方法,我们能知道 SD3 的所有组件。组件包括自编码器 vae, MM-DiT Transformer, 流匹配噪声调度器 scheduler,以及三个文本编码器。每个编码器由一个 tokenizer 和一个 text encoder 组成.

1
2
3
4
5
6
7
8
9
10
11
12
def __init__(
self,
transformer: SD3Transformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
):

vae 的用法和之前 SD 的一模一样,编码时用 vae.encode 并乘 vae.config.scaling_factor,解码时除以 vae.config.scaling_factor 并用 vae.decode

文本编码器的用法可以参见 encode_prompt 方法。文本会分别过各个编码器的 tokenizer 和 text encoder,得到三种文本编码,并按照论文中的描述拼接成两种约束信息。这部分代码十分繁杂,多数代码都是在处理数据形状,没有太多有价值的内容。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def encode_prompt(
self,
prompt,
prompt_2,
prompt_3,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
negative_prompt_2,
negative_prompt_3,
...

):
...

return prompt_embeds, negative_prompt_embeds,
pooled_prompt_embeds, negative_pooled_prompt_embeds

采样流水线

我们再来通过阅读流水线的 __call__ 方法了解 SD3 采样的过程。由于 SD3 并没有修改 LDM 的这套生成框架,其采样流水线和 SD 几乎完全一致。SD3 和 SD 的 __call__ 方法的主要区别是,生成文本编码时会生成两种编码。

1
2
3
4
5
6
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(...)

在调用去噪网络时,那个较小的文本编码 pooled_prompt_embeds 会作为一个额外参数输入。

1
2
3
4
5
6
7
8
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]

MM-DiT 去噪模型

相比之下,SD3 的去噪网络 MM-DiT 的改动较大。我们来看一下对应的 SD3Transformer2DModel 类,它位于文件 diffusers\models\transformers\transformer_sd3.py

类的构造函数里有几个值得关注的模块:二维位置编码类 PatchEmbed、组合时刻编码和文本编码模块 CombinedTimestepTextProjEmbeddings、主模块类 JointTransformerBlock

1
2
3
4
5
6
7
8
9
10
11
def __init__(...):
...
self.pos_embed = PatchEmbed(...)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(...)
...
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(..)
for i in range(self.config.num_layers)
]
)

类的前向传播函数 forward 里都是比较常规的操作。数据会依次经过前处理、若干个 Transformer 块、后处理。所有实现细节都封装在各个模块类里。

1
2
3
4
5
6
7
8
9
10
11
def forward(...):
hidden_states = self.pos_embed(hidden_states)
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
for index_block, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(...)

encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
...

接下来我们来看这几个较为重要的子模块。PatchEmbed 类的实现写在 diffusers/models/embeddings.py 里。这个类的实现写得非常清晰。PatchEmbed 类本身用于维护位置编码宽高、特征长度这些信息,计算位置编码的关键代码在 get_2d_sincos_pos_embed 中。get_2d_sincos_pos_embed 会生成 (0, 0), (1, 0), ... 这样的二维坐标网格,再调用 get_2d_sincos_pos_embed_from_grid 生成二维位置编码。get_2d_sincos_pos_embed_from_grid 会调用两次一维位置编码函数 get_1d_sincos_pos_embed_from_grid,也就是 Transformer 里那种标准位置编码生成函数,来分别生成两个方向的编码,最后拼接成二维位置编码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class PatchEmbed(nn.Module):
...
def forward(self, latent):
...
pos_embed = get_2d_sincos_pos_embed(...)

def get_2d_sincos_pos_embed(...):
grid_h = np.arange(...)
grid_w = np.arange(...)
grid = np.meshgrid(grid_w, grid_h)
...
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

def get_2d_sincos_pos_embed_from_grid(...):
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)

emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb

组合时刻编码和文本编码模块 CombinedTimestepTextProjEmbeddings 的代码非常短。它实际上就是用通常的 Timesteps 类获取时刻编码,用一个 text_embedder 模块再次处理文本编码,最后把两个编码加起来。
text_embedder 是一个线性层、激活函数、线性层构成的简单模块。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class CombinedTimestepTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()

self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

def forward(self, timestep, pooled_projection):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)

pooled_projections = self.text_embedder(pooled_projection)

conditioning = timesteps_emb + pooled_projections

return conditioning

class PixArtAlphaTextProjection(nn.Module):
def __init__(...):
...

def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states

MM-DiT 的主要模块 JointTransformerBlockdiffusers/models/attention.py 文件里。这个类的代码写得比较乱。它主要负责处理 LayerNorm 及数据的尺度变换操作,具体的注意力计算由注意力处理器 JointAttnProcessor2_0 负责。两处 LayerNorm 的实现方式竟然是不一样的。

我们先简单看一下构造函数里初始化了哪些模块。代码中,norm1, ff, norm2 等模块都是普通 Transformer 块中的模块。而加了 _context 的模块则表示处理文本分支 $c$ 的模块,如 norm1_context, ff_contextcontext_pre_only 表示做完了注意力计算后,还要不要给文本分支加上 LayerNorm 和 FeedForward。如前文所述,具体的注意力计算由 JointAttnProcessor2_0 负责。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class JointTransformerBlock(nn.Module):

def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
super().__init__()

self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

self.norm1 = AdaLayerNormZero(dim)

if context_norm_type == "ada_norm_continous":
self.norm1_context = AdaLayerNormContinuous(
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
)
elif context_norm_type == "ada_norm_zero":
self.norm1_context = AdaLayerNormZero(dim)

processor = JointAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
processor=processor,
)

self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

if not context_pre_only:
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
else:
self.norm2_context = None
self.ff_context = None

我们再来看 forward 方法。在前向传播时,图像分支和文本分支会分别过 norm1,再一起过注意力操作,再分别过 norm2ff。大概的代码如下所示,我把较复杂的 context 分支的代码略过了。

这份代码写得很不漂亮,按理说模块里两个 LayerNorm + 尺度变换 (即 Adaptive LayerNorm) 的操作是一样的,应该用同样的代码来处理。但是这个模块里 norm1AdaLayerNormZero 类,norm2LayerNorm 类。norm1 会自动做完 AdaLayerNorm 的运算,并把相关变量返回。而在 norm2 处,代码会先执行普通的 LayerNorm,再根据之前的变量手动调整数据的尺度。我们心里知道这份代码是在实现论文里那张结构图就好,没必要去仔细阅读。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

if self.context_pre_only:
...

# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)

# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output

norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output

hidden_states = hidden_states + ff_output
if self.context_pre_only:
...

return encoder_hidden_states, hidden_states

融合注意力的实现方法很简单。和普通的注意力计算相比,这种注意力就是把另一条数据分支 encoder_hidden_states 也做了 QKV 的线性变换,并在做注意力运算前与原来的 QKV 拼接起来。做完注意力运算后,两个数据又会拆分回去。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class JointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""


def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
...

# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

# attention
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

...

# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)

总结

在这篇文章中,我们学习了 SD3 论文及源码中的主要内容。相比于 SD,SD3 做了两项较大的改进:用整流代替原来的 DDPM 中的训练目标;将去噪模型从 U-Net 变成了能更好地处理多模态信息的 MM-DiT。SD3 还在模型结构和训练目标上做了许多小改进,如调整训练噪声采样分布、使用二维位置编码。SD3 论文展示了多项大型消融实验的结果,证明当前的 SD3 是以最优配置训练得到的。SD3 可以在 Diffusers 中使用。当然,由于 SD3 的使用协议较为严格,我们需要做一些配置,才能在代码中使用 SD3。SD3 的采样流水线基本没变,原来 SD 的多数编辑方法能够无缝迁移过来。而 SD3 的去噪模型变动较大,和 U-Net 相关的编辑方法则无法直接用过来。在学习源码时,主要值得学习的是新 MM-DiT 模型中每个 Transformer 层的实现细节。

尽管 SD3 并没有提出新的流匹配方法,但其实验结果表明流匹配模型可能更适合文生图任务。作为研究者,受此启发,我们或许需要关注一下整流等流匹配模型,知道它们的思想,分析它们与原扩散模型训练目标的异同,以拓宽自己的视野。

今年第一次参加 CVPR 的前后发生了很多糟心事。我被气得受不了了。


美国签证不太好申,我周围所有因参加学术会议而申请签证的人都被审查了一个月。我决定参会与申请签证的时间较晚,过签的时间非常仅限。在会议开始前的倒数第二周,我收到了签证通过的通知,并得知我需要在一周后,也就是会议开始前的最后几个工作日领取带了美国签证的护照。能够恰好拿到签证,算得上是运气不错。可下一周发生的种种事情,却谈不上了幸运了。

会议开始前一周的周三早上,我赶早骑行前往学校参加会议前的最后一次周会。骑至半途,忽然天降暴雨。在新加坡,这种现象我已经见怪不怪了。我先穿上日常携带的雨衣,再将书包背到雨衣外,以免背上出太多汗。我又用手机向导师报告我可能会推迟参会,随后将手机和浸满了雨水的眼镜装进口袋。

这场雨就像老天给我开的一个玩笑。我一到目的地,雨就小了下来。上楼,掀开笔记本电脑,发现电脑被穿透了书包的一点雨水浸湿,开不了机。伸手一摸口袋,眼镜竟然也不见了。

这种雨中骑车的事我已经做过好多回了,但我从来没碰到过这么多意外。我有多次电脑进水的经历,可这是我第一次知道穿透书包的那一点水也能弄坏电脑。另外,本来我这条裤子的口袋是不会掉东西的,可这次恰好我要用手机发消息,而眼镜和手机放在了一起,顺势掉了出来。

无论多么令人懊恼,生活都得继续。我用上了我所有的规划能力来处理问题。我先和导师商议将周会延期。随后,我请同学陪我一起找眼镜。还好,眼镜在楼下就找到了,但是被我自己或者其他人踩了一脚,一边的镜框断了。我艰难地用半副眼镜处理完一天的杂事,于傍晚赶往市中心将电脑送修。晚上回家,我用强力胶将断裂的镜框部件粘在一起,眼镜勉强可以用了。第二天周四,我领完美国签证,同时得知电脑严重进水,一天修不好。我回家立刻拿出老电脑并配好工作环境。周五,我准备好一切出国开会所需资料,总算安顿好了一切。

周六,出发前一天,我去换美元。付款时要输入储蓄卡密码,我多次尝试也没能把密码输对,卡被冻结。这时我才想起来为什么我会不记得卡的密码:这张卡的密码是系统随机设置的,平时用卡根本不需要密码,所以我既忘了密码,也忘了密码不是我自己设的。这下去美国用不了这张卡了。还好,我还有国内的信用卡,完成了在新加坡后续的付款。

启程的飞机将于周日早上九点起飞。我起了个大早,提前前往机场。兴冲冲地首次坐上长途国际航班后,我立刻收到了飞机因故障无法起飞的通知,灰溜溜地下了飞机。航空公司给我安排的新航班是 36 小时之后,周一晚上九点。为表赔偿,航空公司给我们预订了新加坡五星级酒店的住宿,并附带了早中晚餐各 20 新币的代金券。走进五星级酒店最普通的房间,我原以为能享受一番,却又被正餐 30 新币起步(还不含税)的菜单气晕了过去。上午我在机场吃过了,中午我就随便点了个 20 新币的汤应付了过去。晚上我兼顾价格与饱腹度,点了一个 30 多新币的披萨。晚餐送至房间,我将代金券和 20 新币交给服务员后,服务员将代金券收走,现金还回,说道:“可以拿明天的早餐钱来抵这顿晚饭。”唉,我又亏了将近 10 新币,早知道点一个更好的披萨了。

周一,白天没有安排,我决定去解决新加坡储蓄卡的事。手机应用有挂失换卡的功能,却楞是没有申请解冻的功能。我只好顶着烈日,跨越条条街道,又穿梭于人山人海的购物中心,找到银行的支行。这个支行只有通过视频交流的虚拟柜台。为了确认身份,柜员问了我三个有关账户的验证问题。有一个问题的华文名词我听不懂,没答上来。因此,卡解冻失败,柜员建议我前往线下柜台。我真是纳闷,我有银行账户的一切访问权限,有身份证、护照,为什么要用这种方法来验证我的身份。但我今天时间多,不跟你们计较了。我又匆忙坐上满载的地铁,赶往下一处有线下柜台的支行。地图上说银行 4 点关门,我是三点半到的,可银行已经关门了。我突然感到一股违和感:今天是周一,为什么地铁那么多人?为什么购物中心那么多人?为什么银行提早下班?我一查,果然如我所料,原来今天是新加坡法定假日。今天,到美国开会的同学在当地开开心心地旅游,留在本地的人则能享受假期。就我一个人,飞机没坐上,假没休息到,事情没办成,从一对对情侣旁边擦身而过,在我本不该在的繁华市区里奔波。没办法,休息一下,晚上早点去坐飞机吧。

晚上,刚进机场,我的镜框因没有粘牢又断开了。还好我早有准备,拿出透明胶勉强固定了镜框。

过安检,我被选为 SSSS 级“幸运”用户,被安检警察拎到队伍前面,进行细致的搜身检查。也好,我也不会带什么危险物品,顺便插了个队。

SSSS 是美国TSA (美国运输安全管理局)随机选取需要接受二次安检的乘客。安检人员看到SSSS的登机牌后会对乘客格外注意,在经过普通的X光扫描检查以后,行李会被拆开进行人工检查,人工彻底搜身,电子产品会接受爆炸物探测。

上了飞机,我总算坐上了有着小屏幕的美联航高级国际航班。航班要航行 15 个小时,我计划在这段时间里写一篇文章。一看时间还早,我就开了两把小游戏。玩到一半,电脑啪地屏幕一黑,没电了。飞机上的插座怎么充不进电啊?过了好一会儿,机长发来广播:“很抱歉,有乘客反映插座失灵,我们将重启设备,尽快修好。(翻译自英文)”我满怀期待地等着电脑的开机,这一等就等到了飞机降落。临走前,贴心的乘务长向大家诚恳地致歉:“今天有一些旅客一直没有接上插座,我们真是深表歉意呢~(翻译自英文)”

第一趟飞机在旧金山降落,我还需要乘坐中转西雅图的飞机。由于下一站是国内航班,我需要在这里完成入境。入境的队伍很长,一眼望去,四五十分钟都不见得能排完队。帮我安排航班的工作人员想必之前是一位极限运动的教练,给我安排了一小时后结束登机、一个半小时后起飞的第二趟航班,让我在这里锻炼极限的时间管理。排队五十分钟后,我极速应付完安检人员的问题,总算顺利入境。眼下飞机只有十分钟就要停止登机了,在这种极限条件下,我的身体潜能与语言表达潜能被猛然激活。我背着沉甸甸的书包——我唯一的行李,仔细地看遍每一个路标,在机场里沿着最短路线狂奔,流畅地办理登机、安检。我这一个多小时一直没来得及上厕所。跑至空荡荡的登机口前,我遗憾地望了一眼对面的厕所后,急忙向柜台后的空乘人员出示机票。“啊!你很幸运啊!快点上飞机。”在我的极限运营下,我成功在起飞前登机,就是肚子下面不太舒服。到了原定的起飞时间,我正双目紧闭静待起飞,忽然听见了机长包含温情的广播:“我们得知有部分旅客还在转机,飞机将在约十五分钟后起飞。”不早说啊!我这么赶是为了什么啊!快把厕所门打开!

飞机起飞,我如愿以偿地解了个手,满足地回座位睡了一觉。可能是剧烈运动导致新陈代谢加快,醒来后,我又是一阵内急。可又好巧不巧,飞机刚开始降落,我还要等半个多小时才能下飞机上厕所。在这段时间里,我仿佛领悟了长生不老之道,一分钟可以当十分钟来过。我思考了我为什么这么难受,为什么上这趟飞机,为什么会存在这个世界上。恍惚的精神慢慢回归,我规划出了飞机降落后最快的下机方式。总算到站了,舱门一打开,我就背起书包,右手紧扣腹部,面露难色,名正言顺地在下机的队伍里插队。一路小跑到厕所前,这下没人看,不用演戏了。可我的手却怎么都不肯从肚子上松开——哦,原来我真的憋得肚子不舒服了。好在厕所只有一步之遥了。我特意找了个高级包厢,坐下来不紧不慢地上起了厕所。这是我第一次憋得这么难受,也是第一次觉得上厕所是人生中最幸福的事。

当地时间凌晨两点,我总算来到了西雅图机场出口。没有飞机要赶,没有厕所要上,充足休息了近 20 个小时的我,即将迎来光明的前程。我没有办理美国的流量,因为我知道在需要用网络的地方都有免费无线网络。按照网上的攻略,我前往停车场的某处,准备用 Uber 打车。一到停车场,凌冽的寒气扑面而来。这冷空调开得有点大啊。不对!这是室外的正常气温。我来美国之前是看了天气预报的,每天的温度都是 10 到 20 多度。我过了太久的夏天,对 10 度的天气没有概念,也没有料到我大晚上还会待在室外。但没事啊,打上车就好了。我在机场就配置好了 Uber,只差呼叫司机这一步了。欸,怎么付款前还要验证手机号?这里信号很差,接收国内手机的短信要三分钟左右。当然,这个延迟是我在多次发送短信后才意识到的。由于我多次输入了上一次短信的验证码,Uber 最后忍无可忍,把我的账号禁用了。寒风中,我身着短袖短裤,在空荡的候车处死盯没有回应的手机,打着牙颤,默默无语。冷静了一会儿,我想我有信用卡有钱,为什么非得用软件打车不可。于是,我走到了普通打车处,稍等片刻,上车,最终总算抵达了酒店。

第一次出国开会,第一次前往美国旅行。对于多数人来说快乐的事情,到我这为什么就成了渡劫呢?这些事或许算不上什么人生大事,但攒在一起总是会令人烦恼。诚然,有些事是我考虑不周,但大多数事是我完全无法控制的。不过这些都没关系。在麻将等概率游戏中搏杀已久的我知道,过去是无法改变的,结果通常是无法控制的,我们能做的只有在此刻寻找问题的最优解。哪怕是碰到了这么多事,我还是拿出了我玩解谜游戏的所有实力,尽力去处理好每一件事。如果还是对生活的不幸感到愤愤不平,不如带着幸灾乐祸的心理想一想我这次事件里的美联航。他们的航班延误,影响了那么多乘客,最后给每个人都赔了一晚的五星级酒店,还不是把问题解决了。


以上是我头脑冷静时写的东西。参会过程省略以示愤怒。


本来我准备了很多要写的东西,回程的时候又被气到了,累死了,不写了。这“一天”里,我提前3小时到机场,坐2小时飞机,等5小时飞机,又坐15小时飞机。累计睡眠5~6乘1个小时。到了新加坡我还想省钱,没打车,在机场走了半小时,又坐了1小时地铁。下地铁后走回家,莫名其妙绕了路,在太阳底下穿长裤走了1个多小时。好不容易回到家,洗完澡想刷个牙,牙膏还没了。真作假时假亦真,没睡醒时醒亦困。看上去我这几天睡了很多次觉,但我根本没有睡醒。算上在美国的时间,我已经颠沛流离了好几天,各种硬抗时差,已经不知道睡够六小时是怎样的感觉了。别人比我早一两天来,晚一两天走,玩得开开心心。我除了在会场当3天“好学生”外,剩下近一周时间全在参与人生的极限挑战。怪不得学术大佬都不愿意出来开会,从明年开始我也是学术大佬了。

在上篇文章中,我们浏览了 Stable Video Diffusion (SVD) 的论文,并特别学习了没有在论文中提及的模型结构、噪声调度器这两个模块。在这篇文章中,让我们来看看 SVD 在 Diffusers 中的源码实现。我们会先学习 SVD 的模型结构,再学习 SVD 的采样流水线。在本文的多数章节中,我都会将 SVD 的结构与 Stable Diffusion (SD) 的做对比,帮助之前熟悉 SD 的读者快速理解 SVD 的性质。强烈建议读者在阅读本文前先熟悉 SD 及其在 Diffusers 中的实现。

Stable Diffusion Diffusers 实现源码解读

简单采样实验

目前开源的 SVD 仅有图生视频模型,即给定视频首帧,模型生成视频的后续内容。在首次开源时,SVD 有 1.0 和 1.0-xt 两个版本。二者模型结构配置相同,主要区别在于训练数据上。SVD 1.0 主要用于生成 14 帧 576x1024 的视频,而 1.0-xt 版本由 1.0 模型微调而来,主要用于生成 25 帧 576x1024 的视频。后来,开发团队又开源了 SVD 1.1-xt,该模型在固定帧率的视频数据上微调,输出视频更加连贯。为了做实验方便,在这篇文章中,我们将使用最基础的 SVD 1.0 模型。

参考 Diffusers 官方文档: https://huggingface.co/docs/diffusers/main/en/using-diffusers/svd ,我们来创建一个关于 SVD 的 “Hello World” 项目。如果你的电脑可以访问 HuggingFace 原站的话,直接运行下面的脚本就行了;如果不能访问原网站,可以尝试取消代码里的那行注释,访问 HuggingFace 镜像站;如果还是不行,则需要手动下载 “stabilityai/stable-video-diffusion-img2vid” 仓库,并将仓库路径改成本地下载的仓库路径。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

import torch
import os
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]

export_to_video(frames, "generated.mp4", fps=7)

成功运行后,我们能得到这样的一个火箭升空视频。它的第一帧会和我们的输入图片一模一样。

SVD 概览

由于 SVD 并没有在论文里对其图生视频模型做详细的介绍,我们没有官方资料可以参考,只能靠阅读源码来了解 SVD 的实现细节。为了让大家在读代码时不会晕头转向,我会在读代码前简单概述一下 SVD 的模型结构和采样方法。

SVD 和 SD 一样,是一个隐扩散模型(Latent Diffusion Model, LDM)。图像(视频帧)的生成由两个阶段组成:先由扩散模型生成压缩图像,再由 VAE 解码成真实图像。

扩散模型在生成图像时,会用一个去噪 U-Net $\epsilon_\theta$ 反复对纯噪声图像 $z_T$ 去噪,直至得到一幅有意义的图片 $z$。为了让模型输出我们想要的图像,我们会用一些额外的信息来约束模型,或者说将约束信息也输入进 U-Net。对于文生图 SD 来说,额外约束是文本。对于图生视频 SVD 来说,额外约束是图像。LDM 提出了两种输入约束信息的方式:与输入噪声图像拼接、作为交叉注意力模块的 K, V。SD 仅使用了交叉注意力的方式,而 SVD 同时使用了两种方式。

上面这两种添加约束信息的方法适用于信息量比较大的约束。实际上,还有一种更简单的输入实数约束信息的方法。除了噪声输入外,去噪模型还必须输入当前的去噪时刻 $t$。自最早的 DDPM 以来,时刻 $t$ 都是先被转换成位置编码,再输入进 U-Net 的所有残差块中。仿照这种输入机制,如果有其他的约束信息和 $t$ 一样可以用一个实数表示,则不必像前面那样将这种约束信息与输入拼接或输入交叉注意力层,只需要把约束也转换成位置编码,再与 $t$ 的编码加在一起。

SVD 给模型还添加了三种额外约束:噪声增强程度、帧率、运动程度。这三种约束都是用和时刻编码相加的形式实现的。

即使现在不完全理解这三种额外约束的意义也不要紧。稍后我们会在学习 U-Net 结构时看到这种额外约束是怎么添加进 U-Net 的,在学习采样流水线时了解这三种约束的意义。

总结一下,除了添加了少数模块外,SVD 和 SD 的整体架构一样,都是以去噪 U-Net 为核心的 LDM。除了原本扩散模型要求的噪声、去噪时刻这两种输入外,SVD 还加入了 4 种约束信息:约束图像(视频首帧)、噪声增强程度、帧率、运动程度。约束图像是最主要的约束信息,它会与噪声输入拼接,且输入进 U-Net 的交叉注意力层中。后三种额外约束会以和处理去噪时刻类似的方式输入进 U-Net 中。

去噪模型结构

接下来,我们来学习 SVD 的去噪模型的结构。在 Diffusers 中,一个扩散模型的参数、配置全部放在一个模型文件夹里,该文件夹的各个子文件夹存储了模型的各个模块,如自编码器、去噪模型、调度器等。我们可以在 https://huggingface.co/stabilityai/stable-video-diffusion-img2vid/tree/main 找到 SVD 的模型文件夹,或者访问我们本地下载好的模型文件夹。

SVD 的去噪 U-Net 放在模型文件夹的 unet 子文件夹里。通过阅读子文件夹里的 config.json,我们就能知道模型类的名字是什么,并知道初始化模型的参数有哪些。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
{
"_class_name": "UNetSpatioTemporalConditionModel",
...
"down_block_types": [
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal"
],
...
"up_block_types": [
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal"
]
}

通过在本地 Diffusers 库文件夹里搜索类名 UNetSpatioTemporalConditionModel,或者利用 IDE 的 Python 智能提示功能,在前文的示例脚本里跳转到 StableVideoDiffusionPipeline 所在文件,再跳转到 UNetSpatioTemporalConditionModel 所在文件,我们就能知道 SVD 的去噪 U-Net 类定义在 diffusers/models/unet_spatio_temporal_condition.py 里。我们可以对照位于 diffusers/models/unet_2d_condition.py 的 SD 的 2D U-Net 类 UNet2DConditionModel 来看一下 SVD 的 U-Net 有何不同。

先来看 __init__ 构造函数。SVD U-Net 几乎就是一个写死了许多参数的特化版 2D U-Net,其构造函数也基本上是 SD 2D U-Net 的构造函数的子集。比如 2D U-Net 允许用 act_fn 来指定模型的激活函数,默认为 "silu",而 SVD U-Net 直接把所有模块的激活函数写死成 "silu"。经过简化后,SVD U-Net 的构造函数可读性高了很多。我们从参数开始读起,逐一了解构造函数每一个参数的意义:

  • sample_size=None:隐空间图片边长。供其他代码调用,与 U-Net 无关。
  • in_channels=8:输入通道数。
  • out_channels=4: 输出通道数。
  • down_block_types:每一大层下采样模块的类名。
  • up_block_types:每一大层上采样模块的类名。
  • block_out_channels = (320, 640, 1280, 1280):每一大层的通道数。
  • addition_time_embed_dim=256: 每个额外约束的通道数。
  • projection_class_embeddings_input_dim=768: 所有额外约束的通道数。
  • layers_per_block=2: 每一大层有几个结构相同的模块。
  • cross_attention_dim=1024: 交叉注意力层的通道数。
  • transformer_layers_per_block=1: 每一大层的每一个模块里有几个 Transformer 层。
  • num_attention_heads=(5, 10, 10, 20): 各大层多头注意力层的头数。
  • num_frames=25: 训练时的帧数。供其他代码调用,与 U-Net 无关。

SVD U-Net 的参数基本和 SD 的一致,不同之处有:1)稍后我们会在采样流水线里看到,SVD 把图像约束拼接到了噪声图像上,所以整个噪声输入的通道数是原来的两倍,从 4 变为 8;2)多了一个给采样代码用的 num_frames 参数,它其实没有被 U-Net 用到。

我们再来大致过一下构造函数的实现细节。SVD U-Net 的整体结构和 2D U-Net 的几乎一致。数据先经过下采样模块,再经过中间模块,最后过上采样模块。下采样模块和上采样模块之间有短路连接。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
for i, down_block_type in enumerate(down_block_types):
...
down_block = get_down_block(...)
self.down_blocks.append(down_block)

self.mid_block = UNetMidBlockSpatioTemporal(...)

for i, up_block_type in enumerate(up_block_types):
...
up_block = get_up_block(...)
self.up_blocks.append(up_block)

self.conv_norm_out = nn.GroupNorm(...)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(...)

扩散模型还需要处理去噪时刻约束 $t$。U-Net 会先用正弦编码(Transformer 里的位置编码)time_proj 来将时刻转为向量,再用一系列线性层 time_embedding 预处理这个编码。该编码后续会输入进 U-Net 主体的每一个模块中。

1
2
3
4
self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
timestep_input_dim = block_out_channels[0]

self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

除了多数扩散模型都有的 U-Net 模块外,SVD 还加入了额外约束模块。如前文所述,对于能用一个实数表示的约束,可以使用和处理时刻类似的方式,先让其过位置编码层,再过线性层,最后把得到的输出编码和时刻编码加起来。所以,和这种额外约束相关的模块在代码里叫做 add_time。在 2D U-Net 里,额外约束是可选的。SD 没有用到额外约束。而 SVD 把额外约束设为了必选模块。稍后我们会在采样流水线里看到,SVD 将视频的帧率、运动程度、噪声增强强度作为了生成时的额外约束。这些约束都是用这种与时刻编码相加的形式实现的。

1
2
self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

构造函数的代码就看完了。在构造函数中,我们认识了 SVD U-Net 的各个模块,但对其工作原理或许还存在着些许疑惑。我们来模型的前向传播函数 forward 里看一下各个模块是怎么处理输入的。

看代码前,我们先回顾一下概念,整理一下 U-Net 的数据处理流程。下面是我之前给 SD U-Net 画的示意图。该图对 SVD 同样适用。和 SD 相比,SVD 的输入 x 不仅包括噪声图像(准确说是多个表示视频帧的图像),还包括作为约束的首帧图像; c 换成了首帧图像的 CLIP 编码;t 不仅包括时刻,还包括一些额外约束。

和上图所示的一样,SVD U-Net 的 forward 方法的输入包含图像 sample,时刻 timestep,交叉注意力层约束(图像编码) encoder_hidden_states , 额外约束 added_time_ids

1
2
3
4
5
6
7
8
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
added_time_ids: torch.Tensor,
return_dict: bool = True,
)

方法首先会处理去噪时刻和额外参数,我们来看一下这两个输入是怎么拼到一起的。

做完一系列和形状相关的处理后,输入时刻 timestep 变成了 timesteps。随后,该变量会先过正弦编码(位置编码)层 time_proj,再过一些线性层 time_embedding,得到最后输入 U-Net 主体的时刻嵌入 emb。这两个模块的命名非常容易混淆,千万别弄反了。类似地,额外约束也是先过正弦编码层 add_time_proj,再过一些线性层 add_embedding,最后其输出 aug_emb 会加到 emb 上。当然,为了确保结果可以相加,time_embeddingadd_time_proj 的输出通道数是相同的。

1
2
3
4
5
6
7
8
9
10
11
12
13
# preprocessing
# timesteps = timestep

t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb)

time_embeds = self.add_time_proj(added_time_ids.flatten())
time_embeds = time_embeds.reshape((batch_size, -1))
time_embeds = time_embeds.to(emb.dtype)
aug_emb = self.add_embedding(time_embeds)
emb = emb + aug_emb

这里有关额外约束的处理写得很差,逻辑也很难读懂。在构造函数里,额外约束的正弦编码层 add_time_proj 的输出通道数 addition_time_embed_dim 是 256, 线性模块 add_embedding 的输入通道数 projection_class_embeddings_input_dim 是 768。两个通道数不一样的模块是怎么接起来的?

1
2
3
4
5
6
7
8
9
def __init__(
...
addition_time_embed_dim: int = 256,
projection_class_embeddings_input_dim: int = 768,
...
)

self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

原来,在下面这份模块前向传播代码中,added_time_ids 的形状是 [batch_size, 3]。其中的 3 表示有三个额外约束。做了 flatten() 再过 add_time_proj 后,可以得到形状为 [3 * batch_size, 256] 的正弦编码 time_embeds。之所以三个约束可以用同一个模块来处理,是因为正弦编码没有学习参数,对所有输入都会产生同样的输出。得到 time_embeds 后,再根据从输入噪声图像里得到的 batch_size,用 reshapetime_embeds 的形状变成 [batch_size, 768]。这样,time_embeds 就可以输入进 add_embedding 里了。 add_embedding 是有可学习参数的,三个约束必须分别处理。

1
2
3
4
time_embeds = self.add_time_proj(added_time_ids.flatten())
time_embeds = time_embeds.reshape((batch_size, -1))
time_embeds = time_embeds.to(emb.dtype)
aug_emb = self.add_embedding(time_embeds)

这些代码不应该这样写的。当前的写法不仅可读性差,还不利于维护。比较好的写法是在构造函数里把输入参数从projection_class_embeddings_input_dim 改为 num_add_time,表示额外约束的数量。之后,把 add_embedding 的输入通道数改成 num_add_time * addition_time_embed_dim。这样,使用者不必手动设置合理的 add_embedding 的输入通道数(比如保证 768 必须是 256 的 3 倍),只设置有几个额外约束就行了。这样改了之后,为了提升可读性,还可以像下面那样把 reshape 里的那个 -1 写清楚来。Diffusers 采用这种比较混乱的写法,估计是因为这段代码是从 2D U-Net 里摘抄出来的。而原 2D U-Net 需要兼容更复杂的情况,所以 add_time_projadd_embedding 的通道数需要分别指定。

1
2
3
time_embeds = time_embeds.reshape((batch_size, -1))
->
time_embeds = time_embeds.reshape((batch_size, self.num_add_time * self.addition_time_embed_dim))

预处理完时刻和额外约束后,方法还会修改所有输入的形状,使得它们第一维的长度都是 batch_size 乘视频帧数。正如我们在上一篇文章中学到的,为了兼容图像模型里的模块,我们要先把视频长度那一维和 batch 那一维合并,等到了和时序相关的模块再对视频长度那一维单独处理。

1
2
3
4
5
6
7
8
# Flatten the batch and frames dimensions
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
sample = sample.flatten(0, 1)
# Repeat the embeddings num_video_frames times
# emb: [batch, channels] -> [batch * frames, channels]
emb = emb.repeat_interleave(num_frames, dim=0)
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)

后面的代码就和 2D U-Net 的几乎一样了。数据依次经过下采样块、中间块、上采样块。下采样块的中间结果还会保存在栈 down_block_res_samples 里,作为上采样模块的额外输入。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
sample = self.conv_in(sample)

image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)

down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(...)
down_block_res_samples += res_samples

sample = self.mid_block(...)

for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
sample = upsample_block(...)

sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)

光看 U-Net 类,我们还看不出 SVD 的 3D U-Net 和 2D U-Net 的区别。接下来,我们来看一看 U-Net 中某一个具体的模块是怎么实现的。由于 U-Net 下采样块、中间块、上采样块的结构是类似的,我们只挑某一大层的下采样模块类 CrossAttnDownBlockSpatioTemporal 来学习。

CrossAttnDownBlockSpatioTemporal 类中,我们可以看到 SVD U-Net 的每一个子模块都可以拆成残差卷积块和 Transformer 块。数据在经过子模块时,会先过残差块,再过 Transformer 块。我们来继续深究时序残差块类 SpatioTemporalResBlock 和时序 Transformer 块 TransformerSpatioTemporalModel 的实现细节。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# __init__
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
SpatioTemporalResBlock(...)
)
attentions.append(
TransformerSpatioTemporalModel(...)
)
# forward
blocks = list(zip(self.resnets, self.attentions))
for resnet, attn in blocks:
hidden_states = resnet(hidden_states, ...)
hidden_states = attn(hidden_states, ...)

在开始看代码之前,我们再回顾一下论文里有关 3D U-Net 块的介绍。SVD 的 U-Net 是从 Video LDM 的 U-Net 改过来的。下面的模块结构图源自 Video LDM 论文,我将其改成了能描述 SVD U-Net 块的图。图中红框里的模块表示在原 SD 2D U-Net 块的基础上新加入的模块。可以看出,SVD 实际上就是在原来的 2D 残差块后面加了一个 3D 卷积层,原空间注意力块后面加了一个时序注意力层。旧模块输出和新模块输出之间用一个比例 $\alpha$ 来线性混合。中间数据形状变换的细节我们已经在上篇文章里学过了,这篇文章里我们主要关心这些模块在代码里大概是怎么定义的。

3D 残差块类 SpatioTemporalResBlockdiffusers/models/resnet.py 文件中。它有三个子模块,分别对应上文示意图中的 2D 残差块、时序残差块(3D 卷积)、混合模块。在运算时,旧模块的输出会缓存到
hidden_states_mix 中,新模块的输出为 hidden_states,二者最终会送入混合模块 time_mixer 做一个线性混合。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class SpatioTemporalResBlock(nn.Module):
def __init__(
self,
...
):
super().__init__()
self.spatial_res_block = ResnetBlock2D(...)
self.temporal_res_block = TemporalResnetBlock(...)
self.time_mixer = AlphaBlender(...)

def forward(
self,
...
):
hidden_states = self.spatial_res_block(hidden_states, temb)

...

hidden_states_mix = hidden_states

...

hidden_states = self.temporal_res_block(hidden_states, temb)
hidden_states = self.time_mixer(
x_spatial=hidden_states_mix,
x_temporal=hidden_states,
)
...
return hidden_states

ResnetBlock2D 是 SD 2D U-Net 的残差模块,我们在这篇文章里就不去学习它了。 时序残差块 TemporalResnetBlock 和 2D 残差块的结构几乎完全一致,唯一的区别在于 2D 卷积被换成了 3D 卷积。从代码中我们可以知道,这个模块是一个标准的残差块,数据会依次过两个卷积层,并在最后输出前与输入相加。扩散模型中的时刻约束 temb 会在数据过完第一个卷积层后,加到数据上。值得注意的是,虽然类里面的卷积层名字叫 3D 卷积,但实际上它的卷积核形状为 (3, 1, 1),这说明这个卷积层实际上只是一个时序维度上窗口大小为 3 的 1D 卷积层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class TemporalResnetBlock(nn.Module):
def __init__(...):
kernel_size = (3, 1, 1)
padding = [k // 2 for k in kernel_size]

self.norm1 = torch.nn.GroupNorm(...)
self.conv1 = nn.Conv3d(...)

if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
else:
self.time_emb_proj = None

self.norm2 = torch.nn.GroupNorm(...)

self.dropout = torch.nn.Dropout(0.0)
self.conv2 = nn.Conv3d(...)

self.nonlinearity = get_activation("silu")

self.use_in_shortcut = self.in_channels != out_channels

self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = nn.Conv3d(...)

def forward(self, input_tensor, temb):
hidden_states = input_tensor

hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)

if self.time_emb_proj is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, :, None, None]
temb = temb.permute(0, 2, 1, 3, 4)
hidden_states = hidden_states + temb

hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)

if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)

output_tensor = input_tensor + hidden_states

return output_tensor

混合模块 AlphaBlender 其实就只是定义了一个可学习的混合比例 mix_factor,之后用这个比例来混合空间层输出和时序层输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class AlphaBlender(nn.Module):

def __init__(
self,
alpha: float,
...
):
...
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))

def forward(
self,
x_spatial,
x_temporal,
...
) -> torch.Tensor:
# Get mix_factor
alpha = self.get_alpha(...)
alpha = alpha.to(x_spatial.dtype)

if self.switch_spatial_to_temporal_mix:
alpha = 1.0 - alpha

x = alpha * x_spatial + (1.0 - alpha) * x_temporal
return x

看完了3D 残差块 SpatioTemporalResBlock 的内容,我们接着来看 3D 注意力块 TransformerSpatioTemporalModel 的内容。TransformerSpatioTemporalModel 也主要由 2D Transformer 块 BasicTransformerBlock、时序 Transformer 块 TemporalBasicTransformerBlock 、混合模块组成 AlphaBlender。它们的连接方式和上面的残差块类似。时序 Transformer 块和普通 2D Transformer 块一样,都是有自注意力、交叉注意力、全连接层的标准 Transformer 模块,它们的区别只在于时序 Transformer 块对输入做形状变换的方式不同,会让数据在时序维度上做信息交互。这里我们就不去进一步深究它们的实现细节了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class TransformerSpatioTemporalModel(nn.Module):
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: int = 320,
out_channels: Optional[int] = None,
num_layers: int = 1,
cross_attention_dim: Optional[int] = None,
):
...

self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(...)
for d in range(num_layers)
]
)

self.temporal_transformer_blocks = nn.ModuleList(
[
TemporalBasicTransformerBlock(...)
for _ in range(num_layers)
]
)

self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
self.time_proj = Timesteps(in_channels, True, 0)
self.time_mixer = AlphaBlender(alpha=0.5, ...)

这个时序 Transformer 模块类有一个地方值得注意。我们知道,Transformer 模型本身是不知道输入数据的顺序的。无论是注意力层还是全连接层,它们都与顺序无关。为了让模型知道数据的先后顺序,比如在 NLP 里我们希望模型知道一句话里每个单词的前后顺序,我们会给输入数据加上位置编码。而有些时候我们觉得模型不用知道数据的先后顺序。比如在 SD 的 2D 图像 Transformer 块里,我们把每个像素当成一个 token,每个像素在 Transformer 块的运算方式是相同的,与其所在位置无关。而在处理视频时序的 Transformer 块中,知道视频每一帧的先后顺序看起来还是很重要的。所以,和 SD 的 2D Transformer 块不同,SVD 的时序 Transformer 块根据视频的帧号设置了位置编码,用和 NLP 里处理文本类似的方式处理视频。SVD 的时序 Transformer 类在构造函数里定义了生成位置编码的模块 TimestepEmbedding, Timesteps。在前向传播时,forward 方法会用 torch.arange(num_frames) 根据总帧数生成帧号列表,并经过两个模块得到最终的位置编码嵌入 emb。嵌入 emb 会在数据过时序 Transformer 块前与输入 hidden_states_mix 相加。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class TransformerSpatioTemporalModel(nn.Module):
def __init__(...):
...
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
self.time_proj = Timesteps(in_channels, True, 0)
...
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
...
):
...

num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
num_frames_emb = num_frames_emb.reshape(-1)
t_emb = self.time_proj(num_frames_emb)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]

for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
hidden_states = block(
...
)

hidden_states_mix = hidden_states
hidden_states_mix = hidden_states_mix + emb

hidden_states_mix = temporal_block(...)
hidden_states = self.time_mixer(...)

hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

output = hidden_states + residual
...

到这里,我们就读完了 SVD U-Net 的主要代码。相比 SD U-Net,SVD U-Net 主要做了以下修改:

  • 由于输入多了一张约束图像,输入通道数变为原来的两倍。
  • 多加了三个和视频相关的额外约束。它们是通过和扩散模型的时刻嵌入相加输入进模型的。它们的命名通常与 add_time 相关。
  • 仿照 Video LDM 的结构设计,SVD 也在 2D 残差块后面加入了由 3D 卷积组成的时序残差块,在空间 Transformer 块后面加入了对时序维度做注意力的时序 Transformer 块。新旧模块的输出会以一个可学习的比例线性混合。

VAE 结构

SVD 不仅微调了 SD 的 U-Net,还微调了 VAE 的解码器,让输出视频在时序上更加连贯。由于更新 VAE 和更新 U-Net 的方法几乎一致,我们就来快速看一下 SVD 的时序 VAE 的结构,而跳过每个模块的更新细节。

通过阅读 VAE 的配置文件,我们可以知道时序 VAE 的类名为 AutoencoderKLTemporalDecoder,它位于文件 diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py 中。从它的构造函数里我们可以知道,时序 VAE 的编码器类是 Encoder,和 SD 的一样,只是解码器类从 Decoder 变成了 TemporalDecoder。我们来看一下这个新解码器类的代码做了哪些改动。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
...
):
super().__init__()

self.encoder = Encoder(...)

self.decoder = TemporalDecoder(...)

...

在 SD 中,VAE 和 U-Net 的组成模块是几乎一致的,二者的结构主要有三个区别:1)由于 VAE 的解码器和编码器是独立的,它们之间没有残差连接。而 U-Net 是一个整体,它的编码器(下采样块)和解码器(上采样块)之间有残差连接,以减少数据在下采样中的信息损失; 2)由于 VAE 中图像的尺寸较大,仅在 VAE 最深层图像尺寸为 64x64 时才有自注意力层。具体来说,这个自注意力层加到了 VAE 解码器的一开头,代码中相关模块称为 mid_block;3)VAE 仅有空间自注意力,而 SD U-Net 用了完整的 Transformer 块(包含自注意力层、交叉注意力层、全连接层)。由于 SD VAE 和 U-Net 结构上的相似性,SVD 的开发者直接把对 U-Net 的更新也搬到了 VAE 上来。

SVD VAE 解码器仅做了两项更新:1)将所有模块里的 2D 残差块都被换成了我们在上文中见过的 3D 残差块;2)在最终输出前加了一个 3D 卷积(时序维度上的 1D 卷积)。VAE 的自注意力层的结构并没有更新。更新 2D 残差块的方法和 U-Net 的是一致的。比如在新的上采样块类 UpBlockTemporalDecoder 中,我们就可以看到之前在新 U-Net 里看过的 3D 残差块类 SpatioTemporalResBlock 的身影。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
...

class UpBlockTemporalDecoder(nn.Module):
def __init__(...):
super().__init__()
for i in range(num_layers):
...
resnets.append(SpatioTemporalResBlock(...))

class TemporalDecoder(nn.Module):
def __init__(...):
super().__init__()
self.layers_per_block = layers_per_block

self.conv_in = nn.Conv2d(...)
self.mid_block = MidBlockTemporalDecoder(...)

...

for i in range(len(block_out_channels)):
...
up_block = UpBlockTemporalDecoder(...)
self.up_blocks.append(up_block)

...

conv_out_kernel_size = (3, 1, 1)
self.time_conv_out = torch.nn.Conv3d(...)

采样流水线

看完了 U-Net 和 VAE 的代码后,我们来看整套 SVD 的采样代码。和其他方法一样,在 Diffusers 中,一套采样方法会用一个流水线类 (xxxPipeline)来表示。SVD 对应的流水线类叫做 StableVideoDiffusionPipeline。我们可以利用 IDE 的代码跳转功能,在本文开头的示例采样脚本中跳转至 StableVideoDiffusionPipeline 所在源文件 diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py

如示例脚本所示,使用流水线类时,可以将类实例 pipe 当成一个函数来用。这种用法实际上会调用实例的 __call__ 方法。所以,在阅读流水线类的代码时,我们可以先忽略其他部分,直接看 __call__ 方法。

1
2
pipe = StableVideoDiffusionPipeline.from_pretrained(...)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]

__call__ 的参数定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
height: int = 576,
width: int = 1024,
num_frames: Optional[int] = None,
num_inference_steps: int = 25,
min_guidance_scale: float = 1.0,
max_guidance_scale: float = 3.0,
fps: int = 7,
motion_bucket_id: int = 127,
noise_aug_strength: float = 0.02,
decode_chunk_size: Optional[int] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
return_dict: bool = True,
):

__call__ 的参数就是我们在使用 SVD 采样时能修改的参数,我们需要把其中的主要参数弄懂。各参数的解释如下:

  • image:SVD 会根据哪张图片生成视频。
  • height, width: 生成视频的尺寸。如果输入图片与这个尺寸对不上,会将输入图片的尺寸调整为该尺寸。
  • num_frames: 生成视频的帧数。SVD 1.0 版默认 14 帧,1.0-xt 版默认 25 帧。
  • min_guidance_scale, max_guidance_scale: 使用 Classifiser-free Guidance (CFG) 的强度范围。SVD 用了一种特殊的设置 CFG 强度的机制,稍后我们会在采样代码里见到。
  • fps:输出视频期望的帧率。SVD 的额外约束。实际上这个帧率肯定是不准的,只不过提高这个值可以让视频更平滑。
  • motion_bucket_id: SVD 的额外约束。官方没有解释该值的原理,只说明了提高该值能让输出视频的运动更多。
  • noise_aug_strength: 对输入图片添加的噪声强度。值越低输出视频越像原图。
  • decode_chunk_size: 一次放几张图片进时序 VAE 做解码,用于在内存占用和效果之间取得一个平衡。按理说一次处理所有图片得到的视频连续性最好,但那样也会消耗过多的内存。
  • num_videos_per_prompt: 对于每张输入图片 (prompt),输出几段视频。
  • generator: PyTorch 的随机数生成器。如果想要手动控制生成中的随机种子,就手动设置这个变量。
  • latents: 强制指定的扩散模型的初始高斯噪声。
  • output_type: 输出图片格式,是 NumPy、PIL,还是 PyTorch。
  • callback_on_step_endcallback_on_step_end_tensor_inputs 用于在不修改原流水线代码的情况下向采样过程中添加额外的处理逻辑。学习代码的时候可以忽略。
  • return_dict: 流水线是返回一个词典,还是像普通 Python 函数一样返回用元组表示的多个返回值。

大致搞清楚了输入参数的意义后,我们来看流水线的执行代码。一开始的代码都是在预处理输入,可以直接跳过。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width)

# 2. Define call parameters
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
device = self._execution_device
self._guidance_scale = max_guidance_scale

之后,代码开始预处理交叉注意力层的约束信息。在 SD 里,约束信息是文本,所以这一步会用 CLIP 文本编码器得到约束文本的嵌入。而 SVD 是一个图生视频模型,所以这一步会用 CLIP 图像编码器得到约束图像的嵌入。

1
2
# 3. Encode input image
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)

代码还把额外约束帧率 fps 减了个一,因为训练的时候模型实际上输入的额外约束是 fps - 1

1
fps = fps - 1

接着,代码开始处理与噪声拼接的约束图像。回顾一下,SVD 的约束图像以两种形式输入进模型:一种是过 CLIP 图像编码器,以交叉注意力 K,V 的形式输入,其预处理如上部分的代码所示;另一种形式是与原去噪 U-Net 的噪声输入拼接,其预处理如当前这部分代码所示。

在预处理要拼接的图像时,代码会先调用预处理器 image_processor.preprocess,把其他格式的图像转成 PyTorch 的 Tensor 类型。之后,代码会随机生成一点高斯噪声,并把噪声根据噪声增强强度 noise_aug_strength 加到这张约束图像上。这种做法来自于之前有约束图像的扩散模型 Cascaded diffusion modelsnoise_aug_strength 稍后会作为额外约束输入进 U-Net 里,与去噪时刻的编码相加。

1
2
3
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
image = image + noise_aug_strength * noise

加了这个噪声后,图像会过 VAE 的编码器,得到 image_latentsimage_latents 会通过 repeat 操作复制成多份,并于稍后拼接到每一帧带噪图像上。注意,一般图像在过 VAE 的编码器后,要乘一个系数 vae.config.scaling_factor; 在过 VAE 的解码器前,要除以这个系数。然而,只有在这个地方,image_latents 没有乘系数。我个人觉得这是开发者的一个失误。当然,做不做这个操作对于模型来说区别不大,因为模型能很快学会这种系数上的差异。

1
2
3
4
5
6
7
8
9
# 4. Encode input image using VAE
image_latents = self._encode_vae_image(
image,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
)
image_latents = image_latents.to(image_embeddings.dtype)
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)

下一步,代码会把三个额外约束拼接在一起,得到 added_time_ids。它会接入到 U-Net 中,与时刻编码加到一起。在训练时,帧率 fps 和 运动程度 motion_bucket_id 完全来自于数据集标注,而 noise_aug_strength 是可以随机设置的。在采样时,这三个参数都可以手动设置。

1
2
3
4
5
6
7
8
9
10
11
# 5. Get Added Time IDs
added_time_ids = self._get_add_time_ids(
fps,
motion_bucket_id,
noise_aug_strength,
image_embeddings.dtype,
batch_size,
num_videos_per_prompt,
self.do_classifier_free_guidance,
)
added_time_ids = added_time_ids.to(device)

再下一步,代码会将采样的总步数 num_inference_steps 告知采样调度器 scheduler。这一步是 Diffusers API 的要求。

1
2
# 6. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)

然后,代码会随机生成初始高斯噪声。不同的随机噪声即对应不同的输出视频。

1
2
3
4
5
6
7
8
9
10
11
12
13
# 7. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_frames,
num_channels_latents,
height,
width,
image_embeddings.dtype,
device,
generator,
latents,
)

开始采样前,SVD 对约束图像的强度做了一种很特殊的设定。在看代码之前,我们先回顾一下约束强度的意义。现在的扩散模型普遍使用了 CFG (Classifier-free Guidance) 技术,它允许我们在采样时灵活地调整模型和约束信息的相符程度。这个强度默认取 1.0。我们可以通过增大强度来提升模型的生成效果,比如在 SD 中,这个强度一般取 7.5,这代表模型会更加贴近输入文本。

而 SVD 中,约束信息为图像。开发者对视频的不同帧采用了不同的约束强度:首帧为 min_guidance_scale, 末帧为 max_guidance_scale。强度从首帧到末帧线性增加。默认情况下,约束强度的范围是 [1, 3]。

1
2
3
4
5
6
7
# 8. Prepare guidance scale
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
guidance_scale = guidance_scale.to(device, latents.dtype)
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
guidance_scale = _append_dims(guidance_scale, latents.ndim)

self._guidance_scale = guidance_scale

最后,就来到了扩散模型的去噪循环了。根据之前采样调度器返回的采样时刻列表 timesteps,代码从中取出去噪时刻,对纯噪声输入迭代去噪。

1
2
3
4
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):

去噪迭代的一开始,代码会根据是否要执行 CFG 来决定是否要把输入额外复制一份。这是因为做 CFG 时,我们需要把同一个输入过两次去噪模型,一次带约束,一次不带约束。为了简化这个流程,我们可以直接把输入复制一遍,这样只要过一次去噪模型就能得到两个输出了。下一行的 scale_model_input 是 Diffusers 的 API 要求,可以忽略。

1
2
3
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

接着,加了噪声、过了 VAE 解码器、没有乘系数的约束图像 image_latents 会与普通的噪声拼接到一起,作为模型的直接输入。

1
2
# Concatenate image_latents over channels dimension
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

准备好了所有输入后,代码调用 U-Net 对输入噪声图像去噪。输入包括直接输入 latent_model_input,去噪时刻 t,约束图像的 CLIP 嵌入 image_embeddings,三个额外约束的拼接 added_time_ids

1
2
3
4
5
6
7
8
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=image_embeddings,
added_time_ids=added_time_ids,
return_dict=False,
)[0]

去噪结束后,代码根据公式做 CFG。

1
2
3
4
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

有了去噪的输出 noise_pred 还不够,我们还需要用一些比较复杂的公式计算才能得到下一时刻的噪声图像。这一切都被 Diffusers 封装进调度器里了。

1
2
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample

以上就是一步去噪迭代的主要内容。代码会反复执行去噪迭代。这后面除了下面这行会调用 VAE 解码器将隐空间的视频解码回真实视频外,没有其他重要代码了。

1
frames = self.decode_latents(latents, num_frames, decode_chunk_size)

总结

在这篇文章中,我们学习了图生视频模型 SVD 的模型结构和采样代码。整体上看,SVD 相较 SD 在模型上的修改不多,只是在原来的 2D 模块后面加了一些在时序维度上交互信息的卷积块和 Transformer 块。在学习时,我们应该着重关注 SVD 的采样流水线。SVD 使用拼接和交叉注意力两种方式添加了图像约束,并以与时刻编码相加的方式额外输入了三种约束信息。由于视频不同帧对于首帧的依赖情况不同,SVD 还使用了一种随帧号线性增长的 CFG 强度设置方式。

近期,各个科技公司纷纷展示了自己在视频生成模型上的最新成果。虽然不少模型的演示效果都非常惊艳,但其中可供学术界研究的开源模型却少之又少。Stable Video Diffusion (SVD) 算得上是目前开源视频生成模型中的佼佼者,有认真一学的价值。在这篇文章中,我将面向之前已经熟悉 Stable Diffusion (SD) 的读者,简要解读 SVD 的论文。由于 SVD 的部分结构复用了之前的工作,并没有在论文正文中做详细介绍,所以我还会补充介绍一下 SVD 的模型结构、调度器。后续我还会在其他文章中详细介绍 SVD 的代码实现及使用方法。

背景

Stable Video Diffusion 是 Stability 公司于 2023 年 11 月 21 日公布并开源的一套用扩散模型实现的视频生成模型。由于该模型是从 Stability 公司此前发布的著名文生图模型 Stable Diffusion 2.1 微调而成的,因而得名 Stable Video Diffusion。SVD 的技术报告论文与模型同日发布,它对 SVD 的训练过程做了一个详细的分享。由于该论文过分偏向实践,这里我们仅对它的开头及中间模型设计的几处关键部分做解读。

摘要

最近,有许多视频生成模型都是在图像生成模型 SD 的基础上,添加和视频时序相关的模块,并在小规模高质量视频数据集上微调新模型。而 SVD 作者认为,该领域在训练方法及精制数据集的策略上并未达成统一。这篇文章的主要贡献,也正是提出了一套训练方法与精制数据集的方法。具体而言,SVD 的训练由三个阶段组成:文生图预训练、视频预训练、高质量视频微调。同时,SVD 提出了一种系统性的数据精制流程,包含数据的标注与过滤这两部分的策略。论文会分享诸多的实验成果,包括验证精心构建的数据集对生成高质量视频的必要性、探究视频预训练与微调这两步的重要性、展示基础模型如何为图生视频等下游任务提供强大的运动表示、演示模型如何提供多视角三维先验并可以作为微调多视角扩散模型的基础模型在一轮神经网络推理中同时生成多视角的图片。

「构建」一个数据集在论文中通常用动词 curate 及名词 curation 指代。curate 原指展出画作时,选择、组织和呈现艺术品的过程。而现代将这个词用在数据集上时,则转变为表示精心选择、组织和管理数据的过程。中文中并没有完全对应的翻译,我暂时将这个词翻译为「精制」,以区别于随便收集一些数据来构成一个数据集。

总结一下,SVD 并没有强调在模型设计或者采样算法上的创新,而主要宣传了该工作在数据集精制及训练策略上的创新。对于大部分普通研究人员来说,由于没有训练大视频模型的需求,该文章的很多内容都价值不大。我们就只是来大致过一遍这篇文章的主要内容。

SVD 模型架构回顾

Video-LDM 与 SVD

在阅读正文之前,我们先来回顾一下此前视频生成模型的开发历程,并重点探究 SVD 的模型架构——Video LDM 的具体组成。绝大多数工作在训练一个基于扩散模型的视频生成模型时,都是在预训练的 SD 上加入时序模块,如 3D 卷积,并通过微调把一个图像生成模型转换成视频生成模型。由于 SD 是一种 LDM (Latent Diffusion Model),所以这些视频模型都可以归类为 Video-LDM。所谓 LDM,就是一种先生成压缩图像,再用解码模型把压缩图像还原成真实图像的模型。而对于视频,Video-LDM 则会先生成边长压缩过的视频,再把压缩视频还原。

虽然 Video-LDM 严格上来说是一个视频扩散模型的种类,但大家一般会用Video LDM (没有横杠) 来指代 Align your Latents: High-Resolution Video Synthesis with Latent Diffusion Models 这篇工作。这篇论文已在 CVPR 2023 上发布,两个主要作者正是前一年在 CVPR 上发表 SD 论文的主要作者,也是现在这篇 SVD 论文的主要作者。从署名上来看,似乎两个作者在毕业后就加入了 Stability 公司,并将 Video LDM 拓展成了 SVD。论文中也讲到,SVD 完全复用了 Video LDM 的结构。为了了解 SVD 的模型结构,我们再来回顾一下 Video LDM 的结构。

在 SD 的基础上,Video LDM 做对模型结构了两项改动:在扩散模型的去噪模型 U-Net 中加入时序层、在对图像压缩和解压的 VAE 的解码器中加入时序层。

添加时序层

Video LDM 在 U-Net 中加入时序层的方法与多数同期方法相同,是在每个原来处理图像的空间层后面加上处理视频的时序层。Video LDM 加入的时序层包括 3D 卷积层与时序注意力层。这些新模块本身不难理解,但我们需要着重关注这些新模块是怎么与原模型兼容的。

要兼容各个模块,其实就是要兼容数据的形状。本来,图像生成模型的 U-Net 的输入形状为 B C H W,分别表示图像数、通道数、高、宽。而视频数据的形状是 B T C H W,即视频数、视频长度、通道数、高、宽。要让视频数据复用之前的图像模型的结构,只要把数据前两维合并,变成 (B T) C H W 即可。这种做法就是把 B 组长度为 T 的视频看成了 $B \cdot T$ 张图片。

对于之前已有的空间层,只要把数据形状变成 (B T) C H W 就没问题了。而 SVD 又新加入了两种时序层:3D 卷积和时序注意力。我们来看一下数据是怎么经过这些新的时序层的。

2D 卷积会对 B C H W 的数据的后两个高、宽维度做卷积。类似地,3D 卷积会对数据最后三个时间、高、宽维度做卷积。所以,过 3D 卷积前,要把形状从 (B T) C H W 变成 B C T H W,做完卷积再还原。

接下来我们来看新的时序注意力。这个地方稍微有点难理解,我们从最简单的注意力开始一点一点学习。最早的 NLP 中的注意力层的输入形状为 B L C,表示数据数、token 长度、token 通道数。L 这一维最为重要,它表示了 L个 token 之间互相交换信息。如果把其拓展成图像空间注意力,则 token 表示图像的每一个像素。在这种注意力层中,L(H W)B C H W 的数据会被转换成 B (H W) C 输入进注意力层。这表示同一组图像中,每个像素两两之间交换信息。而让视频数据过空间注意力层时,只需要把 B 换成 (B T) 即可,即把数据形状从 (B T) C H W 变为 (B T) (H W) C。这表示同一组、同一帧的图像的每个像素之间,两两交换信息。

在 SVD 新加入的时序注意力层中,token 依旧指代是某一组、某一帧上的一个像素。然而,这次我们不是让同一张图像的像素互相交换信息,而是让不同时刻的像素互相交换信息。因此,这次 token 长度 LT,它表示要像素在时间维度上交换信息。这样,在视频数据过时序层里的自注意力层时,要把数据形状从 (B T) C H W 变成 (B H W) T C。这表示每一组、图像每一处的像素独立处理,它们仅与同一位置不同时间的像素进行信息交换。

此处如果你没有理解注意力层的形状变换也不要紧,它只是一个实现细节,不影响后面的阅读。如果感兴趣的话,可以回顾一下 Transformer 论文的代码,看一下注意力运算为什么是对 B L C 的数据做操作的。

微调 VAE 解码器

Video LDM 的另一项改动是修改了图像压缩模型 VAE 的解码器。具体来说,方法先在 VAE 的解码器中加入类似的时序层,并在 VAE 配套的 GAN 的判别器里也加入了时序层,随后开始微调。在微调时,编码器不变,仅训练解码器和判别器。

如果你没听过这套 VAE + GAN 的架构的话,请回顾 Stable Diffusion 论文及与其紧密相关的 VQGAN 论文。

以上就是 Video LDM 的模型结构。SVD 对其没有做任何更改,所以也没有在论文里对模型结构做详细介绍。稍有不同的是,Video LDM 仅微调了新加入的模块,而 SVD 在加入新模块后对模型的所有参数都进行了重新训练。

SVD 训练细节

SVD 分四节介绍了模型训练过程。第一节介绍了数据精制的过程,后三节分别介绍了训练的三个阶段:文生图预训练、视频预训练、高质量视频微调。

获取了一个大规模视频数据集后,SVD 的数据精制主要由预处理和标注这两步组成。由于视频生成模型主要关注生成同一个场景的视频,而不考虑转场的问题,每段训练视频也应该尽量只包含一个场景。为此,预处理主要是在用一些自动化视频剪切工具把收集到的视频进一步切成连续的片段。经切片后,视频片段数变为原来的4倍。标注主要是给视频加上文字描述,以训练一个文生视频的模型。SVD 在添加文字描述时用到了多个标注模型,并使用大语言模型来润色描述。经预处理和标注后,得到的数据集被称作 LVD (Large Video Dataset)。

SVD 数据精制的细节中,比较值得注意的是有关视频帧数的处理。由于开发团队发现视频数据的播放速度快慢不一,于是他们使用光流预测模型来大致估计每段视频的播放速度(以帧率 FPS 表示),并将视频的帧率也作为标注。这样,在训练时,视频的帧率也可以作为一种约束信息。这样的好处是,在我们在生成视频时,可以用该约束来指定视频的播放速度。

之后我们来看 SVD 模型训练的三个阶段。对于第一个文生图预训练阶段,论文没有对模型结构做过多修改,因为他们在这一步使用了之前训练好的 SD 2.1。不过,SVD 在这一步做了一个非常重要的改进:SVD 的噪声调度器从原版的 DDPM 改成了 EDM,采样方法也改成了 EDM 的。

EDM 的论文全称为 Elucidating the Design Space of Diffusion-Based Generative Models 。这篇论文用一种概括性较强的数学模型统一表示了此前各种各样的扩散模型结构,并提出了改进版模型的训练及采样策略。简单来说,EDM 把扩散模型不同时刻的噪声强度表示成 $\sigma_t$,它表示在 $t$ 时刻时,对来自数据集的图像加了标准差为 $\sigma_t$ 的高斯噪声 $\mathcal{N}(\mathbf{0}, \sigma_t^2\mathbf{I})$。一开始,对于没加噪声的图像,$\sigma_0=0$。对于最后一个时刻 $T$ 的图像,$\sigma_T$ 要足够大,使得原图像的内容被完全破坏。

这里时刻 $0$ 与时刻 $T$ 的定义与 DDPM 论文相同,与 EDM 论文相反。

有了这样一种统一的表示后,EDM 对扩散模型的训练和采样都做了不少改进。这里我们仅关注其中最重要的一条改进:将离散噪声改进成连续噪声。原来 DDPM 的去噪模型会输入时刻 $t$ 这个参数。EDM 论文指出,$t$ 实际上表示了噪声强度 $\sigma_t$,应该把 $\sigma_t$ 输入进模型。与其用离散的 $t$ 训练一个只认识离散噪声强度的去噪模型,不如训练一个认识连续噪声强度 $\sigma$ 的模型。这样,在采样 $n$ 步时,我们不再是选择离散去噪时刻[timestep[n], timestep[n - 1], ..., 0],而是可以选择连续噪声强度[sigma[n], sigma[n - 1], ..., 0] 。这样采样更灵活,效果也更好。在第一个训练阶段中,SVD 照搬了 EDM 的这种训练方法,改进了原来的 DDPM。SVD 的默认采样策略也使用了 EDM 的 。我们会在之后的代码实践文章中详细学习这种新采样方法。

对于第二个视频预训练阶段,或许是因为视频模型和图像模型的训练过程毫无区别,论文的介绍重点依然放在了这一阶段的数据处理上,而没有强调训练方法上的创新。简单来看,这一阶段的目标是得到一个过滤后的高质量数据集 LVD-F。为了找到这样一种合适的过滤方案,开发团队先用排列组合生成了大量的过滤方案:对每类指标(文本视频匹配度、美学分数、帧率等)都设置 12.5%, 25% 或 50% 的过滤条件,然后不同指标的条件之间排列组合。之后,开发团队抽取原数据集的一个子集 LVD-10M,用各个方案得到过滤后的视频子集 LVD-10-F。最后,用这样得到的子数据集分别训练模型,比较模型输出的好坏,以决定在完整数据集上使用的最优过滤方案。

在第三个阶段,参考以往多阶段训练图像模型的经验,SVD 也在另一个小而精的视频数据集上进行微调。此数据集的获取方法并没有在论文中给出,大概率是人工手动收集并标注。

SVD 应用

经上述训练后,开发团队得到了一个低分辨率的基础文生视频模型。在实验部分,SVD 论文除了给出视频生成模型在各大公开数据集上的指标外,还分享了几个基于基础模型的应用。

高分辨率文生视频

基础文生视频最直接的应用就是高分辨率文生视频。实现的方法很简单,只要准备一个高分辨率的视频数据集,在此数据集上微调原基础模型即可。SVD 高分辨率文生视频模型能生成 576 x 1024 的视频。

高分辨率图生视频

除了文生视频外,也可以用基础模型来微调出一个图生视频模型。为了把约束从文本换成图像,开发团队将 U-Net 交叉注意力层的约束从文本嵌入变成了约束图像的图像嵌入,并将约束图像与原 U-Net 的噪声输入在通道维度上拼接在一起。特别地,参考以往 Cascaded diffusion models 论文的经验,约束图像在与噪声输入拼接前,会加上一些噪声。除此之外,由于约束机制的变动,像文生图模型一样将约束强度(CFG scale)设成 7.5 会让 SVD 图生视频模型产生瑕疵。因此,SVD 图生视频模型每一帧的约束强度不同,从第一帧到最后一帧以 1.0 到 3.0 线性增长。

参考之前 AnimateDiff 工作,SVD 也成功训练了相机运动 LoRA,使得图生视频模型只会生成平移、缩放等某一种特定相机运动的视频。

视频插帧

Video LDM 曾提出了一种把基础视频模型变成视频插帧模型方法。该方法以视频片段的首末帧为额外约束,在此新约束下把视频生成模型微调成了预测中间帧的视频预测模型。SVD 以同样方式实现了这一应用。

多视角生成

多视角生成是计算机视觉中另一类重要的任务:给定 3D 物体某个视角的图片,需要算法生成物体另外视角的图片,从而还原 3D 物体的原貌。而视频生成模型从数据中学到了物体的平滑变换规律,恰好能帮助到多视角生成任务。SVD 论文用不少篇幅介绍了如何在 3D 数据集上生成视频并微调基础模型,从而得到一个能生成环绕物体旋转的视频的模型。

结语

Stable Video Diffusion 是在文生图模型 Stable Diffusion 2.1 的基础上添加了和 Video LDM 相同的视频模块微调而成的一套视频生成模型。SVD 的论文主要介绍了其精制数据集的细节,并展示了几个微调基础模型能实现的应用。通过微调基础低分辨率文生视频模型,SVD 可以用于高分辨率文生视频、高分辨率图生视频、视频插帧、多视角生成。

对于没有资源与需求训练大视频模型的多数科研人员而言,没有深究这篇文章细节的必要。并且,由于 SVD 只开源了图生视频模型 (3D模型后来是在 SV3D 论文中正式公布的),这篇文章比较有用的只有和图生视频相关的部分。为了彻底搞懂 SVD 的原理,读这篇论文是不够的,我们还需要通过回顾 Video LDM 论文来了解模型结构,学习 EDM 论文来了解训练及采样机制。

这篇文章主要是面向熟悉 Stable Diffusion 的读者的。如果你缺少某些背景知识,欢迎读我之前介绍 Stable Diffusion 的文章。我没有在本文过多介绍 SVD 的实现细节,欢迎阅读我之后发表的 SVD 代码实践文章。

Stable Diffusion 解读(一):回顾早期工作

Stable Diffusion 解读(二):论文精读

Stable Diffusion 解读(三):原版实现及Diffusers实现源码解读

FID 是一种衡量图像生成模型质量的指标。对于这种常见的指标,一般都能找到好用的 PyTorch 计算接口。然而,当我用 PyTorch 的官方库 TorchEval 来算 FID 指标时,却发现它的结果和多数非官方库无法对齐。我花了不少时间,总算把 TorchEval 的 FID 计算接口修好了。在这篇文章中,我将分享有关 FID 计算的知识以及我调试 TorchEval 的经历,并总结用 pytorch-fid, torch-fidelity, TorchEval 算 FID 的方法。文章最后,我还会分享一个偶然发现的用于反映模型训练时的当前 FID 的方法。

FID 指标简介

FID 的全称是 Fréchet Inception Distance,它用于衡量两个图像分布之间的差距。如果令一个图像分布是训练集,再用生成模型输出的图像构成另一个分布,那么 FID 指标就表示了生成出来的图片和训练集整体上的相似度,也就间接反映了模型对训练集的拟合程度。FID 名字中的 Fréchet Distance 是一种描述两个样本分布的距离的指标,其定位和 KL 散度一样,但某些情况下会比 KL 散度更加合适。FID 用来算 Fréchet Distance 的样本来自预训练 InceptionV3 模型,它名称中的 Inception 由此而来。

计算 FID 的过程如下:

  1. 准备两个图片文件夹。一般一个是训练集,另一个存储了生成模型随机生成的图片。
  2. 用预训练的 InceptionV3 模型把每个输入图片转换成一个 2048 维的向量。
  3. 计算训练集、生成集上输出向量的均值、协方差。
  4. 把均值、协方差代入进下面这个算 Fréchet Distance 的公式,就得到了 FID。

实际上,在用 FID 的时候我们完全不用管它的原理,只要知道它的值越小就越好,并且会调用相关接口即可。需注意的是,由于 FID 是一种和集合相关的指标,算 FID 时一定要给足图片。在构建自己模型的输出集合时,至少得有 10000 张图片,推荐生成 50000 张。否则 FID 的结果会不准确。

用 PyTorch 计算 FID 的第三方库

由于 FID 的计算需要用到一个预训练的 InceptionV3 模型,只有在模型实现完全一致的情况下,FID 的输出结果才是可比的。因此,所有论文汇报的 FID 都基于提出 FID 的作者的官方实现。这份官方实现是用 TensorFlow 写的,后来也有完全等价的 PyTorch 实现。在这一节里,我们就来学习如何用这些基于 PyTorch 的库算 FID。

GitHub 上点赞最多的 PyTorch FID 库是 pytorch-fid。这个库被 FID 官方仓库推荐,且 Stable Diffusion 论文也用了这个库,结果绝对可靠。使用该库的方法很简单,只需要先安装它。

1
pip install pytorch-fid

再准备好两个用于计算 FID 的文件夹,将文件夹路径传给脚本即可。

1
python -m pytorch_fid path/to/dataset1 path/to/dataset2

另一个较为常见的用 PyTorch 算指标的库叫做 torch-fidelity。它用起来和 pytorch-fid 一样简单。一开始,需要用 pip 安装它。

1
pip install torch-fidelity

之后,同样是准备好两个图片文件夹,将文件夹路径传给脚本。

1
fidelity --gpu 0 --fid --input1 path/to/dataset1 --input2 path/to/dataset2

除了命令行脚本外,torch-fidelity 还提供了 Python API。我们可以在 Python 脚本里加入算 FID 的代码。

1
2
3
4
5
6
7
8
import torch_fidelity

metrics_dict = torch_fidelity.calculate_metrics(
input1='path1',
input2='path2',
fid=True
)
print(metrics_dict)

torch-fidelity 还提供了其他便捷的功能。比如直接以某个生成模型为 API 的输入 input1,而不是先把图像生成到一个文件夹里,再把文件夹路径传给 input1。同时,torch-fidelity 还支持计算其他指标,我们只需要在命令行脚本或者 API 里多加几个参数就行了。

修正 TorchEval 里的 FID 计算接口

尽管这些第三方库已经足够好用了,我还是想用 PyTorch 官方近年来推出的指标计算库 TorchEval 来算 FID 指标。原因有两点:

  1. 我的项目其他地方都是用 PyTorch 官方库实现的 (torch 以及 torchvision),算指标也用官方库会让整体代码风格更加统一。我已经用 TorchEval 算了 PSNR、SSIM,使用体验还可以。
  2. 目前,似乎只有 TorchEval 支持在线更新指标的值。也就是说,我可以先生成一部分图片,储存算 FID 需要的中间结果;再生成一部分图片,最终计算此前所有图片与训练集的 FID。这种计算方法的好处我会在文章后面介绍。

以前我都是用 pytorch-fid 来算 FID。而当我换成用 TorchEval 后,却发现结果对不齐。于是,漫长的调试之路开始了。

当你有两块时间不一样的手表时,应该怎样确认时间呢?答案是,再找到第三块表。如果三块表中能有两块表时间一样,那么它们的时间就是正确的。一开始,我并不能确定是哪个库写错了,所以我又测试了 torch-fidelity 的结果。实验发现,torch-fidelity 和 pytorch-fid 的结果是一致的。并且我去确认了 Stable Diffusion 的论文,其中用来计算 FID 的库也是 pytorch-fid。看来,是 TorchEval 结果不对。

像 FID 这么常见的指标,大家的中间计算过程肯定都没错,就是一些细微的预处理不太一样。抱着这样的想法,我随意地比对了一下二者的代码,很快就发现 TorchEval 把输入尺寸调成 [299, 299] 了,而 pytorch-fid 没做。可删掉这段代码,程序直接报错了。我深入阅读了 pytorch-fid 的代码,发现它的写法和 TorchEval 不一样,把调整尺寸为 [299, 299] 写到了另一个地方。且通过调查发现,InceptionV3 网络的输入尺寸必须是 [299, 299] 的,是我孤陋寡闻了。唉,看来这次的调试不能太随意啊。

我准备拿出我的真实实力来调 bug。我认真整理了一下算 FID 的步骤,将其主要过程总结为以下几步:

  1. 用预训练权重初始化 InceptionV3
  2. 用 InceptionV3 算两个数据集输出的均值、协方差
  3. 根据均值、协方差算距离

最后那个算距离的过程不涉及任何神经网络,输出该是什么就是什么。这一块是最不容易出错,且最容易调试的。于是,我决定先排除第三步是否对齐。我把 TorchEval 得到的均值、协方差存下来,用 pytorch-fid 算距离。发现结果和原 TorchEval 的输出差不多。看来算距离这一步没有问题。

接下来,我很自然地想到是不是均值和协方差算错了。我存下了两个库得到的均值、协方差,算了两个库输出之间的误差。结果发现,均值的误差在 0.09 左右,协方差的误差在 0.0002 左右。图像的数据范围在 0~1 之间,0.09 算是一个很大的误差了。可见,第一步和第二步一定存在着无法对齐的部分。

模型输出不同,最容易想到的是模型权重不同。于是,我尝试交换使用二者的模型权重,再比较输出的 FID。两个库的模型定义不太一样,不能直接换模型文件名。我用强大的代码魔改实力强行让新权重分别都跑起来了。结果非常神奇,算上之前的两个 FID,我一共得到了 4 个不一样的 FID 结果。也就是说,A 库 A 模型、B 库 B 模型、A 库 B 模型,B 库 A 模型,结果均不一样。

我被这两个库气得不行,决定认真研究对比二者的模型定义。眼尖的我发现初始化 pytorch-fid 的 InceptionV3 时有一个参数叫 use_fid_inception。作者对此的注释写道:「如果设置为 true,则用 TensorFlow 版 FID 实现;否则,用 torchvision 版 Inception 模型。TensorFlow 的 FID Inception 模型和 torchvision 的在权重和结构上有细微的差别。如果你要计算 FID,强烈推荐将此值设置为 true,以得到和其他论文可比的结果。」总结来说,TorchEval 用的是 torchvision 里的标准 PyTorch 版 InceptionV3,而 pytorch-fid 在标准 PyTorch 版 InceptionV3 外又封装了一层,改了一些模块的定义。为什么要改这些东西呢?这是因为原来的 FID Inception 模型是在 TensorFlow 里实现的,需要改一些结构来将 PyTorch 模型对齐过去。除了模型结构外,二者的权重也有一定差别。大家都是用 TensorFlow 版模型算 FID,一切都应该以 pytorch-fid 的为准。这个 TorchEval 太离谱了,我也懒得认真修改了,直接注释掉 TorchEval 里原 FIDInceptionV3 的定义,然后大笔一挥:

1
2
from pytorch_fid.inception import \
InceptionV3 as FIDInceptionV3

按理说,这下权重和模型结构都对齐了。FID 计算的第一、第二步绝对不会有错。而开始的结果表明,FID 计算的第三步也没有错。那么,两个库就应该对齐了。我激动地又测了 TorchEval 的结果,发现结果还是无法对齐!

这不应该啊?难道哪步测错了?人生就是在不断自我怀疑中度过的。而怀疑自我,首先会怀疑最久远的自我。所以,我感觉是最早测第三步的时候有问题。之前我是把 TorchEval 的均值、协方差放到 pytorch-fid 里,结果与 TorchEval 自己的输出一致。这次我反过来,把 pytorch-fid 的均值、协方差放到 TorchEval 的算距离函数里算。这次,我第一次见到 TorchEval 输出了正确的 FID。由此可见,第三步没错。难道是均值和协方差又没对齐了?

自我怀疑开始进一步推进,我开始怀疑第二步输出的均值、协方差还是没有对齐。我再次计算了 pytorch-fid 和 TorchEval 的输出之间的误差,发现误差这次仅有 1e-16,可以认为没有区别。我花了很多时间复习协方差的计算,想找出 TorchEval 里的 bug。可是越学习,越觉得 TorchEval 写得很对。这一回,我找不到错误了。

调试代码,不怕到处有错,而怕「没错却有错」。「没错」,指的是每一步中间步骤都找不到错误;「有错」,指的是最终结果还是错了。没有错误,就得创造错误。我开启了随机乱调模式,希望能触发一个错误。回忆一下,算 FID 要用到两个数据集,一般一个是训练集,一个是模型输出的集合。在 TorchEval 最后一步算距离时,我乱改代码,让一个集合的均值、协方差不变,即来自原 TorchEval 的 Inception 模型的输出;而让另一个的集合的均值、协方差来自 pytorch-fid。理论上说,如果两个库的均值、协方差是对齐的,那么这次输出的 FID 也应该是正确的。欸,这回代码报错了,运行不了。报错说数据精度不统一。原来,TorchEval 的输出精度是 float32,而 pytorch-fid 的输出精度是 float64。之前测试距离计算函数时,数据要么全来自 TorchEval,要么全来自 pytorch-fid,所以没报过这个错。可是这个错只是一个运行上的错误,稍微改改就好了。

我把 pytorch-fid 相关数据的精度统一成了 float32。这下代码跑起来了,可 FID 不对了。调试过程中,如果上一次成功,而这一次失败,则应该想办法把代码退回上一次的,再次测试。因此,我又修改了最后用 TorchEval 计算距离的数据来源,让所有数据都来自 pytorch-fid。可是,修改后,FID 输出没变,还是错的。

为什么两轮测试之前,我全用 pytorch-fid 的输出、TorchEval 的距离计算函数没有错,这次却错了?到底是哪里不同?当测试两份差不多的代码后,一份对了,一份错了,那么错误就可以定位到两份代码的差异处。仔细回顾一下我的调试经历,相信你可以推理出 bug 出自哪了。

没错!我仔细比对了当前代码和我记忆中两轮测试前的代码,仅发现了一处不同——我把 pytorch-fid 的输出数据的精度改成了 float32。把精度改回 float64 就对了。同样,如果把 TorchEval 的输出数据的精度改成 float64,再扔进 TorchEval 的距离计算函数里算,结果也是对的。问题出在 TorchEval 的距离计算函数的数据精度上。

定位到了 bug 的位置,再找出 bug 的原因就很简单了。对比 pytorch-fid 的距离计算函数和 TorchEval 的,可以发现二者描述的计算公式完全相同。然而,pytorch-fid 是用 NumPy 算的,而 TorchEval 是用 PyTorch 算的。算 FID 的距离时,会涉及矩阵特征值等较为复杂的运算,它们对数据精度要求较高。像 NumPy 这种久经考验的库应该会自动把数据变成高精度再计算,而 PyTorch 就没做这么多细腻的处理了。

汇总一下我调试的结论。TorchEval 在权重初始化、模型计算、距离计算这三步中均有错误。前两步没有让 InceptionV3 模型和普遍使用的 TensorFlow 版对齐,最后一步没有考虑输入精度,用了不够稳定的 PyTorch API 来做复杂矩阵运算。要用 TorchEval 算出正确的 FID,需要做以下修改:

  • 安装 pytorch-fid 和 TorchEval
  • 打开 torcheval/metrics/image/fid.py
  • 注释掉 FIDInceptionV3 类,在文件开头加上 from pytorch_fid.inception import InceptionV3 as FIDInceptionV3
  • FrechetInceptionDistance 类的构造函数中,在定义所有浮点数据时加上 dtype=torch.float64

这里点名批评 TorchEval。开源的时候吹得天花乱坠,结果根本没人用,这么简单的有关 FID 的 bug 也发现不了。我发了一个修正此 bug 的相关 issue https://github.com/pytorch/torcheval/issues/192,截至目前还是没有官方人员回复。这个库的开发水平实在太逆天了,希望他们能尽快维护好。

在线计算 FID

前文提到,我用 TorchEval 的原因是它支持在线计算 FID。具体来说,可以建立一个 FID 管理类,之后用 update 方法来不断往某个集合加入新图片,并随时使用 compute 方法算出当前所有图片的 FID。我之前写代码忘了清空旧图片的中间结果时发现了一个相关应用。经我使用下来,这种应用非常有用,我们可以用它高效估计训练时的当前 FID。

回顾一下,要得到准确的 FID 值,一般需要 50000 张图片。而训练图像生成模型时,如果每次验证都要生成这么多图片,则大部分时间都会消耗在验证上了。为了加快 FID 的验证,我发现可以用一种 「全局 FID」来近似表示当前的模型拟合情况。具体来说,我先用训练集的所有图片初始化 FID 的集合 1 的中间结果,再在模型训练中每次验证时随机生成 500 张图片,将其中间结果加到 FID 的集合 2 中,并输出一次当前 FID。这样,随着训练不断推进,算 FID 的图片的数量会逐渐满足 50000 张的要求,但是这些图片并不是来自同一个模型,而是来自不同训练程度的模型。这样得到的 FID 仅能大致反映当前的真实 FID 值,有时偏高、有时偏低。但经我测试发现,这种全局 FID 的相对关系很能反映最终的真实 FID 的相对关系。训练两个不同超参的模型时,如果一个全局 FID 较大,那它最终的 FID 一般也会较大。同时,如果训练一切正常,则全局 FID 会随验证轮数单调递减(因为图片数量变多,且拟合情况不会变差)。如果某一次验证时全局 FID 增加了,则模型也一定在这段时间里变差了。通过这种验证方式,我们能够大致评估模型在训练中的拟合情况。这应该是一种很容易想到的工程技巧,但由于分享自己训练生成模型的经验帖较少,且重要性不足以写进论文,我没有在任何地方看到有人介绍这种技巧。

总结

FID 是评估图像生成模型的重要指标。通过 pytorch-fid 等库,我们能轻松地用 PyTorch 计算两个图像分布间的 FID。而通过计算输出分布和训练分布之间的 FID,我们就能评估当前模型的拟合情况。

FID 的计算本身是很简单的。所以在介绍 FID 的计算方法之外,我分享了我调试 TorchEval 的漫长过程。这段经历很有意思,我学到了不少调 bug 的新知识。此前我从来没想到过数据精度竟然会大幅影响某个值的结果。这段经历启示我们,做一些复杂运算时,不要用 PyTorch 算,最好拿 NumPy 等更稳定的库来计算。如果你调 bug 的经验不足,这段经历也能给你许多参考。

文章最后我分享了一种算全局 FID 的方法。它可以高效反映生成模型在训练时的拟合情况。该功能很容易实现,感兴趣的话可以自己尝试一下。

相信大家都在网上看过这种「笑容逐渐消失」的表情包:一张图片经过经过平滑的变形,逐渐变成另一张截然不同的图片。

对此,计算机科学中有一种专门描述此应用的任务——图像变形(image morphing)。给定两张图像,图像变形算法会输出一系列合理的插值图像。当按顺序显示这些插值图像时,它们应该能构成一个描述两张输入图像平滑变换的视频。

图像变形可以广泛运用于创意制作中。比如在做 PPT 时,我们可以在翻页处用图像变形做出炫酷的过渡效果。当然,图像变形也可以用在更严谨的场合。比如在制作游戏中的 2D 人物动画时,可以让画师只画好一系列关键帧,再用图像变形来补足中间帧。可是,这种任务对于中间插值图像的质量有着很高的要求。而传统的基于优化的图像变形算法只能对两张输入图像的像素进行一定程度的变形与混合,难以生成高质量的中间帧。有没有一种更好的图像变形算法呢?

针对这一需求,我们提出了 DiffMorpher —— 一种基于预训练扩散模型 (Stable Diiffusion)的图像变形算法。该研究由 DragGAN 作者潘新钢教授指导,经清华大学、上海人工智能实验室、南洋理工大学 S-Lab 合作完成。目前,该工作已经被 CVPR 2024 接收。

我们可以借助 DiffMorpher 实现许多应用。最简单的玩法是输入两张人脸,生成人脸的渐变图。

如果输入一系列图片,我们还能制作更长更丰富的渐变图。

而当输入的两张图片很相似时,我们可以用该工具制作出质量尚可的补间动画。

在这篇文章中,让我们来浏览一下 DiffMorpher 的工作原理,并学习如何使用这一工具。学习 DiffMorpher 的一些技术也能为我们开发其他基于扩散模型的编辑工具提供启发。

项目官网:https://kevin-thu.github.io/DiffMorpher_page/

代码仓库:https://github.com/Kevin-thu/DiffMorpher

隐变量插值

如前所述,图像变形任务在生成插值图像时不仅需要混合输入图像的内容,还需要补充生成一些内容。用预训练的图像生成模型来完成图像变形是再自然不过的想法了。前几年,已经有一些工作探究了如何用 GAN 来完成图像变形。使用 GAN 做图像变形的方法非常直接:在 GAN 中,每张被生成的图片都由一个高维隐变量决定。可以说,隐变量蕴含了生成一张图像所需的所有信息。那么,只要先使用 GAN 反演(inversion)把输入图片变成隐变量,再对隐变量做插值,就能用其生成两张输入图像的一系列中间过渡图像了。

对于隐变量,我们一般使用球面插值(slerp)而不是线性插值。

然而,GAN 生成的图像往往局限于某一类别,泛用性差。因此,用 GAN 做图像变形时,往往得不到高质量的图像插值结果。

而以 Stable Diffusion(SD)为代表的图像生成扩散模型以能生成各式各样的图像而著称。我们可以在 SD 上也用类似的过程来实现图像插值。具体来说,我们需要对 DDIM 反演得到的纯噪声图像(隐变量)进行插值,并对输入文本的嵌入进行插值,最后根据插值结果生成图像。

可是,扩散模型也存在缺陷:扩散模型的隐变量没有 GAN 的那么适合编辑。如下面的动图所示,如果仅使用简单的隐变量插值,会存在着两个问题:1)早期和晚期的中间帧和输入图像非常相近,而中期的中间帧又变化过快,图像的过渡非常突兀;2)中间帧的图像质量较低。这样的结果无法满足实际应用的要求。

LoRA 插值

扩散模型的隐变量不适合编辑,准确来说是其所在隐空间的性质导致的。模型只能处理隐空间中部分区域的隐变量。如果对隐变量稍加修改,让隐变量「跑」到了一个模型处理不了的区域,那模型就生成不了高质量的结果。而在对两个输入隐变量做插值时,插值的隐变量很可能就位于一个模型处理不了的区域。

想解决此问题,我们需要参考一些其他的工作。为了提升扩散模型的编辑单张图像的能力,一些往期工作会在单张图片上微调预训练扩散模型(即训练集只由同一张图片构成,让模型在单张图片上过拟合)。这样,无论是调整初始的隐变量还是文本输入,模型总是能够生成一些和该图片很相近的图片。比如在 Imagic 工作中,为了编辑输入图片,算法会先在输入图片上微调扩散模型,再用新的文本描述重新生成一次图片。这样,最终生成的图片既和原图很接近(鸟的外观差不多),又符合新的文本描述(鸟张开了翅膀)。

后来,许多工作用 LoRA 代替了全参数微调。LoRA 是一种高效的模型微调技术。在训练 LoRA 时,原来的模型权重不用修改,只需要训练额外引入的少量参数即可。假设原模型的参数为$W$,则 LoRA 参数可以表示为 $\Delta W$,新模型可以表示为$W + \Delta W$,其中 $\Delta W$ 里的参数比 $W$ 要少得多。训练 LoRA 的目的和全参数微调是一样的,只不过 LoRA 相对而言更加高效。

对单张图片训练 LoRA,其实就是在把整个隐空间都变成一个能生成高质量图像的空间。但付出的代价是,模型只能生成和该图片差不多的图片。

我们能不能把 LoRA 的这种性质放在隐变量插值上呢?我们可以认为,LoRA 的参数 $\Delta W$ 存储了新的隐空间的一些信息。如果我们不仅对两个输入图片的隐变量做插值,还对分别对两个输入图片训练一个 LoRA,得到$\Delta W_1, \Delta W_2$,再对两个 LoRA 的参数进行插值,得到$\Delta W = \alpha \Delta W_1 + (1-\alpha) \Delta W_2$,就能让中间的插值隐变量也能生成有意义的图片,且该图片会保留两个输入图片的性质。

相关实验结果能支撑我们的假设。下图展示了不同 LoRA 配置下,对隐变量做插值得到的结果。第一行和第二行表示分别仅使用左图或右图的 LoRA,第三行表示对 LoRA 也进行插值。可以看出,使用 LoRA 后,所有图片的质量都还不错。固定 LoRA,对隐变量做插值时,图像的风格会随隐变量变化而变化,而图像的语义内容会与训练该 LoRA 的图像相同。而对 LoRA 也进行插值的话,图像的风格、语义都会平滑过渡。

下图是前文那个例子的某些中间帧不使用 LoRA 插值和使用 LoRA 插值的结果。可以看出,使用了 LoRA 后,图像质量提升了很多。而通过对 LoRA 的插值,输出图像也会保留两个输入图像的特征。

自注意力输入的插值与替换

使用 LoRA 插值后,中间帧的图像质量得到了大幅提升,可是图像变形不连贯的问题还是没有得到解决。要提升图像变换的连贯性,还需要使用到一项和自注意力相关的技术。

深度学习中常见的注意力运算都可以表示成交叉注意力 $CrossAttn(\mathbf{x}, \mathbf{y})$,它表示数据 $\mathbf{x}$ 从数据 $\mathbf{y}$ 中获取了一次信息。交叉注意力的特例是自注意力 $SelfAttn(\mathbf{x}) = CrossAttn(\mathbf{x}, \mathbf{x})$,它表示数据 $\mathbf{x}$ 自己内部做了一次信息聚合。多数扩散模型的 U-Net 都使用了自注意力层。

由于自注意力本质上是一种交叉注意力,我们可以把另一个图像的自注意力输入替换某图像的自注意力输入。具体来说,我们可以先生成一张参考图像,将自注意力输入$\mathbf{x’}$缓存下来。再开始生成当前图片,对于原来的自注意力计算 $CrossAttn(\mathbf{x}, \mathbf{x})$,我们把第二个 $\mathbf{x}$ 换成 $\mathbf{x’}$,让计算变成 $CrossAttn(\mathbf{x}, \mathbf{x’})$。这样,在生成当前图片时,当前图片会和参考图片更加相似一些。

在扩散模型生成图像时,每个去噪时刻的每个自注意力模块的输入都有自己的意义。在替换输入时,我们必须用参考图像当前时刻当前模块的输入来做替换。

这种注意力替换技巧通常用在基于图像扩散模型的视频编辑任务里。一般我们会以输出视频的第一帧为参考图像,让生成后续帧的自注意力模块参考第一帧的信息。这样视频每一帧的风格都会更加一致。

我们可以把视频编辑任务的这种技巧挪用到图像变形任务里。在图像变形中,每一个中间帧要以一定的混合比例参考两个输入图像。那么,我们也可以先分别生成两个输入图像,缓存它们的自注意力输入$\mathbf{x}_0, \mathbf{x}_1$。在生成混合比例为 $\alpha$ 的中间帧时,我们先混合自注意力输入$\mathbf{x’} = \alpha \mathbf{x}_0 + (1-\alpha) \mathbf{x}_1$,再以 $\mathbf{x’}$ 为自注意力的第二个参数,计算 $CrossAttn(\mathbf{x_{\alpha}}, \mathbf{x’})$。

下面是不使用/使用自注意力替换的结果。可以看出,不使用注意力替换时,视频中间某帧会出现突变。而使用了注意力替换后,视频平滑了很多。

在实验中我们也发现,直接用 $\mathbf{x’}$ 来替换注意力输入会降低中间帧的质量。为了权衡质量与过渡性,我们会让替换的注意力输入在原输入 $\mathbf{x_{\alpha}}$ 和 $\mathbf{x’}$ 之间做一个混合,即令插入的注意力输入为 $\mathbf{x’} \gets \lambda \mathbf{x’} + (1-\lambda) \mathbf{x_{\alpha}}$。最终实验中我们令 $\lambda=0.6$。

重调度采样

通过注意力插值,我们解决了中间帧跳变的问题。然而,视频的变换速度还是不够平均。在开始和结束时,视频的变化速度较慢;而在中间时刻,视频的变化又过快。

我们使用了一种重新选择混合比例 $\alpha$ 的重调度策略来解决这一问题。之前,我们在选择混合比例时,是均匀地在 0~1 之间采样。比如要生成 10 段过渡,9个中间帧,我们就可以令混合比例为 $[0, 0.1, 0.2, …, 0.9, 1]$。但是,由于不同比例处插值图像的变化率不同,这样选取混合比例会导致每两帧之间变化量不均匀。

上图是一个可能的变化率分布图。图的横坐标是插值的混合比例,或者说视频渐变的时刻,图的纵坐标是图像内容随时间的变化率。每个矩形的面积表示相邻两帧之间的 LPIPS 感知误差。如果等间距采样混合比例的话,由于每个时刻的变化率不同,矩形的面积也不同,图像的变化会时快时慢。

我们希望重新选择一些采样的横坐标,使得相邻帧构成的矩形的面积尽可能一致。通过使用类似于平均颜色分布的直方图均衡化(histogram equalization)算法,我们可以得到重采样的混合比例 $[0, \alpha_1, \alpha_2, …, \alpha_{n-1}, 1]$,达到下面这种相邻帧变化量几乎相同的结果。

下面是不使用/使用重采样的结果。可以看出,二者生成的中间图像几乎是一致的,但左边的视频在开头和结尾会停顿一会儿,而右边的视频的内容一直都在均匀地变化。

在线示例与代码

看完了该工作的原理,我们来动手使用一下 DiffMorpher。我们先来运行一下在线示例。在线示例可以在 OpenXLab (https://openxlab.org.cn/apps/detail/KaiwenZhang/DiffMorpher ) 或者 HuggingFace(https://huggingface.co/spaces/Kevin-thu/DiffMorpher )上访问。

使用 WebUI 时,可以直接点击 Run 直接运行示例,或者手动上传两张图片并给定 prompt 再运行。

如果你对一些细节感兴趣,也可以手动 clone GitHub 仓库。配置环境的过程也很简单,只需要准备一个有 PyTorch 的运行环境,再安装相关 Pip 包即可。注意,该项目用的 Diffusers 版本较旧,最新的 Diffusers 可能无法成功运行,建议直接照着 requirements.txt 里的版本来。

1
2
3
git clone https://github.com/Kevin-thu/DiffMorpher.git
cd DiffMorpher
pip install -r requirements.txt

配置好了环境后,可以直接尝试仓库里自带的示例:

1
2
3
4
5
python main.py \
--image_path_0 ./assets/Trump.jpg --image_path_1 ./assets/Biden.jpg \
--prompt_0 "A photo of an American man" --prompt_1 "A photo of an American man" \
--output_path "./results/Trump_Biden" \
--use_adain --use_reschedule --save_inter

运行后,就能得到下面的结果:

总结与展望

图像变形任务的目标是在两个有对应关系的图像之间产生一系列合理的过渡帧。传统基于像素变形与混合的方法无法在中间帧里生成新内容。我们希望用包含丰富图像信息的预训练扩散模型来完成图像变形任务。然而,直接对扩散模型的隐变量插值,会出现中间帧质量低、结果不连贯这两个问题。为了解决这两个问题,我们对扩散模型生成两个输入图像时的诸多属性进行了插值,包括 LoRA 插值、自注意力插值,分别解决了中间帧质量与结果连贯性的问题。另外,加入了重调度采样后,输出视频的连贯性得到了进一步的提升。

受限于图像变形这一任务本身的上限,DiffMorpher 在实际应用中的质量难以比拟专门面向某一任务的方法(比如只做拖拽式编辑,或者只做视频插帧)。这篇工作在科研上的贡献会远大于其在应用上的贡献。方法中一些较为新颖的插值手段或许会帮助到未来的图像编辑工作。

尽管 DiffMorpher 已经算是一个不错的图像变形工具了,该方法并没有从本质上提升扩散模型的可编辑性。相比 GAN 而言,逐渐对扩散模型的隐变量修改难以产生平滑的输出结果。比如在拖拽式编辑任务中,DragGAN 只需要优化 GAN 的隐变量就能产生合理的编辑效果,而扩散模型中的类似工具(如 DragDiffusion, DragonDiffusion)需要更多设计才能达到同样的结果。从本质上提升扩散模型的可编辑性依然是一个值得研究的问题。

出于可读性的考虑,本文没有过多探讨技术细节。如果你对相关技术感兴趣,欢迎阅读我之前的文章:

LoRA 在 Stable Diffusion 中的三种应用:原理讲解与代码示例

Stable Diffusion 中的自注意力替换技术与 Diffusers 实现

今天,一则重磅消息席卷了 AI 圈:OpenAI 发布了视频模型 Sora,能根据文本生成长达一分钟的高质量 1920x1080 视频,生成能力远超此前只能生成 25 帧 576x1024 图像的顶尖视频生成模型 Stable Video Diffusion。

同时,OpenAI 也公布了一篇非常简短的技术报告。报告仅大致介绍了 Sora 的架构及应用场景,并未对模型的原理详加介绍。让我们来快速浏览一下这份报告,看看科研人员从这份报告中能学到什么。

官网链接:https://openai.com/sora

技术报告链接:https://openai.com/research/video-generation-models-as-world-simulators

这篇文章没怎么贴视频,感兴趣的话可以对照着原报告中的视频阅读。

LDM 与 DiT 的结合

简单来说,Sora 就是 Latent Diffusion Model (LDM) [1] 加上 Diffusion Transformer (DiT) [2]。我们先简要回顾一下这两种模型架构。

LDM 就是 Stable Diffusion 使用的模型架构。扩散模型的一大问题是计算需求大,难以拟合高分辨率图像。为了解决这一问题,实现 LDM时,会先训练一个几乎能无损压缩图像的自编码器,能把 512x512 的真实图像压缩成 64x64 的压缩图像并还原。接着,再训练一个扩散模型去拟合分辨率更低的压缩图像。这样,仅需少量计算资源就能训练出高分辨率的图像生成模型。

LDM 的扩散模型使用的模型是 U-Net。而根据其他深度学习任务中的经验,相比 U-Net,Transformer 架构的参数可拓展性强,即随着参数量的增加,Transformer 架构的性能提升会更加明显。这也是为什么大模型普遍都采用了 Transformer 架构。从这一动机出发,DiT 应运而生。DiT 在 LDM 的基础上,把 U-Net 换成了 Transformer。

顺带一提,Transformer 本来是用于文本任务的,它只能处理一维的序列数据。为了让 Transformer 处理二维图像,通常会把输入图像先切成边长为 $p$ 的图块,再把每个图块处理成一项数据。也就是说,原来边长为 $I$ 的正方形图片,经图块化后,变成了长度为 $(I/p)^2$ 的一维序列数据。

Transformer 是一种和顺序无关的计算。比如对于输入”abc”和”bca”,Transformer 会输出一模一样的值。为了描述数据的先后顺序,使用 Transformer 时,一般会给数据加一个位置编码。

Sora 是一个视频版的 DiT 模型。让我们看一下 Sora 在 DiT 上做了哪些改进。

时空自编码器

在此之前,许多工作都尝试把预训练 Stable Diffusion 拓展成视频生成模型。在拓展时,视频的每一帧都会单独输入进 Stable Diffusion 的自编码器,再重新构成一个压缩过的图像序列。而 VideoLDM[3] 工作发现,直接对视频使用之前的图像自编码器,会令输出视频出现闪烁的现象。为此,该工作对自编码器的解码器进行了微调,加入了一些能够处理时间维度的模块,使之能一次性处理整段压缩视频,并输出连贯的真实视频。

Sora 则是从头训练了一套能直接压缩视频的自编码器。相比之前的工作,Sora 的自编码器不仅能在空间上压缩图像,还能在时间上压缩视频长度。这估计是为什么 Sora 能生成长达一分钟的视频。

报告中提到,Sora 也能处理图像,即长度为1的视频。那么,自编码器怎么在时间上压缩长度为1的视频呢?报告中并没有给出细节。我猜测该自编码器在时间维度做了填充(比如时间被压缩成原来的 1/2,那么就对输入视频填充空数据直至视频长度为偶数),也可能是输入了视频长度这一额外约束信息。

时空压缩图块

输入视频经过自编码器后,会被转换成一段空间和时间维度上都变小的压缩视频。这段压缩视频就是 Sora 的 DiT 的拟合对象。在处理视频数据时,DiT 较 U-Net 又有一些优势。

之前基于 U-Net 的去噪模型在处理视频数据时(如 [3]),都需要额外加入一些和时间维度有关的操作,比如时间维度上的卷积、自注意力。而 Sora 的 DiT 是一种完全基于图块的 Transformer 架构。要用 DiT 处理视频数据,不需要这种设计,只要把视频看成一个 3D 物体,再把 3D 物体分割成「图块」,并重组成一维数据输入进 DiT 即可。和原本图像 DiT 一样,假设视频边长为 $I$,时长也为 $I$,要切成边长为 $p$ 的图块,最后会得到 $(I/p)^3$ 个数据。

报告没有给出视频图块化的细节。

处理任意分辨率、时长的视频

报告中反复提及,Sora 在训练和生成时使用的视频可以是任何分辨率(在 1920x1080 以内)、任何长宽比、任何时长的。这意味着视频训练数据不需要做缩放、裁剪等预处理。这些特性是绝大多数其他视频生成模型做不到的,让我们来着重分析一下这一特性的原理。

Sora 的这种性质还是得益于 Transformer 架构。前文提到,Transformer 的计算与输入顺序无关,必须用位置编码来指明每个数据的位置。尽管报告没有提及,我觉得 Sora 的 DiT 使用了类似于 $(x, y, t)$ 的位置编码来表示一个图块的时空位置。这样,不管输入的视频的大小如何,长度如何,只要给每个图块都分配一个位置编码,DiT 就能分清图块间的相对关系了。

相比以前的工作,Sora 的这种设计是十分新颖的。之前基于 U-Net 的 Stable Diffusion 为了保证所有训练数据可以统一被处理,输入图像都会被缩放与裁剪至同一大小。由于训练数据中有被裁剪的图像,模型偶尔也会生成被裁剪的图像。生成训练分辨率以外的图像时,模型的表现有时也会不太好。SDXL [4] 的解决方式是把裁剪的长宽做为额外信息输入进 U-Net。为了生成没有裁剪的图像,只要令输入的裁剪长宽为 0 即可。类似地,SDXL 也把图像分辨率做为额外输入,使得 U-Net 学习不同分辨率、长宽比的图像。相比 SDXL,Sora 的做法就简洁多了。

之前基于 DiT 的模型 (比如华为的 PixArt [5])似乎都没有利用到 Transformer 可以随意设置位置编码这一性质。DiT 在处理输入图块时,会先把图块变形成一维数据,再从左到右编号,即从从左到右,从上到下地给二维图块组编号。这种位置编码并没有保留图像的二维空间信息,因此,在这种编码下,模型的输入分辨率必须固定。比如对于下面这个$4\times4$的图块组,如果是从左到右、从上到下编码,模型等于是强行学习到了「1号在0号右边、4号在0号下面」这样的位置信息。如果输入的图块形状为 $4 \times 5$,那么图块间的相对关系就完全对不上了。而如果像 Sora 这样以视频图块的 $(x, y, t)$ 来生成位置编码的话,就没有这种问题了,输入视频可以是任何分辨率、任何长度。

Transformer 在视频生成的可拓展性

前文提过,Transformer 的特点就是可拓展性强,即模型越大,训练越久,效果越好。报告中展示了1倍、4倍、16倍某单位训练时间下的生成结果,可以看出模型确实一直有进步。

语言理解能力

之前大部分文生图扩散模型都是在人工标注的图片-文字数据集上训练的。后来大家发现,人工标注的图片描述质量较低,纷纷提出了各种提升标注质量的方法。Sora 复用了自家 DALL·E 3 的重标注技术,用一个训练的能生成详细描述的标注器来重新为训练视频生成标注。这种做法不仅解决了视频缺乏标注的问题,且相比人工标注质量更高。Sora 的部分结果展示了其强大了抽象理解能力(如理解人和猫之间的交互),这多半是因为视频标注模型足够强大,视频生成模型学到了视频标注模型的知识。但同样,视频标注模型的相关细节完全没有公开。

其他生成功能

  • 基于已有图像和视频进行生成:除了约束文本外,Sora 还支持在一个视频前后补充内容(如果是在一张图片后面补充内容,就是图生视频)。报告没有给出实现细节,我猜测是直接做了反演(inversion)再把反演得到的隐变量替换到随机初始隐变量中。
  • 视频编辑:报告明确写出,只用简单的 SDEdit (即目前 Stable Diffusion 中的图生图)即可实现视频编辑。
  • 视频内容融合:可能是对两个视频的初始隐变量做了插值。
  • 图像生成:当然,Sora 也可以生成图像。报告表明,Sora 可以生成最大 2048x2048 的图像。

涌现出的能力

通过学习大量数据,Sora 还涌现出一些意想不到的能力。

  • 3D 一致性:视频中包含自然的相机视角变换。之前的 Stable Video Diffusion 也有类似发现。
  • 长距离连贯性:AI 生成出来的视频往往有物体在中途突然消失的情况。而 Sora 有时候能克服这一问题。
  • 与世界的交互:比如在描述画画的视频中,画纸上的内容随画笔生成。
  • 模拟数字世界:报告展示了在输入文本有”Minecraft”时,模型能生成非常真实的 Minecraft 游戏视频。这大概只能说明模型的拟合能力太强了,以至于学会了生成 Minecraft 这一种特定风格的视频。

局限性

报告结尾还是给出了一些失败的生成示例,比如玻璃杯在桌子上没有摔碎。这表明模型还不能完全学会某些物理性质。然而,我觉得现阶段 Sora 已经展示了足够强大的学习能力。想模拟现有视频中已经包含的物理现象,只需要增加数据就行了。

总结

Sora 是一个惊艳的视频生成模型,它以卓越的生成能力(高分辨率、长时间)与生成质量令一众同期的视频生成模型黯然失色。Sora 的技术报告非常简短,不过我们从中还是可以学到一些东西。从技术贡献上来看,Sora 的创新主要有两点:

  1. 让 LDM 的自编码器也在视频时间维度上压缩。
  2. 使用了一种不限制输入形状的 DiT

其中,第二点贡献是非常有启发性的。DiT 能支持不同形状的输入,大概率是因为它以视频的3D位置生成位置编码,打破了一维编码的分辨率限制。后续大家或许会逐渐从 U-Net 转向 DiT 来建模扩散模型的去噪模型。

我认为 Sora 的成功有三个原因。前两个原因对应两项创新。第一,由于在时间维度上也进行了压缩,Sora 最终能生成长达一分钟的视频;第二,使用 DiT 不仅去除了视频空间、时间长度上的限制,还充分利用了 Transformer 本身的可拓展性,使训练一个视频生成大模型变得可能。第三个原因来自于视频标注模型。之前 Stable Diffusion 能够成功,很大程度上是因为有一个能够关联图像与文本的 CLIP 模型,且有足够多的带标注图片。相比图像,视频训练本来就少,带标注的视频就更难获得了。一个能够理解视频内容,生成详细视频标注的标注器,一定是让视频生成模型理解复杂文本描述的关键。除了这几点原因外,剩下的就是砸钱、扩大模型、加数据了。

Sora 显然会对 AIGC 社区产生一定影响。对于 AIGC 爱好者而言,他们或许会多了一些生成创意视频的方法,比如给部分帧让 Sora 来根据文本补全剩余帧。当然,目前 Sora 依然不能取代视频创作者,长视频的质量依然有待观察。对于正在开发相似应用的公司,我觉得他们应该要连夜撤销之前的方案,转换为这套没有分辨率限制的 DiT 的方案。他们的压力应该会很大。对于相关科研人员而言,除了学习这种较为新颖的 DiT 用法外,也没有太多收获了。这份技术报告透露出一股「我绝对不会开源」的意思。没有开源模型,普通的研究者也就什么都做不了。新技术的诞生绝对不可能靠一家公司,一个模型就搞定。像之前的 Stable Diffusion,也是先开源了一个基础模型,科研者和爱好者再补充了各种丰富的应用。我呼吁各大公司尽快训练并开源一个这种不限分辨率的 DiT,这样科研界或许会抛开 U-Net,基于 DiT 开发出新的扩散模型应用。

参考论文

  1. Latent Diffusion Model, Stable Difusion: High-Resolution Image Synthesis with Latent Diffusion Models
  2. DiT: Scalable Diffusion Models with Transformers
  3. VideoLDM: Align your Latents: High-Resolution Video Synthesis with Latent Diffusion Models
  4. SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis
  5. PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis

在使用预训练 Stable Diffusion (SD) 生成图像时,如果将其 U-Net 的自注意力层在某去噪时刻的输入 K, V 替换成另一幅参考图像的,则输出图像会和参考图像更加相似。许多无需训练的 SD 编辑科研工作都运用了此性质。尤其对于是对于视频编辑任务,如果在生成某一帧时将注意力输入替换成之前帧的,则输出视频会更加连贯。在这篇文章中,我们将快速学习 SD 自注意力替换技术的原理,并在 Diffusers 里实现一个基于此技术的视频编辑流水线。

注意力计算

我们先来回顾一下 Transformer 论文中提出的注意力机制。所有注意力机制都基于一种叫做放缩点乘注意力(Scaled Dot-Product Attention)的运算:

其中,$Q \in \mathbb{R}^{a \times d_k}, K \in \mathbb{R}^{b \times d_k}, V \in \mathbb{R}^{b \times d_v}$。注意力计算可以理解成先算 $a$ 个长度为 $d_k$ 的向量对 $b$ 个长度为 $d_k$ 的向量的相似度,再以此相似度为权重算 $a$ 个向量对 $b$ 个长度为 $d_v$ 的向量的加权和。

注意力计算是没有可学习参数的。为了加入参数,Transformer 设计了如下所示的注意力层,其中 $W^Q, W^K, W^V, W^O$ 都是参数。

一般在使用注意力层时,会让$K=V$。这种注意力叫做交叉注意力。交叉注意力可以理解成数据 $A$ 想从数据 $B$ 里提取信息,提取的根据是 $A$ 里每个向量和 $B$ 里每个向量的相似度。

交叉注意力的特例是自注意力,此时 $Q=K=V$ 。这表示数据里的向量两两之间交换了一次信息。

SD 中的自注意力替换

SD 的 U-Net 既用到了自注意力,也用到了交叉注意力。自注意力用于图像特征自己内部信息聚合。交叉注意力用于让生成图像对齐文本,其 Q 来自图像特征,K, V 来自文本编码。

由于自注意力其实可以看成一种特殊的交叉注意力,我们可以把自注意力的 K, V 替换成来自另一幅参考图像的特征。这样,扩散模型的生成图片会既和原本要生成的图像相似,又和参考图像相似。当然,用来替换的特征必须和原来的特征「格式一致」,不然就生成不了有意义的结果了。

什么叫「格式一致」呢?我们知道,扩散模型在采样时有很多步,U-Net 中又有许多自注意力层。每一步时的每一个自注意力层的输入都有自己的「格式」。也就是说,如果你要把某时刻某自注意力层的 K, V 替换,就得先生成参考图像,用生成参考图像过程中此时刻此自注意力层的输入替换,而不能用其他时刻或者其他自注意力层的。

一般这种编辑技术只会用在自注意力层而不是交叉注意力层上,这是因为 SD 中的交叉注意力是用来关联图像与文字的,另一幅图像的信息无法输入。当然,除了 SD,只要是用到了自注意力模块的扩散模型,都能用此方法编辑,只不过大部分工作都是基于 SD 开发的。

自注意力替换的应用

自注意力替换最常见的应用是提升 SD 视频编辑的连续性。在此任务中,一般会先正常编辑第一帧,再将后续帧的自注意力的 K, V 替换成第一帧的。这种技术在文献中一般被称为帧间注意力(cross-frame attention)。较早提出此论文的工作是 Text2Video-Zero。

自注意力替换也可以用于提升单幅图像编辑的保真度。一个例子是拖拽单幅图像的 DragonDiffusion。此应用可以拓展到图像插值上,比如 DiffMorpher 在图像插值时对两幅参考图像的自注意力输入等比例插值,再替换掉对应插值图像的自注意力的 K, V。

在 Diffusers 里实现自注意力替换

Diffusers 的 U-Net 专门提供了用于修改注意力计算的 AttentionProcessor 类。借助相关接口,我们可以方便地修改注意力的计算方法。在这个示例项目中,我们来用 Diffusers 实现一个参考第一帧和上一帧的注意力输入的 SD 视频编辑流水线。相比逐帧生成编辑图片,该流水线的结果会更加平滑一点。项目网址:https://github.com/SingleZombie/DiffusersExample/tree/main/ReplaceAttn

AttentionProcessor

在 Diffusers 中,U-Net 的每一个注意力模块都有一个 AttentionProcessor 类的实例。AttentionProcessor 类的 __call__ 方法描述了注意力计算的过程。如果我们想修改某些注意力模块的计算,就需要自己定义一个注意力处理类,其 __call__ 方法的参数需与 AttentionProcessor 的兼容。之后,我们再调用相关接口把原来的处理类换成我们自己写的处理类。下面我们将先看一下 AttentionProcessor 类的实现细节,再实现我们自己的
注意力处理类。

AttentionProcessor 类在 diffusers/models/attention_processor.py 文件里。它只有一个 __call__ 方法,其主要内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class AttnProcessor:

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
query = attn.to_q(hidden_states, *args)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states

key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)

query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states

方法参数中,hidden_states 是 Q, encoder_hidden_states 是 K, V。如果 K, V 没有传入(为 None),则 K, V 会被赋值成 Q。该方法的实现细节和 Tranformer 中的注意力层完全一样,此处就不多加解释了。一般替换注意力的输入时,我们不用改这个方法的实现,只会在需要的时候调用这个方法。

attention_processor.py 文件中还有一个功能类似的类 AttnProcessor2_0,它和 AttentionProcessor 的区别在于它调用了 PyTorch 2.0 起启用的算子 F.scaled_dot_product_attention 代替手动实现的注意力计算。这个算子更加高效,如果你确定 PyTorch 版本至少为 2.0,就可以用 AttnProcessor2_0 代替 AttentionProcessor

看完了 AttentionProcessor 类后,我们来看该怎么在 U-Net 里将原注意力处理类替换成我们自己写的。U-Net 类的 attn_processors 属性会返回一个词典,它的 key 是每个处理类所在位置,比如 down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor,它的 value 是每个处理类的实例。为了替换处理类,我们需要构建一个格式一样的词典attn_processor_dict,再调用 unet.set_attn_processor(attn_processor_dict) ,取代原来的 attn_processors。假如我们自己实现了处理类 MyAttnProcessor,我们可以编写下面的代码来实现替换:

1
2
3
4
5
6
7
8
attn_processor_dict = {}
for k in unet.attn_processors.keys():
if we_want_to_modify(k):
attn_processor_dict[k] = MyAttnProcessor()
else:
attn_processor_dict[k] = AttnProcessor()

unet.set_attn_processor(attn_processor_dict)

实现帧间注意力处理类

熟悉了 AttentionProcessor 类的相关内容,我们来编写自己的帧间注意力处理类。在处理第一帧时,该类的行为不变。对于之后的每一帧,该类的 K, V 输入会被替换成视频第一帧和上一帧的输入在序列长度维度上的拼接结果,即:

你是否会感到疑惑:为什么 K, V 的序列长度可以修改?别忘了,在注意力计算中,Q, K, V 的形状分别是:$Q \in \mathbb{R}^{a \times d_k}, K \in \mathbb{R}^{b \times d_k}, V \in \mathbb{R}^{b \times d_v}$。注意力计算只要求 K,V 的序列长度 $b$ 相同,并没有要求 Q, K 的序列长度相同。

现在,注意力计算不再是一个没有状态的计算,它的运算结果取决于第一帧和上一帧的输入。因此,我们在注意力处理类中需要额外维护这两个变量。我们可以按照如下代码编写类的构造函数。除了处理继承外,我们还需要创建两个数据词典来存储不同时间戳下第一帧和上一帧的注意力输入。

1
2
3
4
5
class CrossFrameAttnProcessor(AttnProcessor):
def __init__(self):
super().__init__()
self.first_maps = {}
self.prev_maps = {}

在运行方法中,我们根据 encoder_hidden_states 是否为空来判断该注意力是自注意力还是交叉注意力。我们仅修改自注意力。当该注意力为自注意力时,假设我们知道了当前时刻 t,我们就可以根据 t 获取当前时刻第一帧和前一帧的输入,并将它们拼接起来得到 cross_map。以此 cross_map 为当前注意力的 K, V,我们就实现了帧间注意力。

1
2
3
4
5
6
7
8
9
10
11
12
13
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, **kwargs):

if encoder_hidden_states is None:
# Is self attention
cross_map = torch.cat(
(self.first_maps[t], self.prev_maps[t]), dim=1)
res = super().__call__(attn, hidden_states, cross_map, **kwargs)

else:
# Is cross attention
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)

return res

由于 Diffusers 经常修改函数接口,在调用普通的注意力计算接口时,最好原封不动地按照 super().__call__(..., **kwargs) 写,不然这份代码就不能兼容后续版本的 Diffusers。

上述代码只描述了后续帧的行为。如前所述,我们的注意力计算有两种行为:对于第一帧,我们不修改注意力的计算过程,只缓存其输入;对于之后每一帧,我们替换注意力的输入,同时维护当前「上一帧」的输入。既然注意力在不同情况下有不同行为,我们就应该用一个变量来记录当前状态,让 __call__ 能根据此变量决定当前的行为。相关的伪代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, **kwargs):

if encoder_hidden_states is None:
# Is self attention
if self.state == FIRST_FRAME:
res = super().__call__(attn, hidden_states, cross_map, **kwargs)
# update maps
else:
cross_map = torch.cat(
(self.first_maps[t], self.prev_maps[t]), dim=1)
res = super().__call__(attn, hidden_states, cross_map, **kwargs)
# update maps

else:
# Is cross attention
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)

return res

在伪代码中,self.state 表示当前注意力的状态,它的值表明注意力计算是在处理第一帧还是后续帧。在视频编辑流水线中,我们应按照下面的伪代码,先编辑第一帧,再修改注意力状态后编辑后续帧。

1
2
3
4
edit(frames[0])
set_attn_state(SUBSEQUENT_FRAMES)
for i in range(1, len(frames)):
edit(frames[i])

现在,有一个问题:我们该怎么修改怎么每一个注意力模块的处理器的状态呢?显然,最直接的方式是想办法访问每一个注意力模块的处理器,再直接修改对象的属性。

1
2
3
4
modules = unet.get_attn_moduels
for module in modules:
if we_want_to_modify(module):
module.processor.state = ...

但是,每次都去遍历所有模块会让代码更加凌乱。同时,这样写也会带来代码维护上的问题:我们每次遍历注意力模块时,都可能要判断该注意力模块是否应该修改。而在用前面讲过的处理类替换方法 unet.set_attn_processor 时,我们也得判断一遍。同一段逻辑重复写在两个地方,非常不利于代码更新。

一种更优雅的实现方式是:我们定义一个状态管理类,所有注意力处理器都从同一个全局状态管理类对象里获取当前的状态信息。想修改每一个处理器的状态,不需要遍历所有对象,只需要改一次全局状态管理类对象就行了。

按照这种实现方式,我们先编写一个状态类。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class AttnState:
STORE = 0
LOAD = 1

def __init__(self):
self.reset()

@property
def state(self):
return self.__state

def reset(self):
self.__state = AttnState.STORE

def to_load(self):
self.__state = AttnState.LOAD

在注意力处理类中,我们在初始化时保存状态类对象的引用,在运行时根据状态类对象获取当前状态。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class CrossFrameAttnProcessor(AttnProcessor):

def __init__(self, attn_state: AttnState):
super().__init__()
self.attn_state = attn_state
self.first_maps = {}
self.prev_maps = {}

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, **kwargs):

if encoder_hidden_states is None:
# Is self attention

if self.attn_state.state == AttnState.STORE:
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)
else:
cross_map = torch.cat(
(self.first_maps[t], self.prev_maps[t]), dim=1)
res = super().__call__(attn, hidden_states, cross_map, **kwargs)
else:
# Is cross attention
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)

return res

到目前为止,假设已经维护好了之前的输入,我们的注意力处理类能执行两种不同的行为了。现在,我们来实现之前输入的维护。使用之前的注意力输入时,我们其实需要知道当前的时刻 t。当前的时刻也算是另一个状态,最好是也在状态管理类里维护。但为了简化我们的代码,我们可以偷懒让每个处理类自己维护当前时刻。具体做法是:如果知道了去噪迭代的总时刻数,我们就可以令当前时刻从0开始不断自增,直到最大时刻时,再重置为0。加入了时刻处理及之前输入维护的完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class AttnState:
STORE = 0
LOAD = 1

def __init__(self):
self.reset()

@property
def state(self):
return self.__state

@property
def timestep(self):
return self.__timestep

def set_timestep(self, t):
self.__timestep = t

def reset(self):
self.__state = AttnState.STORE
self.__timestep = 0

def to_load(self):
self.__state = AttnState.LOAD

class CrossFrameAttnProcessor(AttnProcessor):

def __init__(self, attn_state: AttnState):
super().__init__()
self.attn_state = attn_state
self.cur_timestep = 0
self.first_maps = {}
self.prev_maps = {}

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, **kwargs):

if encoder_hidden_states is None:
# Is self attention

tot_timestep = self.attn_state.timestep
if self.attn_state.state == AttnState.STORE:
self.first_maps[self.cur_timestep] = hidden_states.detach()
self.prev_maps[self.cur_timestep] = hidden_states.detach()
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)
else:
tmp = hidden_states.detach()
cross_map = torch.cat(
(self.first_maps[self.cur_timestep], self.prev_maps[self.cur_timestep]), dim=1)
res = super().__call__(attn, hidden_states, cross_map, **kwargs)
self.prev_maps[self.cur_timestep] = tmp

self.cur_timestep += 1
if self.cur_timestep == tot_timestep:
self.cur_timestep = 0
else:
# Is cross attention
res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)

return res

代码中,tot_timestep 为总时刻数,cur_timestep 为当前时刻。每运算一次,cur_timestep 加一,直至总时刻时再归零。在处理第一帧时,我们把当前时刻的输入同时存入第一帧缓存 first_maps 和上一帧缓存 prev_maps 中。对于后续帧,我们先做替换过输入的注意力计算,再更新上一帧缓存 prev_maps

视频编辑流水线

准备好了我们自己写的帧间注意力处理类后,我们来编写一个简单的 Diffusers 视频处理流水线。该流水线基于 ControlNet 与图生图流水线,其主要代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class VideoEditingPipeline(StableDiffusionControlNetImg2ImgPipeline):
def __init__(
self,
...
):
super().__init__(...)
self.attn_state = AttnState()
attn_processor_dict = {}
for k in unet.attn_processors.keys():
if k.startswith("up"):
attn_processor_dict[k] = CrossFrameAttnProcessor(
self.attn_state)
else:
attn_processor_dict[k] = AttnProcessor()

self.unet.set_attn_processor(attn_processor_dict)

def __call__(self, *args, images=None, control_images=None, **kwargs):
self.attn_state.reset()
self.attn_state.set_timestep(
int(kwargs['num_inference_steps'] * kwargs['strength']))
outputs = [super().__call__(
*args, **kwargs, image=images[0], control_image=control_images[0]).images[0]]
self.attn_state.to_load()
for i in range(1, len(images)):
image = images[i]
control_image = control_images[i]
outputs.append(super().__call__(
*args, **kwargs, image=image, control_image=control_image).images[0])
return outputs

在构造函数中,我们创建了一个全局注意力状态对象 attn_state。它的引用会传给每一个帧间注意力处理对象。一般修改自注意力模块时,只会修改 U-Net 上采样部分的,而不会动下采样部分和中间部分的。因此,在过滤注意力模块时,我们的判断条件是 k.startswith("up")。把新的注意力处理器词典填完后,用 unet.set_attn_processor 更新所有的处理类对象。

1
2
3
4
5
6
7
8
9
10
self.attn_state = AttnState()
attn_processor_dict = {}
for k in unet.attn_processors.keys():
if k.startswith("up"):
attn_processor_dict[k] = CrossFrameAttnProcessor(
self.attn_state)
else:
attn_processor_dict[k] = AttnProcessor()

self.unet.set_attn_processor(attn_processor_dict)

__call__ 方法中,我们要基于原图像编辑流水线 super().__call__(),实现我们的视频编辑流水线。在这个过程中,我们的主要任务是维护好注意力管理对象中的状态。一开始,我们要把管理类重置,根据参数设置最大去噪时刻数。经重置后,注意力处理器的状态默认为 STORE,即会保存第一帧的输入。处理完第一帧后,我们运行 attn_state.to_load() 改变注意力处理器的状态,让它们每次做注意力运算时先读第一帧和上一帧的输入,再维护上一帧输入的缓存。

1
2
3
4
5
6
7
8
9
10
11
12
13
def __call__(self, *args, images=None, control_images=None,  **kwargs):
self.attn_state.reset()
self.attn_state.set_timestep(
int(kwargs['num_inference_steps'] * kwargs['strength']))
outputs = [super().__call__(
*args, **kwargs, image=images[0], control_image=control_images[0]).images[0]]
self.attn_state.to_load()
for i in range(1, len(images)):
image = images[i]
control_image = control_images[i]
outputs.append(super().__call__(
*args, **kwargs, image=image, control_image=control_image).images[0])
return outputs

运行该流水线的示例脚本在项目根目录下的 replace_attn.py 文件中。示例中使用的视频可以在 https://github.com/williamyang1991/Rerender_A_Video/blob/main/videos/pexels-koolshooters-7322716.mp4 下载,下载后应重命名为 woman.mp4。不使用和使用新注意力处理器的输出结果如下:

可以看出,虽然注意力替换不能解决生成视频的闪烁问题,但帧间的一致性提升了不少。将注意力替换技术和其他技术结合起来的话,我们就能得到一个不错的 SD 视频生成工具。

总结

扩散模型中的自注意力替换是一种常见的提升图片一致性的技术。该技术的实现方法是将扩散模型 U-Net 中自注意力的 K, V 输入替换成另一幅图片的。在这篇文章中,我们学习了一个较为复杂的基于 Diffusers 开发的自注意力替换示例项目,用于提升 SD 视频生成的一致性。在这个过程中,我们学习了和 AttentionProcessor 相关接口函数的使用,并了解了如何基于全局管理类实现一个代码可维护性强的多行为注意力处理类。如果你能看懂这篇文章的示例,那你在开发 Diffusers 的注意力处理类时基本上不会碰到任何难题。

项目网址:https://github.com/SingleZombie/DiffusersExample/tree/main/ReplaceAttn

如果你想进一步学习 Diffusers 中视频编辑流水线的开发,可以参考我给 Diffusers 写的流水线:https://github.com/huggingface/diffusers/tree/main/examples/community#Rerender_A_Video

如果你一直关注 Stable Diffusion (SD) 社区,那你一定不会对 “LoRA” 这个名词感到陌生。社区用户分享的 SD LoRA 模型能够修改 SD 的画风,使之画出动漫、水墨或像素等风格的图片。但实际上,LoRA 不仅仅能改变 SD 的画风,还有其他的妙用。在这篇文章中,我们会先简单学习 LoRA 的原理,再认识科研中 LoRA 的三种常见应用:1) 还原单幅图像;2)风格调整;3)训练目标调整,最后阅读两个基于 Diffusers 的 SD LoRA 代码实现示例。

LoRA 的原理

在认识 LoRA 之前,我们先来回顾一下迁移学习的有关概念。迁移学习指在一次新的训练中,复用之前已经训练过的模型的知识。如果你自己动手训练过深度学习模型,那你应该不经意间地使用到了迁移学习:比如你一个模型训练了 500 步,测试后发现效果不太理想,于是重新读取该模型的参数,又继续训练了 100 步。之前那个被训练过的模型叫做预训练模型(pre-trained model),继续训练预训练模型的过程叫做微调(fine-tune)。

知道了微调的概念,我们就能来认识 LoRA 了。LoRA 的全称是 Low-Rank Adaptation (低秩适配),它是一种 Parameter-Efficient Fine-Tuning (参数高效微调,PEFT) 方法,即在微调时只训练原模型中的部分参数,以加速微调的过程。相比其他的 PEFT 方法,LoRA 之所以能脱颖而出,是因为它有几个明显的优点:

  • 从性能上来看,使用 LoRA 时,只需要存储少量被微调过的参数,而不需要把整个新模型都保存下来。同时,LoRA 的新参数可以和原模型的参数合并到一起,不会增加模型的运算时间。
  • 从功能上来看,LoRA 维护了模型在微调中的「变化量」。通过用一个介于 0~1 之间的混合比例乘变化量,我们可以控制模型的修改程度。此外,基于同一个原模型独立训练的多个 LoRA 可以同时使用。

这些优点在 SD LoRA 中的体现为:

  • SD LoRA 模型一般都很小,一般只有几十 MB。
  • SD LoRA 模型的参数可以合并到 SD 基础模型里,得到一个新的 SD 模型。
  • 可以用一个 0~1 之间的比例来控制 SD LoRA 新画风的程度。
  • 可以把不同画风的 SD LoRA 模型以不同比例混合。

为什么 LoRA 能有这些优点呢?LoRA 名字中的 「低秩」又是什么意思呢?让我们从 LoRA 的优点入手,逐步揭示它原理。

上文提到过,LoRA 之所以那么灵活,是因为它维护了模型在微调过程中的变化量。那么,假设我们正在修改模型中的一个参数 $W \in \mathbb{R}^{d \times d}$,我们就应该维护它的变化量 $\Delta W \in \mathbb{R}^{d \times d}$,训练时的参数用 $W + \Delta W$ 表示。这样,想要在推理时控制模型的修改程度,只要添加一个 $\alpha \in [0, 1]$,令使用的参数为 $W + \alpha \Delta W$即可。

可是,这样做我们还是要记录一个和原参数矩阵一样大的参数矩阵 $\Delta W$,这就算不上是参数高效微调了。为此,LoRA 的作者提出假设:模型参数在微调时的变化量中蕴含的信息没有那么多。为了用更少的信息来表示参数的变化量$\Delta W$,我们可以把$\Delta W$拆解成两个低秩矩阵的乘积:

其中,$A \in \mathbb{R}^{r \times d}$, $B \in \mathbb{R}^{d \times r}$,$d$ 是一个比 $r$ 小得多的数。这样,通过用两个参数量少得多的矩阵 $A, B$ 来维护变化量,我们不仅提高了微调的效率,还保持了使用变化量来描述微调过程的灵活性。这就是 LoRA 的全部原理,它十分简单,用 $\Delta W = BA$ 这一行公式足以表示。

了解了 LoRA 的原理,我们再回头看前文提及的 LoRA 的四项优点。LoRA 模型由许多参数量较少的矩阵 $A, B$ 来表示,它可以被单独存储,且占用空间不大。由于 $\Delta W = BA$ 维护的其实是参数的变化量,我们既可以把它与预训练模型的参数加起来得到一个新模型以提高推理速度,也可以在线地用一个混合比例来灵活地组合新旧模型。LoRA 的最后一个优点是各个基于同一个原模型独立训练出来的 LoRA 模型可以混合使用。LoRA 甚至可以作用于被其他方式修改过的原模型,比如 SD LoRA 支持带 ControlNet 的 SD。这一点其实来自于社区用户的实践。一个可能的解释是,LoRA 用低秩矩阵来表示变化量,这种低秩的变化量恰好与其他方法的变化量「错开」,使得 LoRA 能向着一个不干扰其他方法的方向修改模型。

我们最后来学习一下 LoRA 的实现细节。LoRA 有两个超参数,除了上文中提到的$r$,还有一个叫$\alpha$的参数。LoRA 的作者在实现 LoRA 模块时,给修改量乘了一个 $\frac{\alpha}{r}$ 的系数,即对于输入$x$,带了 LoRA 模块后的输出为 $Wx + \frac{\alpha}{r}BAx$。作者解释说,调这个参数几乎等于调学习率,一开始令$\alpha=r$即可。在我们要反复调超参数$r$时,只要保持$\alpha$不变,就不用改其他超参数了(因为不加$\alpha$的话,改了$r$后,学习率等参数也得做相应调整以维持同样的训练条件)。当然,实际运用中,LoRA 的超参数很好调。一般令$r=4, 8, 16$即可。由于我们不怎么会变$r$,总是令$\alpha=r$就够了。

为了使用 LoRA,除了确定超参数外,我们还得指定需要被微调的参数矩阵。在 SD 中使用 LoRA 时,大家一般会对 SD 的 U-Net 的所有多头注意力模块的所有参数矩阵做微调。即对于多头注意力模块的四个矩阵 $W_Q, W_K, W_V, W_{out}$ 进行微调。

LoRA 在 SD 中的三种运用

LoRA 在 SD 的科研中有着广泛的应用。按照使用 LoRA 的动机,我们可以把 LoRA 的应用分成:1) 还原单幅图像;2)风格调整;3)训练目标调整。通过学习这些应用,我们能更好地理解 LoRA 的本质。

还原单幅图像

SD 只是一个生成任意图片的模型。为了用 SD 来编辑一张给定的图片,我们一般要让 SD 先学会生成一张一模一样的图片,再在此基础上做修改。可是,由于训练集和输入图片的差异,SD 或许不能生成完全一样的图片。解决这个问题的思路很简单粗暴:我们只用这一张图片来微调 SD,让 SD 在这张图片上过拟合。这样,SD 的输出就会和这张图片非常相似了。

较早介绍这种提高输入图片保真度方法的工作是 Imagic,只不过它采取的是完全微调策略。后续的 DragDiffusion 也用了相同的方法,并使用 LoRA 来代替完全微调。近期的 DiffMorpher 为了实现两幅图像间的插值,不仅对两幅图像单独训练了 LoRA,还通过两个 LoRA 间的插值来平滑图像插值的过程。

风格调整

LoRA 在 SD 社区中最受欢迎的应用就是风格调整了。我们希望 SD 只生成某一画风,或者某一人物的图片。为此,我们只需要在一个符合我们要求的训练集上直接训练 SD LoRA 即可。

由于这种调整 SD 风格的方法非常直接,没有特别介绍这种方法的论文。稍微值得一提的是基于 SD 的视频模型 AnimateDiff,它用 LoRA 来控制输出视频的视角变换,而不是控制画风。

由于 SD 风格化 LoRA 已经被广泛使用,能否兼容 SD 风格化 LoRA 决定了一个工作是否易于在社区中传播。

训练目标调整

最后一个应用就有一点返璞归真了。LoRA 最初的应用就是把一个预训练模型适配到另一任务上。比如 GPT 一开始在大量语料中训练,随后在问答任务上微调。对于 SD 来说,我们也可以修改 U-Net 的训练目标,以提升 SD 的能力。

有不少相关工作用 LoRA 来改进 SD。比如 Smooth Diffusion 通过在训练目标中添加一个约束项并进行 LoRA 微调来使得 SD 的隐空间更加平滑。近期比较火的高速图像生成方法 LCM-LoRA 也是把原本作用于 SD 全参数上的一个模型蒸馏过程用 LoRA 来实现。

SD LoRA 应用总结

尽管上述三种 SD LoRA 应用的设计出发点不同,它们本质上还是在利用微调这一迁移学习技术来调整模型的数据分布或者训练目标。LoRA 只是众多高效微调方法中的一种,只要是微调能实现的功能,LoRA 基本都能实现,只不过 LoRA 更轻便而已。如果你想微调 SD 又担心计算资源不够,那么用 LoRA 准没错。反过来说,你想用 LoRA 在 SD 上设计出一个新应用,就要去思考微调 SD 能够做到哪些事。

Diffusers SD LoRA 代码实战

看完了原理,我们来尝试用 Diffusers 自己训一训 LoRA。我们会先学习 Diffusers 训练 LoRA 的脚本,再学习两个简单的 LoRA 示例: SD 图像插值与 SD 图像风格迁移。

项目网址:https://github.com/SingleZombie/DiffusersExample/tree/main/LoRA

Diffusers 脚本

我们将参考 Diffusers 中的 SD LoRA 文档 https://huggingface.co/docs/diffusers/training/lora ,使用官方脚本 examples/text_to_image/train_text_to_image_lora.py 训练 LoRA。为了使用这个脚本,建议直接克隆官方仓库,并安装根目录和 text_to_image 目录下的依赖文件。本文使用的 Diffusers 版本是 0.26.0,过旧的 Diffusers 的代码可能和本文展示的有所出入。目前,官方文档也描述的是旧版的代码。

1
2
3
4
5
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
cd examples/text_to_image
pip install -r requirements.txt

这份代码使用 accelerate 库管理 PyTorch 的训练。对同一份代码,只需要修改 accelerate 的配置,就能实现单卡训练或者多卡训练。默认情况下,用 accelerate launch 命令运行 Python 脚本会使用所有显卡。如果你需要修改训练配置,请参考相关文档使用 accelerate config 命令配置环境。

做好准备后,我们来开始阅读 examples/text_to_image/train_text_to_image_lora.py 的代码。这份代码写得十分易懂,复杂的地方都有注释。我们跳过命令行参数部分,直接从 main 函数开始读。

一开始,函数会配置 accelerate 库及日志记录器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()

随后的代码决定是否手动设置随机种子。保持默认即可。

1
2
3
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)

接着,函数会创建输出文件夹。如果我们想把模型推送到在线仓库上,函数还会创建一个仓库。我们的项目不必上传,忽略所有 args.push_to_hub 即可。另外,if accelerator.is_main_process: 表示多卡训练时只有主进程会执行这段代码块。

1
2
3
4
5
6
7
8
9
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id

准备完辅助工具后,函数正式开始着手训练。训练前,函数会先实例化好一切处理类,包括用于维护扩散模型中间变量的 DDPMScheduler,负责编码输入文本的 CLIPTokenizer, CLIPTextModel,压缩图像的VAE AutoencoderKL,预测噪声的 U-Net UNet2DConditionModel。参数 args.pretrained_model_name_or_path 是 Diffusers 在线仓库的地址(如runwayml/stable-diffusion-v1-5),或者本地的 Diffusers 模型文件夹。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)

函数还会设置各个带参数模型是否需要计算梯度。由于我们待会要优化的是新加入的 LoRA 模型,所有预训练模型都不需要计算梯度。另外,函数还会根据 accelerate 配置自动设置这些模型的精度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# Freeze the unet parameters before adding adapters
for param in unet.parameters():
param.requires_grad_(False)

# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

把预训练模型都调好了后,函数会配置 LoRA 模块并将其加入 U-Net 模型中。最近,Diffusers 更新了添加 LoRA 的方式。Diffusers 用 Attention 处理器来描述 Attention 的计算。为了把 LoRA 加入到 Attention 模块中,早期的 Diffusers 直接在 Attention 处理器里加入可训练参数。现在,为了和其他 Hugging Face 库统一,Diffusers 使用 PEFT 库来管理 LoRA。我们不需要关注 LoRA 的实现细节,只需要写一个 LoraConfig 就行了。

PEFT 中的 LoRA 文档参见 https://huggingface.co/docs/peft/conceptual_guides/lora

LoraConfig 中有四个主要参数: r, lora_alpha, init_lora_weights, target_modulesr, lora_alpha 的意义我们已经在前文中见过了,前者决定了 LoRA 矩阵的大小,后者决定了训练速度。默认配置下,它们都等于同一个值 args.rankinit_lora_weights 表示如何初始化训练参数,gaussian是论文中使用的方法。target_modules 表示 Attention 模块的哪些层需要添加 LoRA。按照通常的做法,会给所有层,即三个输入变换矩阵 to_k, to_q, to_v 和一个输出变换矩阵 to_out.0 加 LoRA。

创建了配置后,用 unet.add_adapter(unet_lora_config) 就可以创建 LoRA 模块。

1
2
3
4
5
6
7
8
9
10
11
12
13
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

unet.add_adapter(unet_lora_config)
if args.mixed_precision == "fp16":
for param in unet.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

更新完了 U-Net 的结构,函数会尝试启用 xformers 来提升 Attention 的效率。PyTorch 在 2.0 版本也加入了类似的 Attention 优化技术。如果你的显卡性能有限,且 PyTorch 版本小于 2.0,可以考虑使用 xformers

1
2
3
4
5
6
7
8
9
10
11
12
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers

xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
...
)
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

做完了 U-Net 的处理后,函数会过滤出要优化的模型参数,这些参数稍后会传递给优化器。过滤的原则很简单,如果参数要求梯度,就是待优化参数。

1
lora_layers = filter(lambda p: p.requires_grad, unet.parameters())

之后是优化器的配置。函数先是配置了一些细枝末节的训练选项,一般可以忽略。

1
2
3
4
5
6
7
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True

然后是优化器的选择。我们可以忽略其他逻辑,直接用 AdamW

1
2
3
4
5
6
7
8
9
10
11
12
# Initialize the optimizer
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"..."
)

optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW

选择了优化器类,就可以实例化优化器了。优化器的第一个参数是之前准备好的待优化 LoRA 参数,其他参数是 Adam 优化器本身的参数。

1
2
3
4
5
6
7
optimizer = optimizer_cls(
lora_layers,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)

准备了优化器,之后需要准备训练集。这个脚本用 Hugging Face 的 datasets 库来管理数据集。我们既可以读取在线数据集,也可以读取本地的图片文件夹数据集。在本文的示例项目中,我们将使用图片文件夹数据集。稍后我们再详细学习这样的数据集文件夹该怎么构建。相关的文档可以参考 https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
data_dir=args.train_data_dir,
)
else:
data_files = {}
if args.train_data_dir is not None:
data_files["train"] = os.path.join(args.train_data_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

训练 SD 时,每一个数据样本需要包含两项信息:图像数据与对应的文本描述。在数据集 dataset 中,每个数据样本包含了多项属性。下面的代码用于从这些属性中取出图像与文本描述。默认情况下,第一个属性会被当做图像数据,第二个属性会被当做文本。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names

# 6. Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)

准备好了数据集,接下来要定义数据预处理流程以创建 DataLoader。函数先定义了一个把文本标签预处理成 token ID 的 token 化函数。我们不需要修改它。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
return inputs.input_ids

接着,函数定义了图像数据的预处理流程。该流程是用 torchvision 中的 transforms 实现的。如代码所示,处理流程中包括了 resize 至指定分辨率 args.resolution、将图像长宽均裁剪至指定分辨率、随机翻转、转换至 tensor 和归一化。

经过这一套预处理后,所有图像的长宽都会被设置为 args.resolution 。统一图像的尺寸,主要的目的是对齐数据,以使多个数据样本能拼接成一个 batch。注意,数据预处理流程中包括了随机裁剪。如果数据集里的多数图片都长宽不一致,模型会倾向于生成被裁剪过的图片。为了解决这一问题,要么自己手动预处理图片,使训练图片都是分辨率至少为 args.resolution 的正方形图片,要么令 batch size 为 1 并取消掉随机裁剪。

1
2
3
4
5
6
7
8
9
10
11
12
# Preprocessing the datasets.
train_transforms = transforms.Compose(
[
transforms.Resize(
args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(
args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

定义了预处理流程后,函数对所有数据进行预处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["pixel_values"] = [
train_transforms(image) for image in images]
examples["input_ids"] = tokenize_captions(examples)
return examples

with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(
seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)

之后函数用预处理过的数据集创建 DataLoader。这里要注意的参数是 batch size args.train_batch_size 和读取数据的进程数 args.dataloader_num_workers 。这两个参数的用法和一般的 PyTorch 项目一样。args.train_batch_size 决定了训练速度,一般设置到不爆显存的最大值。如果要读取的数据过多,导致数据读取成为了模型训练的速度瓶颈,则应该提高 args.dataloader_num_workers

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"]
for example in examples])
pixel_values = pixel_values.to(
memory_format=torch.contiguous_format).float()
input_ids = torch.stack([example["input_ids"] for example in examples])
return {"pixel_values": pixel_values, "input_ids": input_ids}

# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)

如果想用更大的 batch size,显存又不够,则可以使用梯度累计技术。使用这项技术时,训练梯度不会每步优化,而是累计了若干步后再优化。args.gradient_accumulation_steps 表示要累计几步再优化模型。实际的 batch size 等于输入 batch size 乘 GPU 数乘梯度累计步数。下面的代码维护了训练步数有关的信息,并创建了学习率调度器。我们按照默认设置使用一个常量学习率即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True

lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)

# Prepare everything with our `accelerator`.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(
args.max_train_steps / num_update_steps_per_epoch)

在准备工作的最后,函数会用 accelerate 库记录配置信息。

1
2
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))

终于,要开始训练了。训练开始前,函数会准备全局变量并记录日志。

1
2
3
4
5
6
7
8
# Train!
total_batch_size = args.train_batch_size * \
accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
...
global_step = 0
first_epoch = 0

此时,如果设置了 args.resume_from_checkpoint,则函数会读取之前训练过的权重。一般继续训练时可以把该参数设为 latest,程序会自动找最新的权重。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = ...
else:
# Get the most recent checkpoint
path = ...

if path is None:
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0

随后,函数根据总步数和已经训练过的步数设置迭代器,正式进入训练循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):

训练的过程基本和 LDM 论文中展示的一致。一开始,要取出图像batch["pixel_values"] 并用 VAE 把它压缩进隐空间。

1
2
3
4
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(
dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor

再随机生成一个噪声。该噪声会套入扩散模型前向过程的公式,和输入图像一起得到 t 时刻的带噪图像。

1
2
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)

下一步,这里插入了一个提升扩散模型训练质量的小技巧,用上它后输出图像的颜色分布会更合理。原理见注释中的链接。args.noise_offset 默认为 0。如果要启用这个特性,一般令 args.noise_offset = 0.1

1
2
3
4
5
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn(
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
)

然后是时间戳的随机生成。

1
2
3
4
5
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

时间戳和前面随机生成的噪声一起经 DDPM 的前向过程得到带噪图片 noisy_latents

1
2
3
4
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(
latents, noise, timesteps)

再把文本 batch["input_ids"] 编码,为之后的 U-Net 前向传播做准备。

1
2
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]

在 U-Net 推理开始前,函数这里做了一个关于 U-Net 输出类型的判断。一般 U-Net 都是输出预测的噪声 epsilon,可以忽略这段代码。当 U-Net 是想预测噪声时,要拟合的目标是之前随机生成的噪声 noise

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Get the target for loss depending on the prediction type
if args.prediction_type is not None:
# set prediction_type of scheduler if defined
noise_scheduler.register_to_config(
prediction_type=args.prediction_type)

if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(
latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}")

之后把带噪图像、时间戳、文本编码输入进 U-Net,U-Net 输出预测的噪声。

1
2
3
# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps,
encoder_hidden_states).sample

有了预测值,下一步是算 loss。这里又可以选择是否使用一种加速训练的技术。如果使用,则 args.snr_gamma 推荐设置为 5.0。原 DDPM 的做法是直接算预测噪声和真实噪声的均方误差。

1
2
3
4
5
6
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(),
target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
...

训练迭代的最后,要用 accelerate 库来完成梯度计算和反向传播。在更新梯度前,可以通过设置 args.max_grad_norm 来裁剪梯度,以防梯度过大。args.max_grad_norm 默认为 1.0。代码中的 if accelerator.sync_gradients: 可以保证所有 GPU 都同步了梯度再执行后续代码。

1
2
3
4
5
6
7
8
9
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers
accelerator.clip_grad_norm_(
params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

一步训练结束后,更新和步数相关的变量。

1
2
3
4
5
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0

脚本默认每 args.checkpointing_steps 步保存一次中间结果。当需要保存时,函数会清理多余的 checkpoint,再把模型状态和 LoRA 模型分别保存下来。accelerator.save_state(save_path) 负责把模型及优化器等训练用到的所有状态存下来,后面的 StableDiffusionPipeline.save_lora_weights 负责存储 LoRA 模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = ...

if len(checkpoints) >= args.checkpoints_total_limit:
# remove ckpt
...

save_path = os.path.join(
args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)

unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)

StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)

logger.info(f"Saved state to {save_path}")

训练循环的最后,函数会更新进度条上的信息,并根据当前的训练步数决定是否停止训练。

1
2
3
4
5
6
logs = {"step_loss": loss.detach().item(
), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

训完每一个 epoch 后,函数会进行验证。默认的验证方法是新建一个图像生成 pipeline,生成一些图片并保存。如果有其他验证方法,如计算某一指标,可以自行编写这部分的代码。

1
2
3
4
5
6
7
8
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = DiffusionPipeline.from_pretrained(...)
...

所有训练结束后,函数会再存一次最终的 LoRA 模型权重。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)

unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet))
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)

if args.push_to_hub:
...

函数还会再测试一次模型。具体方法和之前的验证是一样的。

1
2
3
4
# Final inference
# Load previous pipeline
if args.validation_prompt is not None:
...

运行完了这里,函数也就结束了。

1
accelerator.end_training()

为了方便使用,我把这个脚本改写了一下:删除了部分不常用的功能,并且配置参数能通过配置文件而不是命令行参数传入。新的脚本为项目根目录下的 train_lora.py,示例配置文件在 cfg 目录下。

cfg 中的某个配置文件为例,我们来回顾一下训练脚本主要用到的参数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
{
"log_dir": "log",
"output_dir": "ckpt",
"data_dir": "dataset/mountain",
"ckpt_name": "mountain",
"gradient_accumulation_steps": 1,
"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
"rank": 8,
"enable_xformers_memory_efficient_attention": true,
"learning_rate": 1e-4,
"adam_beta1": 0.9,
"adam_beta2": 0.999,
"adam_weight_decay": 1e-2,
"adam_epsilon": 1e-08,
"resolution": 512,
"n_epochs": 200,
"checkpointing_steps": 500,
"train_batch_size": 1,
"dataloader_num_workers": 1,
"lr_scheduler_name": "constant",
"resume_from_checkpoint": false,
"noise_offset": 0.1,
"max_grad_norm": 1.0
}

需要关注的参数:output_dir 为输出 checkpoint 的文件夹,ckpt_name 为输出 checkpoint 的文件名。data_dir 是训练数据集所在文件夹。pretrained_model_name_or_path 为 SD 模型文件夹。rank 是决定 LoRA 大小的参数。learning_rate 是学习率。adam 打头的是 AdamW 优化器的参数。resolution 是训练图片的统一分辨率。n_epochs 是训练的轮数。checkpointing_steps 指每过多久存一次 checkpoint。train_batch_size 是 batch size。gradient_accumulation_steps 是梯度累计步数。

要修改这个配置文件,要先把文件夹的路径改对,填上训练时的分辨率,再通过 gradient_accumulation_stepstrain_batch_size 决定 batch size,接着填 n_epochs (一般训 10~20 轮就会过拟合)。最后就可以一边改 LoRA 的主要超参数 rank 一边反复训练了。

SD 图像插值

在这个示例中,我们来实现 DiffMorpher 工作的一小部分,完成一个简单的图像插值工具。在此过程中,我们将学会怎么在单张图片上训练 SD LoRA,以验证我们的训练环境。

这个工具的原理很简单:我们对两张图片分别训练一个 LoRA。之后,为了获取两张图片的插值,我们可以对两张图片 DDIM Inversion 的初始隐变量及两个 LoRA 分别插值,用插值过的隐变量在插值过的 SD LoRA 上生成图片就能得到插值图片。

该示例的所有数据和代码都已经在项目文件夹中给出。首先,我们看一下该怎么在单张图片上训 LoRA。训练之前,我们要准备一个数据集文件夹。数据集文件夹及包含所有图片及一个描述文件 metadata.jsonl。比如单图片的数据集文件夹的结构应如下所示:

1
2
3
├── mountain
│ ├── metadata.jsonl
│ └── mountain.jpg

metadata.jsonl 元数据文件的每一行都是一个 json 结构,包含该图片的路径及文本描述。单图片的元数据文件如下:

1
{"file_name": "mountain.jpg", "text": "mountain"}

如果是多图片,就应该是:

1
2
3
{"file_name": "mountain.jpg", "text": "mountain"}
{"file_name": "mountain_up.jpg", "text": "mountain"}
...

我们可以运行项目目录下的数据集测试文件 test_dataset.py 来看看 datasets 库的数据集对象包含哪些信息。

1
2
3
4
5
6
7
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="dataset/mountain")
print(dataset)
print(dataset["train"].column_names)
print(dataset["train"]['image'])
print(dataset["train"]['text'])

其输出大致为:

1
2
3
4
5
6
7
8
9
10
Generating train split: 1 examples [00:00, 66.12 examples/s]
DatasetDict({
train: Dataset({
features: ['image', 'text'],
num_rows: 1
})
})
['image', 'text']
[<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F0400246670>]
['mountain']

这说明数据集对象实际上是一个词典。默认情况下,数据集放在词典的 train 键下。数据集的 column_names 属性可以返回每项数据有哪些属性。在我们的数据集里,数据的 image 是图像数据,text 是文本标签。训练脚本默认情况下会把每项数据的第一项属性作为图像,第二项属性作为文本标签。我们的这个数据集定义与训练脚本相符。

认识了数据集,我们可以来训练模型了。用下面的两行命令就可以分别在两张图片上训练 LoRA。

1
2
python train_lora.py cfg/mountain.json
python train_lora.py cfg/mountain_up.json

如果要用所有显卡训练,则应该用 accelerate。当然,对于这个简单的单图片训练,不需要用那么多显卡。

1
2
accelerate launch train_lora.py cfg/mountain.json
accelerate launch train_lora.py cfg/mountain_up.json

这两个 LoRA 模型的配置文件我们已经在前文见过了。相比普通的风格化 LoRA,这两个 LoRA 的训练轮数非常多,有 200 轮。设置较大的训练轮数能保证模型在单张图片上过拟合。

训练结束后,项目的 ckpt 文件夹下会多出两个 LoRA 权重文件: mountain.safetensor, mountain_up.safetensor。我们可以用它们来做图像插值了。

图像插值的脚本为 morph.py,它的主要内容为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from inversion_pipeline import InversionPipeline

lora_path = 'ckpt/mountain.safetensor'
lora_path2 = 'ckpt/mountain_up.safetensor'
sd_path = 'runwayml/stable-diffusion-v1-5'


pipeline: InversionPipeline = InversionPipeline.from_pretrained(
sd_path).to("cuda")
pipeline.load_lora_weights(lora_path, adapter_name='a')
pipeline.load_lora_weights(lora_path2, adapter_name='b')

img1_path = 'dataset/mountain/mountain.jpg'
img2_path = 'dataset/mountain_up/mountain_up.jpg'
prompt = 'mountain'
latent1 = pipeline.inverse(img1_path, prompt, 50, guidance_scale=1)
latent2 = pipeline.inverse(img2_path, prompt, 50, guidance_scale=1)
n_frames = 10
images = []
for i in range(n_frames + 1):
alpha = i / n_frames
pipeline.set_adapters(["a", "b"], adapter_weights=[1 - alpha, alpha])
latent = slerp(latent1, latent2, alpha)
output = pipeline(prompt=prompt, latents=latent,
guidance_scale=1.0).images[0]
images.append(output)

对于每一个 Diffusers 的 Pipeline 类实例,都可以用 pipeline.load_lora_weights 来读取 LoRA 权重。如果我们在同一个模型上使用了多个 LoRA,为了区分它们,我们要加上 adapter_name 参数为每个 LoRA 命名。稍后我们会用到这些名称。

1
2
pipeline.load_lora_weights(lora_path, adapter_name='a')
pipeline.load_lora_weights(lora_path2, adapter_name='b')

读好了文件,使用已经写好的 DDIM Inversion 方法来得到两张图片的初始隐变量。

1
2
3
4
5
img1_path = 'dataset/mountain/mountain.jpg'
img2_path = 'dataset/mountain_up/mountain_up.jpg'
prompt = 'mountain'
latent1 = pipeline.inverse(img1_path, prompt, 50, guidance_scale=1)
latent2 = pipeline.inverse(img2_path, prompt, 50, guidance_scale=1)

最后开始生成不同插值比例的图片。根据混合比例 alpha,我们可以用 pipeline.set_adapters(["a", "b"], adapter_weights=[1 - alpha, alpha]) 来融合 LoRA 模型的比例。随后,我们再根据 alpha 对隐变量插值。用插值隐变量在插值 SD LoRA 上生成图片即可得到最终的插值图片。

1
2
3
4
5
6
7
8
9
n_frames = 10
images = []
for i in range(n_frames + 1):
alpha = i / n_frames
pipeline.set_adapters(["a", "b"], adapter_weights=[1 - alpha, alpha])
latent = slerp(latent1, latent2, alpha)
output = pipeline(prompt=prompt, latents=latent,
guidance_scale=1.0).images[0]
images.append(output)

下面两段动图中,左图和右图分别是无 LoRA 和有 LoRA 的插值结果。可见,通过 LoRA 权重上的插值,图像插值的过度会更加自然。

图片风格迁移

接下来,我们来实现最流行的 LoRA 应用——风格化 LoRA。当然,训练一个每张随机输出图片都质量很高的模型是很困难的。我们退而求其次,来实现一个能对输入图片做风格迁移的 LoRA 模型。

训练风格化 LoRA 对技术要求不高,其主要难点其实是在数据收集上。大家可以根据自己的需求,准备自己的数据集。我在本文中会分享我的实验结果。我希望把《弹丸论破》的画风——一种颜色渐变较多的动漫画风——应用到一张普通动漫画风的图片上。

由于我的目标是拟合画风而不是某一种特定的物体,我直接选取了 50 张左右的游戏 CG 构成训练数据集,且没有对图片做任何处理。训风格化 LoRA 时,文本标签几乎没用,我把所有数据的文本都设置成了游戏名 danganronpa

1
2
3
{"file_name": "1.png", "text": "danganronpa"}
...
{"file_name": "59.png", "text": "danganronpa"}

我的配置文件依然和前文的相同,LoRA rank 设置为 8。我一共训了 100 轮,但发现训练后期模型的过拟合很严重,其实令 n_epochs 为 10 到 20 就能有不错的结果。50 张图片训 10 轮最多几十分钟就训完。

由于训练图片的内容不够多样,且图片预处理时加入了随机裁剪,我的 LoRA 模型随机生成的图片质量较低。于是我决定在图像风格迁移任务上测试该模型。具体来说,我使用了 ControlNet Canny 加上图生图 (SDEdit)技术。相关的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel
from PIL import Image
import cv2
import numpy as np

lora_path = '...'
sd_path = 'runwayml/stable-diffusion-v1-5'
controlnet_canny_path = 'lllyasviel/sd-controlnet-canny'

prompt = '1 man, look at right, side face, Ace Attorney, Phoenix Wright, best quality, danganronpa'
neg_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, {multiple people}'
img_path = '...'
init_image = Image.open(img_path).convert("RGB")
init_image = init_image.resize((768, 512))
np_image = np.array(init_image)

# get canny image
np_image = cv2.Canny(np_image, 100, 200)
np_image = np_image[:, :, None]
np_image = np.concatenate([np_image, np_image, np_image], axis=2)
canny_image = Image.fromarray(np_image)
canny_image.save('tmp_edge.png')

controlnet = ControlNetModel.from_pretrained(controlnet_canny_path)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
sd_path, controlnet=controlnet
)
pipe.load_lora_weights(lora_path)

output = pipe(
prompt=prompt,
negative_prompt=neg_prompt,
strength=0.5,
guidance_scale=7.5,
controlnet_conditioning_scale=0.5,
num_inference_steps=50,
image=init_image,
cross_attention_kwargs={"scale": 1.0},
control_image=canny_image,
).images[0]
output.save("tmp.png")

StableDiffusionControlNetImg2ImgPipeline 是 Diffusers 中 ControlNet 加图生图的 Pipeline。使用它生成图片的重要参数有:

  • strength:0~1 之间重绘比例。越低越接近输入图片。
  • controlnet_conditioning_scale: 0~1 之间的 ControlNet 约束比例。越高越贴近约束。
  • cross_attention_kwargs={"scale": scale}:此处的 scale 是 0~1 之间的 LoRA 混合比例。越高越贴近 LoRA 模型的输出。

这里贴一下输入图片和两张编辑后的图片。



可以看出,输出图片中人物的画风确实得到了修改,颜色渐变更加丰富。我在几乎没有调试 LoRA 参数的情况下得到了这样的结果,可见虽然训练一个高质量的随机生成新画风的 LoRA 难度较高,但只是做风格迁移还是比较容易的。

尽管实验的经历不多,我还是基本上了解了 SD LoRA 风格化的能力边界。LoRA 风格化的本质还是修改输出图片的分布,数据集的质量基本上决定了生成的质量,其他参数的影响不会很大(包括训练图片的文本标签)。数据集最好手动裁剪至 512x512。如果想要生成丰富的风格化内容而不是只生成人物,就要丰富训练数据,减少人物数据的占比。训练时,最容易碰到的机器学习上的问题是过拟合问题。解决此问题的最简单的方式是早停,即不用最终的训练结果而用中间某一步的结果。如果你想实现改变输出数据分布以外的功能,比如精确生成某类物体、向模型中加入一些改变画风的关键词,那你应该使用更加先进的技术,而不仅仅是用最基本的 LoRA 微调。

总结

LoRA 是当今深度学习领域中常见的技术。对于 SD,LoRA 则是能够编辑单幅图片、调整整体画风,或者是通过修改训练目标来实现更强大的功能。LoRA 的原理非常简单,它其实就是用两个参数量较少的矩阵来描述一个大参数矩阵在微调中的变化量。Diffusers 库提供了非常便利的 SD LoRA 训练脚本。相信读完了本文后,我们能知道如何用 Diffusers 训练 LoRA,修改训练中的主要参数,并在简单的单图片 LoRA 编辑任务上验证训练的正确性。利用这些知识,我们也能把 LoRA 拓展到风格化生成及其他应用上。

本文的项目网址:https://github.com/SingleZombie/DiffusersExample/tree/main/LoRA

看完了Stable Diffusion的论文,在最后这篇文章里,我们来学习Stable Diffusion的代码实现。具体来说,我们会学习Stable Diffusion官方仓库及Diffusers开源库中有关采样算法和U-Net的代码,而不会学习有关训练、VAE、text encoder (CLIP) 的代码。如今大多数工作都只会用到预训练的Stable Diffusion,只学采样算法和U-Net代码就能理解大多数工作了。

建议读者在阅读本文之前了解DDPM、ResNet、U-Net、Transformer。

本文用到的Stable Diffusion版本是v1.5。Diffusers版本是0.25.0。为了提升可读性,本文对源代码做了一定的精简,部分不会运行到的分支会被略过。

算法梳理

在正式读代码之前,我们先用伪代码梳理一下Stable Diffusion的采样过程,并回顾一下U-Net架构的组成。实现Stable Diffusion的代码库有很多,各个库之间的API差异很大。但是,它们实际上都是在描述同一个算法,同一个模型。如果我们理解了算法和模型本身,就可以在学习时主动去找一个算法对应哪一段代码,而不是被动地去理解每一行代码在干什么。

LDM 采样算法

让我们从最早的DDPM开始,一步一步还原Latent Diffusion Model (LDM)的采样算法。DDPM的采样算法如下所示:

1
2
3
4
5
6
7
8
9
10
def ddpm_sample(image_shape):
ddpm_scheduler = DDPMScheduler()
unet = UNet()
xt = randn(image_shape)
T = 1000
for t in T ... 1:
eps = unet(xt, t)
std = ddpm_scheduler.get_std(t)
xt = ddpm_scheduler.get_xt_prev(xt, t, eps, std)
return xt

在DDPM的实现中,一般会有一个类专门维护扩散模型的$\alpha, \beta$等变量。我们这里把这个类称为DDPMScheduler。此外,DDPM会用到一个U-Net神经网络unet,用于计算去噪过程中图像应该去除的噪声eps。准备好这两个变量后,就可以用randn()从标准正态分布中采样一个纯噪声图像xt。它会被逐渐去噪,最终变成一幅图片。去噪过程中,时刻t会从总时刻T遍历至1(总时刻T一般取1000)。在每一轮去噪步骤中,U-Net会根据这一时刻的图像xt和当前时间戳t估计出此刻应去除的噪声eps,根据xteps就能知道下一步图像的均值。除了均值,我们还要获取下一步图像的方差,这一般可以从DDPM调度类中直接获取。有了下一步图像的均值和方差,我们根据DDPM的公式,就能采样出下一步的图像。反复执行去噪循环,xt会从纯噪声图像变成一幅有意义的图像。

DDIM对DDPM的采样过程做了两点改进:1) 去噪的有效步数可以少于T步,由另一个变量ddim_steps决定;2) 采样的方差大小可以由eta决定。因此,改进后的DDIM算法可以写成这样:

1
2
3
4
5
6
7
8
9
10
11
def ddim_sample(image_shape, ddim_steps = 20, eta = 0):
ddim_scheduler = DDIMScheduler()
unet = UNet()
xt = randn(image_shape)
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
for t in timesteps:
eps = unet(xt, t)
std = ddim_scheduler.get_std(t, eta)
xt = ddim_scheduler.get_xt_prev(xt, t, eps, std)
return xt

其中,ddim_steps是去噪循环的执行次数。根据ddim_steps,DDIM调度器可以生成所有被使用到的t。比如对于T=1000, ddim_steps=20,被使用到的就只有[1000, 950, 900, ..., 50]这20个时间戳,其他时间戳就可以跳过不算了。eta会被用来计算方差,一般这个值都会设成0

DDIM是早期的加速扩散模型采样的算法。如今有许多比DDIM更好的采样方法,但它们多数都保留了stepseta这两个参数。因此,在使用所有采样方法时,我们可以不用关心实现细节,只关注多出来的这两个参数。

在DDIM的基础上,LDM从生成像素空间上的图像变为生成隐空间上的图像。隐空间图像需要再做一次解码才能变回真实图像。从代码上来看,使用LDM后,只需要多准备一个VAE,并对最后的隐空间图像zt解码。

1
2
3
4
5
6
7
8
9
10
11
12
13
def ldm_ddim_sample(image_shape, ddim_steps = 20, eta = 0):
ddim_scheduler = DDIMScheduler()
vae = VAE()
unet = UNet()
zt = randn(image_shape)
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
for t in timesteps:
eps = unet(zt, t)
std = ddim_scheduler.get_std(t, eta)
zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
xt = vae.decoder.decode(zt)
return xt

而想用LDM实现文生图,则需要给一个额外的文本输入text。文本编码器会把文本编码成张量c,输入进unet。其他地方的实现都和之前的LDM一样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0):
ddim_scheduler = DDIMScheduler()
vae = VAE()
unet = UNet()
zt = randn(image_shape)
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]

text_encoder = CLIP()
c = text_encoder.encode(text)

for t = timesteps:
eps = unet(zt, t, c)
std = ddim_scheduler.get_std(t, eta)
zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
xt = vae.decoder.decode(zt)
return xt

最后这个能实现文生图的LDM就是我们熟悉的Stable Diffusion。Stable Diffusion的采样算法看上去比较复杂,但如果能够从DDPM开始把各个功能都拆开来看,理解起来就不是那么困难了。

U-Net 结构组成

Stable Diffusion代码实现中的另一个重点是去噪网络U-Net的实现。仿照上一节的学习方法,我们来逐步学习Stable Diffusion中的U-Net是怎么从最经典的纯卷积U-Net逐渐发展而来的。

最早的U-Net的结构如下图所示:

可以看出,U-Net的结构有以下特点:

  • 整体上看,U-Net由若干个大层组成。特征在每一大层会被下采样成尺寸更小的特征,再被上采样回原尺寸的特征。整个网络构成一个U形结构。
  • 下采样后,特征的通道数会变多。一般情况下,每次下采样后图像尺寸减半,通道数翻倍。上采样过程则反之。
  • 为了防止信息在下采样的过程中丢失,U-Net每一大层在下采样前的输出会作为额外输入拼接到每一大层上采样前的输入上。这种数据连接方式类似于ResNet中的「短路连接」。

DDPM则使用了一种改进版的U-Net。改进主要有两点:

  • 原来的卷积层被替换成了ResNet中的残差卷积模块。每一大层有若干个这样的子模块。对于较深的大层,残差卷积模块后面还会接一个自注意力模块。
  • 原来模型每一大层只有一个短路连接。现在每个大层下采样部分的每个子模块的输出都会额外输入到其对称的上采样部分的子模块上。直观上来看,就是短路连接更多了一点,输入信息更不容易在下采样过程中丢失。

最后,LDM提出了一种给U-Net添加额外约束信息的方法:把U-Net中的自注意力模块换成交叉注意力模块。具体来说,DDPM的U-Net的自注意力模块被换成了标准的Transformer模块。约束信息$C$可以作为Cross Attention的K, V输入进模块中。

Stable Diffusion的U-Net还在结构上有少许修改,该U-Net的每一大层都有Transformer块,而不是只有较深的大层有。

至此,我们已经学完了Stable Diffusion的采样原理和U-Net结构。接下来我们来看一看它们在不同框架下的代码实现。

Stable Diffusion 官方 GitHub 仓库

安装

克隆仓库后,照着官方Markdown文档安装即可。

1
git clone git@github.com:CompVis/stable-diffusion.git

先用下面的命令创建conda环境,此后ldm环境就是运行Stable Diffusiion的conda环境。

1
2
conda env create -f environment.yaml
conda activate ldm

之后去网上下一个Stable Diffusion的模型文件。比较常见一个版本是v1.5,该模型在Hugging Face上:https://huggingface.co/runwayml/stable-diffusion-v1-5 (推荐下载v1-5-pruned.ckpt)。下载完毕后,把模型软链接到指定位置。

1
2
mkdir -p models/ldm/stable-diffusion-v1/
ln -s <path/to/model.ckpt> models/ldm/stable-diffusion-v1/model.ckpt

准备完毕后,只要输入下面的命令,就可以生成实现文生图了。

1
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" 

在默认的参数下,“一幅骑着马的飞行员的照片”的绘制结果会被保存在outputs/txt2img-samples中。你也可以通过--outdir <dir>参数来指定输出到的文件夹。我得到的一些绘制结果为:

如果你在安装时碰到了错误,可以在搜索引擎上或者GitHub的issue里搜索,一般都能搜到其他人遇到的相同错误。

主函数

接下来,我们来探究一下scripts/txt2img.py的执行过程。为了方便阅读,我们可以简化代码中的命令行处理,得到下面这份精简代码。(你可以把这份代码复制到仓库根目录下的一个新Python脚本里并直接运行。别忘了修改代码中的模型路径)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from pytorch_lightning import seed_everything
from torch import autocast
from torchvision.utils import make_grid

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)

model.cuda()
model.eval()
return model


def main():
seed = 42
config = 'configs/stable-diffusion/v1-inference.yaml'
ckpt = 'ckpt/v1-5-pruned.ckpt'
outdir = 'tmp'
n_samples = batch_size = 3
n_rows = batch_size
n_iter = 2
prompt = 'a photograph of an astronaut riding a horse'
data = [batch_size * [prompt]]
scale = 7.5
C = 4
f = 8
H = W = 512
ddim_steps = 50
ddim_eta = 0.0

seed_everything(seed)

config = OmegaConf.load(config)
model = load_model_from_config(config, ckpt)

device = torch.device(
"cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
sampler = DDIMSampler(model)

os.makedirs(outdir, exist_ok=True)
outpath = outdir

sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
grid_count = len(os.listdir(outpath)) - 1

start_code = None
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
all_samples = list()
for n in trange(n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if scale != 1.0:
uc = model.get_learned_conditioning(
batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [C, H // f, W // f]
samples_ddim, _ = sampler.sample(S=ddim_steps,
conditioning=c,
batch_size=n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code)

x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp(
(x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

all_samples.append(x_samples_ddim)
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)

# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))
img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1

print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.")


if __name__ == "__main__":
main()

抛开前面一大堆初始化操作,代码的核心部分只有下面几行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
uc = None
if scale != 1.0:
uc = model.get_learned_conditioning(
batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [C, H // f, W // f]
samples_ddim, _ = sampler.sample(S=ddim_steps,
conditioning=c,
batch_size=n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code)

x_samples_ddim = model.decode_first_stage(samples_ddim)

我们来逐行分析一下这段代码。一开始的几行是执行Classifier-Free Guidance (CFG)。uc表示的是CFG中的无约束下的约束张量。scale表示的是执行CFG的程度,scale不等于1.0即表示启用CFG。model.get_learned_conditioning表示用CLIP把文本编码成张量。对于文本约束的模型,无约束其实就是输入文本为空字符串("")。因此,在代码中,若启用了CFG,则会用CLIP编码空字符串,编码结果为uc

如果你没学过CFG,也不用担心。你可以暂时不要去理解上面这段话。等读完了后文中有关CFG的代码后,你差不多就能理解CFG的用法了。

1
2
3
4
uc = None
if scale != 1.0:
uc = model.get_learned_conditioning(
batch_size * [""])

之后的几行是在把用户输入的文本编码成张量。同样,model.get_learned_conditioning表示用CLIP把输入文本编码成张量c

1
2
3
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)

接着是用扩散模型的采样器生成图片。在这份代码中,sampler是DDIM采样器,sampler.sample函数直接完成了图像生成。

1
2
3
4
5
6
7
8
9
10
shape = [C, H // f, W // f]
samples_ddim, _ = sampler.sample(S=ddim_steps,
conditioning=c,
batch_size=n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code)

最后,LDM生成的隐空间图片被VAE解码成真实图片。函数model.decode_first_stage负责图片解码。x_samples_ddim在后续的代码中会被后处理成正确格式的RGB图片,并输出至文件里。

1
x_samples_ddim = model.decode_first_stage(samples_ddim)

Stable Diffusion 官方实现的主函数主要就做了这些事情。这份实现还是有一些凌乱的。采样算法的一部分内容被扔到了主函数里,另一部分放到了DDIM采样器里。在阅读官方实现的源码时,既要去读主函数里的内容,也要去读采样器里的内容。

接下来,我们来看一看DDIM采样器的部分代码,学完采样算法的剩余部分的实现。

DDIM 采样器

回头看主函数的前半部分,DDIM采样器是在下面的代码里导入的:

1
from ldm.models.diffusion.ddim import DDIMSampler

跳转到ldm/models/diffusion/ddim.py文件,我们可以找到DDIMSampler类的实现。

先看一下这个类的构造函数。构造函数主要是把U-Net model给存了下来。后文中的self.model都指的是U-Net。

1
2
3
4
5
6
7
8
9
10
11
12
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule

# in main

config = OmegaConf.load(config)
model = load_model_from_config(config, ckpt)
model = model.to(device)
sampler = DDIMSampler(model)

再沿着类的self.sample方法,看一下DDIM采样的实现代码。以下是self.sample方法的主要内容。这个方法其实就执行了一个self.make_schedule,之后把所有参数原封不动地传到了self.ddim_sampling里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
...
):
if conditioning is not None:
...

self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')

samples, intermediates = self.ddim_sampling(...)

self.make_schedule用于预处理扩散模型的中间计算参数。它的大部分实现细节可以略过。DDIM用到的有效时间戳列表就是在这个函数里设置的,该列表通过make_ddim_timesteps获取,并保存在self.ddim_timesteps中。此外,由ddim_eta决定的扩散模型的方差也是在这个方法里设置的。大致扫完这个方法后,我们可以直接跳到self.ddim_sampling的代码。

1
2
3
4
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
...

穿越重重的嵌套,我们总算能看到DDIM采样的实现方法self.ddim_sampling了。它的主要内容如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
@torch.no_grad()
def ddim_sampling(self, ...):
device = self.model.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
timesteps = self.ddim_timesteps
intermediates = ...
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]

iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)

outs = self.p_sample_ddim(img, cond, ts, ...)
img, pred_x0 = outs

return img, intermediates

这段代码和我们之前自己写的伪代码非常相似。一开始,方法获取了在make_schedule里初始化的DDIM有效时间戳列表self.ddim_timesteps,并预处理成一个iterator。该迭代器用于控制DDIM去噪循环。每一轮循环会根据当前时刻的图像img和时间戳ts计算下一步的图像img。具体来说,代码每次用当前的时间戳step创建一个内容全部为step,形状为(b,)的张量ts。该张量会和当前的隐空间图像img,约束信息张量cond一起传给执行一轮DDIM去噪的p_sample_ddim方法。p_sample_ddim方法会返回下一步的图像img。最后,经过多次去噪后,ddim_sampling方法将去噪后的隐空间图像img返回。

p_sample_ddim里的p_sample看上去似乎意义不明,实际上这个叫法来自于DDPM论文。在DDPM论文中,扩散模型的前向过程用字母$q$表示,反向过程用字母$p$表示。因此,反向过程的一轮去噪在代码里被叫做p_sample

最后来看一下p_sample_ddim这个方法,它的主体部分如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
@torch.no_grad()
def p_sample_ddim(self, x, c, t, ...):
b, *_, device = *x.shape, x.device

if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)


# Prepare variables
...

# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0

方法的内容大致可以拆成三段:首先,方法调用U-Net self.model,使用CFG来计算除这一轮该去掉的噪声e_t。然后,方法预处理出DDIM的中间变量。最后,方法根据DDIM的公式,计算出这一轮去噪后的图片x_prev。我们着重看第一部分的代码。

不启用CFG时,方法直接通过self.model.apply_model(x, t, c)调用U-Net,算出这一轮的噪声e_t。而想启用CFG,需要输入空字符串的约束张量unconditional_conditioning,且CFG的强度unconditional_guidance_scale不为1。CFG的执行过程是:对U-Net输入不同的约束c,先用空字符串约束得到一个预测噪声e_t_uncond,再用输入的文本约束得到一个预测噪声e_t。之后令e_t = et_uncond + scale * (e_t - e_t_uncond)scale大于1,即表明我们希望预测噪声更加靠近有输入文本的那一个。直观上来看,scale越大,最后生成的图片越符合输入文本,越偏离空文本。下面这段代码正是实现了上述这段逻辑,只不过代码使用了一些数据拼接技巧,让空字符串约束下和输入文本约束下的结果在一次U-Net推理中获得。

1
2
3
4
5
6
7
8
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

p_sample_ddim 方法的后续代码都是在实现下面这个DDIM采样公式。代码工工整整地计算了公式中的predicted_x0, dir_xt, noise,非常易懂,没有需要特别注意的地方。

我们已经看完了p_sample_ddim的代码。该方法可以实现一步去噪操作。多次调用该方法去噪后,我们就能得到生成的隐空间图片。该图片会被返回到main函数里,被VAE的解码器解码成普通图片。至此,我们就学完了Stable Diffusion官方仓库的采样代码。

对照下面这份我们之前写的伪代码,我们再来梳理一下Stable Diffusion官方仓库的代码逻辑。官方仓库的采样代码一部分在main函数里,另一部分在ldm/models/diffusion/ddim.py里。main函数主要完成了编码约束文字、解码隐空间图像这两件事。剩下的DDIM采样以及各种Diffusion图像编辑功能都是在ldm/models/diffusion/ddim.py文件中实现的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0)
ddim_scheduler = DDIMScheduler()
vae = VAE()
unet = UNet()
zt = randn(image_shape)
eta = input()
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]

text_encoder = CLIP()
c = text_encoder.encode(text)

for t = timesteps:
eps = unet(zt, t, c)
std = ddim_scheduler.get_std(t, eta)
zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
xt = vae.decoder.decode(zt)
return xt

在学习代码时,要着重学习DDIM采样器部分的代码。大部分基于Diffusion的图像编辑技术都是在DDIM采样的中间步骤中做文章,只要学懂了DDIM采样的代码,学相关图像编辑技术就会非常轻松。除此之外,和LDM相关的文字约束编码、隐空间图像编码解码的接口函数也需要熟悉,不少技术会调用到这几项功能。

还有一些Diffusion相关工作会涉及U-Net的修改。接下来,我们就来看Stable Diffusion官方仓库中U-Net的实现。

U-Net

我们来回头看一下main函数和DDIM采样中U-Net的调用逻辑。和U-Net有关的代码如下所示。LDM模型类 model在主函数中通过load_model_from_config从配置文件里创建,随后成为了sampler的成员变量。在DDIM去噪循环中,LDM模型里的U-Net会在self.model.apply_model方法里被调用。

1
2
3
4
5
6
7
8
# main.py
config = 'configs/stable-diffusion/v1-inference.yaml'
config = OmegaConf.load(config)
model = load_model_from_config(config, ckpt)
sampler = DDIMSampler(model)

# ldm/models/diffusion/ddim.py
e_t = self.model.apply_model(x, t, c)

为了知道U-Net是在哪个类里定义的,我们需要打开配置文件 configs/stable-diffusion/v1-inference.yaml。该配置文件有这样一段话:

1
2
3
4
5
6
model:
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
conditioning_key: crossattn
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel

根据这段话,我们知道LDM类定义在ldm/models/diffusion/ddpm.pyLatentDiffusion里,U-Net类定义在ldm/modules/diffusionmodules/openaimodel.pyUNetModel里。一个LDM类有一个U-Net类的实例。我们先简单看一看LatentDiffusion类的实现。

ldm/models/diffusion/ddpm.py原本来自DDPM论文的官方仓库,内含DDPM类的实现。DDPM类维护了扩散模型公式里的一些变量,同时维护了U-Net类的实例。LDM的作者基于之前DDPM的代码进行开发,定义了一个继承自DDPMLatentDiffusion类。除了DDPM本身的功能外,LatentDiffusion还维护了VAE(self.first_stage_model),CLIP(self.cond_stage_model)。也就是说,LatentDiffusion主要维护了扩散模型中间变量、U-Net、VAE、CLIP这四类信息。这样,所有带参数的模型都在LatentDiffusion里,我们可以从一个checkpoint文件中读取所有的模型的参数。相关代码定义代码如下:

把所有模型定义在一起有好处也有坏处。好处在于,用户想使用Stable Diffusion时,只需要下载一个checkpoint文件就行了。坏处在于,哪怕用户只改了某个子模型(如U-Net),为了保存整个模型,他还是得把其他子模型一起存下来。这其中存在着信息冗余,十分不灵活。Diffusers框架没有把模型全存在一个文件里,而是放到了一个文件夹里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class DDPM(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(self,
unet_config,
...):
self.model = DiffusionWrapper(unet_config, conditioning_key)


class LatentDiffusion(DDPM):
"""main class"""
def __init__(self,
first_stage_config,
cond_stage_config,
...):

self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)

我们主要关注LatentDiffusion类的apply_model方法,它用于调用U-Net self.modelapply_model看上去有很长,但略过了我们用不到的一些代码后,整个方法其实非常短。一开始,方法对输入的约束信息编码cond做了一个前处理,判断约束是哪种类型。如论文里所描述的,LDM支持两种约束:将约束与输入拼接、将约束注入到交叉注意力层中。方法会根据self.model.conditioning_keyconcat还是crossattn,使用不同的约束方式。Stable Diffusion使用的是后者,即self.model.conditioning_key == crossattn。做完前处理后,方法执行了x_recon = self.model(x_noisy, t, **cond)。接下来的处理交给U-Net self.model来完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
# hybrid case, cond is exptected to be a dict
pass
else:
if not isinstance(cond, list):
cond = [cond]
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
cond = {key: cond}

x_recon = self.model(x_noisy, t, **cond)

if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon

现在,我们跳转到ldm/modules/diffusionmodules/openaimodel.pyUNetModel类里。UNetModel只定义了神经网络层的运算,没有多余的功能。我们只需要看它的__init__方法和forward方法。我们先来看较为简短的forward方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)

h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
return self.out(h)

forward方法的输入是x, timesteps, context,分别表示当前去噪时刻的图片、当前时间戳、文本约束编码。根据这些输入,forward会输出当前时刻应去除的噪声eps。一开始,方法会先对timesteps使用Transformer论文中介绍的位置编码timestep_embedding,得到时间戳的编码t_embt_emb再经过几个线性层,得到最终的时间戳编码emb。而context已经是CLIP处理过的编码,它不需要做额外的预处理。时间戳编码emb和文本约束编码context随后会注入到U-Net的所有中间模块中。

1
2
3
4
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)

经过预处理后,方法开始处理U-Net的计算。中间结果h会经过U-Net的下采样模块input_blocks,每一个子模块的临时输出都会被保存进一个栈hs里。

1
2
3
4
 h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)

接着,h会经过U-Net的中间模块。

1
h = self.middle_block(h, emb, context)

随后,h开始经过U-Net的上采样模块output_blocks。此时每一个编码器子模块的临时输出会从栈hs里弹出,作为对应解码器子模块的额外输入。额外输入hs.pop()会与中间结果h拼接到一起输入进子模块里。

1
2
3
4
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)

最后,h会被输出层转换成一个通道数正确的eps张量。

1
return self.out(h)

这段代码的数据连接图如下所示:

在阅读__init__前,我们先看一下待会会用到的另一个模块类TimestepEmbedSequential的定义。在PyTorch中,一系列输入和输出都只有一个变量的模块在串行连接时,可以用串行模块类nn.Sequential来把多个模块合并简化成一个模块。而在扩散模型中,多数模块的输入是x, t, c三个变量,输出是一个变量。为了也能用类似的串行模块类把扩散模型的模块合并在一起,代码中包含了一个TimestepEmbedSequential类。它的行为类似于nn.Sequential,只不过它支持x, t, c的输入。forward中用到的多数模块都是通过TimestepEmbedSequential创建的。

1
2
3
4
5
6
7
8
9
10
11
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):

def forward(self, x, emb, context=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
else:
x = layer(x)
return x

看完了数据的计算过程,我们回头来看各个子模块在__init__方法中是怎么被详细定义的。__init__的主要内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class UNetModel(nn.Module):
def __init__(self, ...):

self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)

self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)

for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(...)]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...))

self.input_blocks.append(TimestepEmbedSequential(*layers))
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(...)
if resblock_updown
else Downsample(...)
)
)

self.middle_block = TimestepEmbedSequential(
ResBlock(...),
AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...),
ResBlock(...),
)

self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(...)
]
ch = model_channels * mult
if ds in attention_resolutions:
layers.append(
AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResBlock(...)
if resblock_updown
else Upsample(...)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)

__init__方法的代码很长。在阅读这样的代码时,我们不需要每一行都去细读,只需要理解代码能拆成几块,每一块在做什么即可。__init__方法其实就是定义了forward中用到的5个模块,我们一个一个看过去即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class UNetModel(nn.Module):
def __init__(self, ...):

self.time_embed = ...

self.input_blocks = nn.ModuleList(...)
for level, mult in enumerate(channel_mult):
...

self.middle_block = ...

self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
...
self.out = ...

先来看time_embed。回忆一下,在forward里,输入的整数时间戳会被正弦编码timestep_embedding(即Transformer中的位置编码)编码成一个张量。之后,时间戳编码处理模块time_embed用于进一步提取时间戳编码的特征。从下面的代码中可知,它本质上就是一个由两个普通线性层构成的模块。

1
2
3
4
5
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)

再来看U-Net最后面的输出模块out。输出模块的结构也很简单,它主要包含了一个卷积层,用于把中间变量的通道数从dims变成model_channels

1
2
3
4
5
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)

接下来,我们把目光聚焦在U-Net的三个核心模块上:input_blocks, middle_block, output_blocks。这三个模块的组成都很类似,都用到了残差块ResBlock和注意力块。稍有不同的是,input_blocks的每一大层后面都有一个下采样模块,output_blocks的每一大层后面都有一个上采样模块。上下采样模块的结构都很常规,与经典的U-Net无异。我们把学习的重点放在残差块和注意力块上。我们先看这两个模块的内部实现细节,再来看它们是怎么拼接起来的。

Stable Diffusion的U-Net中的ResBlock和原DDPM的U-Net的ResBlock功能完全一样,都是在普通残差块的基础上,支持时间戳编码的额外输入。具体来说,普通的残差块是由两个卷积模块和一条短路连接构成的,即y = x + conv(conv(x))。如果经过两个卷积块后数据的通道数发生了变化,则要在短路连接上加一个转换通道数的卷积,即y = conv(x) + conv(conv(x))

在这种普通残差块的基础上,扩散模型中的残差块还支持时间戳编码t的输入。为了把t和输入x的信息融合在一起,t会和经过第一个卷积后的中间结果conv(x)加在一起。可是,t的通道数和conv(x)的通道数很可能会不一样。通道数不一样的数据是不能直接加起来的。为此,每一个残差块中都有一个用于转换t通道数的线性层。这样,tconv(x)就能相加了。整个模块的计算可以表示成y=conv(x) + conv(conv(x) + linear(t))。残差块的示意图和源代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class ResBlock(TimestepBlock):
def __init__(self, ...):
super().__init__()
...

self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)

self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)

if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

def forward(self, x, emb):
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h

代码中的in_layers是第一个卷积模块,out_layers是第二个卷积模块。skip_connection是用于调整短路连接通道数的模块。若输入输出的通道数相同,则该模块是一个恒等函数,不对数据做任何修改。emb_layers是调整时间戳编码通道数的线性层模块。这些模块的定义都在ResBlock__init__里。它们的结构都很常规,没有值得注意的地方。我们可以着重阅读模型的forward方法。

如前文所述,在forward中,输入x会先经过第一个卷积模块in_layers,再与经过了emb_layers调整的时间戳编码emb相加后,输入进第二个卷积模块out_layers。最后,做完计算的数据会和经过了短路连接的原输入skip_connection(x)加在一起,作为整个残差块的输出。

1
2
3
4
5
6
7
8
def forward(self, x, emb):
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h

这里有一点实现细节需要注意。时间戳编码emb_out的形状是[n, c]。为了把它和形状为[n, c, h, w]的图片加在一起,需要把它的形状变成[n, c, 1, 1]后再相加(形状为[n, c, 1, 1]的数据在与形状为[n, c, h, w]的数据做加法时形状会被自动广播成[n, c, h, w])。在PyTorch中,x=x[..., None]可以在一个数据最后加一个长度为1的维度。比如对于形状为[n, c]tt[..., None]的形状就会是[n, c, 1]

残差块的内容到此结束。我们接着来看注意力模块。在看模块的具体实现之前,我们先看一下源代码中有哪几种注意力模块。在U-Net的代码中,注意力模型是用以下代码创建的:

1
2
3
4
if ds in attention_resolutions:
layers.append(
AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...)
)

第一行if ds in attention_resolutions:用于控制在U-Net的哪几个大层。Stable Diffusion每一大层都用了注意力模块,可以忽略这一行。随后,代码根据是否设置use_spatial_transformer来创建AttentionBlock或是SpatialTransformerAttentionBlock是DDPM中采样的普通自注意力模块,而SpatialTransformer是LDM中提出的支持额外约束的标准Transfomer块。Stable Diffusion使用的是SpatialTransformer。我们就来看一看这个模块的实现细节。

如前所述,SpatialTransformer使用的是标准的Transformer块,它和Transformer中的Transformer块完全一致。输入x先经过一个自注意力层,再过一个交叉注意力层。在此期间,约束编码c会作为交叉注意力层的K, V输入进模块。最后,数据经过一个全连接层。每一层的输入都会和输出做一个残差连接。

当然,标准Transformer是针对一维序列数据的。要把Transformer用到图像上,则需要把图像的宽高拼接到同一维,即对张量做形状变换n c h w -> n c (h * w)。做完这个变换后,就可以把数据直接输入进Transformer模块了。
这些图像数据与序列数据的适配都是在SpatialTransformer类里完成的。SpatialTransformer类并没有直接实现Transformer块的细节,仅仅是U-Net和Transformer块之间的一个过渡。Transformer块的实现在它的一个子模块里。我们来看它的实现代码。

SpatialTransformer有两个卷积层proj_in, proj_out,负责图像通道数与Transformer模块通道数之间的转换。SpatialTransformertransformer_blocks才是真正的Transformer模块。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class SpatialTransformer(nn.Module):

def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)

self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)

self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
)

self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))

forward中,图像数据在进出Transformer模块前后都会做形状和通道数上的适配。运算结束后,结果和输入之间还会做一个残差连接。context就是约束信息编码,它会接入到交叉注意力层上。

1
2
3
4
5
6
7
8
9
10
11
12

def forward(self, x, context=None):
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in

每一个Transformer模块的结构完全符合上文的示意图。如果你之前学过Transformer,那这些代码你会十分熟悉。我们快速把这部分代码浏览一遍。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint

def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x

自注意力层和交叉注意力层都是用CrossAttention类实现的。该模块与Transformer论文中的多头注意力机制完全相同。当forward的参数context=None时,模块其实只是一个提取特征的自注意力模块;而当context为约束文本的编码时,模块就是一个根据文本约束进行运算的交叉注意力模块。该模块用不到mask,相关的代码可以忽略。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.heads = heads

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)

def forward(self, x, context=None, mask=None):
h = self.heads

q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

if exists(mask):
...

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)

Transformer块的内容到此结束。看完了SpatialTransformerResBlock,我们可以回头去看模块之间是怎么拼接的了。先来看U-Net的中间块。它其实就是一个ResBlock接一个SpatialTransformer再接一个ResBlock

1
2
3
4
5
self.middle_block = TimestepEmbedSequential(
ResBlock(...),
SpatialTransformer(...),
ResBlock(...),
)

下采样块input_blocks和上采样块output_blocks的结构几乎一模一样,区别只在于每一大层最后是做下采样还是上采样。这里我们以下采样块为例来学习一下这两个块的结构。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)

for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(...)]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...))

self.input_blocks.append(TimestepEmbedSequential(*layers))
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(...)
if resblock_updown
else Downsample(...)
)
)

上采样块一开始是一个调整输入图片通道数的卷积层,它的作用和self.out输出层一样。

1
2
3
4
5
6
7
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)

之后正式进行上采样块的构造。此处代码有两层循环,外层循环表示正在构造哪一个大层,内层循环表示正在构造该大层的哪一组模块。也就是说,共有len(channel_mult)个大层,每一大层都有num_res_blocks组相同的模块。在Stable Diffusion中,channel_mult=[1, 2, 4, 4], num_res_blocks=2

1
2
3
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
...

每一组模块由一个ResBlock和一个SpatialTransformer构成。

1
2
3
4
5
6
7
8
9
10
11
layers = [
ResBlock(...)
]
ch = mult * model_channels
if ds in attention_resolutions:
...
layers.append(
SpatialTransformer(...)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
...

构造完每一组模块后,若现在还没到最后一个大层,则添加一个下采样模块。Stable Diffusion有4个大层,只有运行到前3个大层时才会添加下采样模块。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
...
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(...)
if resblock_updown
else Downsample(...)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2

至此,我们已经学完了Stable Diffusion的U-Net的主要实现代码。让我们来总结一下。U-Net是一种先对数据做下采样,再做上采样的网络结构。为了防止信息丢失,下采样模块和对应的上采样模块之间有残差连接。下采样块、中间块、上采样块都包含了ResBlockSpatialTransformer两种模块。ResBlock是图像网络中常使用的残差块,而SpatialTransformer是能够融合图像全局信息并融合不同模态信息的Transformer块。Stable Diffusion的U-Net的输入除了有图像外,还有时间戳t和约束编码ct会先过几个嵌入层和线性层,再输入进每一个ResBlock中。c会直接输入到所有Transformer块的交叉注意力块中。

Diffusers

Diffusers是由Hugging Face维护的一套Diffusion框架。这个库的代码被封装进了一个Python模块里,我们可以在安装了Diffusers的Python环境中用import diffusers随时调用该库。相比之下,Diffusers的代码架构更加清楚,且各类Stable Diffusion的新技术都会及时集成进Diffusers库中。

由于我们已经在上文中学过了Stable Diffusion官方源码,在学习Diffusers代码时,我们只会大致过一过每一段代码是在做什么,而不会赘述Stable Diffusion的原理。

安装

安装该库时,不需要克隆仓库,只需要直接用pip即可。

1
pip install --upgrade diffusers[torch]

之后,随便在某个地方创建一个Python脚本文件,输入官方的示例项目代码。

1
2
3
4
5
6
from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipeline.to("cuda")
pipeline("An image of a squirrel in Picasso style").images[0].save('output.jpg')

运行代码后,”一幅毕加索风格的松鼠图片”的绘制结果会保存在output.jpg中。我得到的结果如下:

在Diffusers中,from_pretrained函数可以直接从Hugging Face的模型仓库中下载预训练模型。比如,示例代码中from_pretrained("runwayml/stable-diffusion-v1-5", ...)指的就是从模型仓库https://huggingface.co/runwayml/stable-diffusion-v1-5中获取模型。

如果在当前网络下无法从命令行中访问Hugging Face,可以先想办法在网页上访问上面的模型仓库,手动下载v1-5-pruned.ckpt。之后,克隆Diffusers的GitHub仓库,再用Diffusers的工具把Stable Diffusion模型文件转换成Diffusers支持的模型格式。

1
2
3
git clone git@github.com:huggingface/diffusers.git
cd diffusers
python scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path <src> --dump_path <dst>

比如,假设你的模型文件存在ckpt/v1-5-pruned.ckpt,你想把输出的Diffusers的模型文件存在ckpt/sd15,则应该输入:

1
python scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path ckpt/v1-5-pruned.ckpt --dump_path ckpt/sd15 

之后修改示例脚本中的路径,就可以成功运行了。
1
2
3
4
5
6
from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained("ckpt/sd15", torch_dtype=torch.float16)
pipeline.to("cuda")
pipeline("An image of a squirrel in Picasso style").images[0].save('output.jpg')

对于其他的原版SD checkpoint(比如在civitai上下载的),也可以用同样的方式把它们转换成Diffusers兼容的版本。

采样

Diffusers使用Pipeline来管理一类图像生成算法。和图像生成相关的模块(如U-Net,DDIM采样器)都是Pipeline的成员变量。打开Diffusers版Stable Diffusion模型的配置文件model_index.json(在 https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json 网页上直接访问或者在本地的模型文件夹中找到),我们能看到该模型使用的Pipeline:

1
2
3
4
{
"_class_name": "StableDiffusionPipeline",
...
}

diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py中,我们能找到StableDiffusionPipeline类的定义。所有Pipeline类的代码都非常长,一般我们可以忽略其他部分,只看运行方法__call__里的内容。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
...
):

# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks

# 1. Check inputs. Raise error if not correct
self.check_inputs(...)

# 2. Define call parameters
batch_size = ...

device = self._execution_device

# 3. Encode input prompt


prompt_embeds, negative_prompt_embeds = self.encode_prompt(...)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(...)

# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
...

# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
...
)[0]

# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]


# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()


if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None

...

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

虽然这段代码很长,但代码中的关键内容和我们在本文开头写的伪代码完全一致。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0)
ddim_scheduler = DDIMScheduler()
vae = VAE()
unet = UNet()
zt = randn(image_shape)
eta = input()
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]

text_encoder = CLIP()
c = text_encoder.encode(text)

for t = timesteps:
eps = unet(zt, t, c)
std = ddim_scheduler.get_std(t, eta)
zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
xt = vae.decoder.decode(zt)
return xt

我们可以对照着上面的伪代码来阅读这个方法。经过Diffusers框架本身的一些前处理后,方法先获取了约束文本的编码。

1
2
3
# 3. Encode input prompt
# c = text_encoder.encode(text)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(...)

方法再从采样器里获取了要用到的时间戳,并随机生成了一个初始噪声。

1
2
3
4
5
6
7
8
9
10
11
12
13
# Preprocess
...

# 4. Prepare timesteps
# timesteps = ddim_scheduler.get_timesteps(T, ddim_steps)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

# 5. Prepare latent variables
# zt = randn(image_shape)
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
...
)

做完准备后,方法进入去噪循环。循环一开始是用U-Net算出当前应去除的噪声noise_pred。由于加入了CFG,U-Net计算的前后有一些对数据形状处理的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# eps = unet(zt, t, c)

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
...
)[0]

# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

有了应去除的噪声,方法会调用扩散模型采样器对当前的噪声图片进行更新。Diffusers把采样的逻辑全部封装进了采样器的step方法里。对于包括DDIM在内的所有采样器,都可以调用这个通用的接口,完成一步采样。eta等采样器参数会通过**extra_step_kwargs传入采样器的step方法里。

1
2
3
4
5
# std = ddim_scheduler.get_std(t, eta)
# zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

经过若干次循环后,我们得到了隐空间下的生成图片。我们还需要调用VAE把隐空间图片解码成普通图片。代码中的self.vae.decode(latents / self.vae.config.scaling_factor, ...)用于解码图片。

1
2
3
4
5
6
7
8
9
10
11
12
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None

...

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

就这样,我们很快就看完了Diffusers的采样代码。相比之下,Diffusers的封装确实更合理,主要的图像生成逻辑都写在Pipeline类的__call__里,剩余逻辑都封装在VAE、U-Net、采样器等各自的类里。

U-Net

接下来我们来看Diffusers中的U-Net实现。还是打开模型配置文件model_index.json,我们可以找到U-Net的类名。

1
2
3
4
5
6
7
8
{
...
"unet": [
"diffusers",
"UNet2DConditionModel"
],
...
}

diffusers/models/unet_2d_condition.py文件中,我们可以找到类UNet2DConditionModel。由于Diffusers集成了非常多新特性,整个文件就像一锅大杂烩一样,掺杂着各种功能的实现代码。不过,这份U-Net的实现还是基于原版Stable Diffusion的U-Net进行开发的,原版代码的每一部分都能在这份代码里找到对应。在阅读代码时,我们可以跳过无关的功能,只看我们在Stable Diffusion官方仓库中见过的部分。

先看初始化函数的主要内容。初始化函数依然主要包括time_proj, time_embedding, down_blocks, mid_block, up_blocks, conv_in, conv_out这几个模块。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
def __init__(...):
...
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
...
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
...
self.time_embedding = TimestepEmbedding(...)
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
for i, down_block_type in enumerate(down_block_types):
...
down_block = get_down_block(...)

if mid_block_type == ...
self.mid_block = ...

for i, up_block_type in enumerate(up_block_types):
up_block = get_up_block(...)

self.conv_out = nn.Conv2d(...)

其中,较为重要的down_blocks, mid_block, up_blocks都是根据模块类名称来创建的。我们可以在Diffusers的Stable Diffusion模型文件夹的U-Net的配置文件unet/config.json中找到对应的模块类名称。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
{
...
"down_block_types": [
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D"
],
"mid_block_type": "UNetMidBlock2DCrossAttn",
"up_block_types": [
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D"
],
...
}

diffusers/models/unet_2d_blocks.py中,我们可以找到这几个模块类的定义。和原版代码一样,这几个模块的核心组件都是残差块和Transformer块。在Diffusers中,残差块叫做ResnetBlock2D,Transformer块叫做Transformer2DModel。这几个类的执行逻辑和原版仓库的也几乎一样。比如CrossAttnDownBlock2D的定义如下:

1
2
3
4
5
6
class CrossAttnDownBlock2D(nn.Module):
def __init__(...):
for i in range(num_layers):
resnets.append(ResnetBlock2D(...))
if not dual_cross_attention:
attentions.append(Transformer2DModel(...))

接着我们来看U-Net的forward方法。忽略掉其他功能的实现,该方法的主要内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
...):

# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0

# 1. time
timesteps = timestep
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb, timestep_cond)

# 2. pre-process
sample = self.conv_in(sample)

# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
...)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
...)

# 5. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
...)

# 6. post-process
sample = self.conv_out(sample)

return UNet2DConditionOutput(sample=sample)


该方法和原版仓库的实现差不多,唯一要注意的是栈相关的实现。在方法的下采样计算中,每个downsample_block会返回多个残差输出的元组res_samples,该元组会拼接到栈down_block_res_samples的栈顶。在上采样计算中,代码会根据当前的模块个数,从栈顶一次取出len(upsample_block.resnets)个残差输出。

1
2
3
4
5
6
7
8
9
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(...)
down_block_res_samples += res_samples

for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
sample = upsample_block(...)

现在,我们已经看完了Diffusers中U-Net的主要内容。可以看出,Diffusers的U-Net包含了很多功能,一般情况下是难以自己更改这些代码的。有没有什么办法能方便地修改U-Net的实现呢?由于很多工作都需要修改U-Net的Attention,Diffusers给U-Net添加了几个方法,用于精确地修改每一个Attention模块的实现。我们来学习一个修改Attention模块的示例。

U-Net类的attn_processors属性会返回一个词典,它的key是每个Attention运算类所在位置,比如down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor,它的value是每个Attention运算类的实例。默认情况下,每个Attention运算类都是AttnProcessor,它的实现在diffusers/models/attention_processor.py文件中。

为了修改Attention运算的实现,我们需要构建一个格式一样的词典attn_processor_dict,再调用unet.set_attn_processor(attn_processor_dict),取代原来的attn_processors。假如我们自己实现了另一个Attention运算类MyAttnProcessor,我们可以编写下面的代码来修改Attention的实现:

1
2
3
4
5
6
7
8
9

attn_processor_dict = {}
for k in unet.attn_processors.keys():
if we_want_to_modify(k):
attn_processor_dict[k] = MyAttnProcessor()
else:
attn_processor_dict[k] = AttnProcessor()

unet.set_attn_processor(attn_processor_dict)

MyAttnProcessor的唯一要求是,它需要实现一个__call__方法,且方法参数与AttnProcessor的一致。除此之外,我们可以自由地实现Attention处理的细节。一般来说,我们可以先把原来AttnProcessor的实现代码复制过去,再对某些细节做修改。

总结

在这篇文章中,我们学习了Stable Diffusion的原版实现和Diffusers实现的主要内容:采样算法和U-Net。具体来说,在原版仓库中,采样的实现一部分在主函数中,一部分在DDIM采样器类中。U-Net由一个简明的PyTorch模块类实现,其中比较重要的子模块是残差块和Transformer块。相比之下,Diffusers实现的封装更好,功能更多。Diffusers用一个Pipeline类来维护采样过程。Diffusers的U-Net实现与原版完全相同,且支持更复杂的功能。此外,Diffusers还给U-Net提供了精确修改Attention计算的接口。

不管是哪个Stable Diffusion的框架,都会提供一些相同的原子操作。各种基于Stable Diffusion的应用都应该基于这些原子操作开发,而无需修改这些操作的细节。在学习时,我们应该注意这些操作在不同的框架下的写法是怎么样的。常用的原子操作包括:

  • VAE的解码和编码
  • 文本编码器(CLIP)的编码
  • 用U-Net预测当前图像应去除的噪声
  • 用采样器计算下一去噪迭代的图像

在原版仓库中,相关的实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
# VAE的解码和编码
model.decode_first_stage(...)
model.encode_first_stage(...)

# 文本编码器(CLIP)的编码
model.get_learned_conditioning(...)

# 用U-Net预测当前图像应去除的噪声
model.apply_model(...)

# 用采样器计算下一去噪迭代的图像
p_sample_ddim(...)

在Diffusers中,相关的实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
# VAE的解码和编码
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
latents = self.vae.encode(image).latent_dist.sample(generator) * self.vae.config.scaling_factor

# 文本编码器(CLIP)的编码
self.encode_prompt(...)

# 用U-Net预测当前图像应去除的噪声
self.unet(..., return_dict=False)[0]

# 用采样器计算下一去噪迭代的图像
self.scheduler.step(..., return_dict=False)[0]

如今zero-shot(无需训练)的Stable Diffusion编辑技术一般只会修改采样算法和Attention计算,需训练的编辑技术有时会在U-Net里加几个模块。只要我们熟悉了普通的Stable Diffusion是怎么样生成图像的,知道原来U-Net的结构是怎么样的,我们在阅读新论文的源码时就可以把这份代码与原来的代码进行对比,只看那些有修改的部分。相信读完了本文后,我们不仅加深了对Stable Diffusion本身的理解,以后学习各种新出的Stable Diffusion编辑技术时也会更加轻松。