0%

今天看到一则新闻,说 Google DeepMind 发布了一个叫做 GameNGen (英文读音 “game engine”,游戏引擎)的扩散模型,它可以仅用神经网络的生成结果,模拟经典射击游戏 DOOM。

项目网站:https://gamengen.github.io/

视频链接:https://gamengen.github.io/static/videos/e1m1_t.mp4

作为一名未来的游戏设计师,每次看到这类「今天 AI 又取代了创作者」的新闻,我的第一反应总会是愤怒:创作是人类智慧的最高结晶,能做到这种程度的 AI 必然是强人工智能。但显然现在 AI 的水平没有那么高,那么这类宣传完全是无稽之谈。我带着不满看完了论文,果然这个工作并没有在技术上有革命性的突破。不过,这篇论文还是提出了一个比较新颖的科研任务并漂亮地将其解决了的,算是一篇优秀的工作。除了不要脸地将自己的模型称为「游戏引擎」外,这篇工作在宣传时还算克制,对模型的能力没有太多言过其实的描述。

如果不懂相关技术的话,外行人很容易对这篇工作的应用前景产生一些不切实际的幻想。在这篇文章中,我将完整而清晰地介绍这篇工作的内容,再给出我从游戏开发方面和科研方面对这篇工作的评价。

问题定义

看完了这篇工作的展示视频,大家第一个想问的问题一定是:这个模型的输入是什么?是只能随机生成游戏视频,还是能够根据用户的输入来生成后续内容?

答案是,模型可以根据之前的游戏进度及当前的用户输入,输出下一帧的游戏图片。我们来详细看一下该工作对于「游戏」的定义。

论文的第二章详细定义了该工作要完成的任务。作者认为,一个游戏可以由游戏状态游戏画面玩家操作集这三类信息组成。游戏状态包括角色血量、装备、地图、敌人等所有影响游戏进程的信息;游戏画面就是游戏屏幕上显示的二维图片;玩家操作集就是移动、射击等玩家所有可能进行的操作。此外,为了让游戏运行,还需要两类游戏机制:如何根据游戏状态生成当前游戏画面的渲染机制;如何根据游戏状态和当前玩家操作更新下一时刻游戏状态的逻辑机制

而用一个生成模型来模拟游戏时,我们不需要让模型学会游戏机制、状态,只需要让模型根据之前所有的画面和玩家操作,以及当前时刻的玩家操作,输出当前时刻的画面。这样,我们给第一个初始帧和操作,模型就能输出第二帧;给前两帧和之前及现在的操作,模型就能输出第三帧……。也就是说,模型以自回归的方式模拟游戏画面。

给想读这篇论文的读者一点阅读上的提示。论文以「交互环境」$\mathcal{E}$来指代我上文中的「游戏」,以 「交互世界模拟」$q$ 来指代生成模型。在定义 $\mathcal{E}$ 用到的几个字母对应我前文的粗体名词。

在我看来,这篇工作的主要贡献,就是把游戏模拟任务以如此简明的形式清楚定义了出来。怎么把任务做好,纯粹只是工程实现问题。明确了任务后,相关领域的科研人员基本上能猜出这篇工作是怎么实现的了。下面,我们就来看一下论文中分享的实现过程。

GameNGen

用强化学习造数据

本工作的任务是根据某些信息生成下一帧的画面,这种生成某类图像的任务用日益成熟的图像扩散模型技术就能解决了。但是,为了训练扩散模型,本工作有一道跨不过的坎——缺乏游戏画面数据。

为了生成足够多的图片,作者利用强化学习训练了一个玩游戏的 AI。在这一块,作者用了一个非常巧妙的设计:和其他强化学习任务不同,这个玩游戏的 AI 并不是为了将游戏漂亮地通关,而是造出尽可能多样的数据。因此,该强化学习的奖励函数包括了击中敌人、使用武器、探索地图等丰富内容,鼓励 AI 制造出不同的游戏画面。

带约束图像扩散模型

有了数据后,问题进一步缩小:现在该怎么用 Stable Diffussion 这个比较成熟的图像生成模型来根据之前的图片之前及当前的操作来生成图片。

「之前的图片」和「之前及现在的操作」是输入给图像生成模型的两类额外信息。用专业术语来说,它们是给一个随机生成图像的模型的约束条件,用于让模型的输出不那么随机。而带约束图像生成也是一个被研究得比较透的任务了。

以下内容是写给相关科研者看的,看不懂可以跳过。

先谈怎么让模型约束于之前的图片。本文参考了经典的 Cascaded Diffusion 实现图像约束:将之前的图片过 VAE 编码器,与扩散模型原来的噪声输入拼接。然而,这种约束方式存在分布不匹配的问题:训练时,图像约束来自训练集;推理时,图像约束来自于模型之前自回归的生成结果。为了填平两类图像在分布上的差异,我们需要给所有约束图像加噪,并把加噪程度当成额外约束输入进模型。这种图像约束方法和图生视频的 Stable Video Diffusion 是一样的。

再看模型怎么约束于操作。每个操作都有独特的含义,它和语言中的单词是类似的。因此,我们可以把离散的操作变成嵌入向量,用 Stable Diffusion 处理单词的机制来处理操作。所以,在 GameNGen 的扩散模型中,文本约束没有了,被操作约束取代了。顺带一提,我们要输入之前及当前的操作,它们构成了一个操作序列。我们只要把这个操作序列当成由单词构成的文本,还是用原来 Stable Diffusion 那套处理文本的机制就行了。

此外,该工作还微调了 Stable Diffusion 的 VAE 的解码器,用以提升其在特定数据上的重建效果。这种操作也是比较常见的,比如 Stable Video Diffusion 将图像解码器微调成了视频解码器。

实验及结果

本工作仅对图像约束用了强度为 1.5 的 Classifier-Free Guidance (CFG),没有对操作约束加 CFG。采样用的是 DDIM 采样器,实验表明 4 步采样的结果就足够好了。在单个 TPU-v5 上,模拟模型每秒能渲染 20 帧画面。

扩散模型基于 Stable Diffusion V1.4 训练。输入图像分辨率为 $320 \times 256$。之前图像和操作的上下文窗口长度为 $64$。训练集包含 900M 张图片,要用 128 块 TPU-v5e 训练 700,000 步。

论文还评价了模型的生成质量。由于训练集能够提供当前要预测的帧的真值,因此我们可以用重建误差来反映模型的质量。具体来说,论文展示了表示图像相似度的 PSNR 和图像感知误差的 LPIPS。 重建质量通过两个任务来反映:

  • 图像质量:输入之前图像、操作均为数据集(按理说是测试集而不是训练集)里的真值,仅评测当前帧。此时重建质量较好,平均 PSNR 为 29.43,和一般图像经 JPEG 压缩产生的损耗相近。
  • 视频质量:给定初始图像和数据集里的之前操作,让模型自回归地生成一段视频。这种情况下每一帧的质量会逐渐降低,如下所示。

除此之外,论文还展示了人类的评估结果:给出两段长度为 1.6 秒或 3.2 秒的游戏视频,分别来自真实游戏和 AI 生成,请人类分辨哪段视频是 AI 生成的。对于 1.6 秒的视频,正确率为 58%;对于 3.2 秒的视频,正确率为 60%

此类评测的最低(最优)正确率是 50%,因为我们总是可以随便猜一个。

论文总结与评价

信息整理

到目前为止,我已经客观介绍了论文中展示的内容。让我用外行人也能看懂的语言总结一下:

本工作提出的 GameNGen 模型可以根据之前的游戏画面、用户的历史操作、当前操作,在完全不了解游戏机制的前提下,生成包含游戏逻辑(血量、弹药)的当前帧画面。生成画面的模型是一个深度学习模型。因此,需要一个包含了过往操作、真实游戏画面的大型数据集来训练该模型。本工作利用强化学习制造了大量数据。训练用了 128 张 TPU-v5e 计算卡。该生成模型能在单张 TPU-v5 (专门用来做深度学习的计算「显卡」,只租不卖,比价值 25000 美元的 H100 显卡要好) 上能以 20 帧每秒的速度生成画面。模型最多能记住 64 帧,即 3.2 秒内的游戏信息。若将 3.2 秒的生成视频与真实游戏视频比较,人类仅有 60% 的概率分辨出 AI 生成的视频。当然,利用自回归技术反复输入之前生成的画面,模型也能够生成更长的视频,正如本文开头所展示的那个视频。

论文中提及的缺陷

哪怕不懂深度学习,大家也可以根据上述信息,提出自己的看法。当然,在那之前,我们先看一下作者在论文里是怎么描述模型的缺陷的。

作者讲到,GameNGen 受制于有限的记忆。模型仅能获取三秒之短的历史信息,却能在极其长的时间里保持游戏逻辑稳定,这很了不起(怎么讲着讲着又开始夸起自己了?)。但由于模型学到了太多东西,模型会创造短期记忆处理不了的场景,比如角色在某处打倒了敌人,数分钟后又返回打倒了敌人的地方(文章讲了半天讲不出个东西,我替他们总结了一下)。对于这些问题,以当前的模型架构,再怎么加大记忆窗口也无济于事。

面向交互视频游戏的新范式

基于这篇工作的成果,作者还展望了未来:(以下是我对原文的精心翻译,欢迎对比)

如今,视频游戏是靠人类编程实现的。而 GameNGen 表明了以神经网络的权重来描述游戏这种新范式的部分可行性。GameNGen 展示了,在现有硬件上用神经网络模型来高效地交互运行一个复杂的游戏 (DOOM)不是奢望,这样的模型架构与模型权重是存在的。尽管还有很多重要问题要解决,我们希望这种范式能够造福众生。比如说,用这种新范式开发视频游戏的代价可能更小,门槛更低,借此,我们或许能仅通过编辑文本描述或示例图片来开发游戏。这个愿景的一小部分,即在现有游戏上略作调整或创造新行为,也许会在不久的将来就能实现。比如即使拿不到作者的源代码,我们也可能可以把几段游戏截图变成一个可玩的游戏关卡或者基于现有图像创建新角色。这个范式还能带来其他好处,比如严格控制帧率、内存占用不变。路漫漫其修远兮,尽管我们目前还没有探究这些方向,我们还是乐于尝试!希望我们这一小步,会在未来的某天,化为人们享受视频游戏的美好瞬间;甚至更进一步,化为人们使用交互软件系统的点点日常。

作为一个对这段抒情深有共鸣的人,在看论文前,我有千言万语想喷,却找不到切入点。看了这段话,我总算知道该怎么针对性地发表意见了。

锐评

总有人说,AI 要取代人类了。

之前是说 AI 可以代替画师,又是说 AI 能代替人写小说。现在,来说 AI 能够完全模拟游戏了。

我真的很不解:人类的作品怎么会沦落到和现有 AI 的对比了?

我想了很久,为什么我无法容忍「现有 AI 技术能代替人类创作」这种观点。我的核心论点是:1)现有基于深度学习的 AI 无法达到和人同等的智力水平;2)达到人类同等的智力水平,意味着能够理解人类的行为,进而在创作、编程、教学、心理咨询等所有现在看来比较困难的领域看齐甚至超越人类。

详细对深度学习了解了一段时间后,绝大多数人都能推理出深度学习的上限。通过对深度学习应用的种种观察,我们能够用我们脑中的那个「神经网络」,那种基于数据推理现实的能力,预知深度学习的能力上限。深度学习适用于一些数据定义良好,目标定义良好的任务。只要给了数据,给了目标,网络就能学习,甚至涌现出一些意想不到的强大理解、生成能力。

但是,数据不是全部,永远有从大量数据学不到的东西。

那就是人心。

就和无数探讨机器人的作品所展示的一样。

正因为我们是人,所以我们能感知我们生活在这个世界上。我们感受痛苦,所以我们思考,并追逐美好。

无论是穷尽多少数据,复读多少经书也体会不到的;无论是洞察多少规律,拟合多少逻辑也推理不出的,就是人心。

和人心等价的一切任务,是现在的深度学习 AI 做不到的。

有关深度学习是否能达到人类水平,那是技术讨论,我们在别处再谈。如果最终能认同目前深度学习无法达到人类水平,但目前认为「现有 AI 技术能代替人类创作」的话,简单来看有两种可能:

  1. 不懂深度学习

  2. 不懂创作的难度

而对于一个相关专业人士来说,只能是第二种可能了。所以听到懂深度学习的人讲出「AI 要代替人类」时,我都会下意识地认为他在贬低人类创作的含金量,自然是怒火中烧。不懂创作,就不要妄加评论。

除此之外,还有人明明知道当前 AI 的实力,还要违心地宣传 AI 如何如何,宣传自己的垃圾工作多么有价值。这种人只会为自己的利益考虑,纯纯的坏而已。这种无可救药的坏连讨论的价值也没有。

最后还有一种可能,很多人做研究时,并不会像我这样想这么多。在他们看来,科研就是从现有知识出发,朝外迈一小步。不管这一步是否方向正确,不管这一步有多么小,只要是拓宽人类的知识边界,那就是好的。他们的研究是纯洁的、无私的,不在意有生之年能否看到自己的成果被用上,甚至不在意自己的研究是否真的会被用到,只是为了科研增砖添瓦而已。我不得不承认,这是真正的、高尚的科研。

他们是幸福的,可以不在意眼前的得失。所以,他们可以望着远处高峰,轻松而豪放地说出:「希望我的研究,能化为人们享受视频游戏的美好瞬间」。

然而,这句话,对我而言,是沉重的。如果这种话是我说出来的,那么它不会是期盼,而是矢志不渝的誓言。不是拿着望远镜向远处眺望,不是用手指着地图挥斥方遒,而是用我的脚,一步,一步,踏出来的。

我是各种作品的鼓舞下走过来的,优秀的作品对我而言是神圣的。所以,我希望立刻,亲眼见到更多的好作品。没有对好作品急功近利的渴望,也就说明他们生活中有更加便捷的能量来源。所以我说,真正能以纯洁的心做科研的人,是幸福的人。

把我这些话总结一下,能认为深度学习能代替人类创作游戏,要么是深度学习的信徒,要么是不懂深度学习,要么是不懂创作,要么是坏,要么就是没想那么多觉得有新科研工作就是好事。

有人可能还会说:「我也同意深度学习代替不了人类,但也不能说这些技术就完全没用」。这我非常同意,我就认为大家应该把现在的 AI 当成一种全新的工具。基于这些新工具,我们把创新的重点放在如何适配这些工具上,辅助以前的应用,或者开发一些新的应用,而不是非得一步到位直接妄想着把人类取代了。比如,根据简笔画生成图片就是一个很好的新应用啊。

回到这篇文章的锐评上来。一上来,标题就写着《扩散模型是实时游戏引擎》。其实这是一个在顶级计算卡上每秒生成 20 张低分辨率图片的模型,这真的是我们认为的实时吗?作为一个游戏引擎,你能修改游戏机制吗?哦,我对引擎的理解有误,这不是游戏开发引擎,而是一个运行游戏的引擎。那踏踏实实叫做「游戏模拟器」不好吗?标题取得夸张一点,想吸引大家注意,能够理解,不多讲了。

文章主体部分都是客观陈述,写得非常清晰,我读起来也很舒服。本来都准备把「锐评」改成「简述」的,看完作者最后那段对未来的畅想后,一阵无名火在我心中燃起。通过「我们或许能仅通过编辑文本描述或示例图片来开发游戏」这段话,我感觉这是作者是几个对游戏制作质量没有那么高要求的人。可是,就是这样的人,却能写出「路漫漫其修远兮,尽管我们目前还没有探究这些方向,我们还是乐于尝试!」这说明他们可能是真心热爱游戏的玩家。那为什么,为什么只做出了这种程度的工作呢?900M 张图片,128 块卡,别说爱好者,就是一般的大学实验室,都难以跟进这篇工作,这是想要给设计师、爱好者开发新工具的态度吗?没有其他更加贴近用户的项目了吗?好,你说你们以长期的科研为主,这只是这个方向的初步尝试,你们重心在科研上而不是提供游戏开发工具上。那你们是抱着多大的觉悟说出「希望我们这一小步,会在未来的某天,化为人们享受视频游戏的美好瞬间」的?给人的感觉就是一群深居象牙塔的人,一辈子也不去了解业界真的需要什么,只是「正确地」做着科研而已。

非常抱歉,以上都是我的主观评价,请恕我对作者的想法妄加猜测。看完文章最后那段话后,我就有了这样一种矛盾的愤怒感。我喷了这么多,其实也不是想喷这篇工作的作者,更想批判的,是我长年以来在生活中的见闻。把气撒到这篇工作上,可能只是我嫉妒他们,没有 128 块卡去做想做的事情而已。

但我毫不怀疑地相信,我要有的资源最后都会有的。「游戏开发的新范式」、「造福众生」、「今天的一小步」……,如果有一天我说出了这些话,那必然不是在论文里,而是在我的产品得到了用户的充分肯定后,向世界吹嘘的胜利宣言吧。

新科研方向的讨论

先谈一下这篇工作在科研上给我们的启发。我认为有三点:

  1. 对于深度学习应用来说,不要去在意功能有多么异想天开,只要把问题定义好,数据准备好,问题就可解。
  2. 可以以用户操作为约束,用生成模型建模一个可交互世界。这个「用户操作」不一定局限于游戏玩家的操作。
  3. 强化学习可以用来造大批数据。并且,我们需要精心设计模型的学习目标,使其造出多样的数据。

再来看顺着这篇文章的结果,我们能有怎样的新思考。拿图像生成模型这种结果极不稳定的东西做要求输入输出可控的游戏是不可能的。但是,我们应该把思路逆转过来:哪些任务可以以不确定的图像为输入?最容易想到的是缺数据的自动驾驶任务。也就是说,这篇工作实际上是提出了一种带交互的图像数据生成器。

这篇工作用图像模型学习了 3D 场景在移动后的变化。也就是说,模型「理解」了 3D 场景。那么,有没有办法从模型中抽取出相关的知识呢?按理说,能理解 3D,就能理解物体是有远近的。那么,深度估计、语义分割这种任务是不是可以直接用这种模型来做呢?以交互为约束的图像生成模型可能蕴含了比文生图模型更加丰富的图像知识。很可惜,不知道这篇工作最后会不会开源。

如果一个模型能够建模一个简单的世界,我们下一步要思考的是怎么编辑这个世界。就像有了图像生成模型,我们要给它加上文本约束一样。最容易想到的是以二维图片为约束,生成一个世界里的三维物体。但目前这个模型要的训练量太大了,做这种新实验的代价根本不敢想。

当前这个模型的结果还是比较弱的。别看这里用了扩散模型,模型的训练目标实际上是一个重建任务而非生成任务,最后的评测指标也是重建误差而没有考虑 FID 等图像生成质量指标。有没有办法让模型设法输出有多样性的新内容?还有,这个模型在设计上应该是一个自回归模型,只不过下一张图片是用扩散模型隐式建立了概率分布。但由于用户操作是在线而不是离线一次性给出,这种任务在时序上没办法用视频扩散模型建模。所以,从本质上看,要设计带用户操作的生成模型,其实是要一种时序信息在线输入的生成模型。除了自回归模型外,能不能用一种全新的生成范式呢?一旦有了这样一种新式模型,训练交互世界神经网络的代价将大幅降低。

除了时序信息不好输入外,这个任务还面临另一个问题:随着时间不断推移,模型会忘记以前的内容。NLP 领域通过 Transformer 全局信息交互暴力地解决了这类问题,但在长视频生成中,这种问题还是没能很好解决。或许需要其他领域为长时序建模提供了更好的工具后,才能考虑长时间的世界交互。但我们也可以从另一个角度思考:把要模拟的世界简化,不要去做模拟游戏那么难的任务,做一个时序依赖少的模拟任务。

结束科研新方向的思考前,我再对这个模型的训练量喷两句。搞大批数据,搞大模型,一堆 GPU 狂训,强行拟合一个任务,在我看来是非常不优雅的做法。得不到学术界广泛跟进的工作,是发展不下去的。在思考这种世界模拟模型的种种应用前,最重要的还是把训练量降下去。我认为优化在线时序约束生成模型是最有价值的方向。

最后我再从游戏开发的角度讲一下这篇工作的启发。我们只讨论如何用生成模型减少美术工作量,不去探究怎么让模型学习设计复杂的游戏机制。

真的想要为游戏开发提供工具的话,应该从 2D 游戏而不是 3D 游戏入手。拟合 2D 游戏画面的训练代价较小,且 2D 游戏对开发者来说更容易做。仅基于现在的 AI 技术,我们就已经能够做出一些简单的美术生成应用了。比如前段比较火的植物大战僵尸杂交版,我们完全可以想办法定义一个科研任务,训练一个融合两种植物的生成模型。明明有更切实际的 2D 不做去做 3D,这也是我为什么觉得这篇工作在为游戏开发提供工具这个层面上显得非常没有诚意。这篇文章建模的 3D 世界模拟器其实更适合前面提到的自动驾驶等真实场景,反而不是很适合游戏开发。

从这篇工作出发,或许可以联想出非常多游戏开发应用。这里我随便举一个 2D 网格地图生成任务。用同一种地板拼接地形时,根据地板的连通情况,游戏开发引擎会自动生成完整地形。但是,这种生成是基于写死的规则的,且最终每种地形还是以正方形网格为单位呈现。能不能用生成模型让这种地形生成更加多样呢?

从本文的启发 2 出发,我们或许可以有一些新想法:可以以操作为约束,用生成模型建模一个可交互世界。那么,我们将「操作」定义为设计师向地图上铺上一块地板,而不是定义为玩家的操作。这样,模型就能站在设计师的角度的学习生成地图了。

如果数据足够的话,这种定义方式一定能让神经网络学会地形生成的。问题的关键就在于如何获取数据。根据玩家的操作,我们能够自然生成大量游戏数据。而根据设计师的每一步想法,每一步都绘制合理且高质量的图片,可能要花费大量资源。因此,如果是从这个角度出发,需要大量美术、游戏设计、深度学习的相关人员参与进来。

总结

GameNGen 将模拟 3D 可交互场景的任务定义为根据历史画面、历史及当前操作生成当前画面的带约束图像生成任务。该工作用强化学习巧妙造出大量数据,用扩散模型实现带约束图像生成。结果表明,该模型不仅能自回归地生成连贯的游戏画面,还能学会子弹、血量等复杂交互信息。然而,受制于硬件及模型架构限制,模型要求的训练资源极大,且一次只能看到 3.2 秒内的信息。这种大量数据驱动的做法难以在学校级实验室里复刻,也不能够归纳至更一般的 3D 世界模拟任务上。

我个人认为,从科研的角度来看,这篇工作最大的贡献是提出了一种用带约束图像生成来描述 3D 世界模拟任务的问题建模方式。其次的贡献是确确实实通过长期的工程努力把这个想法做成功了,非常不容易。但从游戏开发的角度来看,这个工作现阶段没什么用处。

从科研启发的角度思考,这篇工作告诉我们,定义好交互世界里的操作,我们就能部分地用图像生成模型建模一个交互世界。从本质上来看,这是一个每一个时刻的约束都在线给出的视频生成任务。针对这个任务,我们既可以去思考能否用自回归以外的更高效的方式来实现它,也可以去思考是否可以修改对于「操作」的定义来实现模拟玩家操作以外的世界模拟任务。

前几个月,推出了著名文生图模型 Stable Diffusion 的 Stability AI 公司曝出了核心团队集体离职的消息。一时间,AI 从业者们议论纷纷,不知道这究竟是团队出现了矛盾,还是这些员工觉得文生图模型做下去没有前途了。而近期,该核心团队重新组建的创业团队 Black Forest Labs(黑森林实验室)带着名为 FLUX.1 的文生图模型「复仇归来」。FLUX.1 受到了用户的广泛好评,让人们期盼更强开源文生图模型的热情得以延续。

Black Forest Labs 的成员基本上都是 Stable Diffusion 3 的作者,其中三名元老级成员还是 Stable Diffusion 论文的作者。同时,FLUX.1 也是一个在 Stable Diffusion 3 架构上做改进的模型。不管从哪个角度,FLUX.1 都称得上是Stable Diffusion 3 的「精神续作」。秉承着此前的开源精神,FLUX.1 也在上线之始就为社区开放了源代码和模型权重。不过,配套的技术论文并没能及时发布,想要了解 FLUX.1 技术细节的用户恐怕还得等上一阵子。为了尽快搞清楚 FLUX.1 相较 Stable Diffusion 3 做了哪些改进,我直接去细读了 FLUX.1 的源码。在这篇文章中,按照惯例,我将主要从源码层面上分享 FLUX.1 中已知的科研创新,做一个官方论文发布前的前瞻解读,而不会评测 FLUX.1 的图像生成效果。

具体来说,我会介绍 FLUX.1 中的以下改动:

  • 略微变动的图块化策略
  • 不使用 Classifier-Free Guidance 的指引蒸馏
  • 为不同分辨率图像调整流匹配噪声调度
  • 用二维旋转式位置编码 (RoPE) 代替二维正弦位置编码
  • 在原 Stable Diffusion 3 双流 Transformer 块后添加并行单流 Transformer 块

我会先简单介绍 FLUX.1 的官方公告及 Diffusers 版使用示例,再按照我读代码的逻辑,从熟悉整个代码框架,到深究每一处创新的代码细节,最后分享我对于 FLUX.1 科研改进上的分析。对源码不感兴趣的读者,可以跳过通读代码框架章节,或者直接阅读感兴趣的那部分改动。想看省流版文章的读者,可以直接跳到结尾看总结。

建议读者在学习 Flux.1 前熟悉 Stable Diffusion 3。欢迎参考我之前写的文章:Stable Diffusion 3 论文及源码概览。

模型简介与 Diffusers 示例脚本

在正式阅读源码前,我们先来看一下官方推文(https://blackforestlabs.ai/announcing-black-forest-labs/ )中有关 FLUX.1 的简介,并在 Diffusers 中跑通 FLUX.1 的图像生成示例脚本。

据官方介绍,FLUX.1 是一套文生图模型。它有三个变体(variant,可以理解成结构相似或相同,但权重不同的几个模型):

  • FLUX.1 [pro]: FLUX.1 系列的最强模型,只能通过付费的 API 或者在线平台使用。
  • FLUX.1 [dev]:FLUX.1 [pro] 的指引蒸馏(guidance-distilled)模型,质量与文本匹配度与原模型相近,运行时更高效。
  • FLUX.1 [schnell]:为本地开发和个人使用而裁剪过的本系列最快模型。据 Diffusers 中的文档介绍,这是一个 Timestep-distilled(时间戳蒸馏)的模型,因此仅需 1~4 步就可以完成生成。无法设置指引强度。

官方对这些模型的详细介绍少之又少。FLUX.1 [dev] 用到的指引蒸馏技术似乎来自论文 On Distillation of Guided Diffusion Models,其目标是让模型直接学习 Classifier-Free Guidance (CFG) 的生成结果,使得模型一次输出之前要运行两次才能得到的指引生成结果,节约一半的运行时间。官方也没有讲 FLUX.1 [schnell] 的蒸馏细节,似乎它是从 FLUX.1 [dev] 中用扩散模型加速蒸馏手段得到的模型。因此,FLUX.1 [schnell] 不仅能一次输出有指引的生成结果,还能在极少的采样步数里完成生成。

官方推文中还说,FLUX.1 的生成神经网络基于 Stable Diffusion 3 的 MMDiT 架构和并行的 DiT 块,参数量扩大至 120 亿。生成模型是根据流匹配(flow matching)推导的扩散模型。为了提升性能与效率,模型新引入了旋转式位置编码 (RoPE) 和并行注意力层。

这段话这么长,还把并行注意力说了两遍,其实没有多少新东西。说白了,FLUX.1 就是在 Stable Diffusion 3 的基础上,加了 RoPE 和并行注意力层。官方推文到这里就没有其他有关模型细节的介绍了。FLUX.1 具体做了哪些改动,我们直接去源码里看。

FLUX.1 的官方仓库是 https://github.com/black-forest-labs/flux 。相比 Stable Diffusion 那个臃肿杂乱的 generative-models 仓库,这个仓库的代码要简洁很多。不过,我还是推荐使用 Diffusers 框架来运行 FLUX.1。

Diffusers 中运行 FLUX.1 的官方文档为 https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux 。目前(2024 年 8 月 11 日),相关代码还在 Diffusers 的在线主分支里进行开发,并没有集成进 pip 版的 Diffusers 里。因此,要在 Diffusers 中使用 FLUX,必须要从源码安装 Diffusers:

1
2
3
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install -e .

安装完毕后,我们可以随便新建一个 python 脚本,填入以下的官方示例代码。在能够连通 Hugging Face 的环境中运行此脚本的话,脚本会自动下载模型并把生成结果保存在 image.png 中。注意,FLUX.1 的神经网络很大,显存占用极高,可能至少需要在 RTX 3090 同等级的显卡上运行。在示例代码中,我还改了一行,使用 pipe.enable_sequential_cpu_offload() 让模型把更多参数临时放到 CPU 上,避免显存不够。经测试,改了这一行后,FLUX.1 才勉强能在显存为 24G 的 RTX 3090 上运行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
# pipe.enable_model_cpu_offload()
pipe.enable_sequential_cpu_offload()

prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256,
height=1024,
width=1024,
generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("image.png")

由于随机数是固定的,运行后,我们应该总能得到这样的图片:

通读代码框架

由于开发还没有结束,在当前 Diffusers 的 FLUX.1 源码中,我们能看到各种潦草的写法及残缺不全的文档,这让读源码变成了一项颇具趣味的挑战性任务。让我们先看一下代码的整体框架,找出 FLUX.1 相较 Stable Diffusioni 3 在代码上的改动,再来详细分析这些创新。

和 Diffusers 中的其他生成模型一样,FLUX.1 的采样算法写在一个采样流水线类里。我们可以通过示例脚本里的 FluxPipeline 类跳转到定义该类的文件 diffusers/pipelines/flux/pipeline_flux.py 里。这个文件是从 Stable Diffusion 3 的采样流水线文件 diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py 改过来的,大部分文档都没有更新。我们可以用肉眼对比两份文件的区别。

先看构造函数。Stable Diffusion 3 用了三个文本编码器,clip-vit-large-patch14, CLIP-ViT-bigG-14-laion2B-39B-b160k, t5-v1_1-xxl,而 FLUX.1 没有用第二个 CLIP 编码器,只用了另外两个文本编码器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
def __init__(
self,
transformer: SD3Transformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
):

class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
):

