世界模型的可识别令牌对应

arXiv cs.LG 论文

摘要

本文提出可识别令牌对应(Identifiable Token Correspondence)方法,通过建模跨时间帧的令牌对应关系,提升基于Transformer的世界模型在视觉强化学习中的时间一致性,在多个基准测试中取得最先进结果。

arXiv:2605.16457v1 公告类型: 新成果 摘要: 基于Transformer的世界模型在视觉强化学习中表现出色,但在长程推演中常出现时间不一致问题,包括物体重复、消失和变形。主要原因在于现有方法大多将下一帧预测纯粹视为令牌生成问题,而未明确建模令牌在时间上的对应关系。我们将下一帧预测建模为具有潜在令牌对应变量的结构化概率推理问题,推导出一个模型,其中每个下一帧令牌要么通过复制前一帧的令牌来解释,要么通过生成新令牌来解释。实验表明,该方法在4个具有挑战性的基准测试中达到了最先进性能。所提方法在Craftax-classic基准测试上获得了72.5%的回报率和35.6%的得分,显著超过了之前的最佳结果67.4%和27.9%。源代码已发布于https://github.com/snu-mllab/Identifiable-Token-Correspondence。
查看原文
查看缓存全文

缓存时间: 2026/05/19 06:43

# 面向世界模型的可辨识标记对应关系
来源:https://arxiv.org/html/2605.16457
###### 摘要

基于 Transformer 的世界模型在视觉强化学习中展现出强大性能,但在长程推演中常出现时间不一致性问题,包括物体重复、消失和变形。一个关键原因是现有方法大多将下一帧预测纯粹视为标记生成问题,而未显式建模跨时间标记之间的对应关系。我们将下一帧预测形式化为一个带有潜变量标记对应关系的结构化概率推理问题,推导出一个模型,其中每个下一帧标记要么通过复制前一帧的标记来解释,要么通过生成新标记来解释。实验表明,该方法在 4 个具有挑战性的基准上达到了最先进的性能。在 Craftax-classic 基准上,所提方法实现了 72.5% 的回报率和 35.6% 的分数,显著超越此前最佳结果 67.4% 和 27.9%。我们在 https://github.com/snu-mllab/Identifiable-Token-Correspondence 上开源了源代码。

机器学习, ICML, 强化学习, 世界模型

## 1 引言

请参考图注
图 1:在像 Craftax-classic 和 Atari 这样的视觉环境中,连续帧包含相同的底层实体。

强化学习(RL)提供了一个框架,用于通过奖励信号训练智能体与环境交互 (Sutton and Barto, 2018)。为避免严重依赖昂贵的环境交互,基于模型的 RL 学习环境动力学的预测模型,使智能体能够模拟称为“想象”的未来轨迹 (Hafner et al., 2023; Micheli et al., 2022)。近年来,Transformer 已成为强大的世界模型 (Micheli et al., 2022; Dedieu et al., 2025)。它们将过去状态和动作的序列视为标记流,并逐个标记预测下一个状态。然而,尽管最近取得了进展,这类模型在长程推演中常常表现出时间不一致性,包括物体重复、消失和变形为不同物体。这些错误随时间累积,严重限制了长程想象轨迹对策略训练的效用。

这种失败的一个核心原因是,大多数现有世界模型将下一帧预测纯粹视为标记生成问题 (Micheli et al., 2022; Dedieu et al., 2025)。然而,在现实环境中,连续帧中的许多标记对应的是随时间持续存在并移动的相同底层实体(见图 1)。因此,预测下一帧不仅需要确定每个位置应该出现什么标记,还需要确定该标记的来源。当未显式建模跨时间对应关系时,这两个问题被混为一谈,迫使模型在每一步都重新学习持久结构,从而使身份保持变得脆弱。

请参考图注
图 2:我们提出的世界模型通过求解之前状态标记(\(s_t\),蓝色)与 Transformer 输出的候选下一状态标记(\(\tilde{s}_{t+1}\),绿色)之间的最优传输问题,来增强下一状态预测,生成最终的下一状态标记(\(\hat{s}_{t+1}\))。最优传输根据 \(s_t\) 和 \(\tilde{s}_{t+1}\) 标记到 \(\hat{s}_{t+1}\) 位置的亲和矩阵进行定义。求解器接收亲和矩阵并生成传输计划,将 \(s_t\) 或 \(\tilde{s}_{t+1}\) 中的一个标记分配给 \(\hat{s}_{t+1}\) 中的每个最终下一状态标记。这种方法能够有效重用相关过去标记。

