PRX Part 3 — 在24小时内训练文本到图像模型!
摘要
Photoroom的 PRX Part 3 演示了如何通过结合优化的架构和训练技术(包括感知损失、TREAD 令牌路由和 Muon 优化器)在24小时内训练文本到图像模型。
查看缓存全文
缓存时间: 2026/04/20 17:27
PRX 第3部分 — 24小时内训练文本到图像模型!
来源: https://huggingface.co/blog/Photoroom/prx-part3 返回文章列表 (https://huggingface.co/blog)
- 介绍 (https://huggingface.co/blog/Photoroom/prx-part3#introduction) - 训练方案 (https://huggingface.co/blog/Photoroom/prx-part3#the-training-recipe)- X预测和像素空间训练 (https://huggingface.co/blog/Photoroom/prx-part3#x-prediction-and-training-in-the-pixel-space) - 感知损失 (https://huggingface.co/blog/Photoroom/prx-part3#perceptual-losses) - 使用TREAD的令牌路由 (https://huggingface.co/blog/Photoroom/prx-part3#token-routing-with-tread) - 使用REPA和DINOv3的表示对齐 (https://huggingface.co/blog/Photoroom/prx-part3#representation-alignment-with-repa-and-dinov3) - 优化器:Muon (https://huggingface.co/blog/Photoroom/prx-part3#optimizer-muon) - 训练设置 (https://huggingface.co/blog/Photoroom/prx-part3#training-settings) - 结果和总结思考 (https://huggingface.co/blog/Photoroom/prx-part3#results-and-closing-thoughts) - 下一步计划? (https://huggingface.co/blog/Photoroom/prx-part3#whats-next) - 致谢 (https://huggingface.co/blog/Photoroom/prx-part3#acknowledgements)
介绍
欢迎回来 👋
在之前的两篇文章(第1部分和第2部分)中,我们探索了扩散模型的各种架构和训练技巧。我们试图独立评估每个想法,衡量吞吐量、收敛速度和最终图像质量,并试图理解什么真正改进了性能。
在这篇文章中,我们想回答一个更实际的问题:
当我们结合所有有效的技巧时会发生什么?
我们不会一次优化一个维度,而是将最有前景的成分堆叠在一起,看看在严格的计算预算下能推进多远。
具体来说,我们进行了一场24小时速通:
- 32个H200 GPU
- 约1500美元总计算预算(2美元/小时/GPU)
这远离了早期扩散模型的时代,当时训练有竞争力的模型需要花费数百万美元。这里的目标是展示该领域的发展程度,以及仔细的工程在仅仅一天的训练中能带你走多远。
这次速通不仅仅是一个有趣的实验。它很可能将作为我们未来大规模训练方案的基础。
除了结果外,我们还开源了我们的代码(Github链接),其中包含:
- 用于此次速通的训练代码
- 之前博客文章中的实验框架
这样你就可以自己复现、修改和扩展所有内容。
训练方案
现在让我们一起看看这个24小时运行中包含了什么。
X预测和像素空间训练
我们使用来自《Back to Basics: Let Denoising Generative Models Denoise》Li和He, 2025的x预测公式。如第2部分所示,这使得直接在像素空间训练并完全消除了对VAE的需求。我们使用32的补丁大小,在初始令牌投影层中使用256维瓶颈。这种设计保持了序列长度可控,使像素空间训练即使在更高分辨率下也在计算上可行。
在512px处,序列长度为:
(512 / 32)^2 = 256
在1024px处,序列长度变为:
(1024 / 32)^2 = 1024
与通常的256px → 512px → 1024px计划不同,我们直接从512px开始,然后在1024px处微调。
随着可控的令牌计数和现代硬件,像素空间训练不再禁止。它只是一个更清晰、更直接的公式。
感知损失
在像素空间中直接预测x_0的一个很好的副作用是,我们可以重复使用来自经典计算机视觉的整个工具箱。
当你的模型输出潜在表示时,感知监督变得很尴尬。你要么必须解码回像素,要么在可能或可能不符合人类感知的学习潜在空间中定义损失。一旦你直接预测像素,一切就变得简单了。你可以完全按照最初的设计方式插入感知损失。
我们从论文《PixelGen: Pixel Diffusion Beats Latent Diffusion with Perceptual Loss》Ma et al.中获得灵感,作者在扩散损失之上引入了额外的感知目标。他们展示了添加感知信号可以明显改进收敛速度和最终视觉质量。
对于这个24小时的运行,我们添加了两个辅助损失:
- LPIPS (Zhang et al.)
- 基于DINO的感知损失(我们使用DINOv2 Oquab et al.)
想法很简单:除了标准的流匹配目标外,我们还鼓励预测的清晰图像在感知特征空间中与目标图像匹配。LPIPS捕捉低级感知相似性,而DINO特征提供更强的语义信号。
我们保持了论文的相同总体思路,但我们调整了一些细节。在我们的实验中,我们凭经验发现这样做效果更好:
- 在池化后的完整图像上应用感知损失,而不是逐补丁特征
- 在所有噪声级别上应用它们
这些是小的实现细节,但在我们的设置中,它们始终给出了更好的结果。
我们对LPIPS损失使用了0.1的权重,对DINO感知损失使用了0.01的权重,与原始论文中推荐的值相匹配。
与主要的transformer前向传播相比,这些损失很轻量级,在我们的设置中只增加了少量开销,同时提供了一致的质量提升。
使用TREAD的令牌路由
为了降低每一步的成本,我们使用令牌路由与TREAD Krause et al., 2025,它随机选择一部分令牌,让它们跳过一个连续的transformer块块,然后稍后重新注入,这样就不会丢失任何东西。
我们选择TREAD而不是SPRINT Park et al., 2025主要是为了简单起见,因为SPRINT的额外复杂性在我们的设置中似乎不值得相当小的额外计算节省(在512px处TREAD的序列长度为64对128)。
遵循TREAD方案,我们从第2个块到倒数第二个块路由transformer中50%的令牌。
路由模型在vanilla CFG下看起来可能更糟,特别是当训练不足时,所以我们实现了一个受《Guiding Token-Sparse Diffusion Models》Krause et al., 2025启发的简单自我引导方案,它使用密集vs路由的条件预测进行引导,而不是依赖于无条件分支。
使用REPA和DINOv3的表示对齐
我们使用了REPA Yu et al., 2024来进行表示对齐。
对于教师模型,我们选择了DINOv3 Siméoni et al. 2025,因为它在我们之前的实验中给出了最好的质量改进。
具体来说,我们在第8个transformer块处应用对齐损失一次,损失权重为0.5。
由于我们将REPA与TREAD路由结合,我们只在非路由令牌上计算对齐损失,意味着实际通过我们应用损失的块的令牌。这保持了REPA信号的一致性,并避免了比较跳过计算路径的令牌的特征。
优化器:Muon
我们使用了Muon优化器,使用来自muon_fsdp_2的FSDP实现,因为它在我们之前的运行中显示出了相对于Adam的明确改进。
Muon仅应用于2D参数(基本上是矩阵)。其他所有内容(偏差、规范化、嵌入等)都使用Adam优化,这就是为什么配置有两个参数组。
| 组 | 应用对象 | 关键参数 |
|---|---|---|
| Muon | 2D参数 | lr=1e-4,momentum=0.95,nesterov=true,ns_steps=5 |
| Adam | 所有非2D参数 | lr=1e-4,betas=(0.9, 0.95),eps=1e-8 |
训练设置
我们在三个公开可用的合成数据集上进行了训练:
- Flux生成 (1.7M),lehduong/flux_generated
- FLUX-Reason-6M (6M),LucasFang/FLUX-Reason-6M
- midjourney-v6-llava (1M),brivangl/midjourney-v6-llava,我们使用Gemini 1.5重新标注以使提示更一致并减少标题噪声。
计划基本上是:在512处快速运行,然后在1024处锐化:
- 512px进行100k步,批量大小1024
- 1024px进行20k步,批量大小512,不使用REPA。
我们还保持权重的EMA用于采样和评估:
smoothing = 0.999update_interval = 10baema_start = 0ba
结果和总结思考
以下是我们在整个运行过程中跟踪的评估曲线和来自最终检查点的一些样本网格:

对于一天的训练运行来说,这已经是一个相当不错的地方了。该模型还不是完美的(你仍然可以看到一些纹理故障、偶尔的奇怪解剖和在非常困难的提示上可能会有些不稳定),但它显然是可用的。提示遵循很强,整体美学一致,1024阶段基本上做了我们想要的:锐化细节而不破坏构图。
关键的收获是我们非常接近。剩余的问题看起来更像是欠训练的人工制品和有限的数据多样性,而不是方案中结构性缺陷的迹象。失败模式与你对一个简单还没有看到足够多样化数据的模型所期望的一致。有了更多计算和更广泛的覆盖,这个确切的设置应该继续以相当可预测的方式改进。
从更大的角度看,这次速通也突显了扩散训练已经走了多远。通过结合像素空间训练、高效路由、表示对齐和轻量级感知引导,你现在可以在一天内在预算上获得一个有意义的模型,这个预算在不久前听起来会是不切实际的。
下一步计划?
这个24小时运行只是一个起点,而不是终点。接下来,我们将继续用更大的规模推进相同的方案,并迭代数据集混合和标注。
此次速通背后的所有代码和配置,以及整个第1部分和第2部分中使用的完整实验框架,都可以在PRX存储库中获得:https://github.com/Photoroom/PRX。
虽然我们不重新分发本次运行中使用的确切训练数据集,但管道是完全可配置的,设计用于轻松适应你自己的数据。你可以插入不同的数据集,调整单个组件(TREAD、REPA、感知损失、Muon等),并以最小的摩擦运行受控实验。我们的目标是使这成为一个快速扩散研究的实用平台,我们希望社区将使用它来在他们自己的设置中探索、基准测试和迭代这些技术。
如果你读到了这里,感谢你的阅读。我们也很希望你加入我们的Discord社区,我们在那里分享PRX的进展和结果,讨论任何与扩散和文本到图像相关的内容。
暂时告别,敬请期待下一轮实验! 🚀
致谢
此次速通受到了几项探索快速和低成本扩散模型训练的最近工作的启发。如果你有兴趣对文本到图像模型进行速通,我们鼓励你查看以下作品:
- Haridas, A., Shen, T., Yu, J. Nitro-T: Training a Text-to-Image Diffusion Model from Scratch in 1 Day. https://rocm.blogs.amd.com/artificial-intelligence/nitro-t-diffusion/README.html
- Bhanded, S. Speedrunning ImageNet Diffusion. https://arxiv.org/abs/2512.12386
- Sehwag, V., Kong, X., Li, J., Spranger, M., Lyu, L. Stretching Each Dollar: Diffusion Training from Scratch on a Micro-Budget. https://arxiv.org/abs/2407.15811
- Yeh, S.-Y. Home-made Diffusion Model from Scratch to Hatch. https://arxiv.org/abs/2509.06068
相似文章
Lens:重新思考基础文本到图像模型的训练效率
Lens是微软推出的一款紧凑型38亿参数文本到图像模型,在训练计算量显著降低的同时,通过密集描述、多分辨率批处理和高效架构,达到了与更大模型竞争甚至超越的性能。
HRM-Text: 仅用1千美元和400亿token训练,采用受大脑启发的分层潜在架构
HRM-Text是一个10亿参数文本生成模型,采用受大脑启发的分层循环架构,仅用400亿token和约1000美元即可实现高效预训练,大幅降低计算和数据需求,使得基础模型训练更加可及。
prunaai/z-image-turbo
阿里巴巴60亿参数的Z-Image-Turbo文生图模型,经PrunaAI进一步压缩,可在8步扩散下于1秒内生成1024×1024双语文字照片级图像。
HRM-Text: 超越规模的高效预训练
HRM-Text 引入了一种分层循环模型,将计算解耦为慢速和快速层级,使得仅使用400亿个token和1500美元预算即可从头开始高效预训练,实现了与更大模型竞争的性能。
使用合成数据构建快速多语言OCR模型
NVIDIA推出Nemotron OCR v2,一个使用合成数据生成技术构建的快速多语言OCR模型。该模型通过采用统一的基于FOTS的架构,在检测、识别和关系组件之间实现特征复用,在单个A100 GPU上达到34.7页/秒的性能。