再往下翻,我们能用火眼金睛发现 FLUX.1 的 VAE 压缩比是 16,是所有版本的 Stable Diffusion VAE 压缩比的两倍。这是为什么呢?不是增加压缩比会让 VAE 重建效果下降吗?

1
2
3
4
5
6
7
8
9
10
11
# SD3
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1)
if hasattr(self, "vae") and self.vae is not None else 8
)

# FLUX.1
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels))
if hasattr(self, "vae") and self.vae is not None else 16
)

查看周围其他代码,我们能找到 _pack_latents_unpack_latents 这两个方法。_pack_latents 其实就是一个图块化操作,它能把 $2 \times 2$ 个像素在通道维度上拼接到一起,而 _unpack_latents 是该操作的逆操作。原来,代码把图块化的两倍压缩比也算进 VAE 里了。这里直接把 vae_scale_factor 乘个 2 是一种非常差,歧义性极强的写法。

1
2
3
4
5
6
7
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
...

@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
...

相比 SD3, FLUX.1 将图块化操作写在了去噪网络外面。因此,SD3 的去噪网络的输入通道数是 16,和 VAE 的隐空间通道数相同;而 FLUX.1 由于把 $2 \times 2$ 个像素在通道上拼接到了一起,其去噪网络的输入通道数是 64。

1
2
3
4
5
6
7
8
9
{
"_class_name": "SD3Transformer2DModel",
"in_channels": 16,
}
{
"_class_name": "FluxTransformer2DModel",
"in_channels": 64,
}

再来看采样主方法 __call__。先看一下它的主要参数。相比之下,FLUX.1 少了一组提示词,且没有负面提示词。少一组提示词是因为少用了一个文本编码器。而没有负面提示词是因为该模型是指引蒸馏过的,在文本指引上没那么灵活。我们稍后还会看到 FLUX.1 具体是怎么利用文本指引的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# SD3
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
)

# FLUX.1
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
)

之后的内容都与其他扩散模型流水线一样,代码会判断输入是否合法、给输入文本编码、随机生成初始化噪声。值得关注的是初始化噪声采样器前的一段新内容:代码会算一个 mu,并传进 retrieve_timesteps 里。这个变量最后会传到流匹配采样算法里。我们先把该改动记在心里,不看细节。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)

在去噪循环部分,FLUX.1 没有做 Classifier-Free Guidance (CFG),而是把指引强度 guidance 当成了一个和时刻 t 一样的约束信息,传入去噪模型 transformer 中。CFG 的本意是过两遍去噪模型,一次输入为空文本,另一次输入为给定文本,让模型的输出远离空文本,靠近给定文本。而负面提示词只是一种基于 CFG 的技巧。把 CFG 里的空文本换成负面文本,就能让结果背离负面文本。但现在这个模型是一个指引蒸馏模型,指引强度会作为一个变量输入模型,固定地表示输入文本和空文本间的差距。因此,我们就不能在这个模型里把空文本换成负面文本了。

除了指引方式上的变动外,FLUX.1 的去噪网络还多了 txt_idsimg_ids 这两个输入。我们待会来看它们的细节。

FLUX.1 的去噪网络和 SD3 的一样,除了输入完整文本嵌入 prompt_embeds 外,依然会将池化过的短文本嵌入 pooled_prompt_embeds 输入进模型。我们现在可以猜测,FLUX.1 使用了和 SD3 类似的文本约束机制,输入了两类文本约束信息。

代码里的 /1000 是临时代码。之后所有涉及乘除 1000 的代码全可以忽略。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
for i, t in enumerate(timesteps):
timestep = t.expand(latents.shape[0]).to(latents.dtype)

# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.tensor([guidance_scale], device=device)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None

noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

采样流水线最后会将隐空间图片解码。如前所述,由于现在图块化和反图块化是在去噪网络外面做的,这里隐空间图片在过 VAE 解码之前做了一次反图块化操作 _unpack_latents。对应的图块化操作是在之前随机生成初始噪声的 prepare_latents 方法里做的,为了节约时间我们就不去看了。

1
2
3
4
5
6
7
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)

接下来,我们再简单看一下去噪网络的结构。在采样流水线里找到对应类 FluxTransformer2DModel,我们能用代码跳转功能定位到文件 diffusers/models/transformers/transformer_flux.py。SD3 去噪网络类是 SD3Transformer2DModel,它位于文件 diffusers/models/transformers/transformer_sd3.py

同样,我们先对比类的构造函数。构造函数的新参数我们暂时读不懂,所以直接跳到构造函数内部。

在使用位置编码时,SD3 用了二维位置编码类 PatchEmbed。该类会先对图像做图块化,再设置位置编码。 而 FLUX.1 的位置编码类叫 EmbedND。从官方简介以及参数里的单词 rope 中,我们能猜出这是一个旋转式位置编码 (RoPE)。

1
2
3
4
5
6
7
8
9
10
11
12
# SD3
self.pos_embed = PatchEmbed(
height=self.config.sample_size,
width=self.config.sample_size,
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size, # hard-code for now.
)

# FLUX.1
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)

再往下看,FLUX.1 的文本嵌入类有两种选择。不设置 guidance_embeds 的话,这个类就是 CombinedTimestepTextProjEmbeddings,和 SD3 的一样。这说明正如我们前面猜想的,FLUX.1 用了和 SD3 一样的额外文本约束机制,将一个池化过的文本嵌入约束加到了文本嵌入上。

设置 guidance_embeds 的话,CombinedTimestepGuidanceTextProjEmbeddings 类应该就得额外处理指引强度了。我们待会来看这个类是怎么工作的。

1
2
3
4
5
6
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
)

之后函数定义了两个线性层。context_embedder 在 SD3 里也有,是用来处理文本嵌入的。但神秘的 x_embedder 又是什么呢?可能得在其他函数里才能知道了。

1
2
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)

函数的末尾定义了两个模块列表。相比只有一种 Transformer 块的 SD3,FLUX.1 用了两种结构不同的 Transformer 块。

1
2
3
4
5
6
7
8
9
10
11
12
13
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(...)
for i in range(self.config.num_layers)
]
)

self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(...)
for i in range(self.config.num_single_layers)
]
)

我们再来看 forward 方法,看看之前看构造函数时留下的问题能不能得到解答。

forward 里首先是用 x_embedder 处理了一下输入。原本在 SD3 中,输入图像会在 pos_embed 里过一个下采样两倍的卷积层,同时完成图块化和修改通道数两件事。而现在 FLUX.1 的图块化写在外面了,所以这里只需要用一个普通线性层 x_embedder 处理一下输入通道数就行了。这样说来,变量名有个 x 估计是因为神经网络的输入名通常叫做 x。既然这样,把它叫做 input_embedder 不好吗?

1
2
3
4
5
# SD3
hidden_states = self.pos_embed(hidden_states)

# FLUX.1
hidden_states = self.x_embedder(hidden_states)

下一步是求时刻编码。这段逻辑是说,如果模型输入了指引强度,就把指引强度当成一个额外的实数约束,将其编码加到时刻编码上。具体细节都在 time_text_embed 的类里。

1
2
3
4
5
6
7
8
9
10
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)

下一行是常规的修改约束文本嵌入。

1
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

再之后的两行出现了一个新操作。输入的 txt_idsimg_ids 拼接到了一起,构成了 ids,作为旋转式位置编码的输入。

1
2
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)

此后图像信息 hidden_states 和文本信息 encoder_hidden_states 会反复输入进第一类 Transformer 块里。和之前相比,模块多了一个旋转式位置编码输入 image_rotary_emb

1
2
3
4
5
6
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)

本来过了这些块后,SD3 会直接会直接返回 hidden_states 经后处理后的信息。而 FLUX.1 在过完第一类 Transformer 块后,将图像和文本信息拼接,又输入了第二类 Transformer 块中。第二类 Transformer 块的输出才是最终输出。

1
2
3
4
5
6
7
8
9
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)

hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

到这里,我们就把 FLUX.1 的代码结构过了一遍。我们发现,FLUX.1 是一个基于 SD3 开发的模型。它在图块化策略、噪声调度器输入、位置编码类型、Transformer 块类型上略有改动。且由于开源的 FLUX.1 是指引蒸馏过的,该模型无法使用 CFG。[dev] 版可以以实数约束的方式设置指引强度,而 [schnell] 版无法设置指引强度。

在这次阅读中,我们已经弄懂了以下细节:

  • 采样流水线会在去噪网络外面以通道堆叠的方式实现图块化。
  • 指引强度不是以 CFG 的形式写在流水线里,而是以约束的形式输入进了去噪网络。

我们还留下了一些未解之谜:

  • 输入进噪声采样器的 mu 是什么?
  • 决定旋转式位置编码的 txt_idsimg_ids 是什么?
  • 旋转式位置编码在网络里的实现细节是什么?
  • 新的那种 Transformer 块的结构是怎么样的?

针对这些问题,我们来细读代码。

调整流匹配标准差

在采样流水线里,我们见到了这样一个神秘变量 mu。从名字中,我们猜测这是一个表示正态分布均值的变量,用来平移 (shift) 某些量的值。

1
2
3
4
5
mu = calculate_shift(...)
timesteps, num_inference_steps = retrieve_timesteps(
...
mu=mu,
)

我们先看 calculate_shift 做了什么。第一个参数 image_seq_len 表示图像 token 数,可以认为是函数的自变量 x。后面四个参数其实定义了一条直线。我们可以认为 base_seq_lenx1, max_seq_lenx2base_shifty1max_shifty2。根据这两个点的坐标就可以解出一条直线方程出来。也就是说,calculate_shift 会根据模型允许的最大 token 数 4096 ($64 \times 64$) 和最小 token 数 256 ($16 \times 16$),把当前的输入 token 数线性映射到 0.5 ~ 1.16 之间。但我们暂时不知道输出 mu 的意义是什么。

1
2
3
4
5
6
7
8
9
10
11
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu

再追踪进调用了 muretrieve_timesteps 函数里,我们发现 mu 并不在参数表中,而是在 kwargs 里被传递给了噪声迭代器的 set_timesteps 方法。

1
2
3
4
5
6
7
8
9
10
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
...
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)

根据流水线构造函数里的类名,我们能找到位于 diffusers/schedulers/scheduling_flow_match_euler_discrete.py 调度器类 FlowMatchEulerDiscreteScheduler

1
2
3
4
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
...)

再找到类的 set_timesteps 方法。set_timesteps 一般是用来设置推理步数 num_inference_steps 的。有些调度器还会在总推理步数确定后,初始化一些其他变量。比如这里的流匹配调度器,会在这个方法里初始化变量 sigmas。我们可以忽略这背后的原理,仅从代码上看,输入 mu 会通过 time_shift 修改 sigmas 的值。

这里的变量命名又乱七八糟,输入 time_shiftsigmas 是第三个参数,而在 time_shift 里的 sigmas 是除了 self 以外的第二个参数。这是因为 Diffusers 在移植官方代码时没有取好变量名。

1
2
3
4
5
6
7
8
9
10
11
12
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
):
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)

我们再跑出去看一下流水线里输入的 sigmas 是什么。假设总采样步数为 $T$,则 sigmas 是 $1$ 到 $\frac{1}{T}$ 间均匀采样的 $T$ 个实数。

1
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

现在要解读 mu 的作用就很容易了。假设 sigmas 是下标和值构成的点,我们可以测试 mu 不同的情况下, sigmas 经过 time_shift 函数形成的曲线图。

可以看出,mu=0则不修改曲线。随着 mu 增大,曲线逐渐上凸。

我对流匹配的具体细节不是很懂,只能大概猜测 mu 的作用。流匹配中,图像沿着某条路线从纯噪声运动到训练集中,标准差 sigma 用于控制不同时刻图像的不确定性。时刻为 0 时,图像为纯噪声,标准差为 1; 时刻为 1 时,图像为生成集合中的图像,标准差要尽可能趋于 0。对于中间时刻,标准差默认按照时刻线性变化。而 mu 是一个 0.5 ~ 1.16 之间的数,可能控制的是中间时刻的噪声均值。图像分辨率越大,token 越多,mu 越大,要加的噪声越重。这也符合之前 Stable Diffusion 3 论文在 Resolution-dependent shifting of timestep schedules 小节里的设计,对于分辨率越高的图像,需要加更多噪声来摧毁原图像的信号。总之,这个 mu 可能是训练的时候加的,用于给高分辨率图像加更多噪声,推理时也不得不带上这个变量。

FLUX.1 官方仓库对应部分是这样写的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def get_lin_function(
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b


def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)

# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)

return timesteps.tolist()

mu 的作用确实和高信号图片有关。但他们的设计初衷是偏移时间戳,而不是根据某种公式修改 sigma。比如原来去噪迭代 0~500 步就表示 t=0 到 t=0.5,偏移时间戳后,0~500 步就变成了 t=0 到 t=0.3。偏移时间戳使得模型能够把更多精力学习对如何对高噪声的图像去噪。

使用单流并行注意力层的 Transformer 架构

接下来的问题都和 FLUX.1 的新 Transformer 架构相关。我们先把整个网络架构弄懂,再去看旋转式位置编码的细节。

为了理清网络架构,我们来根据已知信息,逐步完善网络的模块图。首先,我们先粗略地画一个 Transformer 结构,定义好输入输出。相比 SD3,FLUX.1 多了指引强度和编号集 txt_ids, img_ids这两类输入。

接下来,我们把和 SD3 相似的结构画进来。所有 Transformer 块都是那种同时处理两类 token 的双流注意力块。输入文本的 T5 嵌入会作为文本支流进入主模型。输入文本的 CLIP 嵌入会经池化与MLP,与经过了位置编码和 MLP 的时刻编码加到一起。时刻编码会以 AdaLayerNorm 的方式修改所有层的数据规模,以及数据在输出前的尺寸与均值。

CombinedTimestepGuidanceTextProjEmbeddings 类中,我们能知道小文本嵌入、时刻嵌入、指引嵌入是怎么加到一起的。我们主要关心指引嵌入的有关操作。由于指引强度 guidance 和时刻 timestep 都是实数,所以 guidance_emb 的处理方式与 timesteps_emb 一模一样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()

self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

def forward(self, timestep, guidance, pooled_projection):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)

guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)

time_guidance_emb = timesteps_emb + guidance_emb

pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections

return conditioning

在去噪模型 FluxTransformer2DModelforward 方法中,原先的图块化及二维位置编码模块被一个简单的线性层 x_embedder 取代了,现在的位置编码 image_rotary_emb 会输入进所有层中,而不是一开始和输入加在一起。

1
2
3
4
def forward(hidden_states, ...):
hidden_states = self.x_embedder(hidden_states)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)

之后,除了过 MM-DiT 块以外,文本信息还会和图像信息融合在一起,过若干个单流 Transformer 块。过了这些模块后,原来文本 token 那部分会被丢弃。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
for index_block, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

for index_block, block in enumerate(self.single_transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)

hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

我们已经画完了去噪模型的结构,最后把 VAE 部分加上就比较完美了。

多数模块的细节都可以在 SD3 的论文里找到,除了 RoPE 和单流 DiT 块。我们在这一节里再仔细学习一下单流 DiT 块的结构。

根据官方介绍,FLUX.1 的 Transformer 里用到了并行 Transformer。准确来说,FLUX.1 仅在最后的单流 DiT 块里用到了并行注意力层。并行注意力层是在文章 Scaling Vision Transformers to 22 Billion Parameters 中提出的。如下图所示,这项技术很好理解,只不过是把注意力和线性层之间的串联结构变成并联结构。这样的好处是,由于数据在过注意力层前后本身就要各过一次线性层,在并联后,这些线性层和 MLP 可以融合。这样的话,由于计算的并行度更高,模型的运行效率会高上一些。

顺带一提,在 Q, K 后做归一化以提升训练稳定性也是在这篇文章里提出的。SD3 和 FLUX.1 同样用了这一设计,但用的是 RMSNorm 而不是 LayerNorm。

我们可以在 FluxSingleTransformerBlock 类里找到相关实现。代码不长,我们可以一次性读完。相比上面的示意图,Q, K, V 的投影操作被单独放进了 Attention 类里,并没有和第一个线性层融合。而做了注意力操作后,Att-out 和 MLP-out 确实是放在一起做的。attn_outputmlp_hidden_states 拼接了起来,一起过了 proj_out。此外,这里的归一化层还是 DiT 里的 AdaLN,模块能接收时刻编码的输入。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)

self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

processor = FluxSingleAttnProcessor2_0()
self.attn = Attention(...)

def forward(
self,
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
):
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
)

hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states

return hidden_states

此处具体的注意力运算写在 FluxSingleAttnProcessor2_0 类里。跳过前面繁杂的形状变换操作,我们来看该注意力运算的关键部分。在做完了标准注意力运算 scaled_dot_product_attention 后,一般要调用 attn.to_out[0](hidden_states) 对数据做一次投影变换。但是,在这个注意力运算中,并没有对应的操作。这表明该模块确实是照着并行注意力层设计的,离开注意力的投影与 MLP 的第二个线性层融合到了一起。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:

...

if image_rotary_emb is not None:
query, key = apply_rope(query, key, image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

...

return hidden_states

旋转式位置编码思想及 FLUX.1 实现

旋转式位置编码是苏剑林在 RoFormer: Enhanced Transformer with Rotary Position Embedding 中提出的一种专门为注意力计算设计的位置编码。在这篇文章中,我们来简单地了解一下旋转式位置编码的设计思想,为学习 FLUX.1 的结构做准备。

想深究旋转式位置编码的读者可以去阅读苏剑林的博文,先阅读《让研究人员绞尽脑汁的Transformer位置编码》(https://kexue.fm/archives/8130) 了解该怎么设计位置编码,再阅读《Transformer升级之路:2、博采众长的旋转式式位置编码》(https://kexue.fm/archives/8265) 了解旋转式位置编码的细节。

Transformer 仅包括注意力和全连接两种运算,这两种运算都是和位置无关的。为了让 Transformer 知道词语的前后关系,或者像素间的空间关系,就要给 Transformer 中的 token 注入某种位置信息。然而,仅仅告诉每个 token 它的绝对位置是不够好的,这样做最明显的缺点是模型无法处理训练时没有见过的长序列。比如训练集里最长的句子是 512 个 token,如果输入 600 个 token,由于模型没有见过编号超过 512 的位置编码,就不能很好地处理 512 号以后的 token。因此,我们不仅希望每个 token 知道自己的绝对位置,还希望 token 能从位置编码里知道相对位置的信息。

在提出 Transfomer 的论文中,作者给出了如下的一套正弦位置编码方案。这也是多数工作默认使用的位置编码方式。为了简化表示,我们假设输入 token 是一个二维向量,这样,每个 token 需要的位置编码也是一个二维向量。

其中,$k$ 表示第 $k$ 个 token。这样做的好处是,根据三角函数和角公式,位置编码之间可以用线性组合来表示,这种编码蕴含了一定的相对位置信息。

当我们要把二维向量拓展成 $d$ 维向量时,只需要把 $d$ 维两两打包成一组,每组用不同周期的正弦函数即可。因此,在后文中,我们也不讨论 $d$ 维的 token,只需要搞明白二维的 token 该怎么编码就行。

尽管正弦编码能表示一定的相对信息,但是,由于位置编码之间是线性关系,经过了 Transformer 中最重要的操作——注意力操作后,这种相对位置信息几乎就消失了。有没有一种位置编码方式能够让注意力计算也能知道 token 间的相对位置关系呢?

经苏剑林设计,假设每个 token 的二维位置编码是一个复数,如果用以下的公式来定义绝对位置编码,那么经过注意力计算里的求内积操作后,结果里恰好会出现相对位置关系。设两个 token 分别位于位置 $m$ 和 $n$,令给位置为 $j$ 的注意力输入 Q, K $q_j, k_j$ 右乘上 $e^{ij/10000}$的位置编码,则求 Q, K 内积的结果为:

其中,$i$ 为虚数单位,$*$ 为共轭复数,$Re$ 为取复数实部。只是为了理解方法的思想的话,我们不需要仔细研究这个公式,只需要注意到输入的 Q, K 位置编码分别由位置 $m$, $n$ 决定,而输出的位置编码由相对位置 $m-n$ 决定。这种位置编码既能给输入提供绝对位置关系,又能让注意力输出有相对位置关系,非常巧妙。

根据欧拉公式,我们可以把 $e^i$ 用一个含 $sin$ 和 $cos$ 的向量表示。由于该变换对应向量的旋转,所以这种位置编码被称为「旋转式位置编码」。在实际实现时,我们不需要复数库,只需要用两个分别含 $sin$ 和 $cos$ 的数来表示一个位置编码。也就是说,原来正弦位置编码中每个位置的编码只有一个实数,现在需要两个实数,或者说要一个二维向量。

总结一下用旋转式位置编码替换正弦位置编码后,我们在实现时应该做的改动。现在,我们不是提前算好位置编码,再加到输入上,而是先预处理好位置编码,在每次注意力 Q,K 求内积前给输入乘上。和正弦编码一样,我们会把特征长度为 $d$ 的 token 向量的分量两两分组,分别维护位置关系。但是,现在每个分量的编码由两个而不是一个实数表示。所以,在之后的代码中,我们会看到生成位置编码时,会先把 token 特征向量长度除以二,再给每组 token 生成 $2 \times 2$ 个编码,对应每组两个编码,每个编码长度为二。

我们来看一下 FLUX.1 的 Transformer 是怎么处理位置编码的。在 FluxTransformer2DModelforward 方法中,我们能看到输入的 0, 1, 2, 3 这样的整数位置编码 ids 被传入了位置编码层 pos_embed 中。

1
2
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)

位置编码层类 EmbedND 定义了位置编码的具体计算方式。这个类的逻辑我们暂时跳过,直接看最后真正在算旋转式位置编码的 rope 函数。函数中,输入参数 pos 是一个 0, 1, 2, 3 这样的整数序号张量,dim 表示希望生成多长的位置编码,其值应该等于 token 的特征长度,theta 用来控制三角函数的周期,一般都是取常数 10000。我们能看到,rope 计算了输入的三角函数值,并把长度为 dim 的编码两两分组,每组有 (2, 2) 个位置编码值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."

scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)

batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)

stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()


class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim

def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)

我们来看一下位置编码是怎么传入 Transformer 块的注意力计算的。在预处理完位置编码后,image_rotary_emb 会作为输入参数传入所有 Transformer 块,包括前面的双流块和后面的单流块。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def forward(...):
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)

...

for index_block, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
...
image_rotary_emb=image_rotary_emb,
)

for index_block, block in enumerate(self.single_transformer_blocks):
hidden_states = block(
...
image_rotary_emb=image_rotary_emb,
)

位置编码 image_rotary_emb 最后会传入双流注意力计算类 FluxAttnProcessor2_0 和单流注意力计算类 FluxSingleAttnProcessor2_0。由于位置编码在这两个类中的用法都相同,我们就找 FluxSingleAttnProcessor2_0 的代码来看一看。在其 __call__ 方法中,可以看到,在做完了 Q, K 的投影变换、形状变换、归一化后,方法调用了 apply_rope 来执行旋转式位置编码的计算。而 apply_rope 会把 Q, K 特征向量的分量两两分组,根据之前的公式,模拟与位置编码的复数乘法运算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class FluxSingleAttnProcessor2_0:

def __call__(
self,
...
image_rotary_emb: Optional[torch.Tensor] = None,
):
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

if image_rotary_emb is not None:
query, key = apply_rope(query, key, image_rotary_emb)

def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

这样,我们就看完了旋转式位置编码在 FLUX.1 里的实现。但是,我们还遗留了一个重要问题:在 NLP 中,句子天然有前后关系,我们按照 0, 1, 2, 3 给 token 编号就行了。而在这个模型中,既有图像 token,又有文本 token,该怎么给 token 编号呢?

图像及文本 token 的位置编号

现在,我们把目光倒回到流水线类。输入给去噪模型的序号变量有两个:text_idslatent_image_ids。它们是怎么得到的?

1
2
3
4
5
6
noise_pred = self.transformer(
...
txt_ids=text_ids,
img_ids=latent_image_ids,
...
)[0]

在文本编码方法中,我们看到,text_ids 竟然只是一个全零张量。它的第一维表示 batch 大小,第二维序列长度等于文本编码 prompt_embeds 的长度,第三维序号长度为 3。也就是说,对于每一个文本 token 的每一个位置,都用 (0, 0, 0) 来表示它的位置编号。这也暗示在 FLUX.1 中,token 的位置是三维的。

1
2
3
4
5
6
7
8
def encode_prompt(
...
):
...
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)

return prompt_embeds, pooled_prompt_embeds, text_ids

latent_image_ids 主要是在 _prepare_latent_image_ids 函数里生成的。这个函数的主要输入参数是图像的高宽。根据高宽,函数会生成 (0, 0) ~ (height, width) 的二维位置坐标表格,作为位置坐标 latent_image_ids 的第二、第三维。而位置坐标的第一维全是 0。也就是说,位置为 (i, j) 的像素的位置编号为 (0, i, j)。代码里给高宽除以 2 是因为输入没有考虑 2 倍的图块化,这写得真够乱的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)

return latent_image_ids.to(device=device, dtype=dtype)

def prepare_latents(...):
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)

...

latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents, latent_image_ids

文本位置编号 txt_idsimg_ids 会在第二维,也就是序列长度那一维拼接成 idsids 会输入给 EmbedND 类的实例 pos_embedEmbedND 的构造函数参数中,dim 完全没有被用到,theta 控制编码的三角函数周期,axes_dim 表示位置坐标每一维的编码长度。比如 FLUX.1 的位置坐标是三维的, axes_dim[16, 56, 56],那么它就表示第一个维度用长度 16 的位置编码,后两维用长度 56 的位置编码。位置编号经 rope 函数计算得到旋转式位置编码后,会拼接到一起,最后形成 128 维的位置编码。注意,所有 Transformer 块每个头的特征数 attention_head_dim 也是 128。这两个值必须相等。

「头」指的是「多头注意力」里的「头」。头数乘上每次参与注意力运算的特征长度才等于总特征长度。由于位置编码是给 Q, K 准备的,所以位置编码的长度应该与参与注意力运算的特征长度相同。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class FluxTransformer2DModel():
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
):
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)

def forward(...):
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)

class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim

def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)

我们来整理一下 FLUX.1 的位置编码机制。每个文本 token 的位置编号都是 (0, 0, 0)。位于 (i, j) 的像素的位置编号是 (0, i, j)。它们会生成 128 维的位置编码。编码前 16 个通道是第一维位置编号的位置编码,后面两组 56 个通道分别是第二维、第三位位置编号的位置编码。也就是说,在每个头做多头注意力运算时,特征的前 16 个通道不知道位置信息,中间 56 个通道知道垂直的位置信息,最后 56 个通道知道水平的位置信息。

乍看下来,这种位置编号方式还是非常奇怪的。所有 token 的第一维位置编号都是 0,这一维岂不是什么用都没有?

FLUX.1 旋转式位置编码原理猜测与实验

在这一节中,我将主观分析 FLUX.1 的现有源码,猜测 FLUX.1 未开源的 [pro] 版本中旋转式位置编码是怎么设置的。此外,我还会分享一些简单的相关实验结果。

已开源的 FLUX.1 为什么会出现 (0, 0, 0), (0, i, j) 这样奇怪的位置编号呢?由于现在已开源的两版模型是在 FLUX.1 [pro] 上指引蒸馏的结果,很可能原模型在指引机制,也就是和文本相关的处理机制上与现有模型不同。因此,我使用我独创的代码心理学,对现有源码进行了分析。

首先,令我感到疑惑的是采样流水线里生成位置编号的代码。latent_image_ids 一开始是一个全零张量,你写它加一个数,和直接赋值的结果不是一样的吗?为什么要浪费时间多写一个加法呢?

1
2
3
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

为了确认这段代码不是 Diffusers 的开发者写的,我去看了 FLUX.1 的官方代码,发现他们的写法是一样的。在看 Diffusers 源码时,我们还看到了其他一些写得很差的代码,这些代码其实也都是从官方仓库里搬过来的。

1
2
3
4
5
6
7
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
...

img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

从这些代码中,我们不难猜出开发者的心理。FLUX.1 的开发者想,我们要赶快搞一个大新闻,论文也不写了,直接加班加点准备开源。Diffusers 的开发者一看,你们这么急,我们也得搞快一点。于是他们先把 SD3 的代码复制了一遍,然后又照搬了 FLUX.1 官方仓库里的一些逻辑,直接修改 SD3 的代码。

相信大家都有这样的代码重构经历:把自己写的个人开发代码,急忙删删改改,变成能给别人看的代码。能少改一点,就少改一点。上面的代码用加法而不是赋值,就是重构的时候代码没删干净的痕迹。这说明,一开始的 img_ids 很可能不是一个全零张量,而是写了一些东西在里面。

而另一边,设置文本位置编号的官方源码里,非常干脆地写着一个全零向量。我倾向于这部分代码没有在开源时改过。

1
txt_ids = torch.zeros(bs, txt.shape[1], 3)

那么,问题就来了,这个看似全零的图像位置编号一开始是什么?它对整个位置编码的设计有什么影响?