为了解决这个问题,我们提出了一种具有可辨识标记对应关系(ITC)的 Transformer 世界模型,该模型使用潜变量将每个下一帧标记分配给从前一帧复制的标记或生成的标记。该分配被形式化为前一帧标记与 Transformer 对下一帧标记的预测之间的最优传输问题。这使得能够部分重用先前标记,从而减少幻觉并提高物体随时间的持久性。

我们在 Craftax-classic、Craftax、MinAtar 和 Atari 100K 基准上评估了 ITC。Craftax-classic 是一个具有挑战性的 2D 开放世界游戏,具有长程任务和动态敌人 (Matthews et al., 2024)。ITC 实现了 72.5% 的回报率和 35.6% 的分数,创下新的最先进水平,分别超过了此前最佳结果 67.4% 和 27.9% (Dedieu et al., 2025)。Craftax 是一个基于 Craftax-classic 的更困难环境,ITC 在其中也超越了基线 (Matthews et al., 2024)。MinAtar 是 4 个 Atari 游戏的简化表示集合,测试了跨不同游戏动力学的泛化能力 (Young and Tian, 2019)。ITC 在所有 4 个游戏中超越了此前基于模型的 RL 的最先进水平 (Dedieu et al., 2025)。Atari 100K 是一个包含 26 个具有多样视觉结构的 Atari 游戏的集合。ITC 在 26 个游戏中超越了此前最先进的基于标记的世界模型 (Cohen et al., 2025)。

## 2 预备知识

### 2.1 基于模型的强化学习