1
img_ids = torch.zeros(h // 2, w // 2, 3)

我猜开发者设置这个变量的目的是为了区分文本和图像 token。目前,所有文本 token 的位置编号是 (0, 0, 0),这其实不太合理,因为这种做法实际上是把所有文本 token 都默认当成位置为 (0, 0) 图像 token。为了区分文本和图像 token,应该还有其他设计。我猜最简单的方法是在第一维上做一些改动,比如令所有图像 token 的第一维都是 1。但看起来更合理的做法是对三个维度的编号都一些更改,比如给所有图像位置编号都加上一个常量 (a, b, c)。这样,图像 token 间的相对位置并不会变,而图像和文本 token 的相对位置就不同了,文本就不会默认在图像 (0, 0) 处了。从代码里的加法来看,我更倾向于认为 img_ids 原来是一个三个维度都有值的常量,且这个量或许是可以学习的。而在指引蒸馏时,位置编号的设计被简化了。

网上有人说文本位置编码全零是因为 t5 编码器自带位置编码。而在我看来,过了一个文本编码器后,文本的每个 token 已经包含所有文本的全局信息,文本 token 之间的位置编码在这里已经不重要了。重要的是文本 token 和图像 token 之间的「位置」关系,这并不能通过 t5 的位置编码来反映。

为了验证位置编码的作用,我尝试修改了图像位置编号的定义,还是跑本文开头那个测试示例。

如果把图像位置编号全置零,会得到下面的结果。这说明位置编码对结果的影响还是很大的,模型只能从位置编码处获取 token 间的相对关系。

如果把位置编号除以二,会得到下面的结果。我们能发现,图像好像变模糊了一点,且像素有锯齿化的倾向。这非常合理,因为位置编号除以二后,模型实际上被要求生成分辨率低一倍的结果。但突然又多了一些距离为 0.5 的像素,模型突然就不知道怎么处理了,最终勉强生成了这种略显模糊,锯齿现象明显的图片。注意哦,这里虽然像素间的关系不对,但图中的文字很努力地想要变得正常一点。

位置编号乘二的结果如下所示。可能模型并没有见过没有距离为 1 的图像 token 的情况,结果全乱套了。但尽管是这样,我们依然能看到图中的 “Hello World”。结合上面的结果,这说明文本指引对结果的影响还是很大的,正常的文本 token 在努力矫正图像生成结果。

位置编号乘 1.2 的结果如下所示。图像的结果还是比较正常的。这说明这套位置编码允许位置编号发生小的扰动,且模型能认识非整数的位置编号,即在模型看来,位置编号是连续的。

原图片和将位置编号第一维全置 1 的结果如下所示。如我所料,位置编号的第一维几乎没什么作用。图片只是某些地方发生了更改,整体的画面结构没有变化。

目前看下来,由于现在我们有了显式定义 token 相对位置关系的方法,要在 FLUX.1 上做一些图像编辑任务的科研,最容易想到地方就是位置编码这一块。我目前随便能想到的做法有两个:

  • 直接基于位置编号做超分辨率。想办法修改位置编码的机制,使得所有图像 token 距离 2 个单位时也能某种程度上正常输出图片。以此配置反演一张低分辨率图片,得到纯噪声,重新以图像 token 距离 1 单位的正常配置来生成图片,但旧像素不对新像素做注意力,再想一些办法控制文本那部分,尽量保持旧像素输出不变,最后就能得到两倍超分辨率的结果了。inpainting 似乎也能拿类似的思路来做。
  • 目前所有文本 token 的位置默认是 (0, 0),改变文本 token 的位置编号或许能让我们精确控制文本指定的生成区域。当然,这个任务在之前的 Stable Diffusion 里好像已经被做滥了。

总结

在这篇文章中,我们围绕 FLUX.1 相对 Stable Diffusion 3 的改动,仔细阅读了 FLUX.1 在 Diffusers 中的源码。这些改动具体总结如下:

  • SD3 是在去噪网络里用下采样 2 倍的卷积实现图块化,而 FLUX.1 通过把 $2 \times 2$ 个图像 token 在通道上堆叠直接实现图块化。
  • FLUX.1 目前公布的两个模型都是指引蒸馏过的。我们无需使用 Classifier-Free Guidance,只要把指引强度当成一个约束条件输出进模型,就能在一次推理中得到带指定指引强度的输出。
  • FLUX.1 遵照 Stable Diffusion 3 的噪声调度机制,对于分辨率越高的图像,把越多的去噪迭代放在了高噪声的时刻上。但相较 Stable Diffusion 3,似乎不仅训练时有这种设计,采样时也需要用到这种设计。
  • FLUX.1 将文本的位置编号设为 (0, 0, 0),图像的位置编号设为 (0, i, j),之后用标准的旋转式位置编码对三个维度的编号编码,再把三组编码拼接。这种看似不太合理的位置编号设计方式或许是指引蒸馏导致的,目前从源代码中看不出原 FLUX.1 模型的位置编号设计方式。
  • 在原 Stable Diffusion 的 MM-DiT 块之后,FLUX.1 将文本和图像 token 拼接,输入进了一个单流的 Transformer 块。该 Transformer 块遵照之前并行注意力层的设计,注意力层和 MLP 并联执行,在执行速度上有所提升。

FLUX.1 的总模型结构图如下所示。

作为最强开源 DiT 文生图模型,FLUX.1 狠狠打脸了拖拖拉拉刚开源没多久的 Stable Diffusion 3。可以预见,之后大家会把开发图像编辑工作的基础模型从 U-Net 版 Stable Diffusion 逐渐换成 FLUX.1。这方面的研究目前还是蓝海,值得大家投入精力研究。

FLUX.1 还是在科研上能给我们一些启示的。RoPE 都是 NLP 那边已经出了很久的工作了,直到现在才搬到图像生成这边来。我们或许能够把 NLP 或者其他视觉任务中使用的神经网络技术搬到图像生成这边来,不费什么力气地改进现有的图像生成模型。

但是,在搬运 NLP 技术中,我们也要思考如何更合理地在视觉应用中使用这些技术。文本和图像存在本质上的区别:文本是离散的,而图像是连续的。这种连续性不仅体现在图像的颜色值上,还体现在图像像素间的位置关系上。就以这里的旋转式位置编码为例,NLP 中,token 间的距离就得是整数。而在 CV 中,如果我们认为图像是一种连续信号,那么非整数的 token 距离或许也是有意义的。从文本和图像的本质区别出发,我们或许能够把 NLP 的技术更好地适配到 CV 上,而不是把 Transformer 搬过来,然后加数据一把梭。

自回归是一种根据之前已生成内容,不断递归预测下一项要生成的内容的生成模型。这种生成方式十分易懂,符合我们对生活的观察。比如我们希望模型生成一句话,第一个是「今」字,那么第二个字很可能就是「天」字。如果前三个字是「今天早」,那么第四个字就很可能是「上」。

1
2
3
4
(空)  -> 今
今 -> 天
今天 -> 早
今天早 -> 上

为这种自回归模型的而设计的 Transformer 网络在自然语言处理(NLP)中取得了极大的成功。然而,尽管许多人也尝试用它生成图像,自回归模型却一直没有成为最强大、最受欢迎的图像生成模型。

为了解决此问题,何恺明团队公布了论文 Autoregressive Image Generation without Vector Quantization。作者分析了目前最常见的自回归图像生成模型后,发现模型中的向量离散化 (Vector Quantization, VQ) 是拖累模型能力的罪魁祸首。作者用一些巧妙的方法绕过了 VQ,最终设计出了一种新式自回归模型。该模型在图像生成任务上表现出色,在 ImageNet 图像生成指标上不逊于最先进的图像扩散模型。在这篇博文中,我们就来学习一下这种新颖的无 VQ 自回归图像生成模型。

建议读者在阅读本文前熟悉 VQ-VAE、Transformer、DDPM 等经典工作,了解 NLP 和图像生成中连续值和离散值的概念。可以参考我之前写的文章:

轻松理解 VQ-VAE:首个提出 codebook 机制的生成模型

VQGAN 论文与源码解读:前Diffusion时代的高清图像生成模型

Transformer 论文精读

扩散模型(Diffusion Model)详解:直观理解、数学原理、PyTorch 实现

Stable Diffusion 解读(一):回顾早期工作

知识回顾

连续值与离散值

在计算机科学中,我们既会用到连续值,也会用到离散值。比如颜色就是一个常见的连续值,我们用 0~1 之间的实数表示灰度从全黑到全白。而词元 (token) 需要用离散值表示,比如我们用 “0” 表示字母 “A”,”1” 表示 “B”, “2” 表示 “C”,并不代表 “B” 是 「’A’ 和 ‘C’ 的平均值」。离散值的数值只是用来区分不同概念的。

神经网络默认输入是连续变化的。因此,一个连续值可以直接输入进网络。而代表离散值的整数不能直接输入网络,需要先过一个嵌入层,再正常输入进网络。

自回归与类别分布

在自回归文本生成模型中,为了不断预测下一个词元,通常的做法是用一个神经网络建模下一个词元的类别分布(categorical distribution)。如下面的例子所示,所谓类别分布,就是下一步选择每一个词元的概率。有了概率分布后,我们就能用采样算法采样出下一个词元。

要训练这个预测模型也很简单。每次预测下一个词元的类别分布,其实就是一个分类任务。我们直接照着分类任务的做法,以数据集里现有句子为真值,用交叉熵损失函数就能训练这个预测模型了。

自回归图像生成

由于 Transformer 在 NLP 中的成功,大家也想用 Transformer 做图像生成。在用自回归模型生成图像时,需要考虑图像和文本的两个区别:

  1. 文本是一维的,天然有先后顺序以供自回归生成。而图像是二维的,没有先后顺序。
  2. 图像的颜色值是连续而非离散的。而只有离散值才能用类别分布表示。

解决问题 1 的方法很简单:没有先后顺序,我们就人工定义一个先后顺序就好了,比如从左上到右下给图像编号。

而对于问题 2,一种最简单的方式是把连续的颜色值离散化。比如将原来 0 ~ 1 的灰度值转换为「0 号灰度」、「1 号灰度」、…… 「7号灰度」。神经网络像对待词元一样对待这些灰度值,不知道它们之间的大小关系,只知道生成图像的颜色只能由这 8 种「颜色词语」构成。

向量离散化

把颜色值离散化后,我们的确可以用自回归做图像生成了。但是,由于图像的像素数比文章的词元数要多很多,这种逐像素生成方式会非常慢。为了加速自回归生成,VQ-VAE, VQGAN 等工作借由向量离散化自编码器(VQ 自编码器)实现了一个两阶段的图像生成方法:

  • 训练时,先训练一个包括编码器 (encoder) 和解码器 (decoder) 两个子模型的 VQ 自编码器,再训练一个生成压缩图像的自回归模型。
  • 生成时,先用自回归模型生成出一个压缩图像,再用 VQ 自编码器将其复原成真实图像。

相比普通的自编码器,VQ 自编码器有一项特点:它生成的压缩图像仅由离散值组成。这样,它就同时完成了两项任务,使得自回归模型能够高效地实现图像生成:1)将连续图像变成离散图像;2)减少要生成的像素数。

如果你还是不太理解 VQ 的作用,请先回顾 VQ-VAE 工作,再来学习这篇工作。

抛弃 VQ,拥抱扩散模型

我们来总结一下为什么要使用基于 VQ 的自回归图像生成:大家想用基于 Transformer 的自回归模型做图像生成。自回归模型在预测下一个词元/像素时,通常会用一个类别分布来建模下一项数据。由于类别分布只能描述离散数据,而图像又是连续数据,我们需要把连续像素值变成离散值。一种常用的将连续图像变成离散图像的方法是 VQ 自编码器,它既能减少图像尺寸以提高生成效率,又能将连续图像变成离散图像。

但相比普通的自编码器,如 VAE,VQ 自编码器有着一些缺点:

  • VQ 自编码器很难训练
  • VQ 自编码器的重建效果没有 VAE 好。比如在 Stable Diffusion 中,开发者选择了用 VAE 而不是 VQ-VAE 作为自编码器

出于抛弃 VQ 的想法,论文的作者发问道:「自回归图像生成真的需要和 VQ 绑定起来吗?」注意到,在我们刚刚阐述使用 VQ 自回归生成的动机时,用了几个「通常」、「常用」这样的非肯定词。这表明我们的这条推理链不是必然的。要取代 VQ,我们可以从两个方面入手:

  1. 换一种更强力的把连续图像变成离散图像的方法
  2. 从更根本处入手,不用类别分布来建模下一项数据

论文的作者选择了第二种做法:不就是建模一个像素值的分布吗?我们为什么要用死板的类别分布呢?既然扩散模型如此强大,能够拟合复杂的图像分布,那用它来拟合一个像素值的分布还不是轻轻松松?论文的核心思想也就呼之欲出了:用扩散模型而不是类别分布来建模自回归模型中下一个像素值的分布,从而抛弃自编码器里的 VQ 操作,提升模型能力。

可能读者第一次看到这个想法时会有些疑惑:扩散模型不是用来生成一整张图像的吗?它怎么建模一个像素值的分布?它和自回归模型又有什么关系?我们来多花点时间深入理解这个想法。

在文本自回归生成中,输入是已生成文本,输出是下一个词元的类别分布。

而在图像自回归生成中,输入是已生成像素,输出是下一个像素的类别分布。现在,我们希望不用类别分布,而用另一种方式,根据之前的像素生成出下一个像素

论文作者从扩散模型中获取了灵感。扩散模型是一种强力的生成模型,它可以不根据任何信息,或根据类别、文本等信息,隐式建模训练集的图像分布,从而生成符合训练集分布的图像。既然扩散模型能够建模复杂的图像分布,那它也可以根据之前像素的信息,建模下一个像素的分布。

那么,在这种新式自回归模型里,我们可以用约束于 Transformer 输出的上下文信息的扩散模型来建模下一个像素的分布,尽管现在我们并不知道每种颜色出现的概率。

这样做的好处是,以前我们只能用离散的有限类型的颜色(准确来说是图像词元)来表示图像,现在我们能够用连续值来表示图像。模型能够更加轻松地生成内容丰富的图像。

当然,抛弃了 VQ 后,自回归模型确实不需要 VQ 自编码器来把连续图像变成离散图像了。但是,我们依然需要用自编码器来压缩图像,减少要生成的像素数。本工作依然采取了 VQ-VAE、VQGAN 那种两阶段的生成方式,只不过把 VQ 自编码器换成了用 KL loss 约束的 VAE。

训练这种扩散模型的方法很简单。在每一步训练时,我们知道上下文像素是什么,也知道当前像素的真值是什么。那么,只要以上下文像素为约束,用当前像素的真值去训练一个带约束扩散模型就行了。作者把训练这种隐式描述下一个像素值分布的误差函数称为 Diffusion Loss。

具体来说,本工作使用了最基础的带约束 DDPM 扩散模型。它和标准 DDPM 的唯一区别在于误差函数多了一个约束信息 $z$,该信息是上下文像素过 Transformer 的输出。

$t$ 时刻的噪声图像 $x_t$也是由 DDPM 加噪公式得来的。

Diffusion Loss 不仅可以用来训练表示分布的扩散模型,还可以训练前面提取上下文信息的 Transformer。由于约束信息 $z$ 来自 Transformer,可以把 Diffusion Loss 的梯度通过 $z$ 回传到 Transformer 的参数里。

扩散模型的采样公式也和 DDPM 的一样,这里不再赘述。特别地,以前的自回归模型在使用类别分布时,会用温度来控制采样的多样性。为了在扩散模型中也加入类似的温度参数,本工作参考了 Diffusion models beat GANs on image synthesis 论文的有关设计。

在具体模型超参数上,本工作的 DDPM 训练时有 1000 步,采样时有 100 步。乍看之下,DDPM 会为整个生成模型增加许多计算量,但由于只需要建模一个像素的分布,这套模型的 DDPM 可以用非常轻量级的结构。默认配置下,这套模型的 DDPM 的去噪模型是一个由 3 个残差块组成小型 MLP。每个残差块由 LayerNorm、线性层、SiLU、线性层组成。约束信息 $z$ 会和时刻 $t$ 的编码加在一起,用 DiT (Scalable diffusion models with Transformers) 里的 AdaLN 约束机制输入进 LayerNorm 层里。

套用更先进的自回归模型

仅是去掉 VQ,把 Diffusion Loss 加进标准自回归模型,并不能得到一个很好的图像生成模型。于是,作者用更加先进的一些自回归模型(掩码生成模型 Masked Gernerative Models,如 MaskGIT: Masked generative image TransformerMAGE: Masked generative encoder to unify representation learning and image synthesis)代替标准自回归模型,极大提升了模型的生成能力。

双向注意力

在标准 Transformer 中(如下图 (a) causal 所示),每一个词元只能看到自己及之前词元的信息。这样做的好处是模型能够并行训练,串行推理。训练和推理的速度都会比较快。但是,由于每个词元看不到后面词元的信息,Transformer 提取整个句子(图像)特征的能力会下降。

而 MAE (Masked autoencoders are scalable vision learners) 论文提出了一种双向注意力机制,它可以让词元两两之间都传递信息。但是,这样模型就不能用同一个句子并行训练了,也失去了 KV cache 加速推理的手段。

如果你不太了解 Transformer 为什么是并行训练,请仔细回顾 Transformer 论文中有关自回归机制的描述。

广义自回归模型

除了双向注意力外,作者还将一些掩码生成模型的设计融合进标准自回归模型。这种广义上的自回归模型效果更好,且能缓解双向注意力导致的推理速度慢的问题。

一般来说,用图像自回归模型时,我们都是按从左到右,从上到下的顺序生成词元,如下图 (a) 所示。但是,这种顺序不一定是最合理的。

按理来说,模型应该可以通过任何顺序生成词元,这样模型学到的生成方式更加多样。更合理的生成方式应该如下图 (b) 所示,不是从左到右,从上到下给词元编号,而是随机选择一个排列给图像编号。这样就能按照随机的顺序生成图像的词元了。

而在掩码自回归生成中,模型可以一次性生成任意一个集合的词元。因此,为了加速 (b) 模型,我们可以如下图 (c) 所示,在随机给词元编号后一次生成多个词元。(b) 可以看成是 (c) 一次只预测下一个词元的特例。

Transformer 模型配置

本工作并没有给 Transformer 加入新设计,我们来确认一遍论文中介绍的 Transformer 配置。

本工作依然采取了两阶段的生成方法。第一个阶段的自编码器(又可以理解成 NLP 中的 tokenizer)来自 LDM 工作官方仓库的 VQ-16 和 KL-16 模型。前者是 VQ 自编码器(VQGAN),后者是一个加强版的 VAE。

本工作用的 Transformer 和 ViT 一样。得到图像词元后,词元会加上位置编码,且词元序列开头会附加一个 [cls] 词元,用以在类别约束生成任务里输入类别。

基于这个类别词元,本工作使用了一种特别的 Classifier-free guidance (CFG) 机制:模型用一个假类别词元来表示「类别不明」。训练时,10% 的正确类别词元被替换成了假类别词元。这样,在用扩散模型时,就可以根据标准 CFG 的做法,用正确类别和假类别实现 CFG。详情请参见论文附录 B。

在训练掩码自回归模型时,70%~100% 的词元是未知的。由于采样序列可能会很短,作者在输入序列前附加了 64 个 [cls] 词元。掩码自回归模型的其他主要设计都与 MAE 相同。

实验结果

本工作面向的是图像生成任务,主要评估 ImageNet 数据集上按类别生成的 FID 和 IS 指标。FID 越低越好,IS 越高越好。这篇工作的实验结果中有许多信息,让我们来仔细看一看这份结果。

Diffusion Loss 与广义自回归模型

论文首先展示了 Diffusion Loss、广义自回归模型这两项主要设计的优越性,如下表所示。由于图像是按类别生成的,可以用 CFG 提升模型的生成效果。为了公平比较,模型使用的 VQ 自编码器和 KL 自编码器都来自 LDM 仓库。

表格的 4 大行展示了改进自回归模型的影响,每一大行里不同 loss 的对比体现了 Diffusioin Loss 的影响。

从第一大行可以看出,Diffusion Loss 似乎对标准自回归的改进不是很明显,且这一套方法的生成能力并不出色。只有把自回归模型逐渐改进后,Diffusion Loss 的效果才能逐渐体现出来。在后几行掩码自回归模型中,Diffusion Loss 的作用还是很大的。

而对比前三大行,我们可以发现自回归模型的架构极大地提升了生成效果,且似乎将 Transformer 由 causal 改成 bidirect 的提升更加显著。

第四大行相比第三大行,提升了每次预测的词元数,主要是为了加速。这两行的对比结果表明,做了这个加速操作后,模型生成能力并没有下降多少。后续实验都是基于第四行的配置。

Diffusion Loss 适配不同的自编码器

相比原来类别分布,用 Diffusion Loss 解除了自编码器必须输出离散图像的限制。因此,目前的模型能够适配多种自编码器,如下表所示。图中 rFID 指的是图像重建任务的 FID,越低越好。这里的 VQ-16 指的是将 VQGAN 的 VQ 层当作解码器的一部分,这样 VQGAN 的编码器输出也可以看成是连续图像,和 LDM 里的做法一样。最后一行的 KL-16 是作者重新重新在 ImageNet 上训练的 VAE,而前两行的 VQ-16 和 KL-16 是在 OpenImages 上训练的。由于后文的实验都基于 ImageNet,所以后文都会用第五行那个 VAE。

首先对比一下这里 VQ-16 w/o CFG 的 FID 和上表里最后一大行 CrossEnt 的 FID。这两组实验的自编码器相同,仅有误差函数不同。将误差函数从交叉熵换成了 Diffusion Loss 后,FID 从 8.79 变成了 7.82。这一项直接对比的实验证明了不考虑自编码器的改进时,Diffusion Loss 本身的优越性。

再对比前两行,KL 的自编码器无论是图像恢复指标还是最后的生成指标都优于 VQ 的自编码器。这印证了论文开头想要抛弃 VQ 自编码器的动机:VQ 自编码器逊于 KL 自编码器。

第三、第四行展示了方法也可以兼容下采样 8 倍的自编码器。本来测试用的 ImageNet 是 $256 \times 256$ 大小的,按照一开始下采样 16 倍的配置,能得到 $16 \times 16$ 的压缩图像,即输入 Transformer 的词元序列长度为 $16 \times 16$。现在改成了下采样 8 倍后,为了兼容之前 $16 \times 16$ 的序列长度,作者把 $2 \times 2$ 个像素打包成一个词元。论文里没讲是怎么打包的,我猜测是在通道上拼接。Consistency 是另一套自编码器,作者展示这个估计是为了说明这套方法兼容性很强。

和 SOTA 图像生成模型对比

为了证明方法的优越性,论文还展示了本工作与其他 SOTA 工作在 ImageNet 图像生成任务上的定量对比结果。下表是 ImageNet $256 \times 256$ 的结果。为了方便对比,我还贴出了 DiT 论文里展示的表格(左表)。本文的模型在表里被称作 MAR。

下表是 ImageNet $512 \times 512$ 的结果。左边那张表是 EDM2 展示的结果。

从表里可以看出,本工作在 ImageNet 图像生成任务上表现很不错,超越了绝大多数模型。

图像生成速度对比

下面是不同生成模型的速度对比结果。第一张图是本论文展示的和 DiT 的对比结果。DIT 采用的扩散模型采样步数是 (50 ,75, 150, 250)。由于本工作的性能瓶颈在自回归模型而不在扩散模型上,所以本工作展示的不同采样步数由自回归步数决定。图中的自回归步数是 (8, 16, 32, 64, 128)。中间的图是 LDM 的结果,同模型不同点表示的是采样步数为 (10, 20, 50, 100, 200) 的结果。右边的表是 EDM2 的采样速度等指标。左边两张图是 ImageNet $256 \times 256$ 上的,最右边的表是 ImageNet $512 \times 512$ 上的。

由于不同图表的采样速度指标不太一样,我们将指标统一成每秒生成的图像。从第一张图的对比可以看出,DiT 最快也是一秒 2.5 张图像左右,而 MAR 又快又好,默认(自回归步数 64)一秒生成 3 张图左右。同时,通过 MAR 和有 kv cache 加速的标准 AR 的对比,我们能发现 MAR 在默认自回归步数下还是比标准 AR 慢了不少。

我们再看中间 LDM 的速度。我们观察一下最常使用的 LDM-8。如果是令 DDIM 步数为 20 (第二快的结果)的话,LDM-8 的生成速度在一秒 16 张图像左右,还是比 MAR 快很多。DDIM 步数取 50 时也会比 MAR 快一些。

最后看右边较新的图像扩散模型 EDM2 的速度。由于这个是在 $512 \times 512$ 的图片上测试的,和前面的速度相比时大概要乘个 4。哪怕是最大的 XXL 模型,在有 guidance 时,生成速度也是 2 张图片每秒。换算到 $256 \times 256$ 上约 8 张图片每秒,还是比 MAR 快。

总结

自回归图像生成中的向量离散化和类别分布必须同时使用。为了去除表现较差的向量离散化操作,本工作的作者重新用扩散模型建模了自回归中下一个图像词元的分布,从而提升了模型的生成能力。由于标准自回归模型生成能力有限,为了进一步提升模型,作者又引入了最新的掩码自回归模型。最终的模型在 ImageNet 图像生成指标上取得了几乎最顶尖的结果。

以上是论文的叙述逻辑。但掩码自回归那一块应该是之前工作的研究成果,这篇文章实际上就是把新提出的 diffusion loss 用到了掩码自回归上,把本来在 ImageNet 上生成能力尚可的掩码自回归推到了最前列。

这篇文章在科研上的最大创新是打破了大家在图像自回归上的固有思维,认为必须用离散词元,必须用类别分布。但仔细一想,建模一个分布的方法其实许许多多。随便把另一种生成完整图像的模型用到生成一个像素上,就能取代之前的类别分布,得到更好的图像生成结果。这篇文章用简单的 DDPM 只是为了验证这个想法的可行性,用更复杂的模型或许能有更好的结果,但用 DDPM 做验证就足够了。之后肯定会有各种后续工作,研究如何用更好的模型来建模本框架中一个像素值的分布。

反过来想,这篇文章也在提醒我们,扩散模型并不只是可以用来生成图像,它的本质是建模一个分布。如果某个模型中间需要建模一个简单的分布的话,都可以尝试用 DDPM。

相比其科研创新,这篇文章在 ImageNet 图像生成指标的成就反而没有那么耀眼了。本工作在 ImageNet 的 FID 等指标上取得了几乎最优的结果,战胜了多数最强的扩散模型,有望将大家的科研眼光从扩散模型移到自回归上。但由于自回归本身步数较多,且每一步要在 Transformer 里做完整的注意力操作,这种方法的速度还是比扩散模型要慢一点。

目前 GitHub 上已有本工作的复现:https://github.com/lucidrains/autoregressive-diffusion-pytorch

Diffusers 库为社区用户提供了多种扩散模型任务的训练脚本。每个脚本都平铺直叙,没有多余的封装,把训练的绝大多数细节都写在了一个脚本里。这种设计既能让入门用户在不阅读源码的前提下直接用脚本训练,又方便高级用户直接修改脚本。

可是,这种设计就是最好的吗?关于训练脚本的最佳设计风格,社区用户们往往各执一词。有人更喜欢更贴近 PyTorch 官方示例的写法,而有人会喜欢用 PyTorch Lightning 等封装度高、重复代码少的库。而在我看来,选择哪种风格的训练脚本,确实是个人喜好问题。但是,在开始使用训练脚本之前,我们要从细节入手,理解训练脚本到底要做哪些事。学懂了之后,不管是用别人的训练库,还是定制适合自己的训练脚本,都是很轻松的。不管怎么说,Diffusers 的这种训练脚本是一份很好的学习素材。

当然,我在用 Diffusers 的训练脚本时,发现一旦涉及多类任务的训练,比如既要能训练 Stable Diffusion,又要能训练 VAE,那么这份脚本就会用起来比较困难,而写两份训练脚本又会有很大的冗余。Diffusers 的训练脚本依然有改进的空间。

在这篇文章中,我会主要面向想系统性学习扩散模型训练框架的读者,先详细介绍 Diffusers 官方训练脚本,再分享我重构训练脚本的过程,使得脚本能够更好地兼容多类模型的训练。文章的末尾,我会展示几个简单的扩散模型训练实例。

在阅读本文时,建议大家用电脑端,一边看源代码一边读文章。「官方训练脚本细读」一节细节较多,初次阅读时可以快速浏览,看完「训练脚本内容总结」中的流程图,再回头仔细看一遍。

准备源代码

我们将以最简单的 DDPM 官方训练脚本 examples/unconditional_image_generation/train_unconditional.py 为例,学习训练脚本的通用写法。examples 文件夹在位于 Diffusers 官方 GitHub 仓库中,用 pip 安装的 Diffusers 可能没有这个文件夹,最好是手动 clone 官方仓库,再在本地查看这个文件夹。使用 Diffusers 训练时,可能还要安装其他库。官方在不同的训练教程里给了不同的安装指令,建议大家都安装上。

1
2
3
cd examples/text_to_image
pip install -r requirements.txt
pip install diffusers[training]

我为本教程准备的脚本在仓库 https://github.com/SingleZombie/DiffusersExample 中。请 clone 这个仓库,再切换到 TrainingScript 目录下。train_official.py 是原官方训练脚本 train_unconditional.pytrain_0.py 是第一次修改后的训练脚本
train_1.py 是第二次修改后的训练脚本。

官方训练脚本细读

先拉到文件的最底部,我们能在这找到程序的入口。在 parse_args 函数中,脚本会用 argparse 库解析命令行参数,并将所有参数保存在 args 里。args 会传进 main 函数里。稍后我们看到所有 args. 打头的变量调用,都表明该变量来自于命令行参数。

1
2
3
if __name__ == "__main__":
args = parse_args()
main(args)

接着,我们正式开始学习训练主函数。一开始,函数会配置 accelerate 库及日志记录器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(
project_dir=args.output_dir, logging_dir=logging_dir)

# a big number for high resolution or big dataset
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))
accelerator = Accelerator(...)

if args.logger == "tensorboard":
if not is_tensorboard_available():
...

elif args.logger == "wandb":
if not is_wandb_available():
...
import wandb

在配置日志的中途,函数插入了一段修改模型存取逻辑的代码。为了让我们阅读代码的顺序与实际运行顺序一致,我们等待会用到了这段代码时再回头来读。

1
2
3
4
5
6
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
def save_model_hook(models, weights, output_dir):
...
def load_model_hook(models, input_dir):
...

跳过上面的代码,还是日志配置。

1
2
3
4
5
6
7
8
9
10
11
12
13
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()

之后其他版本的训练脚本会有一段设置随机种子的代码,我们给这份脚本补上。

1
2
3
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)

接着,函数会创建输出文件夹。如果我们想把模型推送到在线仓库上,函数还会创建一个仓库。

这段代码还出现了一行比较重要的判断语句:if accelerator.is_main_process:。在多卡训练时,只有主进程会执行这个条件语句块里的内容。该判断在并行编程中十分重要。很多时候,比如在输出、存取模型时,我们只需要让一个进程执行操作就行了。这个时候就要用到这行判断语句。

1
2
3
4
5
6
7
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
repo_id = create_repo(...).repo_id

准备完辅助工具后,函数开始准备模型。输入参数里的 model_config_name_or_path 表示预定义的模型配置文件。如果该配置文件不存在,则函数会用默认的配置创建一个 DDPM 的 U-Net 模型。在写我们自己的训练脚本时,我们需要在这个地方初始化我们需要的所有模型。比如训练 Stable Diffusion 时,除了 U-Net,需要在此处准备 VAE、CLIP 文本编码器。

1
2
3
4
5
6
# Initialize the model
if args.model_config_name_or_path is None:
model = UNet2DModel(...)
else:
config = UNet2DModel.load_config(args.model_config_name_or_path)
model = UNet2DModel.from_config(config)

这份脚本还帮我们写好了维护 EMA(指数移动平均)模型的功能。EMA 模型用于存储模型可学习的参数的局部平均值。有时 EMA 模型的效果会比原模型要好。

1
2
3
4
5
6
7
# Create EMA for the model.
if args.use_ema:
ema_model = EMAModel(
model.parameters(),
model_cls=UNet2DModel,
model_config=model.config,
...)

此处函数还会根据 accelerate 配置自动设置模型的精度。

1
2
3
4
5
6
7
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
args.mixed_precision = accelerator.mixed_precision
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
args.mixed_precision = accelerator.mixed_precision

函数还会尝试启用 xformers 来提升 Attention 的效率。PyTorch 在 2.0 版本也加入了类似的 Attention 优化技术。如果你的显卡性能有限,且 PyTorch 版本小于 2.0,可以考虑使用 xformers

1
2
3
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
...

准备了 U-Net 后,函数会准备噪声调度器,即定义扩散模型的细节。

注意,扩散模型不是一个神经网络,而是一套定义了加噪、去噪公式的模型。扩散模型中需要一个去噪模型来去噪,去噪模型一般是一个神经网络。

1
2
3
4
5
6
7
# Initialize the scheduler
accepts_prediction_type = "prediction_type" in set(
inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_prediction_type:
noise_scheduler = DDPMScheduler(...)
else:
noise_scheduler = DDPMScheduler(...)

准备完所有扩散模型组件后,函数开始准备其他和训练相关的模块。其他版本的训练脚本会在这个地方加一段缓存梯度和自动放缩学习率的代码,我们给这份脚本补上。

1
2
3
4
5
6
7
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()

if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)

函数先准备的训练模块是优化器。这里默认使用的优化器是 AdamW

1
2
3
4
5
6
7
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)

函数随后会准备训练集。这个脚本用 HuggingFace 的 datasets 库来管理数据集。我们既可以读取在线数据集,也可以读取本地的图片文件夹数据集。自定义数据集的方法可以参考 https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

1
2
3
4
5
6
7
8
9
10
11
12
if args.dataset_name is not None:
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
split="train",
)
else:
dataset = load_dataset(
"imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

有了数据集后,函数会继续准备 PyTorch 的 DataLoader。在这一步中,除了定义 DataLoader 外,我们还要编写数据预处理的方法。下面这段代码的编写顺序和执行顺序不同,我们按执行顺序来整理一遍下面的代码:

  1. 将预定义的预处理函数传给数据集对象 dataset.set_transform(transform_images)。在使用数据集里的数据时,才会调用这个函数预处理图像。
  2. 使用 PyTorch API 定义 DataLoader。train_dataloader = ...
  3. 每次用 DataLoader 获取数据时,一个数据词典 examples 会被传入预处理函数 transform_imagesexamples 里既包含了图像数据,也包含了数据的各种标签。而对于无约束图像生成任务,我们只需要图像数据,因此可以直接通过词典的 "image" 键得到 PIL 格式的图像数据。用 convert("RGB") 把图像转成三通道后,该 PIL 图像会被传入预处理流水线。
  4. 图像预处理流水线 augmentations 是用 Torchvision 里的 transform API 定义的。默认的流水线包括短边缩放至指定分辨率、按分辨率裁剪、随机反转、归一化。
  5. 处理过的数据会被存到词典的 "input" 键里。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Preprocessing the datasets and DataLoaders creation.
augmentations = transforms.Compose(
[
transforms.Resize(
args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(
args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

def transform_images(examples):
images = [augmentations(image.convert("RGB"))
for image in examples["image"]]
return {"input": images}

logger.info(f"Dataset size: {len(dataset)}")

dataset.set_transform(transform_images)
train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)

在准备工作的最后,函数会准备学习率调度器。

1
2
3
4
5
6
7
# Initialize the learning rate scheduler
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=(len(train_dataloader) * args.num_epochs),
)

准备完了所有模块,函数会调用 accelerate 库来把所有模块变成适合并行训练的模块。

1
2
3
4
5
6
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)

if args.use_ema:
ema_model.to(accelerator.device)

之后函数还会用 accelerate 库配置训练日志。默认情况下日志名 run 由当前脚本名决定。如果不想让之前的日志被覆盖的话,可以让日志名 run 由当前的时间决定。

1
2
3
if accelerator.is_main_process:
run = os.path.split(__file__)[-1].split(".")[0]
accelerator.init_trackers(run)

马上就要开始训练了。在此之前,函数会准备全局变量并记录日志。注意,这里函数会算一次总的 batch 数,它由输入 batch 数、进程数(显卡数)、梯度累计步数共同决定。梯度累计是一种用较少的显存实现大 batch 训练的技术。使用这项技术时,训练梯度不会每步优化,而是累计了若干步后再优化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
total_batch_size = args.train_batch_size * \
accelerator.num_processes * args.gradient_accumulation_steps
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
max_train_steps = args.num_epochs * num_update_steps_per_epoch

logger.info("***** Running training *****")
logger.info(f" Num examples = {len(dataset)}")
logger.info(f" Num Epochs = {args.num_epochs}")
logger.info(
f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(
f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_train_steps}")

global_step = 0
first_epoch = 0

在开始训练前,如果设置了 args.resume_from_checkpoint,则函数会读取之前训练过的权重。负责读取训练权重的函数是 load_state

1
2
3
4
5
6
7
8
9
10
11
12
13
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = ..
else:
# Get the most recent checkpoint
...

if path is None:
...
else:
accelerator.load_state(os.path.join(args.output_dir, path))
accelerator.print(f"Resuming from checkpoint {path}")
...

在每个 epoch 中,函数会重置进度条。接着,函数会进入每一个 batch 的训练迭代。

1
2
3
4
5
6
7
# Train!
for epoch in range(first_epoch, args.num_epochs):
model.train()
progress_bar = tqdm(total=num_update_steps_per_epoch,
disable=not accelerator.is_local_main_process)
progress_bar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):

如果是继续训练的话,训练开始之前会更新当前的步数 step

1
2
3
4
5
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue

训练的一开始,函数会从数据的 "input" 键里取出图像数据。此处的键名是我们之前在数据预处理函数 transform_images 里写的。

1
clean_images = batch["input"].to(weight_dtype)

之后函数会设置扩散模型训练中的其他变量,包含随机噪声、时刻。由于本文的重点并不是介绍扩散模型的原理,这段代码我们就快速略过。

1
2
3
4
noise = torch.randn(...)
timesteps =...
noisy_images = noise_scheduler.add_noise(
clean_images, noise, timesteps)

接下来,函数会用去噪网络做前向传播。为了让模型能正确累计梯度,我们要用 with accelerator.accumulate(model): 把模型调用与反向传播的逻辑包起来。在这段代码中,我们会先得到模型的输出 model_output,再根据扩散模型得到损失函数 loss,最后用 accelerate 库的 API accelerator 代替原来 PyTorch API 来完成反向传播、梯度裁剪,并完成参数更新、学习率调度器更新、优化器更新。

1
2
3
4
5
6
7
8
9
10
11
12
13
with accelerator.accumulate(model):
# Predict the noise residual
model_output = model(noisy_images, timesteps).sample

loss = ...

accelerator.backward(loss)

if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

确保一步训练结束后,函数会更新和步数相关的变量。

1
2
3
4
5
if accelerator.sync_gradients:
if args.use_ema:
ema_model.step(model.parameters())
progress_bar.update(1)
global_step += 1

在这个地方,函数还会尝试保存模型。默认情况下,每 args.checkpointing_steps 步保存一次中间结果。确认要保存后,函数会算出当前的保存点名称,并根据最大保存点数 checkpoints_total_limit 决定是否要删除以前的保存点。做完准备后,函数会调用 save_state 保存当前训练时的所有中间变量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
f accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [
d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(
checkpoints, key=lambda x: int(x.split("-")[1]))

if len(checkpoints) >= args.checkpoints_total_limit:
...

save_path = os.path.join(
args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")

在这个地方,主函数开头设置的存取模型回调函数终于派上用场了。在调用 save_state 时,会自动触发下面的回调函数来保存模型。如果不加下面的代码,所有模型默认会以 .safetensor 的形式存下来。而用了下面的代码后,模型能够被 save_pretrained 存进一个文件夹里,就像其他标准 Diffusers 模型一样。

这里的输入参数 models 来自于之前的 accelerator.prepare,感兴趣可以去阅读文档或源码。

1
2
3
4
5
6
7
8
9
10
11
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
if args.use_ema:
ema_model.save_pretrained(
os.path.join(output_dir, "unet_ema"))

for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet"))

# make sure to pop weight so that corresponding model is not saved again
weights.pop()

与上面的这段代码对应,脚本还提供了读取文件的回调函数。它会在继续中断的训练后调用 load_state 时被调用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def load_model_hook(models, input_dir):
if args.use_ema:
load_model = EMAModel.from_pretrained(
os.path.join(input_dir, "unet_ema"), UNet2DModel)
ema_model.load_state_dict(load_model.state_dict())
ema_model.to(accelerator.device)
del load_model

for i in range(len(models)):
# pop models so that they are not loaded again
model = models.pop()

# load diffusers style into model
load_model = UNet2DModel.from_pretrained(
input_dir, subfolder="unet")
model.register_to_config(**load_model.config)

model.load_state_dict(load_model.state_dict())
del load_model

两个回调函数需要用下面的代码来设置。

1
2
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

回到最新的代码处。训练迭代的末尾,脚本会记录当前步的日志。

1
2
3
4
5
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
if args.use_ema:
logs["ema_decay"] = ema_model.cur_decay_value
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)

执行完了一个 epoch 后,脚本调用 accelerate API 保证所有进程均训练完毕。

1
2
progress_bar.close()
accelerator.wait_for_everyone()

此处脚本可能会在主进程中验证模型或保存模型。如果当前是最后一个 epoch,或者达到了配置指定的验证/保存时刻,脚本就会执行验证/保存。

1
2
3
4
5
6
if accelerator.is_main_process:
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
...

if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
...

脚本默认的验证方法是随机生成图片,并用日志库保存图片。生成图片的方法是使用标准 Diffusers 采样流水线 DDPMPipeline。由于此时模型 model 可能被包裹成了一个用于多卡训练的 PyTorch 模块,需要用相关 API 把 model 解包成普通 PyTorch 模块 unet。如果使用了 EMA 模型,为了避免对 EMA 模型的干扰,此处需要先保存 EMA 模型参数,采样结束再还原参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
unet = accelerator.unwrap_model(model)
if args.use_ema:
ema_model.store(unet.parameters())
ema_model.copy_to(unet.parameters())

pipeline = DDPMPipeline(
unet=unet,
scheduler=noise_scheduler,
)

generator = torch.Generator(device=pipeline.device).manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
images = pipeline(...).images

if args.use_ema:
ema_model.restore(unet.parameters())

# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")

if args.logger == "tensorboard":
...
elif args.logger == "wandb":
...

在保存模型时,脚本同样会先用去噪模型 model 构建一个流水线,再调用流水线的保存方法 save_pretrained 将扩散模型的所有组件(去噪模型、噪声调度器)保存下来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
unet = accelerator.unwrap_model(model)

if args.use_ema:
ema_model.store(unet.parameters())
ema_model.copy_to(unet.parameters())

pipeline = DDPMPipeline(
unet=unet,
scheduler=noise_scheduler,
)

pipeline.save_pretrained(args.output_dir)

if args.use_ema:
ema_model.restore(unet.parameters())

if args.push_to_hub:
upload_folder(...)

一个 epoch 训练的代码就到此结束了。所有 epoch 的训练结束后,脚本调用 API 结束训练。这个 API 会自动关闭所有的日志库。训练代码到这里也就结束了。

1
accelerator.end_training()

训练脚本内容总结

大概熟悉了一遍这份训练脚本后,我们可以用下面的流程图概括训练脚本的执行顺序和主要内容。

去掉命令行参数

我不喜欢用命令行参数传训练参数,而喜欢把训练参数写进配置文件里,理由有:

  • 我一般会直接在命令行里手敲命令。如果命令行参数过多,我则会把要运行的命令及其参数保存在某文件里。这样还不如把参数写在另外的文件里。
  • 将大量参数藏在一个词典 args 里,而不是把所有需用的参数在某处定义好,是一种很差的编程方式。各个参数将难以追踪。

在正式重构脚本之前,我做的第一步是去掉脚本中原来的命令行参数,将所有参数先塞进一个数据类里面。脚本将只留一个命令行参数,表示参数配置文件的路径。具体做法如下:

先编写一个存命令行参数的数据类。这个类是一个 Python 的 dataclass。Python 中 dataclass 是一种专门用来放数据的类。定义数据类时,我们只需要定义类中所有数据的类型及默认值,不需要编写任何方法。初始化数据类时,我们只需要传一个词典或列表。一个示例如下(示例来源 https://www.geeksforgeeks.org/understanding-python-dataclasses/):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from dataclasses import dataclass

# A class for holding an employees content
@dataclass
class employee:

# Attributes Declaration
# using Type Hints
name: str
emp_id: str
age: int
city: str


emp1 = employee("Satyam", "ksatyam858", 21, 'Patna')
emp2 = employee("Anurag", "au23", 28, 'Delhi')
emp3 = employee({"name": "Satyam",
"emp_id": "ksatyam858",
"age": 21,
"city": 'Patna'})

print("employee object are :")
print(emp1)
print(emp2)
print(emp3)

我们可以用 dataclass 编写一个存储所有命令行参数的数据类,该类开头内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from dataclasses import dataclass

@dataclass
class BaseTrainingConfig:
# Dir
logging_dir: str
output_dir: str

# Logger and checkpoint
logger: str = 'tensorboard'
checkpointing_steps: int = 500
checkpoints_total_limit: int = 20
valid_epochs: int = 100
valid_batch_size: int = 1
save_model_epochs: int = 100
resume_from_checkpoint: str = None

之后在训练脚本里,我们可以把旧的命令行参数全删了,再加一个命令行参数 cfg,表示训练配置文件的路径。我们可以用 omegaconf 打开这个配置文件,得到一个词典 data_dict,再用这个词典构建配置文件 cfg。接下来,只需要把原来代码里所有 args. 改成 cfg. 就行了。

1
2
3
4
5
6
7
8
9
from omegaconf import OmegaConf
from training_cfg_0 import BaseTrainingConfig

parser = argparse.ArgumentParser()
parser.add_argument('cfg', type=str)
args = parser.parse_args()

data_dict = OmegaConf.load(args.cfg)
cfg = BaseTrainingConfig(**data_dict)

第一次修改过的训练脚本为 train_0.py,配置文件类在 training_cfg_0.py 里,示例配置文件为 cfg_0.json,一个简单 DDPM 模型配置写在 unet_cfg 目录里。可以直接运行下面的命令测试此训练脚本。

1
python train_0.py cfg_0.json

在配置文件里,我们只需要改少量的训练参数就行了。如果想知道还有哪些参数可以改,可以去查看 training_cfg_0.py 文件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
{
"logging_dir": "logs",
"output_dir": "models/ddpm_0",

"model_config": "unet_cfg",
"num_epochs": 10,
"train_batch_size": 64,
"checkpointing_steps": 5000,
"valid_epochs": 1,
"valid_batch_size": 4,
"dataset_name": "ylecun/mnist",
"resolution": 32,
"learning_rate": 1e-4
}

读者感兴趣的话也可以尝试这样改一遍代码。这样做会强迫自己读一遍训练脚本,让自己更熟悉这份代码。

适配多种任务的训练脚本

如果只是训练一种任务,Diffusers 的这种训练脚本还算好用。但如果我们想用完全相同的训练流程训练多种任务,这种脚本的弊端就暴露出来了:

  • 各任务的官方示例脚本本身就不完全统一。比如有的训练脚本支持设置随机种子,有的不支持。
  • 一旦想修改训练过程,就得同时修改所有任务的脚本。这不符合编程中「代码复用」的思想。

为此,我想重构一下官方训练脚本,将训练流程和每种任务的具体训练过程解耦开,让一份训练脚本能够被多种任务使用。于是,我又从头过了一遍训练脚本,将代码分成两类:所有任务都会用到的代码、仅 DDPM 训练会用到的代码。如下图所示,我用红字表示了训练脚本中应该由具体任务决定的部分。

根据这个划分规则,我将仅和 DDPM 相关的代码剥离出来,并用一个描述某具体任务的训练器接口类的方法调用代替原有代码。这样,每次换一个训练任务,只需要重新实现一个训练器类就行了。如下图所示,原流程图中所有红字的内容都可以由接口类的方法代替。对于不同任务,我们需要实现不同的训练器类。

具体在代码中,我写了一个接口类 Trainer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class Trainer(metaclass=ABCMeta):
def __init__(self, weight_dtype, accelerator, logger, cfg):
self.weight_dtype = weight_dtype
self.accelerator = accelerator
self.logger = logger
self.cfg = cfg

@abstractmethod
def init_modules(self,
enable_xformer: bool = False,
gradient_checkpointing: bool = False):
pass

@abstractmethod
def init_optimizers(self, train_batch_size):
pass

@abstractmethod
def init_lr_schedulers(self, gradient_accumulation_steps, num_epochs):
pass

def set_dataset(self, dataset, train_dataloader):
self.dataset = dataset
self.train_dataloader = train_dataloader

@abstractmethod
def prepare_modules(self):
pass

@abstractmethod
def models_to_train(self):
pass

@abstractmethod
def training_step(self, global_step, batch) -> dict:
pass

@abstractmethod
def validate(self, epoch, global_step):
pass

@abstractmethod
def save_pipeline(self):
pass

@abstractmethod
def save_model_hook(self, models, weights, output_dir):
pass

@abstractmethod
def load_model_hook(self, models, input_dir):
pass

根据类型名和初始化参数可以创建具体的训练器。

1
2
3
4
5
6
7
8
9
10
def create_trainer(type, weight_dtype, accelerator, logger, cfg_dict) -> Trainer:
from ddpm_trainer import DDPMTrainer
from sd_lora_trainer import LoraTrainer

__TYPE_CLS_DICT = {
'ddpm': DDPMTrainer,
'lora': LoraTrainer
}

return __TYPE_CLS_DICT[type](weight_dtype, accelerator, logger, cfg_dict)

原来训练脚本里的具体训练逻辑被接口类方法调用代替。比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# old
if cfg.model_config is None:
model = UNet2DModel(...)
else:
config = UNet2DModel.load_config(cfg.model_config)
model = UNet2DModel.from_config(config)

# Create EMA for the model.
if cfg.use_ema:
ema_model = EMAModel(...)
...

# new
trainer.init_modules(enable_xformers, cfg.gradient_checkpointing)

原来仅和 DDPM 训练相关的代码全被我搬到了 DDPMTrainer 类中。与之对应,除了代码需要搬走外,原配置文件里的数据也需要搬走。我在 DDPMTrainer 类里加了一个 DDPMTrainingConfig 数据类,用来存对应的配置数据。

1
2
3
4
5
6
7
8
9
@dataclass
class DDPMTrainingConfig:
# Diffuion Models
model_config: str
ddpm_num_steps: int = 1000
ddpm_beta_schedule: str = 'linear'
prediction_type: str = 'epsilon'
ddpm_num_inference_steps: int = 100
...

因此,我们需要用稍微复杂一点的方式来创建配置文件。现在全局训练配置和任务配置放在两组配置里。配置文件最外层除 "base" 外的那个键表明了训练器的类型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
{
"base": {
"logging_dir": "logs",
"output_dir": "models/ddpm_1",
"checkpointing_steps": 5000,
"valid_epochs": 1,
"dataset_name": "ylecun/mnist",
"resolution": 32,
"train_batch_size": 64,
"num_epochs": 10
},
"ddpm": {
"model_config": "unet_cfg",
"learning_rate": 1e-4,
"valid_batch_size": 4
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
__TYPE_CLS_DICT = {
'base': BaseTrainingConfig,
'ddpm': DDPMTrainingConfig,
'lora': LoraTrainingConfig
}


def load_training_config(config_path: str) -> Dict[str, BaseTrainingConfig]:
data_dict = OmegaConf.load(config_path)

# The config must have a "base" key
base_cfg_dict = data_dict.pop('base')

# The config must have one another model config
assert len(data_dict) == 1
model_key = next(iter(data_dict))
model_cfg_dict = data_dict[model_key]
model_cfg_cls = __TYPE_CLS_DICT[model_key]

return {'base': BaseTrainingConfig(**base_cfg_dict),
model_key: model_cfg_cls(**model_cfg_dict)}

这样改完过后,训练脚本开头也需要稍作更改,其他地方保持不变。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from training_cfg_1 import BaseTrainingConfig, load_training_config
from trainer import Trainer, create_trainer

def main():
parser = argparse.ArgumentParser()
parser.add_argument('cfg', type=str)
args = parser.parse_args()

cfgs = load_training_config(args.cfg)
cfg: BaseTrainingConfig = cfgs.pop('base')
trainer_type = next(iter(cfgs))
trainer_cfg_dict = cfgs[trainer_type]

...

trainer: Trainer = create_trainer(
trainer_type, weight_dtype, accelerator, cfg.logger, trainer_cfg_dict)

这次修改过的训练脚本为 train_1.py,配置文件类在 training_cfg_1.py 里,DDPM 训练器在 TrainingScript/ddpm_trainer.py 里,示例配置文件为 cfg_1.json。可以直接运行下面的命令测试此训练脚本。

1
python train_1.py cfg_1.json

运行这一版或者上一版的训练脚本后,我们都能很快训练完一个 MNIST 上的 DDPM 模型。从训练可视化结果可以看出,代码重构大概是没有出错,模型能正确生成图片。

对训练器类的程序设计思路感兴趣的话,欢迎阅读附录。

添加新的训练任务

为了验证这套新代码的可拓展性,我仿照 Diffusers 官方 SD LoRA 训练脚本 examples/text_to_image/train_text_to_image_lora.py,快速实现了一个 SD LoRA 训练器类。这个类在 sd_lora_trainer.py 文件里。

我来简单介绍添加新训练任务的过程。要添加新训练任务,要修改三处:

  1. 创建新文件,在文件里定义配置数据类及实现训练器类。
  2. trainer.py 里导入新训练器类。
  3. training_cfg_1.py 里导入新配置数据类。

先来看较简单的第二处和第三处修改。导入新训练器类只需要加一行 import 和一条词典项。

1
2
3
4
5
6
7
8
9
10
def create_trainer(type, weight_dtype, accelerator, logger, cfg_dict) -> Trainer:
from ddpm_trainer import DDPMTrainer
from sd_lora_trainer import LoraTrainer

__TYPE_CLS_DICT = {
'ddpm': DDPMTrainer,
'lora': LoraTrainer
}

return __TYPE_CLS_DICT[type](weight_dtype, accelerator, logger, cfg_dict)

导入新配置数据类也一样,一行 import 和一项词典项。

1
2
3
4
5
6
7
from sd_lora_trainer import LoraTrainingConfig

__TYPE_CLS_DICT = {
'base': BaseTrainingConfig,
'ddpm': DDPMTrainingConfig,
'lora': LoraTrainingConfig
}

而实现一个训练器类会比较繁琐。我是先把 DDPM 训练器类复制了过来,在此基础上进行修改。由于 SD LoRA 训练器有官方训练脚本作为参考,我还是和之前实现 DDPM 训练器一样,从官方训练脚本里抠出对应代码,将其填入训练器类方法里。比如在初始化模块时,我们不仅需要初始化 U-Net,还有 VAE 等模块。在初始化优化器时,应该只优化 LoRA 参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class LoraTrainer(Trainer):
def init_modules(self,
enable_xformer=False,
gradient_checkpointing=False):
cfg = self.cfg
# Load scheduler, tokenizer and models.
self.noise_scheduler = DDPMScheduler...
self.tokenizer = CLIPTokenizer...
self.text_encoder = CLIPTextModel...
self.vae = AutoencoderKL...
self.unet = UNet2DConditionModel...
# freeze parameters of models to save more memory
self.unet.requires_grad_(False)
self.vae.requires_grad_(False)
self.text_encoder.requires_grad_(False)

for param in self.unet.parameters():
param.requires_grad_(False)

unet_lora_config = LoraConfig(...)
self.lora_layers = filter(
lambda p: p.requires_grad, self.unet.parameters())
...

def init_optimizers(self, train_batch_size):
...
self.optimizer = torch.optim.AdamW(
self.lora_layers,
...)

SD LoRA 训练器类在 sd_lora_trainer.py 文件里,对应配置文件为 cfg_lora.json。用下面的代码即可尝试 LoRA 训练。

1
python train_1.py cfg_lora.json

可能是 MNIST 数据集的图片太小了,而 SD 又是为较大的图片设计的,又或是 LoRA 的拟合能力有限,生成的效果不是很好。但可以看出,SD LoRA 学到了 MNIST 的图片风格。

就我自己使用下来,添加一个新的训练任务还是非常轻松的。我可以只关心初始化模型、训练、验证等实现细节,而不用关心那些通用的训练代码。当然,这份通用训练脚本还不够强大,还不能处理更复杂的数据集。SD LoRA 其实需要一个带文本标注的数据集,但由于我只是想测试添加新训练器的难度,就没有去改数据集,只是默认用了空文本来训练 LoRA。

总结

我自己在使用 Diffusers 训练脚本时,发现这种训练脚本难以适配多任务训练,于是重构了一份拓展性更强的训练脚本。在这篇文章中,我先是介绍了 Diffusers 训练脚本的通用框架,再分享了我改写脚本的过程。相信读者在读完本文后,不仅能够熟悉 Diffusers 训练脚本的具体原理,还能够动手修改它,或者基于我的这一版改进脚本,编写一份适合自己的训练脚本。

我重构的这套训练器也没有太多封装,在维持 Diffusers 那种平铺直叙风格的同时,将每种训练任务独有的代码、数据搬了出来,让开发者专注于编写新的逻辑。我没怎么用过别的训练框架,不太好直接对比。但至少相比于 PyTorch Lightning 那种模型和训练逻辑写在同一个类里的写法,我更认可 Diffusers 这种将模型结构和训练、采样分离的设计。这套框架的训练器也只有训练的逻辑,不会掺杂其他逻辑。

本文的代码链接为 https://github.com/SingleZombie/DiffusersExample/tree/main/TrainingScript

注意,这份代码是我随手写的,只测试了简单的训练命令。如果发现 bug,欢迎提 issue。这份代码仅供本文教学使用,功能有限,以后我会在其他地方更新这份代码。另外,以后我写其他训练教程时也会复用这套代码。

附录:训练器程序设计思路

在设计训练器接口类的接口时,其实我没有做多少主观设计,基本上都是按照一些设计原则,机械地将原来的训练脚本进行重构。我也不知道这些原则是怎么想出来的,只是根据我多年写代码的经验,我感觉按照这些规则做可以保证训练脚本和训练器之间耦合度更低,易于拓展。这些原则有:

  1. 如果在另一项任务里这行代码会变动,则这项代码应写入训练器类。
  2. 如果某一数据的调用全部都被放入了训练器类里,那么这个数据应该是训练器类的成员变量。如果该数据来自配置文件,则将该数据的定义从全局配置移入训练器配置。
  3. 如果某数据既要在训练脚本中使用,又要在训练器类里使用,则在训练脚本中初始化该数据,并以初始化参数或者接口参数两种方式将数据传入训练器。传入方式由数据被确定的时刻决定。比如脚本一开始就初始化好的日志对象应该作为初始化参数,而一些中途计算的当前 batch 数等参数应该作为接口参数。
  4. 原则上,训练脚本不从数据类里获取数据。

根据这些原则,在设计训练器接口类时,我并没有一开始就定下有哪些接口、接口的参数分别是什么,而是一边搬运代码,一边根据代码的实际内容动态地编写接口类。比如一开始,我的接口类构造函数并没有加入日志库类型。

1
2
3
class Trainer(metaclass=ABCMeta):
def __init__(self, weight_dtype, accelerator, cfg):
...

后来写训练器验证方法时,我发现这里必须要获取日志类的类型,不得已在构造函数里多加了一个参数。

1
2
3
4
5
6
def validate(self, epoch, global_step):
...
if self.logger == "tensorboard":
...

def __init__(self, weight_dtype, accelerator, logger, cfg):

原则 3 和原则 4 本质上是将训练脚本也看成一个对象。所有数据要么属于训练脚本,要么属于训练类。原则 4 不从训练器里获取信息,某种程度上体现了面向对象中的封装性,不让训练器去改训练脚本里的数据。我尽可能地遵守了原则 4,但只有一处例外。在调用 accelerate.prapare 后,train_dataloader 在训练器里发生了更改。而 train_dataloader 其实是属于训练脚本的。没办法,这里只能去训练器里获取一次数据。我没来得及仔细研究,说不定 accelerate.prapare 可以多次调用,这样我就能让训练脚本自己维护 train_dataloader

1
2
trainer.prepare_modules()
train_dataloader = trainer.train_dataloader

这样看下来,这份代码框架在各种角度上都有很大的改进空间。以后我会来慢慢改进这份代码。就目前的设计,训练中整体逻辑、数据集、训练器三部分应该是相互独立的。数据集我还没有单独拿出来写。应该至少实现纯图像、带文本标注图像这两种数据集。

这次重构之后,我也有一些程序设计上的体会。重构代码比从头做程序设计要简单很多。重构只需要根据已有代码,设计出一套更合理的逻辑,像我这样按照某些原则,无脑地修改代码就行了。而程序设计需要考虑未知的情况,为未来可能加入的功能铺路。也正因为从头设计更难,有时会出现设计过度或者设计不足的情况。感觉更合理的开发方式是从头设计与重构交替进行。

近期,最受开源社区欢迎的文生图模型 Stable Diffusion 的最新版本 Stable Diffusion 3 开放了源码和模型参数。开发者宣称,Stable Diffusion 3 使用了全新的模型结构和文本编码方法,能够生成更符合文本描述且高质量的图片。得知 Stable Diffusion 3 开源后,社区用户们纷纷上手测试,在网上分享了许多测试结果。而在本文中,我将面向之前已经熟悉 Stable Diffusion 的科研人员,快速讲解 Stable Diffusion 3 论文的主要内容及其在 Diffusers 中的源码。对于 Stable Diffusion 3 中的一些新技术,我并不会介绍其细节,而是会讲清其设计动机并指明进一步学习的参考文献。

内容索引

本文会从多个角度简单介绍 SD3,具体要介绍的方面如下所示。读者可以根据自己的需求,跳转到感兴趣的部分阅读。

流匹配原理简介

流匹配是一种定义图像生成目标的方法,它可以兼容当前扩散模型的训练目标。流匹配中一个有代表性的工作是整流 (rectified flow),它也正是 SD3 用到的训练目标。我们会在本文中通过简单的可视化示例学习流匹配的思想。

SD3 中的 DiT

我们会从一个简单的类 ViT 架构开始,学习 SD3 中的去噪网络 DiT 模型是怎么一步一步搭起来的。读者不需要提前学过 DiT,只需要了解 Transformer 的结构,并大概知道视觉任务里的 Transformer 会做哪些通用的修改(如图块化),即可学懂 SD3 里的 DiT。

SD3 模型与训练策略改进细节

除了将去噪网络从 U-Net 改成 DiT 外,SD3 还在模型结构与训练策略上做了很多小改进:

  • 改变训练时噪声采样方法
  • 将一维位置编码改成二维位置编码
  • 提升 VAE 隐空间通道数
  • 对注意力 QK 做归一化以确保高分辨率下训练稳定

本文会简单介绍这些改进。

大型消融实验

对于想训练大型文生图模型的开发者,SD3 论文提供了许多极有价值的大型消融实验结果。本文会简单分析论文中的两项实验结果:各训练目标在文生图任务中的表现、SD3 的参数扩增实验结果。

SD3 Diffusers 源码解读

本文会介绍如何配置 Diffusers 环境以用代码运行 SD3,并简单介绍相比于 SD,SD3 的采样代码和模型代码有哪些变动。

论文阅读

核心贡献

介绍 Stable Diffusion 3 (SD3) 的文章标题为 Scaling Rectified Flow Transformers for High-Resolution Image Synthesis。与其说它是一篇技术报告,更不如说它是一篇论文,因为它确实是按照撰写学术论文的一般思路,将正文的叙述重点放到了方法的核心创新点上,而没有过多叙述工程细节。正如其标题所示,这篇文章的内容很简明,就是用整流 (rectified flow) 生成模型、Transformer 神经网络做了模型参数扩增实验,以实现高质量文生图大模型。

由于这是一篇实验主导而非思考主导的文章,论文的开头没有太多有价值的内容。从我们读者学习论文的角度,文章的核心贡献如下:

从方法设计上:

  • 首次在大型文生图模型上使用了整流模型。
  • 用一种新颖的 Diffusion Transformer (DiT) 神经网络来更好地融合文本信息。
  • 使用了各种小设计来提升模型的能力。如使用二维位置编码来实现任意分辨率的图像生成。

从实验上:

  • 开展了一场大规模、系统性的实验,以验证哪种扩散模型/整流模型的学习目标最优。
  • 开展了扩增模型参数的实验 (scaling study),以证明提升参数量能提升模型的效果。

整流模型简介

由于 SD3 最后用了整流模型来建模图像生成,所以文章是从一种称为流匹配 (Flow Matching) 的角度而非更常见的扩散模型的角度来介绍各种训练目标。鉴于 SD3 并没有对其他论文中提出的整流模型做太多更改,我们在阅读本文时可以主要关注整流的想法及其与扩散模型的关系,后续再从其他论文中学习整流的具体原理。在此,我们来大致认识一下流匹配与整流的想法。

所谓图像生成,其实就是让神经网络模型学习一个图像数据集所表示的分布,之后从分布里随机采样。比如我们想让模型生成人脸图像,就是要让模型学习一个人脸图像集的分布。为了直观理解,我们可以用二维点来表示一张图像的数据。比如在下图中我们希望学习红点表示的分布,即我们希望随机生成点,生成的点都落在红点处,而不是落在灰点处。

我们很难表示出一个适合采样的复杂分布。因此,我们会把学习一个分布的问题转换成学习一个简单好采样的分布到复杂分布的映射。一般这个简单分布都是标准正态分布。如下图所示,我们可以用简单的算法采样在原点附近的来自标准正态分布的蓝点,我们要想办法得到蓝点到红点的映射方法。

学习这种映射依然是很困难的。而近年来包括扩散模型在内的几类生成模型用一种巧妙的方法来学习这种映射:从纯噪声(标准正态分布里的数据)到真实数据的映射很难表示,但从真实数据到纯噪声的逆映射很容易表示。所以,我们先人工定义从图像数据集到噪声的变换路线(红线),再让模型学习逆路线(蓝线)。让噪声数据沿着逆路线走,就实现了图像生成。

我们又可以用一种巧妙的方法间接学习图像生成路线。知道了预定义的数据到噪声的路线后,我们其实就知道了数据在路线上每一位置的速度(红箭头)。那么,我们可以以每一位置的反向速度(蓝箭头)为真值,学习噪声到真实数据的速度场。这样的学习目标被称为流匹配。

对于不同的扩散模型及流匹配模型,其本质区别在于图像到噪声的路线的定义方式。在扩散模型中,图像到噪声的路线是由一个复杂的公式表示的。而整流模型将图像到噪声的路线定义为了直线。比如根据论文的介绍,整流中 $t$ 时刻数据 $z_t$ 由真实图像 $x_0$ 变换成纯噪声 $\epsilon$ 的位置为:

而较先进的扩散模型 EDM 提出的路线公式为($b_t$ 是一个形式较为复杂的变量):

由于整流最后学习出来的生成路线近乎是直线,这种模型在设计上就支持少步数生成。

虽然整流模型是这样宣传的,但实际上 SD3 还是默认用了 28 步来生成图像。单看这篇文章,原整流论文里的很多设计并没有用上。对整流感兴趣的话,可以去阅读原论文 Flow straight and fast: Learning to generate and transfer data with rectified flow

流匹配模型和扩散模型的另一个区别是,流匹配模型天然支持 image2image 任务。从纯噪声中生成图像只是流匹配模型的一个特例。

非均匀训练噪声采样

在学习这样一种生成模型时,会先随机采样一个时刻 $t \in [0, 1]$,根据公式获取此时刻对应位置在生成路线上的速度,再让神经网络学习这个速度。直观上看,刚开始和快到终点的路线很好学,而路线的中间处比较难学。因此,在采样时刻 $t$ 时,SD3 使用了一种非均匀采样分布。

如下图所示,SD3 主要考虑了两种公式: mode(左)和 logit-norm (右)。二者的共同点是中间多,两边少。mode 相比 logit-norm,在开始和结束时概率不会过分接近 0。

网络整体架构

以上内容都是和训练相关的理论基础,下面我们来看多数用户更加熟悉的文生图架构。

从整体架构上来看,和之前的 SD 一样,SD3 主要基于隐扩散模型(latent diffusion model, LDM)。这套方法是一个两阶段的生成方法:先用一个 LDM 生成隐空间低分辨率的图像,再用一个自编码器把图像解码回真实图像。

扩散模型 LDM 会使用一个神经网络模型来对噪声图像去噪。为了实现文生图,该去噪网络会以输入文本为额外约束。相比之前多数扩散模型,SD3 的主要改进是把去噪模型的结构从 U-Net 变为了 DiT。

DiT 的论文为 Scalable Diffusion Models with Transformers。如果只是对 DiT 的结构感兴趣的话,可以去直接通过读 SD3 的源码来学习。读 DiT 论文时只需要着重学习 AdaLayerNormZero 模块。

提升自编码器通道数

在当时设计整套自编码器 + LDM 的生成架构时,SD 的开发者并没有仔细改进自编码器,用了一个能把图像下采样 8 倍,通道数变为 4 的隐空间图像。比如输入 $512 \times 512 \times 3$ 的图像会被自编码器编码成 $64 \times 64 \times 4$。而近期有些工作发现,这个自编码器不够好,提升隐空间的通道数能够提升自编码器的重建效果。因此,SD3 把隐空间图像的通道数从 $4$ 改为了 $16$。

多模态 DiT (MM-DiT)

SD3 的去噪模型是一个 Diffusion Transformer (DiT)。如果去噪模型只有带噪图像这一种输入的话,DiT 则会是一个结构非常简单的模型,和标准 ViT 一样:图像过图块化层 (Patching) 并与位置编码相加,得到序列化的数据。这些数据会像标准 Transformer 一样,经过若干个子模块,再过反图块层得到模型输出。DiT 的每个子模块 DiT-Block 和标准 Transformer 块一样,由 LayerNorm, Self-Attention, 一对一线性层 (Pointwise Feedforward, FF) 等模块构成。

图块化层会把 $2\times2$ 个像素打包成图块,反图块化层则会把图块还原回像素。

然而,扩散模型中的去噪网络一定得支持带约束生成。这是因为扩散模型约束于去噪时刻 $t$。此外,作为文生图模型,SD3 还得支持文本约束。DiT 及本文的 MM-DiT 把模型设计的重点都放在了处理额外约束上。

我们先看一下模块是怎么处理较简单的时刻约束的。此处,如下图所示,SD3 的模块保留了 DiT 的设计,用自适应 LayerNorm (Adaptive LayerNorm, AdaLN) 来引入额外约束。具体来说,过了 LayerNorm 后,数据的均值、方差会根据时刻约束做调整。另外,过完 Attention 层或 FF 层后,数据也会乘上一个和约束相关的系数。

我们再来看文本约束的处理。文本约束以两种方式输入进模型:与时刻编码拼接、在注意力层中融合。具体数据关联细节可参见下图。如图所示,为了提高 SD3 的文本理解能力,描述文本 (“Caption”) 经由三种编码器编码,得到两组数据。一组较短的数据会经由 MLP 与文本编码加到一起;另一组数据会经过线性层,输入进 Transformer 的主模块中。

将约束编码与时刻编码相加是一种很常见的做法。此前 U-Net 去噪网络中处理简单约束(如 ImageNet 类型约束)就是用这种方法。

SD3 的 DiT 的子模块结构图如下所示。我们可以分几部分来看它。先看时刻编码 $y$ 的那些分支。和标准 DiT 子模块一样,$y$ 通过修改 LayerNorm 后数据的均值、方差及部分层后的数据大小来实现约束。再看输入的图像编码 $x$ 和文本编码 $c$。二者以相同的方式做了 DiT 里的 LayerNorm, FF 等操作。不过,相比此前多数基于 DiT 的模型,此模块用了一种特殊的融合注意力层。具体来说,在过注意力层之前,$x$ 和 $c$ 对应的 $Q, K, V$ 会分别拼接到一起,而不是像之前的模型一样,$Q$ 来自图像,$K, V$ 来自文本。过完注意力层,输出的数据会再次拆开,回到原本的独立分支里。由于 Transformer 同时处理了文本、图像的多模态信息,所以作者将模型取名为 MM-DiT (Multimodal DiT)。

论文里讲:「这个结构可以等价于两个模态各有一个 Transformer,但是在注意力操作时做了拼接,使得两种表示既可以在独自的空间里工作也可以考虑到另一个表示。」然而,我不太喜欢这种尝试去凭空解读神经网络中间表示的表述。仅从数据来源来看,过了一个注意力层后,图像信息和文本信息就混在了一起。你很难说,也很难测量,之后的 $x$ 主要是图像信息,$c$ 主要是文本信息。只能说 $x, c$ 都蕴含了多模态的信息。之前 SD U-Net 里的 $x, c$ 可以认为是分别包含了图像信息和文本信息,因为之前的 $x$ 保留了二维图像结构,而 $c$ 仅由文本信息决定。

比例可变的位置编码

此前多数方法在使用类 ViT 架构时,都会把图像的图块从左上到右下编号,把二维图块拆成一维序列,再用这种一维位置编码来对待图块。

这样做有一个很大的坏处:生成的图像的分辨率是无法修改的。比如对于上图,假如采样时输入大小不是 $4 \times 4$,而是 $4 \times 5$,那么 $0$ 号图块的下面就是 $5$ 而不是 $4$ 了,模型训练时学习到的图块之间的位置关系全部乱套。

解决此问题的方法很简单,只需要将一维的编码改为二维编码。这样 Transformer 就不会搞混二维图块间的关系了。

SD3 的 MM-DiT 一开始是在 $256^2$ 固定分辨率上训练的。之后在高分辨率图像上训练时,开发者用了一些巧妙的位置编码设置技巧,让不同比例的高分辨率图像也能共享之前学到的这套位置编码。详细公式请参见原论文。

训练数据预处理

看完了模块设计,我们再来看一下 SD3 在训练中的一些额外设计。在大规模训练前,开发者用三个方式过滤了数据:

  1. 用了一个 NSFW 过滤器过滤图片,似乎主要是为了过滤色情内容。
  2. 用美学打分器过滤了美学分数太低的图片。
  3. 移除了看上去语义差不多的图片。

虽然开发者们自信满满地向大家介绍了这些数据过滤技术,但根据社区用户们的反馈,可能正是因为色情过滤器过分严格,导致 SD3 经常会生成奇怪的人体。

由于在训练 LDM 时,自编码器和文本编码器是不变的,因此可以提前处理好所有训练数据的图像编码和文本编码。当然,这是一项非常基础的工程技巧,不应该写在正文里的。

用 QK 归一化提升训练稳定度

按照之前高分辨率文生图模型的训练方法,SD3 会先在 $256^2$ 的图片上训练,再在高分辨率图片上微调。然而,开发者发现,开始微调后,混合精度训练常常会训崩。根据之前工作的经验,这是由于注意力输入的熵会不受控制地增长。解决方法也很简单,只要在做注意力计算之前对 Q, K 做一次归一化就行,具体做计算的位置可以参考上文模块图中的 “RMSNorm”。不过,开发者也承认,这个技巧并不是一个长久之策,得具体问题具体分析。看来这种 DiT 模型在大规模训练时还是会碰到许多训练不稳定的问题,且这些问题没有一个通用解。

哪种扩散模型训练目标最适合文生图任务?

最后我们来看论文的实验结果部分。首先,为了寻找最好的扩散模型/流匹配模型,开发者开展了一场声势浩大的实验。实验涉及 61 种训练公式,其中的可变项有:

  • 对于普通扩散模型,考虑 $\epsilon$- 或 $\mathbf{v}$-prediction,考虑线性或 cosine 噪声调度。
  • 对于整流,考虑不同的噪声调度。
  • 对于 EDM,考虑不同的噪声调度,且尽可能与整流的调度机制相近以保证可比较。

在训练时,除了训练目标公式可变外,优化算法、模型架构、数据集、采样器都不可变。所有模型在 ImageNet 和 CC12M 数据集上训练,在 COCO-2014 验证集上评估 FID 和 CLIP Score。根据评估结果,可以选出每个模型的最优停止训练的步数。基于每种目标下的最优模型,开发者对模型进行最后的排名。由于在最终评估时,仍有采样步数、是否使用 EMA 模型等可变采样配置,开发者在所有 24 种采样配置下评估了所有模型,并用一种算法来综合所有采样配置的结果,得到一个所有模型的最终排名。最终的排名结果如下面的表 1 所示。训练集上的一些指标如表 2 所示。

根据实验结果,我们可以得到一些直观的结论:整流领先于扩散模型。惊人的是,较新推出的 EDM 竟然没有战胜早期的 LDM (“eps/linear”)。

当然,我个人认为,应该谨慎看待这份实验结果。一般来说,大家做图像生成会用一个统一的指标,比如 ImageNet 上的 FID。这篇论文相当于是新提出了一种昂贵的评价方法。这种评价方法是否合理,是否能得到公认还犹未可知。另外,想说明一个生成模型的拟合能力不错,用 ImageNet 上的 FID 指标就足够有说服力了,大家不会对一个简单的生成模型有太多要求。然而,对于大型文生图模型,大家更关心的是模型的生成效果,而 FID 和 CLIP Score 并不能直接反映文生图模型的质量。因此,光凭这份实验结果,我们并不能说整流一定比之前的扩散模型要好。

会关注这份实验结果的应该都是公司里的文生图开发者。我建议体量小的公司直接参考这份实验结果,无脑使用整流来代替之前的训练目标。而如果有能力做同等级的实验的话,则不应该错过改良后的扩散模型,如最新的 EDM2,说不定以后还会有更好的文生图训练目标。

参数扩增实验结果

现在多数生成模型都会做参数扩增实验,即验证模型表现随参数量增长而增长,确保模型在资源足够的情况下可以被训练成「大模型」。SD3 也做了类似的实验。开发者用参数 $d$ 来控制 MM-DiT 的大小,Transformer 块的个数为 $d$,且所有特征的通道数与 $d$ 成正比。开发者在 $256^2$ 的数据上训练了所有模型 500k 步,每 50k 步在 CoCo 数据集上统计验证误差。最终所有评估指标如下图所示。可以说,所有指标都表明,模型的表现的确随参数量增长而增长。更多结果请参见论文。

Diffusers 源码阅读

测试脚本

我们来阅读一下 SD3 在最流行的扩散模型框架 Diffusers 中的源码。在读源码前,我们先来跑通官方的示例脚本。

由于使用协议的限制,SD3 的环境搭起来稍微有点麻烦。首先,我们要确保 Diffuers 和 Transformers 都用的是最新版本。

1
pip install --upgrade diffusers transformers

之后,我们要注册 HuggingFace 账号,再在 SD3 的模型网站 https://huggingface.co/stabilityai/stable-diffusion-3-medium 里确认同意某些使用协议。之后,我们要设置 Access Token。具体操作如下所示,先点右上角的 “settings”,再点左边的 “Access Tokens”,创建一个新 token。将这个 token 复制保存在本地后,点击 token 右上角选项里的 “Edit Permission”,在权限里开启 “… public gated repos …”。

最后,我们用命令行登录 HuggingFace 并使用 SD3。先用下面的命令安装 HuggingFace 命令行版。

1
pip install -U "huggingface_hub[cli]"

再输入 huggingface-cli login,命令行会提示输入 token 信息。把刚刚保存好的 token 粘贴进去,即可完成登录。

1
2
3
huggingface-cli login

Enter your token (input will not be visible): 在这里粘贴 token

做完准备后,我们就可以执行下面的测试脚本了。注意,该脚本会自动下载模型,我们需要保证当前环境能够访问 HuggingFace。执行完毕后,生成的 $1024 \times 1024$ 大小的图片会保存在 tmp.png 里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe(
"A cat holding a sign that says hello world",
negative_prompt="",
num_inference_steps=28,
guidance_scale=7.0,
).images[0]

image.save('tmp.png')

我得到的图片如下所示。看起来 SD3 理解文本的能力还是挺强的。

模型组件

接下来我们来快速浏览一下 SD3 流水线 StableDiffusion3Pipeline 的源码。在 IDE 里使用源码跳转功能可以在 diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py 里找到该类的源码。

通过流水线的 __init__ 方法,我们能知道 SD3 的所有组件。组件包括自编码器 vae, MM-DiT Transformer, 流匹配噪声调度器 scheduler,以及三个文本编码器。每个编码器由一个 tokenizer 和一个 text encoder 组成.

1
2
3
4
5
6
7
8
9
10
11
12
def __init__(
self,
transformer: SD3Transformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
):

vae 的用法和之前 SD 的一模一样,编码时用 vae.encode 并乘 vae.config.scaling_factor,解码时除以 vae.config.scaling_factor 并用 vae.decode

文本编码器的用法可以参见 encode_prompt 方法。文本会分别过各个编码器的 tokenizer 和 text encoder,得到三种文本编码,并按照论文中的描述拼接成两种约束信息。这部分代码十分繁杂,多数代码都是在处理数据形状,没有太多有价值的内容。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def encode_prompt(
self,
prompt,
prompt_2,
prompt_3,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
negative_prompt_2,
negative_prompt_3,
...

):
...

return prompt_embeds, negative_prompt_embeds,
pooled_prompt_embeds, negative_pooled_prompt_embeds

采样流水线

我们再来通过阅读流水线的 __call__ 方法了解 SD3 采样的过程。由于 SD3 并没有修改 LDM 的这套生成框架,其采样流水线和 SD 几乎完全一致。SD3 和 SD 的 __call__ 方法的主要区别是,生成文本编码时会生成两种编码。

1
2
3
4
5
6
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(...)

在调用去噪网络时,那个较小的文本编码 pooled_prompt_embeds 会作为一个额外参数输入。

1
2
3
4
5
6
7
8
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]

MM-DiT 去噪模型

相比之下,SD3 的去噪网络 MM-DiT 的改动较大。我们来看一下对应的 SD3Transformer2DModel 类,它位于文件 diffusers\models\transformers\transformer_sd3.py

类的构造函数里有几个值得关注的模块:二维位置编码类 PatchEmbed、组合时刻编码和文本编码模块 CombinedTimestepTextProjEmbeddings、主模块类 JointTransformerBlock

1
2
3
4
5
6
7
8
9
10
11
def __init__(...):
...
self.pos_embed = PatchEmbed(...)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(...)
...
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(..)
for i in range(self.config.num_layers)
]
)

类的前向传播函数 forward 里都是比较常规的操作。数据会依次经过前处理、若干个 Transformer 块、后处理。所有实现细节都封装在各个模块类里。

1
2
3
4
5
6
7
8
9
10
11
def forward(...):
hidden_states = self.pos_embed(hidden_states)
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
for index_block, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(...)

encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
...

接下来我们来看这几个较为重要的子模块。PatchEmbed 类的实现写在 diffusers/models/embeddings.py 里。这个类的实现写得非常清晰。PatchEmbed 类本身用于维护位置编码宽高、特征长度这些信息,计算位置编码的关键代码在 get_2d_sincos_pos_embed 中。get_2d_sincos_pos_embed 会生成 (0, 0), (1, 0), ... 这样的二维坐标网格,再调用 get_2d_sincos_pos_embed_from_grid 生成二维位置编码。get_2d_sincos_pos_embed_from_grid 会调用两次一维位置编码函数 get_1d_sincos_pos_embed_from_grid,也就是 Transformer 里那种标准位置编码生成函数,来分别生成两个方向的编码,最后拼接成二维位置编码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class PatchEmbed(nn.Module):
...
def forward(self, latent):
...
pos_embed = get_2d_sincos_pos_embed(...)

def get_2d_sincos_pos_embed(...):
grid_h = np.arange(...)
grid_w = np.arange(...)
grid = np.meshgrid(grid_w, grid_h)
...
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

def get_2d_sincos_pos_embed_from_grid(...):
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)

emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb

组合时刻编码和文本编码模块 CombinedTimestepTextProjEmbeddings 的代码非常短。它实际上就是用通常的 Timesteps 类获取时刻编码,用一个 text_embedder 模块再次处理文本编码,最后把两个编码加起来。
text_embedder 是一个线性层、激活函数、线性层构成的简单模块。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class CombinedTimestepTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()

self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

def forward(self, timestep, pooled_projection):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)

pooled_projections = self.text_embedder(pooled_projection)

conditioning = timesteps_emb + pooled_projections

return conditioning

class PixArtAlphaTextProjection(nn.Module):
def __init__(...):
...

def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states

MM-DiT 的主要模块 JointTransformerBlockdiffusers/models/attention.py 文件里。这个类的代码写得比较乱。它主要负责处理 LayerNorm 及数据的尺度变换操作,具体的注意力计算由注意力处理器 JointAttnProcessor2_0 负责。两处 LayerNorm 的实现方式竟然是不一样的。

我们先简单看一下构造函数里初始化了哪些模块。代码中,norm1, ff, norm2 等模块都是普通 Transformer 块中的模块。而加了 _context 的模块则表示处理文本分支 $c$ 的模块,如 norm1_context, ff_contextcontext_pre_only 表示做完了注意力计算后,还要不要给文本分支加上 LayerNorm 和 FeedForward。如前文所述,具体的注意力计算由 JointAttnProcessor2_0 负责。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class JointTransformerBlock(nn.Module):

def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
super().__init__()

self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

self.norm1 = AdaLayerNormZero(dim)

if context_norm_type == "ada_norm_continous":
self.norm1_context = AdaLayerNormContinuous(
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
)
elif context_norm_type == "ada_norm_zero":
self.norm1_context = AdaLayerNormZero(dim)

processor = JointAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
processor=processor,
)