强化学习考虑部分可观测马尔可夫决策过程 (POMDP),其特征为 \(\mathbb{S}, \mathbb{A}, \Omega, T, O, R, \gamma\),其中 \(\mathbb{S}\) 是状态集,\(\mathbb{A}\) 是离散动作集,\(\Omega\) 是观测集,\(T\) 给出状态间的转移概率 \(T(s' \mid s, a)\),\(O\) 给出观测概率 \(O(o \mid s)\),\(R\) 是奖励函数 \(R(s, a)\) (Sutton and Barto, 2018)。目标是找到一个策略 \(\pi\),为每个状态选择动作,以最大化期望折扣回报 \(\mathbb{E}_\pi \left[ \sum_{t \geq 0} \gamma^t r_t \right]\),其中 \(\gamma\) 是折扣因子。世界模型接收过去状态 \(s_t\) 和动作 \(a_t\) 作为输入,然后返回下一个状态 \(\hat{s}_{t+1}\)、奖励 \(r_t\) 和终止信号 \(d_t\) 的预测输出,类似于真实环境。在训练过程中,智能体通过使用策略 \(\pi\) 与环境交互来收集真实环境轨迹。然后,世界模型在保存在回放缓冲区中的轨迹上进行训练。随着训练的进行,智能体同时使用从真实环境收集的轨迹和从世界模型生成的轨迹(称为“想象”)进行训练。

### 2.2 RoPE

旋转位置编码 (RoPE) 是一种位置编码方法,通过对查询和键向量应用旋转,将位置信息注入 Transformer 的注意力机制 (Su et al., 2024)。这些旋转使得注意力操作自然地编码标记之间的相对偏移。具体来说,每个输入标记嵌入被划分为坐标对,每个坐标对形成一个 2D 子空间,并根据标记的 1D 位置索引应用旋转。由于其简单性和可扩展性,RoPE 已成为现代 Transformer 架构中的标准位置编码。

然而,RoPE 使用单一维度的位置索引,无法区分时间差异(即来自不同时间步的标记)和空间差异(即来自同一帧内不同位置的标记)。为了将空间和时间信息都融入模型,已经开发了用于多维信息的 3D 位置编码 (Wang et al., 2024; Wei et al., 2025)。每个标记的嵌入被划分为三个子向量,分别对应其时间、垂直和水平坐标。然后沿每个轴独立应用 RoPE,使注意力机制能够捕获跨空间和时间的局部关系结构。这种公式允许模型泛化到局部交互(例如相邻像素或帧),而与绝对位置无关。它在空间和时间维度上都保持了邻接性,而原始 RoPE 失去了垂直轴和时间轴的邻接性。在 3D RoPE 基础上添加绝对位置嵌入也能改善标记表示 (Agarwal et al., 2025)。

### 2.3 分词器

Transformer 世界模型需要一个分词器来将状态和动作转换为供 Transformer 使用的离散标记。Dedieu 等人 (2025) 引入了使用最近邻补丁查找将视觉观测转换为标记的分词器。每个标记代表图像状态的一个特定视觉补丁。首先,每一帧被划分为 \(L\) 个视觉补丁的网格 \(\{p_1, \ldots, p_L\}\),其中 \(p_i \in [0,1]^{h \times w \times 3}\),高度为 \(h\),宽度为 \(w\)。分词器维护一个码本 \(C = \{c_1, \ldots, c_K\}\),由 \(K\) 个码 \(c_i \in [0,1]^{h \times w \times 3}\) 组成。每个补丁 \(p\) 通过在码本中找到其最近邻来映射到标记 \(q\):

\[
q = \mathrm{enc}(p) = \operatorname*{argmin}_{1 \leq i \leq K} \|p - c_i\|_2^2 .
\]

码本通过从回放缓冲区中采样补丁来构建。如果一个补丁与所有现有码的距离足够远:当 \(\min_{1 \leq i \leq K} \|p - c_i\|_2^2 > \tau\)(对于选定的阈值 \(\tau\)),则添加该补丁。为了将标记转换回图像,分词器检索每个标记对应的码 \(\mathrm{dec}(q) = c_q\) 并重新组装网格为完整图像。

### 2.4 最优传输

最优传输是一类优化问题,基于元素间移动质量的给定成本,比较和对齐概率分布 (Peyré and Cuturi, 2019)。最优传输考虑分别定义在源域和目标域上的概率分布 \(\mathbf{a} \in \Delta^{n-1}\) 和 \(\mathbf{b} \in \Delta^{m-1}\)。给定一个成本矩阵 \(\bm{C}\),它寻求一个传输计划 \(\bm{\Pi} \in \mathbb{R}_+^{n \times m}\),使得在边际约束 \(\bm{\Pi} \mathbf{1}_m = \mathbf{a}\) 和 \(\bm{\Pi}^\top \mathbf{1}_n = \mathbf{b}\) 下,成本 \(\langle \mathbf{\Pi}, \bm{C} \rangle = \sum_{i=1}^n \sum_{j=1}^m \Pi_{ij} C_{ij}\) 最小化。

为了高效求解最优传输问题,已经提出了正则化变体。一种流行的方法是在目标中添加熵正则化项,得到 *Sinkhorn 距离*,它可以利用迭代矩阵缩放高效计算 (Cuturi, 2013)。Sinkhorn 算法在 \(O(n^2 / \epsilon^2)\) 时间内求解正则化问题,达到期望近似误差 \(\epsilon\),使其适用于大规模问题。

## 3 方法

基于第 2 节中提出的概念,我们的方法围绕一个利用最优传输求解器来建模帧间标记对应关系的 Transformer 世界模型展开。在分词器将状态和动作转换为标记后,标记嵌入被添加 3D 位置编码,然后馈入 Transformer。Transformer 输出标记由最优传输求解器用于生成下一状态标记,如图 2 所示。通过这个过程,世界模型为策略训练生成想象轨迹。

**算法 1** 使用最优传输的解码
**输入**: Transformer 预测 \(\mathbf{p}\),前一帧标记 \(\mathbf{u}\),每帧标记数 \(L\),Sinkhorn 正则化参数 \(\epsilon\),Sinkhorn 迭代次数 \(T\)
**输出**: 下一帧的生成标记 \(\mathbf{u}'\)
1. 根据公式 (1) 和 (2) 计算 \(\bm{A}^{(prev)}\)、\(\bm{A}^{(gen)}\)
2. \(\bm{A} = \begin{pmatrix} \bm{A}^{(prev)} & \mathbf{0} \in \mathbb{R}^{L \times L} \\ \bm{A}^{(gen)} & \mathbf{0} \in \mathbb{R}^{L \times L} \end{pmatrix} \in \mathbb{R}^{(2L) \times (2L)}\)
3. \(\bm{P} = \textsc{Sinkhorn}(-\bm{A}, \epsilon, T)\)
4. \(\bm{P}^{(prev)} = \bm{P}[1:L, 1:L]\)
5. \(\bm{P}^{(gen)} = \bm{P}[L+1:2L, 1:L]\)
6. \(\bm{\Pi}^{(prev)}, \bm{\Pi}^{(gen)} = \textsc{Binarization}(\bm{P}^{(prev)}, \bm{P}^{(gen)})\)
7. **for** \(j = 0\) **to** \(L-1\) **do**
8.     **if** \(\Pi_{ij}^{(prev)} = 1\) 对于某个 \(i\) **then**
9.         \(\mathbf{u}'_j = \mathbf{u}_i\)
10.    **else if** \(\Pi_{jj}^{(gen)} = 1\) **then**
11.        \(\mathbf{u}'_j = \text{sample}(\mathbf{p}_j)\)
12.    **end if**
13. **end for**
14. **返回** \(\mathbf{u}'\)

**算法 2** 部分传输计划的二值化
**输入**: 部分传输计划 \(\bm{P}^{(prev)}\)、\(\bm{P}^{(gen)}\),大值 \(v\)
**输出**: 二值化传输计划 \(\bm{\Pi}^{(prev)}\)、\(\bm{\Pi}^{(gen)}\)
1. \(\bm{P}^{\mathrm{in}} = \text{concatenate}\left( \bm{P}^{(prev)}, \bm{P}^{(gen)} \right)\)
2. 初始化 \(\bm{P}^{(0)} = \bm{P}^{\mathrm{in}},\ t = 0\)
3. **repeat**
4.     \(\mathrm{target} = \operatorname*{argmax}(\bm{P}^{(t)}, \text{dim}=1)\)
5.     \(\bm{\Pi^{\mathrm{initial}}} = \mathbf{0}_{n \times m}\)
6.     **for** \(i = 0\) **to** \(n-1\) **do**
7.         \(\bm{\Pi^{\mathrm{initial}}}[i, \mathrm{target}[i]] = 1\)
8.     **end for**
9.     \(\bm{C} = \bm{P}^{(t)} \odot \bm{\Pi^{\mathrm{initial}}} - v(1 - \bm{\Pi^{\mathrm{initial}}})\)
10.    \(\mathrm{source} = \operatorname*{argmax}(\bm{C}, \text{dim}=0)\)
11.    \(\bm{\Pi}^{\mathrm{out}} = \mathbf{0}_{n \times m}\)
12.    **for** \(j = 0\) **to** \(m-1\) **do**
13.        \(\bm{\Pi}^{\mathrm{out}}[\mathrm{source}[j], j] = 1\)
14.    **end for**
15.    \(\bm{\Pi}^{\mathrm{out}} = \bm{\Pi}^{\mathrm{out}} \odot \bm{\Pi^{\mathrm{initial}}}\)
16.    \(\bm{R} = (1 - \bm{\Pi}^{\mathrm{out}}) \odot \bm{\Pi^{\mathrm{initial}}}\)
17.    \(\bm{P}^{(t+1)} = \bm{P}^{(t)} - v \bm{R}\)
18.    \(t = t+1\)
19. **until** \(\bm{\Pi}^{\mathrm{out}} = \bm{\Pi^{\mathrm{initial}}}\)
20. \(\bm{\Pi}^{(prev)}\) ...

相似文章

Agentic RL: Token-In, Token-Out Done Right (16 minute read)

TLDR AI

This article explains the 'Token-In, Token-Out' (TITO) invariant in reinforcement learning for LLMs, highlighting a common error when training multi-turn agents with tool calls. It presents two solutions: using per-model renderers or designing training to avoid re-encoding decoded tokens, emphasizing prefix-preserving chat templates.

TONIC:面向任务无线系统的以令牌为中心的语义通信

arXiv cs.LG

本文提出TONIC,一种面向任务无线系统的以令牌为中心的语义通信框架,该框架为令牌分配具效用感知的不等错误保护,并使用基于Transformer的补全模型进行置信度感知门控,在图像分类任务上优于基线方法。

World Machine:面向时间序列的生成式世界建模

arXiv cs.LG

World Machine 提出了一种基于 Transformer 的生成式世界建模架构,用于时间序列分析。该架构通过潜在状态自适应地处理不同长度的上下文,解决了传统 Transformer 的二次内存成本问题。在合成数据集上的实验验证了该方法的可行性,并显示出相比传统 Transformer 的改进。