self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

if not context_pre_only:
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
else:
self.norm2_context = None
self.ff_context = None

我们再来看 forward 方法。在前向传播时,图像分支和文本分支会分别过 norm1,再一起过注意力操作,再分别过 norm2ff。大概的代码如下所示,我把较复杂的 context 分支的代码略过了。

这份代码写得很不漂亮,按理说模块里两个 LayerNorm + 尺度变换 (即 Adaptive LayerNorm) 的操作是一样的,应该用同样的代码来处理。但是这个模块里 norm1AdaLayerNormZero 类,norm2LayerNorm 类。norm1 会自动做完 AdaLayerNorm 的运算,并把相关变量返回。而在 norm2 处,代码会先执行普通的 LayerNorm,再根据之前的变量手动调整数据的尺度。我们心里知道这份代码是在实现论文里那张结构图就好,没必要去仔细阅读。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

if self.context_pre_only:
...

# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)

# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output

norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output

hidden_states = hidden_states + ff_output
if self.context_pre_only:
...

return encoder_hidden_states, hidden_states

融合注意力的实现方法很简单。和普通的注意力计算相比,这种注意力就是把另一条数据分支 encoder_hidden_states 也做了 QKV 的线性变换,并在做注意力运算前与原来的 QKV 拼接起来。做完注意力运算后,两个数据又会拆分回去。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class JointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""


def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
...

# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

# attention
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

...

# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)

总结

在这篇文章中,我们学习了 SD3 论文及源码中的主要内容。相比于 SD,SD3 做了两项较大的改进:用整流代替原来的 DDPM 中的训练目标;将去噪模型从 U-Net 变成了能更好地处理多模态信息的 MM-DiT。SD3 还在模型结构和训练目标上做了许多小改进,如调整训练噪声采样分布、使用二维位置编码。SD3 论文展示了多项大型消融实验的结果,证明当前的 SD3 是以最优配置训练得到的。SD3 可以在 Diffusers 中使用。当然,由于 SD3 的使用协议较为严格,我们需要做一些配置,才能在代码中使用 SD3。SD3 的采样流水线基本没变,原来 SD 的多数编辑方法能够无缝迁移过来。而 SD3 的去噪模型变动较大,和 U-Net 相关的编辑方法则无法直接用过来。在学习源码时,主要值得学习的是新 MM-DiT 模型中每个 Transformer 层的实现细节。

尽管 SD3 并没有提出新的流匹配方法,但其实验结果表明流匹配模型可能更适合文生图任务。作为研究者,受此启发,我们或许需要关注一下整流等流匹配模型,知道它们的思想,分析它们与原扩散模型训练目标的异同,以拓宽自己的视野。

今年第一次参加 CVPR 的前后发生了很多糟心事。我被气得受不了了。


美国签证不太好申,我周围所有因参加学术会议而申请签证的人都被审查了一个月。我决定参会与申请签证的时间较晚,过签的时间非常仅限。在会议开始前的倒数第二周,我收到了签证通过的通知,并得知我需要在一周后,也就是会议开始前的最后几个工作日领取带了美国签证的护照。能够恰好拿到签证,算得上是运气不错。可下一周发生的种种事情,却谈不上了幸运了。

会议开始前一周的周三早上,我赶早骑行前往学校参加会议前的最后一次周会。骑至半途,忽然天降暴雨。在新加坡,这种现象我已经见怪不怪了。我先穿上日常携带的雨衣,再将书包背到雨衣外,以免背上出太多汗。我又用手机向导师报告我可能会推迟参会,随后将手机和浸满了雨水的眼镜装进口袋。

这场雨就像老天给我开的一个玩笑。我一到目的地,雨就小了下来。上楼,掀开笔记本电脑,发现电脑被穿透了书包的一点雨水浸湿,开不了机。伸手一摸口袋,眼镜竟然也不见了。

这种雨中骑车的事我已经做过好多回了,但我从来没碰到过这么多意外。我有多次电脑进水的经历,可这是我第一次知道穿透书包的那一点水也能弄坏电脑。另外,本来我这条裤子的口袋是不会掉东西的,可这次恰好我要用手机发消息,而眼镜和手机放在了一起,顺势掉了出来。

无论多么令人懊恼,生活都得继续。我用上了我所有的规划能力来处理问题。我先和导师商议将周会延期。随后,我请同学陪我一起找眼镜。还好,眼镜在楼下就找到了,但是被我自己或者其他人踩了一脚,一边的镜框断了。我艰难地用半副眼镜处理完一天的杂事,于傍晚赶往市中心将电脑送修。晚上回家,我用强力胶将断裂的镜框部件粘在一起,眼镜勉强可以用了。第二天周四,我领完美国签证,同时得知电脑严重进水,一天修不好。我回家立刻拿出老电脑并配好工作环境。周五,我准备好一切出国开会所需资料,总算安顿好了一切。

周六,出发前一天,我去换美元。付款时要输入储蓄卡密码,我多次尝试也没能把密码输对,卡被冻结。这时我才想起来为什么我会不记得卡的密码:这张卡的密码是系统随机设置的,平时用卡根本不需要密码,所以我既忘了密码,也忘了密码不是我自己设的。这下去美国用不了这张卡了。还好,我还有国内的信用卡,完成了在新加坡后续的付款。

启程的飞机将于周日早上九点起飞。我起了个大早,提前前往机场。兴冲冲地首次坐上长途国际航班后,我立刻收到了飞机因故障无法起飞的通知,灰溜溜地下了飞机。航空公司给我安排的新航班是 36 小时之后,周一晚上九点。为表赔偿,航空公司给我们预订了新加坡五星级酒店的住宿,并附带了早中晚餐各 20 新币的代金券。走进五星级酒店最普通的房间,我原以为能享受一番,却又被正餐 30 新币起步(还不含税)的菜单气晕了过去。上午我在机场吃过了,中午我就随便点了个 20 新币的汤应付了过去。晚上我兼顾价格与饱腹度,点了一个 30 多新币的披萨。晚餐送至房间,我将代金券和 20 新币交给服务员后,服务员将代金券收走,现金还回,说道:“可以拿明天的早餐钱来抵这顿晚饭。”唉,我又亏了将近 10 新币,早知道点一个更好的披萨了。

周一,白天没有安排,我决定去解决新加坡储蓄卡的事。手机应用有挂失换卡的功能,却楞是没有申请解冻的功能。我只好顶着烈日,跨越条条街道,又穿梭于人山人海的购物中心,找到银行的支行。这个支行只有通过视频交流的虚拟柜台。为了确认身份,柜员问了我三个有关账户的验证问题。有一个问题的华文名词我听不懂,没答上来。因此,卡解冻失败,柜员建议我前往线下柜台。我真是纳闷,我有银行账户的一切访问权限,有身份证、护照,为什么要用这种方法来验证我的身份。但我今天时间多,不跟你们计较了。我又匆忙坐上满载的地铁,赶往下一处有线下柜台的支行。地图上说银行 4 点关门,我是三点半到的,可银行已经关门了。我突然感到一股违和感:今天是周一,为什么地铁那么多人?为什么购物中心那么多人?为什么银行提早下班?我一查,果然如我所料,原来今天是新加坡法定假日。今天,到美国开会的同学在当地开开心心地旅游,留在本地的人则能享受假期。就我一个人,飞机没坐上,假没休息到,事情没办成,从一对对情侣旁边擦身而过,在我本不该在的繁华市区里奔波。没办法,休息一下,晚上早点去坐飞机吧。

晚上,刚进机场,我的镜框因没有粘牢又断开了。还好我早有准备,拿出透明胶勉强固定了镜框。

过安检,我被选为 SSSS 级“幸运”用户,被安检警察拎到队伍前面,进行细致的搜身检查。也好,我也不会带什么危险物品,顺便插了个队。

SSSS 是美国TSA (美国运输安全管理局)随机选取需要接受二次安检的乘客。安检人员看到SSSS的登机牌后会对乘客格外注意,在经过普通的X光扫描检查以后,行李会被拆开进行人工检查,人工彻底搜身,电子产品会接受爆炸物探测。

上了飞机,我总算坐上了有着小屏幕的美联航高级国际航班。航班要航行 15 个小时,我计划在这段时间里写一篇文章。一看时间还早,我就开了两把小游戏。玩到一半,电脑啪地屏幕一黑,没电了。飞机上的插座怎么充不进电啊?过了好一会儿,机长发来广播:“很抱歉,有乘客反映插座失灵,我们将重启设备,尽快修好。(翻译自英文)”我满怀期待地等着电脑的开机,这一等就等到了飞机降落。临走前,贴心的乘务长向大家诚恳地致歉:“今天有一些旅客一直没有接上插座,我们真是深表歉意呢~(翻译自英文)”

第一趟飞机在旧金山降落,我还需要乘坐中转西雅图的飞机。由于下一站是国内航班,我需要在这里完成入境。入境的队伍很长,一眼望去,四五十分钟都不见得能排完队。帮我安排航班的工作人员想必之前是一位极限运动的教练,给我安排了一小时后结束登机、一个半小时后起飞的第二趟航班,让我在这里锻炼极限的时间管理。排队五十分钟后,我极速应付完安检人员的问题,总算顺利入境。眼下飞机只有十分钟就要停止登机了,在这种极限条件下,我的身体潜能与语言表达潜能被猛然激活。我背着沉甸甸的书包——我唯一的行李,仔细地看遍每一个路标,在机场里沿着最短路线狂奔,流畅地办理登机、安检。我这一个多小时一直没来得及上厕所。跑至空荡荡的登机口前,我遗憾地望了一眼对面的厕所后,急忙向柜台后的空乘人员出示机票。“啊!你很幸运啊!快点上飞机。”在我的极限运营下,我成功在起飞前登机,就是肚子下面不太舒服。到了原定的起飞时间,我正双目紧闭静待起飞,忽然听见了机长包含温情的广播:“我们得知有部分旅客还在转机,飞机将在约十五分钟后起飞。”不早说啊!我这么赶是为了什么啊!快把厕所门打开!

飞机起飞,我如愿以偿地解了个手,满足地回座位睡了一觉。可能是剧烈运动导致新陈代谢加快,醒来后,我又是一阵内急。可又好巧不巧,飞机刚开始降落,我还要等半个多小时才能下飞机上厕所。在这段时间里,我仿佛领悟了长生不老之道,一分钟可以当十分钟来过。我思考了我为什么这么难受,为什么上这趟飞机,为什么会存在这个世界上。恍惚的精神慢慢回归,我规划出了飞机降落后最快的下机方式。总算到站了,舱门一打开,我就背起书包,右手紧扣腹部,面露难色,名正言顺地在下机的队伍里插队。一路小跑到厕所前,这下没人看,不用演戏了。可我的手却怎么都不肯从肚子上松开——哦,原来我真的憋得肚子不舒服了。好在厕所只有一步之遥了。我特意找了个高级包厢,坐下来不紧不慢地上起了厕所。这是我第一次憋得这么难受,也是第一次觉得上厕所是人生中最幸福的事。

当地时间凌晨两点,我总算来到了西雅图机场出口。没有飞机要赶,没有厕所要上,充足休息了近 20 个小时的我,即将迎来光明的前程。我没有办理美国的流量,因为我知道在需要用网络的地方都有免费无线网络。按照网上的攻略,我前往停车场的某处,准备用 Uber 打车。一到停车场,凌冽的寒气扑面而来。这冷空调开得有点大啊。不对!这是室外的正常气温。我来美国之前是看了天气预报的,每天的温度都是 10 到 20 多度。我过了太久的夏天,对 10 度的天气没有概念,也没有料到我大晚上还会待在室外。但没事啊,打上车就好了。我在机场就配置好了 Uber,只差呼叫司机这一步了。欸,怎么付款前还要验证手机号?这里信号很差,接收国内手机的短信要三分钟左右。当然,这个延迟是我在多次发送短信后才意识到的。由于我多次输入了上一次短信的验证码,Uber 最后忍无可忍,把我的账号禁用了。寒风中,我身着短袖短裤,在空荡的候车处死盯没有回应的手机,打着牙颤,默默无语。冷静了一会儿,我想我有信用卡有钱,为什么非得用软件打车不可。于是,我走到了普通打车处,稍等片刻,上车,最终总算抵达了酒店。

第一次出国开会,第一次前往美国旅行。对于多数人来说快乐的事情,到我这为什么就成了渡劫呢?这些事或许算不上什么人生大事,但攒在一起总是会令人烦恼。诚然,有些事是我考虑不周,但大多数事是我完全无法控制的。不过这些都没关系。在麻将等概率游戏中搏杀已久的我知道,过去是无法改变的,结果通常是无法控制的,我们能做的只有在此刻寻找问题的最优解。哪怕是碰到了这么多事,我还是拿出了我玩解谜游戏的所有实力,尽力去处理好每一件事。如果还是对生活的不幸感到愤愤不平,不如带着幸灾乐祸的心理想一想我这次事件里的美联航。他们的航班延误,影响了那么多乘客,最后给每个人都赔了一晚的五星级酒店,还不是把问题解决了。


以上是我头脑冷静时写的东西。参会过程省略以示愤怒。


本来我准备了很多要写的东西,回程的时候又被气到了,累死了,不写了。这“一天”里,我提前3小时到机场,坐2小时飞机,等5小时飞机,又坐15小时飞机。累计睡眠5~6乘1个小时。到了新加坡我还想省钱,没打车,在机场走了半小时,又坐了1小时地铁。下地铁后走回家,莫名其妙绕了路,在太阳底下穿长裤走了1个多小时。好不容易回到家,洗完澡想刷个牙,牙膏还没了。真作假时假亦真,没睡醒时醒亦困。看上去我这几天睡了很多次觉,但我根本没有睡醒。算上在美国的时间,我已经颠沛流离了好几天,各种硬抗时差,已经不知道睡够六小时是怎样的感觉了。别人比我早一两天来,晚一两天走,玩得开开心心。我除了在会场当3天“好学生”外,剩下近一周时间全在参与人生的极限挑战。怪不得学术大佬都不愿意出来开会,从明年开始我也是学术大佬了。

在上篇文章中,我们浏览了 Stable Video Diffusion (SVD) 的论文,并特别学习了没有在论文中提及的模型结构、噪声调度器这两个模块。在这篇文章中,让我们来看看 SVD 在 Diffusers 中的源码实现。我们会先学习 SVD 的模型结构,再学习 SVD 的采样流水线。在本文的多数章节中,我都会将 SVD 的结构与 Stable Diffusion (SD) 的做对比,帮助之前熟悉 SD 的读者快速理解 SVD 的性质。强烈建议读者在阅读本文前先熟悉 SD 及其在 Diffusers 中的实现。

Stable Diffusion Diffusers 实现源码解读

简单采样实验

目前开源的 SVD 仅有图生视频模型,即给定视频首帧,模型生成视频的后续内容。在首次开源时,SVD 有 1.0 和 1.0-xt 两个版本。二者模型结构配置相同,主要区别在于训练数据上。SVD 1.0 主要用于生成 14 帧 576x1024 的视频,而 1.0-xt 版本由 1.0 模型微调而来,主要用于生成 25 帧 576x1024 的视频。后来,开发团队又开源了 SVD 1.1-xt,该模型在固定帧率的视频数据上微调,输出视频更加连贯。为了做实验方便,在这篇文章中,我们将使用最基础的 SVD 1.0 模型。

参考 Diffusers 官方文档: https://huggingface.co/docs/diffusers/main/en/using-diffusers/svd ,我们来创建一个关于 SVD 的 “Hello World” 项目。如果你的电脑可以访问 HuggingFace 原站的话,直接运行下面的脚本就行了;如果不能访问原网站,可以尝试取消代码里的那行注释,访问 HuggingFace 镜像站;如果还是不行,则需要手动下载 “stabilityai/stable-video-diffusion-img2vid” 仓库,并将仓库路径改成本地下载的仓库路径。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

import torch
import os
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]

export_to_video(frames, "generated.mp4", fps=7)

成功运行后,我们能得到这样的一个火箭升空视频。它的第一帧会和我们的输入图片一模一样。

SVD 概览

由于 SVD 并没有在论文里对其图生视频模型做详细的介绍,我们没有官方资料可以参考,只能靠阅读源码来了解 SVD 的实现细节。为了让大家在读代码时不会晕头转向,我会在读代码前简单概述一下 SVD 的模型结构和采样方法。

SVD 和 SD 一样,是一个隐扩散模型(Latent Diffusion Model, LDM)。图像(视频帧)的生成由两个阶段组成:先由扩散模型生成压缩图像,再由 VAE 解码成真实图像。

扩散模型在生成图像时,会用一个去噪 U-Net $\epsilon_\theta$ 反复对纯噪声图像 $z_T$ 去噪,直至得到一幅有意义的图片 $z$。为了让模型输出我们想要的图像,我们会用一些额外的信息来约束模型,或者说将约束信息也输入进 U-Net。对于文生图 SD 来说,额外约束是文本。对于图生视频 SVD 来说,额外约束是图像。LDM 提出了两种输入约束信息的方式:与输入噪声图像拼接、作为交叉注意力模块的 K, V。SD 仅使用了交叉注意力的方式,而 SVD 同时使用了两种方式。

上面这两种添加约束信息的方法适用于信息量比较大的约束。实际上,还有一种更简单的输入实数约束信息的方法。除了噪声输入外,去噪模型还必须输入当前的去噪时刻 $t$。自最早的 DDPM 以来,时刻 $t$ 都是先被转换成位置编码,再输入进 U-Net 的所有残差块中。仿照这种输入机制,如果有其他的约束信息和 $t$ 一样可以用一个实数表示,则不必像前面那样将这种约束信息与输入拼接或输入交叉注意力层,只需要把约束也转换成位置编码,再与 $t$ 的编码加在一起。

SVD 给模型还添加了三种额外约束:噪声增强程度、帧率、运动程度。这三种约束都是用和时刻编码相加的形式实现的。

即使现在不完全理解这三种额外约束的意义也不要紧。稍后我们会在学习 U-Net 结构时看到这种额外约束是怎么添加进 U-Net 的,在学习采样流水线时了解这三种约束的意义。

总结一下,除了添加了少数模块外,SVD 和 SD 的整体架构一样,都是以去噪 U-Net 为核心的 LDM。除了原本扩散模型要求的噪声、去噪时刻这两种输入外,SVD 还加入了 4 种约束信息:约束图像(视频首帧)、噪声增强程度、帧率、运动程度。约束图像是最主要的约束信息,它会与噪声输入拼接,且输入进 U-Net 的交叉注意力层中。后三种额外约束会以和处理去噪时刻类似的方式输入进 U-Net 中。

去噪模型结构

接下来,我们来学习 SVD 的去噪模型的结构。在 Diffusers 中,一个扩散模型的参数、配置全部放在一个模型文件夹里,该文件夹的各个子文件夹存储了模型的各个模块,如自编码器、去噪模型、调度器等。我们可以在 https://huggingface.co/stabilityai/stable-video-diffusion-img2vid/tree/main 找到 SVD 的模型文件夹,或者访问我们本地下载好的模型文件夹。

SVD 的去噪 U-Net 放在模型文件夹的 unet 子文件夹里。通过阅读子文件夹里的 config.json,我们就能知道模型类的名字是什么,并知道初始化模型的参数有哪些。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
{
"_class_name": "UNetSpatioTemporalConditionModel",
...
"down_block_types": [
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal"
],
...
"up_block_types": [
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal"
]
}

通过在本地 Diffusers 库文件夹里搜索类名 UNetSpatioTemporalConditionModel,或者利用 IDE 的 Python 智能提示功能,在前文的示例脚本里跳转到 StableVideoDiffusionPipeline 所在文件,再跳转到 UNetSpatioTemporalConditionModel 所在文件,我们就能知道 SVD 的去噪 U-Net 类定义在 diffusers/models/unet_spatio_temporal_condition.py 里。我们可以对照位于 diffusers/models/unet_2d_condition.py 的 SD 的 2D U-Net 类 UNet2DConditionModel 来看一下 SVD 的 U-Net 有何不同。

先来看 __init__ 构造函数。SVD U-Net 几乎就是一个写死了许多参数的特化版 2D U-Net,其构造函数也基本上是 SD 2D U-Net 的构造函数的子集。比如 2D U-Net 允许用 act_fn 来指定模型的激活函数,默认为 "silu",而 SVD U-Net 直接把所有模块的激活函数写死成 "silu"。经过简化后,SVD U-Net 的构造函数可读性高了很多。我们从参数开始读起,逐一了解构造函数每一个参数的意义:

  • sample_size=None:隐空间图片边长。供其他代码调用,与 U-Net 无关。
  • in_channels=8:输入通道数。
  • out_channels=4: 输出通道数。
  • down_block_types:每一大层下采样模块的类名。
  • up_block_types:每一大层上采样模块的类名。
  • block_out_channels = (320, 640, 1280, 1280):每一大层的通道数。
  • addition_time_embed_dim=256: 每个额外约束的通道数。
  • projection_class_embeddings_input_dim=768: 所有额外约束的通道数。
  • layers_per_block=2: 每一大层有几个结构相同的模块。
  • cross_attention_dim=1024: 交叉注意力层的通道数。
  • transformer_layers_per_block=1: 每一大层的每一个模块里有几个 Transformer 层。
  • num_attention_heads=(5, 10, 10, 20): 各大层多头注意力层的头数。
  • num_frames=25: 训练时的帧数。供其他代码调用,与 U-Net 无关。

SVD U-Net 的参数基本和 SD 的一致,不同之处有:1)稍后我们会在采样流水线里看到,SVD 把图像约束拼接到了噪声图像上,所以整个噪声输入的通道数是原来的两倍,从 4 变为 8;2)多了一个给采样代码用的 num_frames 参数,它其实没有被 U-Net 用到。

我们再来大致过一下构造函数的实现细节。SVD U-Net 的整体结构和 2D U-Net 的几乎一致。数据先经过下采样模块,再经过中间模块,最后过上采样模块。下采样模块和上采样模块之间有短路连接。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
for i, down_block_type in enumerate(down_block_types):
...
down_block = get_down_block(...)
self.down_blocks.append(down_block)

self.mid_block = UNetMidBlockSpatioTemporal(...)

for i, up_block_type in enumerate(up_block_types):
...
up_block = get_up_block(...)
self.up_blocks.append(up_block)

self.conv_norm_out = nn.GroupNorm(...)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(...)

扩散模型还需要处理去噪时刻约束 $t$。U-Net 会先用正弦编码(Transformer 里的位置编码)time_proj 来将时刻转为向量,再用一系列线性层 time_embedding 预处理这个编码。该编码后续会输入进 U-Net 主体的每一个模块中。

1
2
3
4
self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
timestep_input_dim = block_out_channels[0]

self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

除了多数扩散模型都有的 U-Net 模块外,SVD 还加入了额外约束模块。如前文所述,对于能用一个实数表示的约束,可以使用和处理时刻类似的方式,先让其过位置编码层,再过线性层,最后把得到的输出编码和时刻编码加起来。所以,和这种额外约束相关的模块在代码里叫做 add_time。在 2D U-Net 里,额外约束是可选的。SD 没有用到额外约束。而 SVD 把额外约束设为了必选模块。稍后我们会在采样流水线里看到,SVD 将视频的帧率、运动程度、噪声增强强度作为了生成时的额外约束。这些约束都是用这种与时刻编码相加的形式实现的。

1
2
self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

构造函数的代码就看完了。在构造函数中,我们认识了 SVD U-Net 的各个模块,但对其工作原理或许还存在着些许疑惑。我们来模型的前向传播函数 forward 里看一下各个模块是怎么处理输入的。

看代码前,我们先回顾一下概念,整理一下 U-Net 的数据处理流程。下面是我之前给 SD U-Net 画的示意图。该图对 SVD 同样适用。和 SD 相比,SVD 的输入 x 不仅包括噪声图像(准确说是多个表示视频帧的图像),还包括作为约束的首帧图像; c 换成了首帧图像的 CLIP 编码;t 不仅包括时刻,还包括一些额外约束。

和上图所示的一样,SVD U-Net 的 forward 方法的输入包含图像 sample,时刻 timestep,交叉注意力层约束(图像编码) encoder_hidden_states , 额外约束 added_time_ids

1
2
3
4
5
6
7
8
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
added_time_ids: torch.Tensor,
return_dict: bool = True,
)

方法首先会处理去噪时刻和额外参数,我们来看一下这两个输入是怎么拼到一起的。

做完一系列和形状相关的处理后,输入时刻 timestep 变成了 timesteps。随后,该变量会先过正弦编码(位置编码)层 time_proj,再过一些线性层 time_embedding,得到最后输入 U-Net 主体的时刻嵌入 emb。这两个模块的命名非常容易混淆,千万别弄反了。类似地,额外约束也是先过正弦编码层 add_time_proj,再过一些线性层 add_embedding,最后其输出 aug_emb 会加到 emb 上。当然,为了确保结果可以相加,time_embeddingadd_time_proj 的输出通道数是相同的。

1
2
3
4
5
6
7
8
9
10
11
12
13
# preprocessing
# timesteps = timestep

t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb)

time_embeds = self.add_time_proj(added_time_ids.flatten())
time_embeds = time_embeds.reshape((batch_size, -1))
time_embeds = time_embeds.to(emb.dtype)
aug_emb = self.add_embedding(time_embeds)
emb = emb + aug_emb

这里有关额外约束的处理写得很差,逻辑也很难读懂。在构造函数里,额外约束的正弦编码层 add_time_proj 的输出通道数 addition_time_embed_dim 是 256, 线性模块 add_embedding 的输入通道数 projection_class_embeddings_input_dim 是 768。两个通道数不一样的模块是怎么接起来的?

1
2
3
4
5
6
7
8
9
def __init__(
...
addition_time_embed_dim: int = 256,
projection_class_embeddings_input_dim: int = 768,
...
)

self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

原来,在下面这份模块前向传播代码中,added_time_ids 的形状是 [batch_size, 3]。其中的 3 表示有三个额外约束。做了 flatten() 再过 add_time_proj 后,可以得到形状为 [3 * batch_size, 256] 的正弦编码 time_embeds。之所以三个约束可以用同一个模块来处理,是因为正弦编码没有学习参数,对所有输入都会产生同样的输出。得到 time_embeds 后,再根据从输入噪声图像里得到的 batch_size,用 reshapetime_embeds 的形状变成 [batch_size, 768]。这样,time_embeds 就可以输入进 add_embedding 里了。 add_embedding 是有可学习参数的,三个约束必须分别处理。

1
2
3
4
time_embeds = self.add_time_proj(added_time_ids.flatten())
time_embeds = time_embeds.reshape((batch_size, -1))
time_embeds = time_embeds.to(emb.dtype)
aug_emb = self.add_embedding(time_embeds)

这些代码不应该这样写的。当前的写法不仅可读性差,还不利于维护。比较好的写法是在构造函数里把输入参数从projection_class_embeddings_input_dim 改为 num_add_time,表示额外约束的数量。之后,把 add_embedding 的输入通道数改成 num_add_time * addition_time_embed_dim。这样,使用者不必手动设置合理的 add_embedding 的输入通道数(比如保证 768 必须是 256 的 3 倍),只设置有几个额外约束就行了。这样改了之后,为了提升可读性,还可以像下面那样把 reshape 里的那个 -1 写清楚来。Diffusers 采用这种比较混乱的写法,估计是因为这段代码是从 2D U-Net 里摘抄出来的。而原 2D U-Net 需要兼容更复杂的情况,所以 add_time_projadd_embedding 的通道数需要分别指定。

1
2
3
time_embeds = time_embeds.reshape((batch_size, -1))
->
time_embeds = time_embeds.reshape((batch_size, self.num_add_time * self.addition_time_embed_dim))

预处理完时刻和额外约束后,方法还会修改所有输入的形状,使得它们第一维的长度都是 batch_size 乘视频帧数。正如我们在上一篇文章中学到的,为了兼容图像模型里的模块,我们要先把视频长度那一维和 batch 那一维合并,等到了和时序相关的模块再对视频长度那一维单独处理。

1
2
3
4
5
6
7
8
# Flatten the batch and frames dimensions
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
sample = sample.flatten(0, 1)
# Repeat the embeddings num_video_frames times
# emb: [batch, channels] -> [batch * frames, channels]
emb = emb.repeat_interleave(num_frames, dim=0)
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)

后面的代码就和 2D U-Net 的几乎一样了。数据依次经过下采样块、中间块、上采样块。下采样块的中间结果还会保存在栈 down_block_res_samples 里,作为上采样模块的额外输入。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
sample = self.conv_in(sample)

image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)

down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(...)
down_block_res_samples += res_samples

sample = self.mid_block(...)

for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
sample = upsample_block(...)

sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)

光看 U-Net 类,我们还看不出 SVD 的 3D U-Net 和 2D U-Net 的区别。接下来,我们来看一看 U-Net 中某一个具体的模块是怎么实现的。由于 U-Net 下采样块、中间块、上采样块的结构是类似的,我们只挑某一大层的下采样模块类 CrossAttnDownBlockSpatioTemporal 来学习。

CrossAttnDownBlockSpatioTemporal 类中,我们可以看到 SVD U-Net 的每一个子模块都可以拆成残差卷积块和 Transformer 块。数据在经过子模块时,会先过残差块,再过 Transformer 块。我们来继续深究时序残差块类 SpatioTemporalResBlock 和时序 Transformer 块 TransformerSpatioTemporalModel 的实现细节。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# __init__
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
SpatioTemporalResBlock(...)
)
attentions.append(
TransformerSpatioTemporalModel(...)
)
# forward
blocks = list(zip(self.resnets, self.attentions))
for resnet, attn in blocks:
hidden_states = resnet(hidden_states, ...)
hidden_states = attn(hidden_states, ...)

在开始看代码之前,我们再回顾一下论文里有关 3D U-Net 块的介绍。SVD 的 U-Net 是从 Video LDM 的 U-Net 改过来的。下面的模块结构图源自 Video LDM 论文,我将其改成了能描述 SVD U-Net 块的图。图中红框里的模块表示在原 SD 2D U-Net 块的基础上新加入的模块。可以看出,SVD 实际上就是在原来的 2D 残差块后面加了一个 3D 卷积层,原空间注意力块后面加了一个时序注意力层。旧模块输出和新模块输出之间用一个比例 $\alpha$ 来线性混合。中间数据形状变换的细节我们已经在上篇文章里学过了,这篇文章里我们主要关心这些模块在代码里大概是怎么定义的。

3D 残差块类 SpatioTemporalResBlockdiffusers/models/resnet.py 文件中。它有三个子模块,分别对应上文示意图中的 2D 残差块、时序残差块(3D 卷积)、混合模块。在运算时,旧模块的输出会缓存到
hidden_states_mix 中,新模块的输出为 hidden_states,二者最终会送入混合模块 time_mixer 做一个线性混合。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class SpatioTemporalResBlock(nn.Module):
def __init__(
self,
...
):
super().__init__()
self.spatial_res_block = ResnetBlock2D(...)
self.temporal_res_block = TemporalResnetBlock(...)
self.time_mixer = AlphaBlender(...)

def forward(
self,
...
):
hidden_states = self.spatial_res_block(hidden_states, temb)

...

hidden_states_mix = hidden_states

...

hidden_states = self.temporal_res_block(hidden_states, temb)
hidden_states = self.time_mixer(
x_spatial=hidden_states_mix,
x_temporal=hidden_states,
)
...
return hidden_states

ResnetBlock2D 是 SD 2D U-Net 的残差模块,我们在这篇文章里就不去学习它了。 时序残差块 TemporalResnetBlock 和 2D 残差块的结构几乎完全一致,唯一的区别在于 2D 卷积被换成了 3D 卷积。从代码中我们可以知道,这个模块是一个标准的残差块,数据会依次过两个卷积层,并在最后输出前与输入相加。扩散模型中的时刻约束 temb 会在数据过完第一个卷积层后,加到数据上。值得注意的是,虽然类里面的卷积层名字叫 3D 卷积,但实际上它的卷积核形状为 (3, 1, 1),这说明这个卷积层实际上只是一个时序维度上窗口大小为 3 的 1D 卷积层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class TemporalResnetBlock(nn.Module):
def __init__(...):
kernel_size = (3, 1, 1)
padding = [k // 2 for k in kernel_size]

self.norm1 = torch.nn.GroupNorm(...)
self.conv1 = nn.Conv3d(...)

if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
else:
self.time_emb_proj = None

self.norm2 = torch.nn.GroupNorm(...)

self.dropout = torch.nn.Dropout(0.0)
self.conv2 = nn.Conv3d(...)

self.nonlinearity = get_activation("silu")

self.use_in_shortcut = self.in_channels != out_channels

self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = nn.Conv3d(...)

def forward(self, input_tensor, temb):
hidden_states = input_tensor

hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)

if self.time_emb_proj is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, :, None, None]
temb = temb.permute(0, 2, 1, 3, 4)
hidden_states = hidden_states + temb

hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)

if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)

output_tensor = input_tensor + hidden_states

return output_tensor

混合模块 AlphaBlender 其实就只是定义了一个可学习的混合比例 mix_factor,之后用这个比例来混合空间层输出和时序层输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class AlphaBlender(nn.Module):

def __init__(
self,
alpha: float,
...
):
...
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))

def forward(
self,
x_spatial,
x_temporal,
...
) -> torch.Tensor:
# Get mix_factor
alpha = self.get_alpha(...)
alpha = alpha.to(x_spatial.dtype)

if self.switch_spatial_to_temporal_mix:
alpha = 1.0 - alpha

x = alpha * x_spatial + (1.0 - alpha) * x_temporal
return x

看完了3D 残差块 SpatioTemporalResBlock 的内容,我们接着来看 3D 注意力块 TransformerSpatioTemporalModel 的内容。TransformerSpatioTemporalModel 也主要由 2D Transformer 块 BasicTransformerBlock、时序 Transformer 块 TemporalBasicTransformerBlock 、混合模块组成 AlphaBlender。它们的连接方式和上面的残差块类似。时序 Transformer 块和普通 2D Transformer 块一样,都是有自注意力、交叉注意力、全连接层的标准 Transformer 模块,它们的区别只在于时序 Transformer 块对输入做形状变换的方式不同,会让数据在时序维度上做信息交互。这里我们就不去进一步深究它们的实现细节了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class TransformerSpatioTemporalModel(nn.Module):
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: int = 320,
out_channels: Optional[int] = None,
num_layers: int = 1,
cross_attention_dim: Optional[int] = None,
):
...

self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(...)
for d in range(num_layers)
]
)

self.temporal_transformer_blocks = nn.ModuleList(
[
TemporalBasicTransformerBlock(...)
for _ in range(num_layers)
]
)

self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
self.time_proj = Timesteps(in_channels, True, 0)
self.time_mixer = AlphaBlender(alpha=0.5, ...)

这个时序 Transformer 模块类有一个地方值得注意。我们知道,Transformer 模型本身是不知道输入数据的顺序的。无论是注意力层还是全连接层,它们都与顺序无关。为了让模型知道数据的先后顺序,比如在 NLP 里我们希望模型知道一句话里每个单词的前后顺序,我们会给输入数据加上位置编码。而有些时候我们觉得模型不用知道数据的先后顺序。比如在 SD 的 2D 图像 Transformer 块里,我们把每个像素当成一个 token,每个像素在 Transformer 块的运算方式是相同的,与其所在位置无关。而在处理视频时序的 Transformer 块中,知道视频每一帧的先后顺序看起来还是很重要的。所以,和 SD 的 2D Transformer 块不同,SVD 的时序 Transformer 块根据视频的帧号设置了位置编码,用和 NLP 里处理文本类似的方式处理视频。SVD 的时序 Transformer 类在构造函数里定义了生成位置编码的模块 TimestepEmbedding, Timesteps。在前向传播时,forward 方法会用 torch.arange(num_frames) 根据总帧数生成帧号列表,并经过两个模块得到最终的位置编码嵌入 emb。嵌入 emb 会在数据过时序 Transformer 块前与输入 hidden_states_mix 相加。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class TransformerSpatioTemporalModel(nn.Module):
def __init__(...):
...
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
self.time_proj = Timesteps(in_channels, True, 0)
...
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
...
):
...

num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
num_frames_emb = num_frames_emb.reshape(-1)
t_emb = self.time_proj(num_frames_emb)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]

for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
hidden_states = block(
...
)

hidden_states_mix = hidden_states
hidden_states_mix = hidden_states_mix + emb

hidden_states_mix = temporal_block(...)
hidden_states = self.time_mixer(...)

hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

output = hidden_states + residual
...

到这里,我们就读完了 SVD U-Net 的主要代码。相比 SD U-Net,SVD U-Net 主要做了以下修改:

  • 由于输入多了一张约束图像,输入通道数变为原来的两倍。
  • 多加了三个和视频相关的额外约束。它们是通过和扩散模型的时刻嵌入相加输入进模型的。它们的命名通常与 add_time 相关。
  • 仿照 Video LDM 的结构设计,SVD 也在 2D 残差块后面加入了由 3D 卷积组成的时序残差块,在空间 Transformer 块后面加入了对时序维度做注意力的时序 Transformer 块。新旧模块的输出会以一个可学习的比例线性混合。

VAE 结构

SVD 不仅微调了 SD 的 U-Net,还微调了 VAE 的解码器,让输出视频在时序上更加连贯。由于更新 VAE 和更新 U-Net 的方法几乎一致,我们就来快速看一下 SVD 的时序 VAE 的结构,而跳过每个模块的更新细节。

通过阅读 VAE 的配置文件,我们可以知道时序 VAE 的类名为 AutoencoderKLTemporalDecoder,它位于文件 diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py 中。从它的构造函数里我们可以知道,时序 VAE 的编码器类是 Encoder,和 SD 的一样,只是解码器类从 Decoder 变成了 TemporalDecoder。我们来看一下这个新解码器类的代码做了哪些改动。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
...
):
super().__init__()

self.encoder = Encoder(...)

self.decoder = TemporalDecoder(...)

...

在 SD 中,VAE 和 U-Net 的组成模块是几乎一致的,二者的结构主要有三个区别:1)由于 VAE 的解码器和编码器是独立的,它们之间没有残差连接。而 U-Net 是一个整体,它的编码器(下采样块)和解码器(上采样块)之间有残差连接,以减少数据在下采样中的信息损失; 2)由于 VAE 中图像的尺寸较大,仅在 VAE 最深层图像尺寸为 64x64 时才有自注意力层。具体来说,这个自注意力层加到了 VAE 解码器的一开头,代码中相关模块称为 mid_block;3)VAE 仅有空间自注意力,而 SD U-Net 用了完整的 Transformer 块(包含自注意力层、交叉注意力层、全连接层)。由于 SD VAE 和 U-Net 结构上的相似性,SVD 的开发者直接把对 U-Net 的更新也搬到了 VAE 上来。

SVD VAE 解码器仅做了两项更新:1)将所有模块里的 2D 残差块都被换成了我们在上文中见过的 3D 残差块;2)在最终输出前加了一个 3D 卷积(时序维度上的 1D 卷积)。VAE 的自注意力层的结构并没有更新。更新 2D 残差块的方法和 U-Net 的是一致的。比如在新的上采样块类 UpBlockTemporalDecoder 中,我们就可以看到之前在新 U-Net 里看过的 3D 残差块类 SpatioTemporalResBlock 的身影。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
...

class UpBlockTemporalDecoder(nn.Module):
def __init__(...):
super().__init__()
for i in range(num_layers):
...
resnets.append(SpatioTemporalResBlock(...))

class TemporalDecoder(nn.Module):
def __init__(...):
super().__init__()
self.layers_per_block = layers_per_block

self.conv_in = nn.Conv2d(...)
self.mid_block = MidBlockTemporalDecoder(...)

...

for i in range(len(block_out_channels)):
...
up_block = UpBlockTemporalDecoder(...)
self.up_blocks.append(up_block)

...

conv_out_kernel_size = (3, 1, 1)
self.time_conv_out = torch.nn.Conv3d(...)

采样流水线

看完了 U-Net 和 VAE 的代码后,我们来看整套 SVD 的采样代码。和其他方法一样,在 Diffusers 中,一套采样方法会用一个流水线类 (xxxPipeline)来表示。SVD 对应的流水线类叫做 StableVideoDiffusionPipeline。我们可以利用 IDE 的代码跳转功能,在本文开头的示例采样脚本中跳转至 StableVideoDiffusionPipeline 所在源文件 diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py

如示例脚本所示,使用流水线类时,可以将类实例 pipe 当成一个函数来用。这种用法实际上会调用实例的 __call__ 方法。所以,在阅读流水线类的代码时,我们可以先忽略其他部分,直接看 __call__ 方法。

1
2
pipe = StableVideoDiffusionPipeline.from_pretrained(...)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]

__call__ 的参数定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
height: int = 576,
width: int = 1024,
num_frames: Optional[int] = None,
num_inference_steps: int = 25,
min_guidance_scale: float = 1.0,
max_guidance_scale: float = 3.0,
fps: int = 7,
motion_bucket_id: int = 127,
noise_aug_strength: float = 0.02,
decode_chunk_size: Optional[int] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
return_dict: bool = True,
):

__call__ 的参数就是我们在使用 SVD 采样时能修改的参数,我们需要把其中的主要参数弄懂。各参数的解释如下:

  • image:SVD 会根据哪张图片生成视频。
  • height, width: 生成视频的尺寸。如果输入图片与这个尺寸对不上,会将输入图片的尺寸调整为该尺寸。
  • num_frames: 生成视频的帧数。SVD 1.0 版默认 14 帧,1.0-xt 版默认 25 帧。
  • min_guidance_scale, max_guidance_scale: 使用 Classifiser-free Guidance (CFG) 的强度范围。SVD 用了一种特殊的设置 CFG 强度的机制,稍后我们会在采样代码里见到。
  • fps:输出视频期望的帧率。SVD 的额外约束。实际上这个帧率肯定是不准的,只不过提高这个值可以让视频更平滑。
  • motion_bucket_id: SVD 的额外约束。官方没有解释该值的原理,只说明了提高该值能让输出视频的运动更多。
  • noise_aug_strength: 对输入图片添加的噪声强度。值越低输出视频越像原图。
  • decode_chunk_size: 一次放几张图片进时序 VAE 做解码,用于在内存占用和效果之间取得一个平衡。按理说一次处理所有图片得到的视频连续性最好,但那样也会消耗过多的内存。
  • num_videos_per_prompt: 对于每张输入图片 (prompt),输出几段视频。
  • generator: PyTorch 的随机数生成器。如果想要手动控制生成中的随机种子,就手动设置这个变量。
  • latents: 强制指定的扩散模型的初始高斯噪声。
  • output_type: 输出图片格式,是 NumPy、PIL,还是 PyTorch。
  • callback_on_step_endcallback_on_step_end_tensor_inputs 用于在不修改原流水线代码的情况下向采样过程中添加额外的处理逻辑。学习代码的时候可以忽略。
  • return_dict: 流水线是返回一个词典,还是像普通 Python 函数一样返回用元组表示的多个返回值。

大致搞清楚了输入参数的意义后,我们来看流水线的执行代码。一开始的代码都是在预处理输入,可以直接跳过。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width)

# 2. Define call parameters
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
device = self._execution_device
self._guidance_scale = max_guidance_scale

之后,代码开始预处理交叉注意力层的约束信息。在 SD 里,约束信息是文本,所以这一步会用 CLIP 文本编码器得到约束文本的嵌入。而 SVD 是一个图生视频模型,所以这一步会用 CLIP 图像编码器得到约束图像的嵌入。

1
2
# 3. Encode input image
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)

代码还把额外约束帧率 fps 减了个一,因为训练的时候模型实际上输入的额外约束是 fps - 1

1
fps = fps - 1

接着,代码开始处理与噪声拼接的约束图像。回顾一下,SVD 的约束图像以两种形式输入进模型:一种是过 CLIP 图像编码器,以交叉注意力 K,V 的形式输入,其预处理如上部分的代码所示;另一种形式是与原去噪 U-Net 的噪声输入拼接,其预处理如当前这部分代码所示。

在预处理要拼接的图像时,代码会先调用预处理器 image_processor.preprocess,把其他格式的图像转成 PyTorch 的 Tensor 类型。之后,代码会随机生成一点高斯噪声,并把噪声根据噪声增强强度 noise_aug_strength 加到这张约束图像上。这种做法来自于之前有约束图像的扩散模型 Cascaded diffusion modelsnoise_aug_strength 稍后会作为额外约束输入进 U-Net 里,与去噪时刻的编码相加。

1
2
3
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
image = image + noise_aug_strength * noise

加了这个噪声后,图像会过 VAE 的编码器,得到 image_latentsimage_latents 会通过 repeat 操作复制成多份,并于稍后拼接到每一帧带噪图像上。注意,一般图像在过 VAE 的编码器后,要乘一个系数 vae.config.scaling_factor; 在过 VAE 的解码器前,要除以这个系数。然而,只有在这个地方,image_latents 没有乘系数。我个人觉得这是开发者的一个失误。当然,做不做这个操作对于模型来说区别不大,因为模型能很快学会这种系数上的差异。

1
2
3
4
5
6
7
8
9
# 4. Encode input image using VAE
image_latents = self._encode_vae_image(
image,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
)
image_latents = image_latents.to(image_embeddings.dtype)
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)

下一步,代码会把三个额外约束拼接在一起,得到 added_time_ids。它会接入到 U-Net 中,与时刻编码加到一起。在训练时,帧率 fps 和 运动程度 motion_bucket_id 完全来自于数据集标注,而 noise_aug_strength 是可以随机设置的。在采样时,这三个参数都可以手动设置。

1
2
3
4
5
6
7
8
9
10
11
# 5. Get Added Time IDs
added_time_ids = self._get_add_time_ids(
fps,
motion_bucket_id,
noise_aug_strength,
image_embeddings.dtype,
batch_size,
num_videos_per_prompt,
self.do_classifier_free_guidance,
)
added_time_ids = added_time_ids.to(device)

再下一步,代码会将采样的总步数 num_inference_steps 告知采样调度器 scheduler。这一步是 Diffusers API 的要求。

1
2
# 6. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)

然后,代码会随机生成初始高斯噪声。不同的随机噪声即对应不同的输出视频。

1
2
3
4
5
6
7
8
9
10
11
12
13
# 7. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_frames,
num_channels_latents,
height,
width,
image_embeddings.dtype,
device,
generator,
latents,
)

开始采样前,SVD 对约束图像的强度做了一种很特殊的设定。在看代码之前,我们先回顾一下约束强度的意义。现在的扩散模型普遍使用了 CFG (Classifier-free Guidance) 技术,它允许我们在采样时灵活地调整模型和约束信息的相符程度。这个强度默认取 1.0。我们可以通过增大强度来提升模型的生成效果,比如在 SD 中,这个强度一般取 7.5,这代表模型会更加贴近输入文本。

而 SVD 中,约束信息为图像。开发者对视频的不同帧采用了不同的约束强度:首帧为 min_guidance_scale, 末帧为 max_guidance_scale。强度从首帧到末帧线性增加。默认情况下,约束强度的范围是 [1, 3]。

1
2
3
4
5
6
7
# 8. Prepare guidance scale
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
guidance_scale = guidance_scale.to(device, latents.dtype)
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
guidance_scale = _append_dims(guidance_scale, latents.ndim)

self._guidance_scale = guidance_scale

最后,就来到了扩散模型的去噪循环了。根据之前采样调度器返回的采样时刻列表 timesteps,代码从中取出去噪时刻,对纯噪声输入迭代去噪。

1
2
3
4
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):

去噪迭代的一开始,代码会根据是否要执行 CFG 来决定是否要把输入额外复制一份。这是因为做 CFG 时,我们需要把同一个输入过两次去噪模型,一次带约束,一次不带约束。为了简化这个流程,我们可以直接把输入复制一遍,这样只要过一次去噪模型就能得到两个输出了。下一行的 scale_model_input 是 Diffusers 的 API 要求,可以忽略。

1
2
3
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

接着,加了噪声、过了 VAE 解码器、没有乘系数的约束图像 image_latents 会与普通的噪声拼接到一起,作为模型的直接输入。

1
2
# Concatenate image_latents over channels dimension
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

准备好了所有输入后,代码调用 U-Net 对输入噪声图像去噪。输入包括直接输入 latent_model_input,去噪时刻 t,约束图像的 CLIP 嵌入 image_embeddings,三个额外约束的拼接 added_time_ids

1
2
3
4
5
6
7
8
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=image_embeddings,
added_time_ids=added_time_ids,
return_dict=False,
)[0]

去噪结束后,代码根据公式做 CFG。

1
2
3
4
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

有了去噪的输出 noise_pred 还不够,我们还需要用一些比较复杂的公式计算才能得到下一时刻的噪声图像。这一切都被 Diffusers 封装进调度器里了。

1
2
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample

以上就是一步去噪迭代的主要内容。代码会反复执行去噪迭代。这后面除了下面这行会调用 VAE 解码器将隐空间的视频解码回真实视频外,没有其他重要代码了。

1
frames = self.decode_latents(latents, num_frames, decode_chunk_size)

总结

在这篇文章中,我们学习了图生视频模型 SVD 的模型结构和采样代码。整体上看,SVD 相较 SD 在模型上的修改不多,只是在原来的 2D 模块后面加了一些在时序维度上交互信息的卷积块和 Transformer 块。在学习时,我们应该着重关注 SVD 的采样流水线。SVD 使用拼接和交叉注意力两种方式添加了图像约束,并以与时刻编码相加的方式额外输入了三种约束信息。由于视频不同帧对于首帧的依赖情况不同,SVD 还使用了一种随帧号线性增长的 CFG 强度设置方式。

近期,各个科技公司纷纷展示了自己在视频生成模型上的最新成果。虽然不少模型的演示效果都非常惊艳,但其中可供学术界研究的开源模型却少之又少。Stable Video Diffusion (SVD) 算得上是目前开源视频生成模型中的佼佼者,有认真一学的价值。在这篇文章中,我将面向之前已经熟悉 Stable Diffusion (SD) 的读者,简要解读 SVD 的论文。由于 SVD 的部分结构复用了之前的工作,并没有在论文正文中做详细介绍,所以我还会补充介绍一下 SVD 的模型结构、调度器。后续我还会在其他文章中详细介绍 SVD 的代码实现及使用方法。

背景

Stable Video Diffusion 是 Stability 公司于 2023 年 11 月 21 日公布并开源的一套用扩散模型实现的视频生成模型。由于该模型是从 Stability 公司此前发布的著名文生图模型 Stable Diffusion 2.1 微调而成的,因而得名 Stable Video Diffusion。SVD 的技术报告论文与模型同日发布,它对 SVD 的训练过程做了一个详细的分享。由于该论文过分偏向实践,这里我们仅对它的开头及中间模型设计的几处关键部分做解读。

摘要

最近,有许多视频生成模型都是在图像生成模型 SD 的基础上,添加和视频时序相关的模块,并在小规模高质量视频数据集上微调新模型。而 SVD 作者认为,该领域在训练方法及精制数据集的策略上并未达成统一。这篇文章的主要贡献,也正是提出了一套训练方法与精制数据集的方法。具体而言,SVD 的训练由三个阶段组成:文生图预训练、视频预训练、高质量视频微调。同时,SVD 提出了一种系统性的数据精制流程,包含数据的标注与过滤这两部分的策略。论文会分享诸多的实验成果,包括验证精心构建的数据集对生成高质量视频的必要性、探究视频预训练与微调这两步的重要性、展示基础模型如何为图生视频等下游任务提供强大的运动表示、演示模型如何提供多视角三维先验并可以作为微调多视角扩散模型的基础模型在一轮神经网络推理中同时生成多视角的图片。

「构建」一个数据集在论文中通常用动词 curate 及名词 curation 指代。curate 原指展出画作时,选择、组织和呈现艺术品的过程。而现代将这个词用在数据集上时,则转变为表示精心选择、组织和管理数据的过程。中文中并没有完全对应的翻译,我暂时将这个词翻译为「精制」,以区别于随便收集一些数据来构成一个数据集。

总结一下,SVD 并没有强调在模型设计或者采样算法上的创新,而主要宣传了该工作在数据集精制及训练策略上的创新。对于大部分普通研究人员来说,由于没有训练大视频模型的需求,该文章的很多内容都价值不大。我们就只是来大致过一遍这篇文章的主要内容。

SVD 模型架构回顾

Video-LDM 与 SVD

在阅读正文之前,我们先来回顾一下此前视频生成模型的开发历程,并重点探究 SVD 的模型架构——Video LDM 的具体组成。绝大多数工作在训练一个基于扩散模型的视频生成模型时,都是在预训练的 SD 上加入时序模块,如 3D 卷积,并通过微调把一个图像生成模型转换成视频生成模型。由于 SD 是一种 LDM (Latent Diffusion Model),所以这些视频模型都可以归类为 Video-LDM。所谓 LDM,就是一种先生成压缩图像,再用解码模型把压缩图像还原成真实图像的模型。而对于视频,Video-LDM 则会先生成边长压缩过的视频,再把压缩视频还原。

虽然 Video-LDM 严格上来说是一个视频扩散模型的种类,但大家一般会用Video LDM (没有横杠) 来指代 Align your Latents: High-Resolution Video Synthesis with Latent Diffusion Models 这篇工作。这篇论文已在 CVPR 2023 上发布,两个主要作者正是前一年在 CVPR 上发表 SD 论文的主要作者,也是现在这篇 SVD 论文的主要作者。从署名上来看,似乎两个作者在毕业后就加入了 Stability 公司,并将 Video LDM 拓展成了 SVD。论文中也讲到,SVD 完全复用了 Video LDM 的结构。为了了解 SVD 的模型结构,我们再来回顾一下 Video LDM 的结构。

在 SD 的基础上,Video LDM 做对模型结构了两项改动:在扩散模型的去噪模型 U-Net 中加入时序层、在对图像压缩和解压的 VAE 的解码器中加入时序层。

添加时序层

Video LDM 在 U-Net 中加入时序层的方法与多数同期方法相同,是在每个原来处理图像的空间层后面加上处理视频的时序层。Video LDM 加入的时序层包括 3D 卷积层与时序注意力层。这些新模块本身不难理解,但我们需要着重关注这些新模块是怎么与原模型兼容的。

要兼容各个模块,其实就是要兼容数据的形状。本来,图像生成模型的 U-Net 的输入形状为 B C H W,分别表示图像数、通道数、高、宽。而视频数据的形状是 B T C H W,即视频数、视频长度、通道数、高、宽。要让视频数据复用之前的图像模型的结构,只要把数据前两维合并,变成 (B T) C H W 即可。这种做法就是把 B 组长度为 T 的视频看成了 $B \cdot T$ 张图片。

对于之前已有的空间层,只要把数据形状变成 (B T) C H W 就没问题了。而 SVD 又新加入了两种时序层:3D 卷积和时序注意力。我们来看一下数据是怎么经过这些新的时序层的。

2D 卷积会对 B C H W 的数据的后两个高、宽维度做卷积。类似地,3D 卷积会对数据最后三个时间、高、宽维度做卷积。所以,过 3D 卷积前,要把形状从 (B T) C H W 变成 B C T H W,做完卷积再还原。

接下来我们来看新的时序注意力。这个地方稍微有点难理解,我们从最简单的注意力开始一点一点学习。最早的 NLP 中的注意力层的输入形状为 B L C,表示数据数、token 长度、token 通道数。L 这一维最为重要,它表示了 L个 token 之间互相交换信息。如果把其拓展成图像空间注意力,则 token 表示图像的每一个像素。在这种注意力层中,L(H W)B C H W 的数据会被转换成 B (H W) C 输入进注意力层。这表示同一组图像中,每个像素两两之间交换信息。而让视频数据过空间注意力层时,只需要把 B 换成 (B T) 即可,即把数据形状从 (B T) C H W 变为 (B T) (H W) C。这表示同一组、同一帧的图像的每个像素之间,两两交换信息。

在 SVD 新加入的时序注意力层中,token 依旧指代是某一组、某一帧上的一个像素。然而,这次我们不是让同一张图像的像素互相交换信息,而是让不同时刻的像素互相交换信息。因此,这次 token 长度 LT,它表示要像素在时间维度上交换信息。这样,在视频数据过时序层里的自注意力层时,要把数据形状从 (B T) C H W 变成 (B H W) T C。这表示每一组、图像每一处的像素独立处理,它们仅与同一位置不同时间的像素进行信息交换。

此处如果你没有理解注意力层的形状变换也不要紧,它只是一个实现细节,不影响后面的阅读。如果感兴趣的话,可以回顾一下 Transformer 论文的代码,看一下注意力运算为什么是对 B L C 的数据做操作的。

微调 VAE 解码器

Video LDM 的另一项改动是修改了图像压缩模型 VAE 的解码器。具体来说,方法先在 VAE 的解码器中加入类似的时序层,并在 VAE 配套的 GAN 的判别器里也加入了时序层,随后开始微调。在微调时,编码器不变,仅训练解码器和判别器。

如果你没听过这套 VAE + GAN 的架构的话,请回顾 Stable Diffusion 论文及与其紧密相关的 VQGAN 论文。

以上就是 Video LDM 的模型结构。SVD 对其没有做任何更改,所以也没有在论文里对模型结构做详细介绍。稍有不同的是,Video LDM 仅微调了新加入的模块,而 SVD 在加入新模块后对模型的所有参数都进行了重新训练。

SVD 训练细节

SVD 分四节介绍了模型训练过程。第一节介绍了数据精制的过程,后三节分别介绍了训练的三个阶段:文生图预训练、视频预训练、高质量视频微调。

获取了一个大规模视频数据集后,SVD 的数据精制主要由预处理和标注这两步组成。由于视频生成模型主要关注生成同一个场景的视频,而不考虑转场的问题,每段训练视频也应该尽量只包含一个场景。为此,预处理主要是在用一些自动化视频剪切工具把收集到的视频进一步切成连续的片段。经切片后,视频片段数变为原来的4倍。标注主要是给视频加上文字描述,以训练一个文生视频的模型。SVD 在添加文字描述时用到了多个标注模型,并使用大语言模型来润色描述。经预处理和标注后,得到的数据集被称作 LVD (Large Video Dataset)。

SVD 数据精制的细节中,比较值得注意的是有关视频帧数的处理。由于开发团队发现视频数据的播放速度快慢不一,于是他们使用光流预测模型来大致估计每段视频的播放速度(以帧率 FPS 表示),并将视频的帧率也作为标注。这样,在训练时,视频的帧率也可以作为一种约束信息。这样的好处是,在我们在生成视频时,可以用该约束来指定视频的播放速度。

之后我们来看 SVD 模型训练的三个阶段。对于第一个文生图预训练阶段,论文没有对模型结构做过多修改,因为他们在这一步使用了之前训练好的 SD 2.1。不过,SVD 在这一步做了一个非常重要的改进:SVD 的噪声调度器从原版的 DDPM 改成了 EDM,采样方法也改成了 EDM 的。

EDM 的论文全称为 Elucidating the Design Space of Diffusion-Based Generative Models 。这篇论文用一种概括性较强的数学模型统一表示了此前各种各样的扩散模型结构,并提出了改进版模型的训练及采样策略。简单来说,EDM 把扩散模型不同时刻的噪声强度表示成 $\sigma_t$,它表示在 $t$ 时刻时,对来自数据集的图像加了标准差为 $\sigma_t$ 的高斯噪声 $\mathcal{N}(\mathbf{0}, \sigma_t^2\mathbf{I})$。一开始,对于没加噪声的图像,$\sigma_0=0$。对于最后一个时刻 $T$ 的图像,$\sigma_T$ 要足够大,使得原图像的内容被完全破坏。

这里时刻 $0$ 与时刻 $T$ 的定义与 DDPM 论文相同,与 EDM 论文相反。

有了这样一种统一的表示后,EDM 对扩散模型的训练和采样都做了不少改进。这里我们仅关注其中最重要的一条改进:将离散噪声改进成连续噪声。原来 DDPM 的去噪模型会输入时刻 $t$ 这个参数。EDM 论文指出,$t$ 实际上表示了噪声强度 $\sigma_t$,应该把 $\sigma_t$ 输入进模型。与其用离散的 $t$ 训练一个只认识离散噪声强度的去噪模型,不如训练一个认识连续噪声强度 $\sigma$ 的模型。这样,在采样 $n$ 步时,我们不再是选择离散去噪时刻[timestep[n], timestep[n - 1], ..., 0],而是可以选择连续噪声强度[sigma[n], sigma[n - 1], ..., 0] 。这样采样更灵活,效果也更好。在第一个训练阶段中,SVD 照搬了 EDM 的这种训练方法,改进了原来的 DDPM。SVD 的默认采样策略也使用了 EDM 的 。我们会在之后的代码实践文章中详细学习这种新采样方法。

对于第二个视频预训练阶段,或许是因为视频模型和图像模型的训练过程毫无区别,论文的介绍重点依然放在了这一阶段的数据处理上,而没有强调训练方法上的创新。简单来看,这一阶段的目标是得到一个过滤后的高质量数据集 LVD-F。为了找到这样一种合适的过滤方案,开发团队先用排列组合生成了大量的过滤方案:对每类指标(文本视频匹配度、美学分数、帧率等)都设置 12.5%, 25% 或 50% 的过滤条件,然后不同指标的条件之间排列组合。之后,开发团队抽取原数据集的一个子集 LVD-10M,用各个方案得到过滤后的视频子集 LVD-10-F。最后,用这样得到的子数据集分别训练模型,比较模型输出的好坏,以决定在完整数据集上使用的最优过滤方案。

在第三个阶段,参考以往多阶段训练图像模型的经验,SVD 也在另一个小而精的视频数据集上进行微调。此数据集的获取方法并没有在论文中给出,大概率是人工手动收集并标注。

SVD 应用

经上述训练后,开发团队得到了一个低分辨率的基础文生视频模型。在实验部分,SVD 论文除了给出视频生成模型在各大公开数据集上的指标外,还分享了几个基于基础模型的应用。

高分辨率文生视频

基础文生视频最直接的应用就是高分辨率文生视频。实现的方法很简单,只要准备一个高分辨率的视频数据集,在此数据集上微调原基础模型即可。SVD 高分辨率文生视频模型能生成 576 x 1024 的视频。

高分辨率图生视频

除了文生视频外,也可以用基础模型来微调出一个图生视频模型。为了把约束从文本换成图像,开发团队将 U-Net 交叉注意力层的约束从文本嵌入变成了约束图像的图像嵌入,并将约束图像与原 U-Net 的噪声输入在通道维度上拼接在一起。特别地,参考以往 Cascaded diffusion models 论文的经验,约束图像在与噪声输入拼接前,会加上一些噪声。除此之外,由于约束机制的变动,像文生图模型一样将约束强度(CFG scale)设成 7.5 会让 SVD 图生视频模型产生瑕疵。因此,SVD 图生视频模型每一帧的约束强度不同,从第一帧到最后一帧以 1.0 到 3.0 线性增长。

参考之前 AnimateDiff 工作,SVD 也成功训练了相机运动 LoRA,使得图生视频模型只会生成平移、缩放等某一种特定相机运动的视频。

视频插帧

Video LDM 曾提出了一种把基础视频模型变成视频插帧模型方法。该方法以视频片段的首末帧为额外约束,在此新约束下把视频生成模型微调成了预测中间帧的视频预测模型。SVD 以同样方式实现了这一应用。

多视角生成

多视角生成是计算机视觉中另一类重要的任务:给定 3D 物体某个视角的图片,需要算法生成物体另外视角的图片,从而还原 3D 物体的原貌。而视频生成模型从数据中学到了物体的平滑变换规律,恰好能帮助到多视角生成任务。SVD 论文用不少篇幅介绍了如何在 3D 数据集上生成视频并微调基础模型,从而得到一个能生成环绕物体旋转的视频的模型。

结语

Stable Video Diffusion 是在文生图模型 Stable Diffusion 2.1 的基础上添加了和 Video LDM 相同的视频模块微调而成的一套视频生成模型。SVD 的论文主要介绍了其精制数据集的细节,并展示了几个微调基础模型能实现的应用。通过微调基础低分辨率文生视频模型,SVD 可以用于高分辨率文生视频、高分辨率图生视频、视频插帧、多视角生成。

对于没有资源与需求训练大视频模型的多数科研人员而言,没有深究这篇文章细节的必要。并且,由于 SVD 只开源了图生视频模型 (3D模型后来是在 SV3D 论文中正式公布的),这篇文章比较有用的只有和图生视频相关的部分。为了彻底搞懂 SVD 的原理,读这篇论文是不够的,我们还需要通过回顾 Video LDM 论文来了解模型结构,学习 EDM 论文来了解训练及采样机制。

这篇文章主要是面向熟悉 Stable Diffusion 的读者的。如果你缺少某些背景知识,欢迎读我之前介绍 Stable Diffusion 的文章。我没有在本文过多介绍 SVD 的实现细节,欢迎阅读我之后发表的 SVD 代码实践文章。

Stable Diffusion 解读(一):回顾早期工作

Stable Diffusion 解读(二):论文精读

Stable Diffusion 解读(三):原版实现及Diffusers实现源码解读

FID 是一种衡量图像生成模型质量的指标。对于这种常见的指标,一般都能找到好用的 PyTorch 计算接口。然而,当我用 PyTorch 的官方库 TorchEval 来算 FID 指标时,却发现它的结果和多数非官方库无法对齐。我花了不少时间,总算把 TorchEval 的 FID 计算接口修好了。在这篇文章中,我将分享有关 FID 计算的知识以及我调试 TorchEval 的经历,并总结用 pytorch-fid, torch-fidelity, TorchEval 算 FID 的方法。文章最后,我还会分享一个偶然发现的用于反映模型训练时的当前 FID 的方法。

FID 指标简介

FID 的全称是 Fréchet Inception Distance,它用于衡量两个图像分布之间的差距。如果令一个图像分布是训练集,再用生成模型输出的图像构成另一个分布,那么 FID 指标就表示了生成出来的图片和训练集整体上的相似度,也就间接反映了模型对训练集的拟合程度。FID 名字中的 Fréchet Distance 是一种描述两个样本分布的距离的指标,其定位和 KL 散度一样,但某些情况下会比 KL 散度更加合适。FID 用来算 Fréchet Distance 的样本来自预训练 InceptionV3 模型,它名称中的 Inception 由此而来。

计算 FID 的过程如下:

  1. 准备两个图片文件夹。一般一个是训练集,另一个存储了生成模型随机生成的图片。
  2. 用预训练的 InceptionV3 模型把每个输入图片转换成一个 2048 维的向量。
  3. 计算训练集、生成集上输出向量的均值、协方差。
  4. 把均值、协方差代入进下面这个算 Fréchet Distance 的公式,就得到了 FID。

实际上,在用 FID 的时候我们完全不用管它的原理,只要知道它的值越小就越好,并且会调用相关接口即可。需注意的是,由于 FID 是一种和集合相关的指标,算 FID 时一定要给足图片。在构建自己模型的输出集合时,至少得有 10000 张图片,推荐生成 50000 张。否则 FID 的结果会不准确。

用 PyTorch 计算 FID 的第三方库

由于 FID 的计算需要用到一个预训练的 InceptionV3 模型,只有在模型实现完全一致的情况下,FID 的输出结果才是可比的。因此,所有论文汇报的 FID 都基于提出 FID 的作者的官方实现。这份官方实现是用 TensorFlow 写的,后来也有完全等价的 PyTorch 实现。在这一节里,我们就来学习如何用这些基于 PyTorch 的库算 FID。

GitHub 上点赞最多的 PyTorch FID 库是 pytorch-fid。这个库被 FID 官方仓库推荐,且 Stable Diffusion 论文也用了这个库,结果绝对可靠。使用该库的方法很简单,只需要先安装它。

1
pip install pytorch-fid

再准备好两个用于计算 FID 的文件夹,将文件夹路径传给脚本即可。

1
python -m pytorch_fid path/to/dataset1 path/to/dataset2

另一个较为常见的用 PyTorch 算指标的库叫做 torch-fidelity。它用起来和 pytorch-fid 一样简单。一开始,需要用 pip 安装它。

1
pip install torch-fidelity

之后,同样是准备好两个图片文件夹,将文件夹路径传给脚本。

1
fidelity --gpu 0 --fid --input1 path/to/dataset1 --input2 path/to/dataset2

除了命令行脚本外,torch-fidelity 还提供了 Python API。我们可以在 Python 脚本里加入算 FID 的代码。

1
2
3
4
5
6
7
8
import torch_fidelity

metrics_dict = torch_fidelity.calculate_metrics(
input1='path1',
input2='path2',
fid=True
)
print(metrics_dict)

torch-fidelity 还提供了其他便捷的功能。比如直接以某个生成模型为 API 的输入 input1,而不是先把图像生成到一个文件夹里,再把文件夹路径传给 input1。同时,torch-fidelity 还支持计算其他指标,我们只需要在命令行脚本或者 API 里多加几个参数就行了。

修正 TorchEval 里的 FID 计算接口

尽管这些第三方库已经足够好用了,我还是想用 PyTorch 官方近年来推出的指标计算库 TorchEval 来算 FID 指标。原因有两点:

  1. 我的项目其他地方都是用 PyTorch 官方库实现的 (torch 以及 torchvision),算指标也用官方库会让整体代码风格更加统一。我已经用 TorchEval 算了 PSNR、SSIM,使用体验还可以。
  2. 目前,似乎只有 TorchEval 支持在线更新指标的值。也就是说,我可以先生成一部分图片,储存算 FID 需要的中间结果;再生成一部分图片,最终计算此前所有图片与训练集的 FID。这种计算方法的好处我会在文章后面介绍。

以前我都是用 pytorch-fid 来算 FID。而当我换成用 TorchEval 后,却发现结果对不齐。于是,漫长的调试之路开始了。

当你有两块时间不一样的手表时,应该怎样确认时间呢?答案是,再找到第三块表。如果三块表中能有两块表时间一样,那么它们的时间就是正确的。一开始,我并不能确定是哪个库写错了,所以我又测试了 torch-fidelity 的结果。实验发现,torch-fidelity 和 pytorch-fid 的结果是一致的。并且我去确认了 Stable Diffusion 的论文,其中用来计算 FID 的库也是 pytorch-fid。看来,是 TorchEval 结果不对。

像 FID 这么常见的指标,大家的中间计算过程肯定都没错,就是一些细微的预处理不太一样。抱着这样的想法,我随意地比对了一下二者的代码,很快就发现 TorchEval 把输入尺寸调成 [299, 299] 了,而 pytorch-fid 没做。可删掉这段代码,程序直接报错了。我深入阅读了 pytorch-fid 的代码,发现它的写法和 TorchEval 不一样,把调整尺寸为 [299, 299] 写到了另一个地方。且通过调查发现,InceptionV3 网络的输入尺寸必须是 [299, 299] 的,是我孤陋寡闻了。唉,看来这次的调试不能太随意啊。

我准备拿出我的真实实力来调 bug。我认真整理了一下算 FID 的步骤,将其主要过程总结为以下几步:

  1. 用预训练权重初始化 InceptionV3
  2. 用 InceptionV3 算两个数据集输出的均值、协方差
  3. 根据均值、协方差算距离

最后那个算距离的过程不涉及任何神经网络,输出该是什么就是什么。这一块是最不容易出错,且最容易调试的。于是,我决定先排除第三步是否对齐。我把 TorchEval 得到的均值、协方差存下来,用 pytorch-fid 算距离。发现结果和原 TorchEval 的输出差不多。看来算距离这一步没有问题。

接下来,我很自然地想到是不是均值和协方差算错了。我存下了两个库得到的均值、协方差,算了两个库输出之间的误差。结果发现,均值的误差在 0.09 左右,协方差的误差在 0.0002 左右。图像的数据范围在 0~1 之间,0.09 算是一个很大的误差了。可见,第一步和第二步一定存在着无法对齐的部分。

模型输出不同,最容易想到的是模型权重不同。于是,我尝试交换使用二者的模型权重,再比较输出的 FID。两个库的模型定义不太一样,不能直接换模型文件名。我用强大的代码魔改实力强行让新权重分别都跑起来了。结果非常神奇,算上之前的两个 FID,我一共得到了 4 个不一样的 FID 结果。也就是说,A 库 A 模型、B 库 B 模型、A 库 B 模型,B 库 A 模型,结果均不一样。

我被这两个库气得不行,决定认真研究对比二者的模型定义。眼尖的我发现初始化 pytorch-fid 的 InceptionV3 时有一个参数叫 use_fid_inception。作者对此的注释写道:「如果设置为 true,则用 TensorFlow 版 FID 实现;否则,用 torchvision 版 Inception 模型。TensorFlow 的 FID Inception 模型和 torchvision 的在权重和结构上有细微的差别。如果你要计算 FID,强烈推荐将此值设置为 true,以得到和其他论文可比的结果。」总结来说,TorchEval 用的是 torchvision 里的标准 PyTorch 版 InceptionV3,而 pytorch-fid 在标准 PyTorch 版 InceptionV3 外又封装了一层,改了一些模块的定义。为什么要改这些东西呢?这是因为原来的 FID Inception 模型是在 TensorFlow 里实现的,需要改一些结构来将 PyTorch 模型对齐过去。除了模型结构外,二者的权重也有一定差别。大家都是用 TensorFlow 版模型算 FID,一切都应该以 pytorch-fid 的为准。这个 TorchEval 太离谱了,我也懒得认真修改了,直接注释掉 TorchEval 里原 FIDInceptionV3 的定义,然后大笔一挥:

1
2
from pytorch_fid.inception import \
InceptionV3 as FIDInceptionV3

按理说,这下权重和模型结构都对齐了。FID 计算的第一、第二步绝对不会有错。而开始的结果表明,FID 计算的第三步也没有错。那么,两个库就应该对齐了。我激动地又测了 TorchEval 的结果,发现结果还是无法对齐!

这不应该啊?难道哪步测错了?人生就是在不断自我怀疑中度过的。而怀疑自我,首先会怀疑最久远的自我。所以,我感觉是最早测第三步的时候有问题。之前我是把 TorchEval 的均值、协方差放到 pytorch-fid 里,结果与 TorchEval 自己的输出一致。这次我反过来,把 pytorch-fid 的均值、协方差放到 TorchEval 的算距离函数里算。这次,我第一次见到 TorchEval 输出了正确的 FID。由此可见,第三步没错。难道是均值和协方差又没对齐了?

自我怀疑开始进一步推进,我开始怀疑第二步输出的均值、协方差还是没有对齐。我再次计算了 pytorch-fid 和 TorchEval 的输出之间的误差,发现误差这次仅有 1e-16,可以认为没有区别。我花了很多时间复习协方差的计算,想找出 TorchEval 里的 bug。可是越学习,越觉得 TorchEval 写得很对。这一回,我找不到错误了。

调试代码,不怕到处有错,而怕「没错却有错」。「没错」,指的是每一步中间步骤都找不到错误;「有错」,指的是最终结果还是错了。没有错误,就得创造错误。我开启了随机乱调模式,希望能触发一个错误。回忆一下,算 FID 要用到两个数据集,一般一个是训练集,一个是模型输出的集合。在 TorchEval 最后一步算距离时,我乱改代码,让一个集合的均值、协方差不变,即来自原 TorchEval 的 Inception 模型的输出;而让另一个的集合的均值、协方差来自 pytorch-fid。理论上说,如果两个库的均值、协方差是对齐的,那么这次输出的 FID 也应该是正确的。欸,这回代码报错了,运行不了。报错说数据精度不统一。原来,TorchEval 的输出精度是 float32,而 pytorch-fid 的输出精度是 float64。之前测试距离计算函数时,数据要么全来自 TorchEval,要么全来自 pytorch-fid,所以没报过这个错。可是这个错只是一个运行上的错误,稍微改改就好了。

我把 pytorch-fid 相关数据的精度统一成了 float32。这下代码跑起来了,可 FID 不对了。调试过程中,如果上一次成功,而这一次失败,则应该想办法把代码退回上一次的,再次测试。因此,我又修改了最后用 TorchEval 计算距离的数据来源,让所有数据都来自 pytorch-fid。可是,修改后,FID 输出没变,还是错的。

为什么两轮测试之前,我全用 pytorch-fid 的输出、TorchEval 的距离计算函数没有错,这次却错了?到底是哪里不同?当测试两份差不多的代码后,一份对了,一份错了,那么错误就可以定位到两份代码的差异处。仔细回顾一下我的调试经历,相信你可以推理出 bug 出自哪了。

没错!我仔细比对了当前代码和我记忆中两轮测试前的代码,仅发现了一处不同——我把 pytorch-fid 的输出数据的精度改成了 float32。把精度改回 float64 就对了。同样,如果把 TorchEval 的输出数据的精度改成 float64,再扔进 TorchEval 的距离计算函数里算,结果也是对的。问题出在 TorchEval 的距离计算函数的数据精度上。

定位到了 bug 的位置,再找出 bug 的原因就很简单了。对比 pytorch-fid 的距离计算函数和 TorchEval 的,可以发现二者描述的计算公式完全相同。然而,pytorch-fid 是用 NumPy 算的,而 TorchEval 是用 PyTorch 算的。算 FID 的距离时,会涉及矩阵特征值等较为复杂的运算,它们对数据精度要求较高。像 NumPy 这种久经考验的库应该会自动把数据变成高精度再计算,而 PyTorch 就没做这么多细腻的处理了。

汇总一下我调试的结论。TorchEval 在权重初始化、模型计算、距离计算这三步中均有错误。前两步没有让 InceptionV3 模型和普遍使用的 TensorFlow 版对齐,最后一步没有考虑输入精度,用了不够稳定的 PyTorch API 来做复杂矩阵运算。要用 TorchEval 算出正确的 FID,需要做以下修改:

  • 安装 pytorch-fid 和 TorchEval
  • 打开 torcheval/metrics/image/fid.py
  • 注释掉 FIDInceptionV3 类,在文件开头加上 from pytorch_fid.inception import InceptionV3 as FIDInceptionV3
  • FrechetInceptionDistance 类的构造函数中,在定义所有浮点数据时加上 dtype=torch.float64

这里点名批评 TorchEval。开源的时候吹得天花乱坠,结果根本没人用,这么简单的有关 FID 的 bug 也发现不了。我发了一个修正此 bug 的相关 issue https://github.com/pytorch/torcheval/issues/192,截至目前还是没有官方人员回复。这个库的开发水平实在太逆天了,希望他们能尽快维护好。

在线计算 FID

前文提到,我用 TorchEval 的原因是它支持在线计算 FID。具体来说,可以建立一个 FID 管理类,之后用 update 方法来不断往某个集合加入新图片,并随时使用 compute 方法算出当前所有图片的 FID。我之前写代码忘了清空旧图片的中间结果时发现了一个相关应用。经我使用下来,这种应用非常有用,我们可以用它高效估计训练时的当前 FID。

回顾一下,要得到准确的 FID 值,一般需要 50000 张图片。而训练图像生成模型时,如果每次验证都要生成这么多图片,则大部分时间都会消耗在验证上了。为了加快 FID 的验证,我发现可以用一种 「全局 FID」来近似表示当前的模型拟合情况。具体来说,我先用训练集的所有图片初始化 FID 的集合 1 的中间结果,再在模型训练中每次验证时随机生成 500 张图片,将其中间结果加到 FID 的集合 2 中,并输出一次当前 FID。这样,随着训练不断推进,算 FID 的图片的数量会逐渐满足 50000 张的要求,但是这些图片并不是来自同一个模型,而是来自不同训练程度的模型。这样得到的 FID 仅能大致反映当前的真实 FID 值,有时偏高、有时偏低。但经我测试发现,这种全局 FID 的相对关系很能反映最终的真实 FID 的相对关系。训练两个不同超参的模型时,如果一个全局 FID 较大,那它最终的 FID 一般也会较大。同时,如果训练一切正常,则全局 FID 会随验证轮数单调递减(因为图片数量变多,且拟合情况不会变差)。如果某一次验证时全局 FID 增加了,则模型也一定在这段时间里变差了。通过这种验证方式,我们能够大致评估模型在训练中的拟合情况。这应该是一种很容易想到的工程技巧,但由于分享自己训练生成模型的经验帖较少,且重要性不足以写进论文,我没有在任何地方看到有人介绍这种技巧。

总结

FID 是评估图像生成模型的重要指标。通过 pytorch-fid 等库,我们能轻松地用 PyTorch 计算两个图像分布间的 FID。而通过计算输出分布和训练分布之间的 FID,我们就能评估当前模型的拟合情况。

FID 的计算本身是很简单的。所以在介绍 FID 的计算方法之外,我分享了我调试 TorchEval 的漫长过程。这段经历很有意思,我学到了不少调 bug 的新知识。此前我从来没想到过数据精度竟然会大幅影响某个值的结果。这段经历启示我们,做一些复杂运算时,不要用 PyTorch 算,最好拿 NumPy 等更稳定的库来计算。如果你调 bug 的经验不足,这段经历也能给你许多参考。

文章最后我分享了一种算全局 FID 的方法。它可以高效反映生成模型在训练时的拟合情况。该功能很容易实现,感兴趣的话可以自己尝试一下。

相信大家都在网上看过这种「笑容逐渐消失」的表情包:一张图片经过经过平滑的变形,逐渐变成另一张截然不同的图片。

对此,计算机科学中有一种专门描述此应用的任务——图像变形(image morphing)。给定两张图像,图像变形算法会输出一系列合理的插值图像。当按顺序显示这些插值图像时,它们应该能构成一个描述两张输入图像平滑变换的视频。

图像变形可以广泛运用于创意制作中。比如在做 PPT 时,我们可以在翻页处用图像变形做出炫酷的过渡效果。当然,图像变形也可以用在更严谨的场合。比如在制作游戏中的 2D 人物动画时,可以让画师只画好一系列关键帧,再用图像变形来补足中间帧。可是,这种任务对于中间插值图像的质量有着很高的要求。而传统的基于优化的图像变形算法只能对两张输入图像的像素进行一定程度的变形与混合,难以生成高质量的中间帧。有没有一种更好的图像变形算法呢?

针对这一需求,我们提出了 DiffMorpher —— 一种基于预训练扩散模型 (Stable Diiffusion)的图像变形算法。该研究由 DragGAN 作者潘新钢教授指导,经清华大学、上海人工智能实验室、南洋理工大学 S-Lab 合作完成。目前,该工作已经被 CVPR 2024 接收。

我们可以借助 DiffMorpher 实现许多应用。最简单的玩法是输入两张人脸,生成人脸的渐变图。

如果输入一系列图片,我们还能制作更长更丰富的渐变图。

而当输入的两张图片很相似时,我们可以用该工具制作出质量尚可的补间动画。

在这篇文章中,让我们来浏览一下 DiffMorpher 的工作原理,并学习如何使用这一工具。学习 DiffMorpher 的一些技术也能为我们开发其他基于扩散模型的编辑工具提供启发。

项目官网:https://kevin-thu.github.io/DiffMorpher_page/

代码仓库:https://github.com/Kevin-thu/DiffMorpher

隐变量插值

如前所述,图像变形任务在生成插值图像时不仅需要混合输入图像的内容,还需要补充生成一些内容。用预训练的图像生成模型来完成图像变形是再自然不过的想法了。前几年,已经有一些工作探究了如何用 GAN 来完成图像变形。使用 GAN 做图像变形的方法非常直接:在 GAN 中,每张被生成的图片都由一个高维隐变量决定。可以说,隐变量蕴含了生成一张图像所需的所有信息。那么,只要先使用 GAN 反演(inversion)把输入图片变成隐变量,再对隐变量做插值,就能用其生成两张输入图像的一系列中间过渡图像了。

对于隐变量,我们一般使用球面插值(slerp)而不是线性插值。

然而,GAN 生成的图像往往局限于某一类别,泛用性差。因此,用 GAN 做图像变形时,往往得不到高质量的图像插值结果。

而以 Stable Diffusion(SD)为代表的图像生成扩散模型以能生成各式各样的图像而著称。我们可以在 SD 上也用类似的过程来实现图像插值。具体来说,我们需要对 DDIM 反演得到的纯噪声图像(隐变量)进行插值,并对输入文本的嵌入进行插值,最后根据插值结果生成图像。

可是,扩散模型也存在缺陷:扩散模型的隐变量没有 GAN 的那么适合编辑。如下面的动图所示,如果仅使用简单的隐变量插值,会存在着两个问题:1)早期和晚期的中间帧和输入图像非常相近,而中期的中间帧又变化过快,图像的过渡非常突兀;2)中间帧的图像质量较低。这样的结果无法满足实际应用的要求。

LoRA 插值

扩散模型的隐变量不适合编辑,准确来说是其所在隐空间的性质导致的。模型只能处理隐空间中部分区域的隐变量。如果对隐变量稍加修改,让隐变量「跑」到了一个模型处理不了的区域,那模型就生成不了高质量的结果。而在对两个输入隐变量做插值时,插值的隐变量很可能就位于一个模型处理不了的区域。

想解决此问题,我们需要参考一些其他的工作。为了提升扩散模型的编辑单张图像的能力,一些往期工作会在单张图片上微调预训练扩散模型(即训练集只由同一张图片构成,让模型在单张图片上过拟合)。这样,无论是调整初始的隐变量还是文本输入,模型总是能够生成一些和该图片很相近的图片。比如在 Imagic 工作中,为了编辑输入图片,算法会先在输入图片上微调扩散模型,再用新的文本描述重新生成一次图片。这样,最终生成的图片既和原图很接近(鸟的外观差不多),又符合新的文本描述(鸟张开了翅膀)。

后来,许多工作用 LoRA 代替了全参数微调。LoRA 是一种高效的模型微调技术。在训练 LoRA 时,原来的模型权重不用修改,只需要训练额外引入的少量参数即可。假设原模型的参数为$W$,则 LoRA 参数可以表示为 $\Delta W$,新模型可以表示为$W + \Delta W$,其中 $\Delta W$ 里的参数比 $W$ 要少得多。训练 LoRA 的目的和全参数微调是一样的,只不过 LoRA 相对而言更加高效。

对单张图片训练 LoRA,其实就是在把整个隐空间都变成一个能生成高质量图像的空间。但付出的代价是,模型只能生成和该图片差不多的图片。

我们能不能把 LoRA 的这种性质放在隐变量插值上呢?我们可以认为,LoRA 的参数 $\Delta W$ 存储了新的隐空间的一些信息。如果我们不仅对两个输入图片的隐变量做插值,还对分别对两个输入图片训练一个 LoRA,得到$\Delta W_1, \Delta W_2$,再对两个 LoRA 的参数进行插值,得到$\Delta W = \alpha \Delta W_1 + (1-\alpha) \Delta W_2$,就能让中间的插值隐变量也能生成有意义的图片,且该图片会保留两个输入图片的性质。

相关实验结果能支撑我们的假设。下图展示了不同 LoRA 配置下,对隐变量做插值得到的结果。第一行和第二行表示分别仅使用左图或右图的 LoRA,第三行表示对 LoRA 也进行插值。可以看出,使用 LoRA 后,所有图片的质量都还不错。固定 LoRA,对隐变量做插值时,图像的风格会随隐变量变化而变化,而图像的语义内容会与训练该 LoRA 的图像相同。而对 LoRA 也进行插值的话,图像的风格、语义都会平滑过渡。

下图是前文那个例子的某些中间帧不使用 LoRA 插值和使用 LoRA 插值的结果。可以看出,使用了 LoRA 后,图像质量提升了很多。而通过对 LoRA 的插值,输出图像也会保留两个输入图像的特征。

自注意力输入的插值与替换

使用 LoRA 插值后,中间帧的图像质量得到了大幅提升,可是图像变形不连贯的问题还是没有得到解决。要提升图像变换的连贯性,还需要使用到一项和自注意力相关的技术。

深度学习中常见的注意力运算都可以表示成交叉注意力 $CrossAttn(\mathbf{x}, \mathbf{y})$,它表示数据 $\mathbf{x}$ 从数据 $\mathbf{y}$ 中获取了一次信息。交叉注意力的特例是自注意力 $SelfAttn(\mathbf{x}) = CrossAttn(\mathbf{x}, \mathbf{x})$,它表示数据 $\mathbf{x}$ 自己内部做了一次信息聚合。多数扩散模型的 U-Net 都使用了自注意力层。

由于自注意力本质上是一种交叉注意力,我们可以把另一个图像的自注意力输入替换某图像的自注意力输入。具体来说,我们可以先生成一张参考图像,将自注意力输入$\mathbf{x’}$缓存下来。再开始生成当前图片,对于原来的自注意力计算 $CrossAttn(\mathbf{x}, \mathbf{x})$,我们把第二个 $\mathbf{x}$ 换成 $\mathbf{x’}$,让计算变成 $CrossAttn(\mathbf{x}, \mathbf{x’})$。这样,在生成当前图片时,当前图片会和参考图片更加相似一些。

在扩散模型生成图像时,每个去噪时刻的每个自注意力模块的输入都有自己的意义。在替换输入时,我们必须用参考图像当前时刻当前模块的输入来做替换。

这种注意力替换技巧通常用在基于图像扩散模型的视频编辑任务里。一般我们会以输出视频的第一帧为参考图像,让生成后续帧的自注意力模块参考第一帧的信息。这样视频每一帧的风格都会更加一致。

我们可以把视频编辑任务的这种技巧挪用到图像变形任务里。在图像变形中,每一个中间帧要以一定的混合比例参考两个输入图像。那么,我们也可以先分别生成两个输入图像,缓存它们的自注意力输入$\mathbf{x}_0, \mathbf{x}_1$。在生成混合比例为 $\alpha$ 的中间帧时,我们先混合自注意力输入$\mathbf{x’} = \alpha \mathbf{x}_0 + (1-\alpha) \mathbf{x}_1$,再以 $\mathbf{x’}$ 为自注意力的第二个参数,计算 $CrossAttn(\mathbf{x_{\alpha}}, \mathbf{x’})$。

下面是不使用/使用自注意力替换的结果。可以看出,不使用注意力替换时,视频中间某帧会出现突变。而使用了注意力替换后,视频平滑了很多。

在实验中我们也发现,直接用 $\mathbf{x’}$ 来替换注意力输入会降低中间帧的质量。为了权衡质量与过渡性,我们会让替换的注意力输入在原输入 $\mathbf{x_{\alpha}}$ 和 $\mathbf{x’}$ 之间做一个混合,即令插入的注意力输入为 $\mathbf{x’} \gets \lambda \mathbf{x’} + (1-\lambda) \mathbf{x_{\alpha}}$。最终实验中我们令 $\lambda=0.6$。

重调度采样

通过注意力插值,我们解决了中间帧跳变的问题。然而,视频的变换速度还是不够平均。在开始和结束时,视频的变化速度较慢;而在中间时刻,视频的变化又过快。

我们使用了一种重新选择混合比例 $\alpha$ 的重调度策略来解决这一问题。之前,我们在选择混合比例时,是均匀地在 0~1 之间采样。比如要生成 10 段过渡,9个中间帧,我们就可以令混合比例为 $[0, 0.1, 0.2, …, 0.9, 1]$。但是,由于不同比例处插值图像的变化率不同,这样选取混合比例会导致每两帧之间变化量不均匀。

上图是一个可能的变化率分布图。图的横坐标是插值的混合比例,或者说视频渐变的时刻,图的纵坐标是图像内容随时间的变化率。每个矩形的面积表示相邻两帧之间的 LPIPS 感知误差。如果等间距采样混合比例的话,由于每个时刻的变化率不同,矩形的面积也不同,图像的变化会时快时慢。

我们希望重新选择一些采样的横坐标,使得相邻帧构成的矩形的面积尽可能一致。通过使用类似于平均颜色分布的直方图均衡化(histogram equalization)算法,我们可以得到重采样的混合比例 $[0, \alpha_1, \alpha_2, …, \alpha_{n-1}, 1]$,达到下面这种相邻帧变化量几乎相同的结果。

下面是不使用/使用重采样的结果。可以看出,二者生成的中间图像几乎是一致的,但左边的视频在开头和结尾会停顿一会儿,而右边的视频的内容一直都在均匀地变化。

在线示例与代码

看完了该工作的原理,我们来动手使用一下 DiffMorpher。我们先来运行一下在线示例。在线示例可以在 OpenXLab (https://openxlab.org.cn/apps/detail/KaiwenZhang/DiffMorpher ) 或者 HuggingFace(https://huggingface.co/spaces/Kevin-thu/DiffMorpher )上访问。

使用 WebUI 时,可以直接点击 Run 直接运行示例,或者手动上传两张图片并给定 prompt 再运行。

如果你对一些细节感兴趣,也可以手动 clone GitHub 仓库。配置环境的过程也很简单,只需要准备一个有 PyTorch 的运行环境,再安装相关 Pip 包即可。注意,该项目用的 Diffusers 版本较旧,最新的 Diffusers 可能无法成功运行,建议直接照着 requirements.txt 里的版本来。

1
2
3
git clone https://github.com/Kevin-thu/DiffMorpher.git
cd DiffMorpher
pip install -r requirements.txt

配置好了环境后,可以直接尝试仓库里自带的示例:

1
2
3
4
5
python main.py \
--image_path_0 ./assets/Trump.jpg --image_path_1 ./assets/Biden.jpg \
--prompt_0 "A photo of an American man" --prompt_1 "A photo of an American man" \
--output_path "./results/Trump_Biden" \
--use_adain --use_reschedule --save_inter

运行后,就能得到下面的结果:

总结与展望

图像变形任务的目标是在两个有对应关系的图像之间产生一系列合理的过渡帧。传统基于像素变形与混合的方法无法在中间帧里生成新内容。我们希望用包含丰富图像信息的预训练扩散模型来完成图像变形任务。然而,直接对扩散模型的隐变量插值,会出现中间帧质量低、结果不连贯这两个问题。为了解决这两个问题,我们对扩散模型生成两个输入图像时的诸多属性进行了插值,包括 LoRA 插值、自注意力插值,分别解决了中间帧质量与结果连贯性的问题。另外,加入了重调度采样后,输出视频的连贯性得到了进一步的提升。

受限于图像变形这一任务本身的上限,DiffMorpher 在实际应用中的质量难以比拟专门面向某一任务的方法(比如只做拖拽式编辑,或者只做视频插帧)。这篇工作在科研上的贡献会远大于其在应用上的贡献。方法中一些较为新颖的插值手段或许会帮助到未来的图像编辑工作。

尽管 DiffMorpher 已经算是一个不错的图像变形工具了,该方法并没有从本质上提升扩散模型的可编辑性。相比 GAN 而言,逐渐对扩散模型的隐变量修改难以产生平滑的输出结果。比如在拖拽式编辑任务中,DragGAN 只需要优化 GAN 的隐变量就能产生合理的编辑效果,而扩散模型中的类似工具(如 DragDiffusion, DragonDiffusion)需要更多设计才能达到同样的结果。从本质上提升扩散模型的可编辑性依然是一个值得研究的问题。

出于可读性的考虑,本文没有过多探讨技术细节。如果你对相关技术感兴趣,欢迎阅读我之前的文章:

LoRA 在 Stable Diffusion 中的三种应用:原理讲解与代码示例

Stable Diffusion 中的自注意力替换技术与 Diffusers 实现