Identifiable Token Correspondence for World Models
Summary
This paper introduces Identifiable Token Correspondence, a method that models token correspondence across time frames to improve temporal consistency in transformer-based world models for visual reinforcement learning, achieving state-of-the-art results on multiple benchmarks.
# Identifiable Token Correspondence for World Models
Source: [https://arxiv.org/html/2605.16457](https://arxiv.org/html/2605.16457)
###### Abstract
Transformer-based world models have shown strong performance in visual reinforcement learning, but often suffer from temporal inconsistency in long-horizon rollouts, including object duplication, disappearance, and transmutation. A key reason is that most existing approaches treat next-frame prediction purely as a token generation problem, without explicitly modeling correspondence between tokens across time. We formulate next-frame prediction as a structured probabilistic inference problem with latent token correspondence variables, deriving a model in which each next-frame token is explained either by copying a token from the previous frame or by generating a new token. Our experiments show state-of-the-art performance on 4 challenging benchmarks. The proposed method achieves a *return* of 72.5% and a *score* of 35.6% on the Craftax-classic benchmark, significantly surpassing the previous best of 67.4% and 27.9%. We release our source code at [https://github.com/snu-mllab/Identifiable-Token-Correspondence](https://github.com/snu-mllab/Identifiable-Token-Correspondence).
Machine Learning, ICML, Reinforcement Learning, World Models
## 1 Introduction
Figure 1: Sequential frames in visual environments like Craftax-classic and Atari contain the same underlying entities.

Reinforcement learning (RL) provides a framework for training agents to interact with their environment through reward signals (Sutton and Barto, [2018](https://arxiv.org/html/2605.16457#bib.bib32)). To avoid heavy reliance on costly environment interactions, model-based RL learns a predictive model of the environment dynamics, enabling the agent to simulate future trajectories called “imaginations” (Hafner et al., [2023](https://arxiv.org/html/2605.16457#bib.bib22); Micheli et al., [2022](https://arxiv.org/html/2605.16457#bib.bib40)). Recently, transformers have emerged as powerful world models (Micheli et al., [2022](https://arxiv.org/html/2605.16457#bib.bib40); Dedieu et al., [2025](https://arxiv.org/html/2605.16457#bib.bib23)). They treat sequences of past states and actions as token streams and predict the next state token-by-token. However, despite recent advances, such models often exhibit temporal inconsistency in long-horizon rollouts, including object duplication, disappearance, and transmutation into different objects. These errors compound over time and significantly limit the usefulness of long imagined trajectories for policy training.
A central reason for this failure is that most existing world models treat next-frame prediction purely as a token generation problem (Micheli et al., [2022](https://arxiv.org/html/2605.16457#bib.bib40); Dedieu et al., [2025](https://arxiv.org/html/2605.16457#bib.bib23)). In realistic environments, however, many tokens in successive frames correspond to the same underlying entities that persist and move over time (see Figure [1](https://arxiv.org/html/2605.16457#S1.F1)). Predicting the next frame therefore requires determining not only what token should appear at each position, but also where that token comes from. When correspondence across time is not modeled explicitly, these two questions are conflated, forcing the model to relearn persistent structure at every step and making identity preservation fragile.
Figure 2: Our proposed world model enhances next state prediction by solving an optimal transport problem with previous state tokens ($s_t$, blue) and the transformer’s output for candidate next-state tokens ($\tilde{s}_{t+1}$, green) to generate the final next-state tokens ($\hat{s}_{t+1}$). Optimal transport defines an affinity matrix from the $s_t$ and $\tilde{s}_{t+1}$ tokens to the positions for $\hat{s}_{t+1}$. A solver takes the affinity matrix and produces a transport plan, assigning a token from $s_t$ or $\tilde{s}_{t+1}$ to each final next-state token in $\hat{s}_{t+1}$. This approach enables effective reuse of relevant past tokens.

To address this problem, we propose a transformer world model with Identifiable Token Correspondence (ITC), which uses latent variables that assign each next-frame token either to a copied token from the previous frame or a generated token. The assignment is formulated as an optimal transport problem between the previous frame’s tokens and the transformer’s predictions for next-frame tokens. This enables partial reuse of previous tokens, reducing hallucinations and improving object persistence over time.
We evaluate ITC on the Craftax-classic, Craftax, MinAtar, and Atari 100K benchmarks. Craftax-classic is a challenging 2D open-world game featuring long-horizon tasks and dynamic enemies (Matthews et al., [2024](https://arxiv.org/html/2605.16457#bib.bib31)). ITC achieves a return of 72.5% and a score of 35.6%, setting a new state of the art and outperforming the previous best results of 67.4% and 27.9%, respectively (Dedieu et al., [2025](https://arxiv.org/html/2605.16457#bib.bib23)). Craftax is a harder environment based on Craftax-classic, in which ITC also exceeds baselines (Matthews et al., [2024](https://arxiv.org/html/2605.16457#bib.bib31)). MinAtar is a suite of 4 Atari games with simplified representations, which tests generality across different game dynamics (Young and Tian, [2019](https://arxiv.org/html/2605.16457#bib.bib51)). ITC surpasses the previous state of the art for model-based RL in all 4 games (Dedieu et al., [2025](https://arxiv.org/html/2605.16457#bib.bib23)). Atari 100K is a suite of 26 Atari games with diverse visual structure. ITC surpasses the previous state-of-the-art token-based world model (Cohen et al., [2025](https://arxiv.org/html/2605.16457#bib.bib52)) in aggregate metrics across the 26 games.
## 2 Preliminaries
### 2.1 Model-based Reinforcement Learning
Reinforcement learning considers a Partially Observable Markov Decision Process (POMDP), characterized by $(\mathbb{S}, \mathbb{A}, \Omega, T, O, R, \gamma)$, where $\mathbb{S}$ is a set of states, $\mathbb{A}$ is a set of discrete actions, $\Omega$ is a set of observations, $T$ gives the transition probabilities between states $T(s' \mid s, a)$, $O$ gives the observation probabilities $O(o \mid s)$, and $R$ is a reward function $R(s, a)$ (Sutton and Barto, [2018](https://arxiv.org/html/2605.16457#bib.bib32)). The goal is to find a policy $\pi$ that chooses actions in each state so as to maximize the expected discounted return $\mathbb{E}_{\pi}\left[\sum_{t \geq 0} \gamma^{t} r_{t}\right]$, where $\gamma$ is a discount factor. A world model takes the previous state $s_t$ and action $a_t$ as input and returns a predicted next state $\hat{s}_{t+1}$, reward $r_t$, and done signal $d_t$, similar to the real environment. The agent collects real environment trajectories during training by interacting with the environment using the policy $\pi$. The world model then trains on the trajectories saved in the replay buffer. Over the course of training, the agent is trained on both the trajectories collected from the real environment and generated trajectories from the world model, called *imaginations*.
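As a small worked example of the objective above, the discounted return of one finite reward sequence can be computed as follows. This is an illustrative sketch only; the function name `discounted_return` is ours and does not come from the paper's code.

```python
def discounted_return(rewards, gamma):
    """Return sum_t gamma**t * r_t for one finite episode."""
    total = 0.0
    for t, r in enumerate(rewards):
        total += (gamma ** t) * r
    return total
```

The policy objective is the expectation of this quantity over trajectories sampled under $\pi$.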
### 2.2 RoPE
Rotary Position Embedding (RoPE) is a positional encoding method that injects positional information into a transformer’s attention mechanism by applying rotations to query and key vectors (Su et al., [2024](https://arxiv.org/html/2605.16457#bib.bib27)). These rotations cause the attention operation to naturally encode relative offsets between tokens. Concretely, each input token embedding is partitioned into pairs of coordinates, with each pair forming a 2D subspace where a rotation is applied according to the token’s 1D position index. Owing to its simplicity and scalability, RoPE has become the standard positional encoding in modern transformer architectures.
However, RoPE uses a one-dimensional position index, which cannot distinguish between temporal differences (i.e., tokens from different time steps) and spatial differences (i.e., tokens from different positions within the same frame). To incorporate both spatial and temporal information into the model, 3D positional encodings have been developed (Wang et al., [2024](https://arxiv.org/html/2605.16457#bib.bib26); Wei et al., [2025](https://arxiv.org/html/2605.16457#bib.bib25)). Each token’s embedding is divided into three sub-vectors corresponding to its temporal, vertical, and horizontal coordinates. RoPE is then applied independently along each axis, enabling the attention mechanism to capture localized relational structure across both space and time. This formulation allows the model to generalize over local interactions (e.g., neighboring pixels or frames), regardless of absolute location. It preserves adjacency in both spatial and temporal dimensions, whereas the original 1D RoPE loses adjacency along the vertical and temporal axes. On top of 3D RoPE, adding absolute positional embeddings also improves token representations (Agarwal et al., [2025](https://arxiv.org/html/2605.16457#bib.bib61)).
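The axis-wise construction described above can be sketched in NumPy as follows. This is a minimal illustration under our own assumptions, not the paper's implementation: the embedding is split into three equal sub-vectors (so its size must be divisible by 6), the frequency base is the usual 10000, and the names `rope_1d` and `rope_3d` are hypothetical.

```python
import numpy as np

def rope_1d(x, pos, base=10000.0):
    """Rotate consecutive coordinate pairs of x (..., d) by angles pos * freq."""
    pos = np.asarray(pos, dtype=float)
    d = x.shape[-1]
    freqs = base ** (-np.arange(0, d, 2) / d)   # one frequency per 2D pair
    ang = pos[..., None] * freqs                # (..., d/2)
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x, dtype=float)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_3d(x, t, row, col):
    """Apply 1D RoPE independently to three equal slices of x (assumed split)."""
    d = x.shape[-1] // 3
    parts = [rope_1d(x[..., k * d:(k + 1) * d], p)
             for k, p in enumerate((t, row, col))]
    return np.concatenate(parts, axis=-1)
```

Because each slice is rotated by its own coordinate, the dot product between a rotated query and a rotated key depends only on their offsets along time, row, and column, which is the relative-position property the paragraph relies on.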
### 2.3 Tokenizer
Transformer world models require a tokenizer to convert states and actions into discrete tokens for the transformer. Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)) introduced a tokenizer that converts the visual observation to tokens using nearest neighbor patch lookup. Each token represents a particular visual patch of the image state. First, each frame is divided into a grid of $L$ visual patches $\{p_1, \ldots, p_L\}$, where $p_i \in [0,1]^{h \times w \times 3}$ with height $h$ and width $w$. The tokenizer maintains a codebook $C = \{c_1, \ldots, c_K\}$, consisting of $K$ codes $c_i \in [0,1]^{h \times w \times 3}$. Each patch $p$ is mapped to a token $q$ by finding its nearest neighbor in the codebook:

$$q = \mathrm{enc}(p) = \operatorname*{argmin}_{1 \leq i \leq K} \|p - c_i\|_2^2.$$

The codebook is constructed by sampling patches from the replay buffer. A patch is added if it is sufficiently far away from all existing codes, that is, when $\min_{1 \leq i \leq K} \|p - c_i\|_2^2 > \tau$ for a chosen threshold $\tau$. To convert tokens back to images, the tokenizer retrieves the corresponding code for each token, $\mathrm{dec}(q) = c_q$, and reassembles the grid into the full image.
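The encode, codebook-growth, and decode steps above can be sketched as follows. This is an illustrative NumPy sketch rather than the tokenizer of Dedieu et al.; the function names are hypothetical, tokens are 0-indexed, and the codebook is assumed to be non-empty.

```python
import numpy as np

def encode(patches, codebook):
    """Map patches (L, h, w, 3) to indices of their nearest codes (K, h, w, 3)."""
    p = patches.reshape(len(patches), -1)
    c = codebook.reshape(len(codebook), -1)
    d2 = ((p[:, None, :] - c[None, :, :]) ** 2).sum(-1)   # squared L2, (L, K)
    return d2.argmin(axis=1)

def maybe_add_code(patch, codebook, tau):
    """Append patch as a new code if it is farther than tau from every code."""
    d2 = ((codebook - patch) ** 2).reshape(len(codebook), -1).sum(-1)
    if d2.min() > tau:
        codebook = np.concatenate([codebook, patch[None]], axis=0)
    return codebook

def decode(tokens, codebook):
    """Look up the code (image patch) for each token index."""
    return codebook[tokens]
```

Decoding returns the stored code rather than the original patch, so reconstruction error is bounded by how far a patch can be from its nearest code.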
### 2.4 Optimal Transport
Optimal transport is a family of optimization problems that compares and aligns probability distributions based on a given cost of moving mass between elements (Peyré and Cuturi, [2019](https://arxiv.org/html/2605.16457#bib.bib48)). Optimal transport considers probability distributions $\mathbf{a} \in \Delta^{n-1}$ and $\mathbf{b} \in \Delta^{m-1}$ over the source and target domains, respectively. Given a cost matrix $\bm{C}$, it seeks a transport plan $\bm{\Pi} \in \mathbb{R}_{+}^{n \times m}$ that minimizes the cost $\langle \bm{\Pi}, \bm{C} \rangle = \sum_{i=1}^{n} \sum_{j=1}^{m} \Pi_{ij} C_{ij}$, subject to the marginal constraints $\bm{\Pi} \mathbf{1}_m = \mathbf{a}$ and $\bm{\Pi}^{\top} \mathbf{1}_n = \mathbf{b}$.
To solve optimal transport problems efficiently, regularized variants of optimal transport have been proposed. One popular approach introduces an entropic regularization term to the objective, leading to the *Sinkhorn distance*, which can be computed efficiently using iterative matrix scaling (Cuturi, [2013](https://arxiv.org/html/2605.16457#bib.bib37)). The Sinkhorn algorithm solves the regularized problem in $O(n^{2}/\epsilon^{2})$ time for a desired approximation error $\epsilon$, making it practical for large-scale problems.
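The iterative matrix scaling mentioned above can be sketched as follows. This is the textbook Sinkhorn iteration for the balanced problem with equality marginals, written as a minimal NumPy illustration; it is not the solver used in the paper, and the default values of `eps` and `n_iters` are arbitrary choices of ours. A log-domain variant would be preferable for very small `eps`.

```python
import numpy as np

def sinkhorn(C, a, b, eps=0.1, n_iters=200):
    """Approximate the entropic-regularized transport plan for cost matrix C."""
    K = np.exp(-C / eps)              # Gibbs kernel
    u = np.ones_like(a, dtype=float)
    v = np.ones_like(b, dtype=float)
    for _ in range(n_iters):
        u = a / (K @ v)               # rescale rows towards marginal a
        v = b / (K.T @ u)             # rescale columns towards marginal b
    return u[:, None] * K * v[None, :]
```

Each pass alternately rescales rows and columns of the kernel, so the returned plan matches the column marginal exactly and the row marginal up to the convergence error.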
## 3 Method
Based on the concepts presented in Section [2](https://arxiv.org/html/2605.16457#S2), our method centers around a transformer world model that leverages an optimal transport solver to model token correspondence between frames. After the tokenizer converts states and actions to tokens, the token embeddings are augmented with 3D positional encodings, before being fed into a transformer. The transformer output tokens are used by an optimal transport solver to produce the next state tokens, as shown in Figure [2](https://arxiv.org/html/2605.16457#S1.F2). Through this process, the world model generates imagined trajectories for policy training.
**Algorithm 1** Decoding with Optimal Transport

**Input:** transformer prediction $\mathbf{p}$, previous tokens $\mathbf{u}$, number of tokens per frame $L$, Sinkhorn regularization parameter $\epsilon$, number of Sinkhorn iterations $T$

**Output:** generated tokens for next frame $\mathbf{u}'$

1. Compute $\bm{A}^{(prev)}$, $\bm{A}^{(gen)}$ from Equations [1](https://arxiv.org/html/2605.16457#S3.E1) and [2](https://arxiv.org/html/2605.16457#S3.E2)
2. $\bm{A} = \begin{pmatrix} \bm{A}^{(prev)} & \mathbf{0} \\ \bm{A}^{(gen)} & \mathbf{0} \end{pmatrix} \in \mathbb{R}^{(2L) \times (2L)}$, where each $\mathbf{0} \in \mathbb{R}^{L \times L}$
3. $\bm{P} = \textsc{Sinkhorn}(-\bm{A}, \epsilon, T)$
4. $\bm{P}^{(prev)} = \bm{P}[1:L, 1:L]$
5. $\bm{P}^{(gen)} = \bm{P}[L+1:2L, 1:L]$
6. $\bm{\Pi}^{(prev)}, \bm{\Pi}^{(gen)} = \textsc{Binarization}(\bm{P}^{(prev)}, \bm{P}^{(gen)})$
7. **for** $j = 0$ **to** $L-1$ **do**
    - **if** $\Pi^{(prev)}_{ij} = 1$ for some $i$ **then** $\mathbf{u}'_j = \mathbf{u}_i$
    - **else if** $\Pi^{(gen)}_{jj} = 1$ **then** $\mathbf{u}'_j = \mathrm{sample}(\mathbf{p}_j)$
8. **Return** $\mathbf{u}'$
**Algorithm 2** Binarization of partial transport plan

**Input:** partial transport plans $\bm{P}^{(prev)}$, $\bm{P}^{(gen)}$, large value $v$

**Output:** binarized transport plans $\bm{\Pi}^{(prev)}$, $\bm{\Pi}^{(gen)}$

1. $\bm{P}^{\mathrm{in}} = \text{concatenate}\left(\bm{P}^{(prev)}, \bm{P}^{(gen)}\right)$
2. Initialize $\bm{P}^{(0)} = \bm{P}^{\mathrm{in}},\ t = 0$
3. **repeat**
    - $\mathrm{target} = \operatorname*{argmax}(\bm{P}^{(t)}, \text{dim}=1)$
    - $\bm{\Pi}^{\mathrm{initial}} = \mathbf{0}_{n \times m}$
    - **for** $i = 0$ **to** $n-1$ **do** $\bm{\Pi}^{\mathrm{initial}}[i, \mathrm{target}[i]] = 1$
    - $\bm{C} = \bm{P}^{(t)} \odot \bm{\Pi}^{\mathrm{initial}} - v(1 - \bm{\Pi}^{\mathrm{initial}})$
    - $\mathrm{source} = \operatorname*{argmax}(\bm{C}, \text{dim}=0)$
    - $\bm{\Pi}^{\mathrm{out}} = \mathbf{0}_{n \times m}$
    - **for** $j = 0$ **to** $m-1$ **do** $\bm{\Pi}^{\mathrm{out}}[\mathrm{source}[j], j] = 1$
    - $\bm{\Pi}^{\mathrm{out}} = \bm{\Pi}^{\mathrm{out}} \odot \bm{\Pi}^{\mathrm{initial}}$
    - $\bm{R} = (1 - \bm{\Pi}^{\mathrm{out}}) \odot \bm{\Pi}^{\mathrm{initial}}$
    - $\bm{P}^{(t+1)} = \bm{P}^{(t)} - v\bm{R}$
    - $t = t + 1$
4. **until** $\bm{\Pi}^{\mathrm{out}} = \bm{\Pi}^{\mathrm{initial}}$
5. $\bm{\Pi}^{(prev)} = \bm{\Pi}^{\mathrm{out}}[1:L, 1:L]$
6. $\bm{\Pi}^{(gen)} = \bm{\Pi}^{\mathrm{out}}[L+1:2L, 1:L]$
7. **Return** $\bm{\Pi}^{(prev)}$, $\bm{\Pi}^{(gen)}$
Our transformer world model uses optimal transport as the identifiable token correspondence mechanism. In existing approaches, the output of the transformer world model is directly used to predict each token in the next frame (Micheli et al., [2022](https://arxiv.org/html/2605.16457#bib.bib40), [2024](https://arxiv.org/html/2605.16457#bib.bib24); Agarwal et al., [2024](https://arxiv.org/html/2605.16457#bib.bib44); Dedieu et al., [2025](https://arxiv.org/html/2605.16457#bib.bib23)). However, in most visual environments, two adjacent frames are often very similar, e.g., the same tiles but shifted when the player moves right. Our intuition is closely related to the notion of optical flow from classic computer vision tasks (Brox et al., [2004](https://arxiv.org/html/2605.16457#bib.bib46); Vedula et al., [2005](https://arxiv.org/html/2605.16457#bib.bib45); Perazzi et al., [2016](https://arxiv.org/html/2605.16457#bib.bib47)). This relation allows tokens to be taken directly from the previous frame into the next frame, rather than offloading the burden of regenerating all next-state tokens to the transformer. To exploit this, the final next-state token predictions are formulated as an optimal transport problem. The end-to-end decoding process is characterized in Algorithm [1](https://arxiv.org/html/2605.16457#alg1).
Let $L$ be the number of tokens for each frame state. Our method constructs a graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$, where the vertices $\mathcal{V} = \mathcal{V}_S \cup \mathcal{V}_D$ consist of source vertices $\mathcal{V}_S$ that correspond to previous state tokens and candidate next-state tokens, and destination vertices $\mathcal{V}_D$ that represent the finalized next-state tokens ($|\mathcal{V}_S| = 2L$ and $|\mathcal{V}_D| = L$). The edges $\mathcal{E} = \{(u, v) \mid u \in \mathcal{V}_S, v \in \mathcal{V}_D\}$ connect all sources to all destinations. We now define affinities on these edges for transport.

Let $K$ be the size of the codebook. Given transformer predictions $\mathbf{p}_j \in [0,1]^K$ for the next state tokens, and previous state tokens $\mathbf{u}_i \in \{0,1\}^K$ for all $i, j \in \{0, \dots, L-1\}$, we define an affinity matrix $\bm{A}^{(prev)} \in \mathbb{R}^{L \times L}$ that scores the affinity between previous state tokens and predicted next-state tokens. Each entry is computed as:

$$A^{(prev)}_{ij} = \langle \mathbf{p}_j, \mathbf{u}_i \rangle - c_d D\left((x_i, y_i), (x_j, y_j)\right), \quad \forall i, j \in \{0, \ldots, L-1\}, \tag{1}$$

where $c_d$ is a coefficient of cost for distance, $D(\cdot)$ is a distance function for 2D coordinates, and $(x_i, y_i)$ and $(x_j, y_j)$ are the 2D coordinates of the $i$-th and $j$-th tokens, respectively. To allow the model to generate new content not present in the previous frame, the graph includes wildcard tokens. The matrix $\bm{A}^{(gen)} \in \mathbb{R}^{L \times L}$ scores the bonus of admitting newly generated tokens instead of reusing the previous ones, using diagonal entries:

$$A^{(gen)}_{kj} = \begin{cases} \|\mathbf{p}_j\|_{\infty} - c_w, & \text{if } k = j, \\ -\infty, & \text{otherwise,} \end{cases} \quad \forall k, j \in \{0, \ldots, L-1\}, \tag{2}$$

where $c_w$ is a constant penalty for using a wildcard token. With the matrices defined above, an optimal transport plan $\bm{P}^{(prev)}$ and $\bm{P}^{(gen)}$ is computed by optimizing the following equation:

$$\begin{aligned} \operatorname*{minimize}_{\bm{P}^{(prev)},\, \bm{P}^{(gen)} \in [0,1]^{L \times L}} \quad & \left\langle -\begin{pmatrix} \bm{A}^{(prev)} \\ \bm{A}^{(gen)} \end{pmatrix}, \begin{pmatrix} \bm{P}^{(prev)} \\ \bm{P}^{(gen)} \end{pmatrix} \right\rangle \\ \text{subject to} \quad & \bm{P}^{(prev)} \mathbf{1}_L \leq \mathbf{1}_L, \\ & \bm{P}^{(gen)} \mathbf{1}_L \leq \mathbf{1}_L, \\ & \left(\bm{P}^{(prev)} + \bm{P}^{(gen)}\right)^{\top} \mathbf{1}_L = \mathbf{1}_L. \end{aligned} \tag{3}$$
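Equations 1 and 2, together with the stacked matrix in step 2 of Algorithm 1, can be sketched as follows. This is an illustrative NumPy sketch, not the released implementation. Several choices here are our assumptions: the distance $D$ is taken to be Manhattan distance because the text does not specify it, $-\infty$ is replaced by a large negative constant `neg`, and the function names are hypothetical.

```python
import numpy as np

def affinities(p, u, coords, c_d, c_w, neg=-1e9):
    """p: (L, K) predicted distributions, u: (L, K) one-hot previous tokens,
    coords: (L, 2) grid coordinates of the token positions."""
    L = p.shape[0]
    # Assumed distance D: Manhattan distance between token positions.
    dist = np.abs(coords[:, None, :] - coords[None, :, :]).sum(-1)
    A_prev = u @ p.T - c_d * dist          # A_prev[i, j] = <p_j, u_i> - c_d * D
    A_gen = np.full((L, L), neg)           # stand-in for -infinity off the diagonal
    np.fill_diagonal(A_gen, p.max(axis=1) - c_w)   # ||p_j||_inf - c_w
    return A_prev, A_gen

def stacked_affinity(A_prev, A_gen):
    """Pad to the square (2L, 2L) matrix passed (negated) to Sinkhorn."""
    L = A_prev.shape[0]
    Z = np.zeros((L, L))
    return np.block([[A_prev, Z], [A_gen, Z]])
```

The zero-padded right half presumably acts as slack columns, letting source tokens that are not used in the next frame send their mass somewhere, which is how the inequality constraints in Equation 3 would fit a square balanced solver.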
Figure 3: ITC achieves state-of-the-art return and score in Craftax-classic, with significantly faster convergence (Matthews et al., [2024](https://arxiv.org/html/2605.16457#bib.bib31)). Shading indicates standard deviation among seeds. \*Baselines with reported results at 1M steps are displayed with horizontal lines from 900K to 1M steps. DART does not report score, and IRIS and $\Delta$-IRIS do not report standard deviation for score.

Solving this optimization problem involves the Sinkhorn algorithm. By default, the Sinkhorn algorithm minimizes the objective given by a cost matrix rather than an affinity matrix, so the cost matrix is set as the negative of the computed affinity matrix. Solving the optimal transport problem yields a partial transport plan, represented by a matrix with continuous values in the range $[0,1]$.
However, our application requires a strict one-to-one mapping between discrete tokens. To address this, we convert the partial transport plan into a binary assignment matrix with values $\{0,1\}$ using a greedy binarization procedure based on column-wise argmax. Specifically, for each column in the transport matrices $\bm{P}^{(prev)}$ and $\bm{P}^{(gen)}$, we identify the row with the highest transport weight, selecting that row in either $\bm{P}^{(prev)}$ or $\bm{P}^{(gen)}$, whichever yields the larger value. In the event of a conflict where multiple columns select the same row, we retain the assignment corresponding to the column with the higher transport value and reassign the conflicting column using argmax again, excluding rows that have already been assigned. The complete binarization procedure, which is adapted from Kim et al. ([2020](https://arxiv.org/html/2605.16457#bib.bib30)), is described in Algorithm [2](https://arxiv.org/html/2605.16457#alg2).
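The greedy procedure described in this paragraph can be sketched as follows. The sketch follows the prose (each column picks its best row, conflicts are won by the larger value, losers pick again among unassigned rows) and is not a line-by-line port of Algorithm 2. As an added assumption, off-diagonal entries of $\bm{P}^{(gen)}$ are masked so that a generated token can only fill its own position, matching Equation 2; the name `binarize` is ours.

```python
import numpy as np

def binarize(P_prev, P_gen):
    """Greedy one-to-one binarization of a partial transport plan."""
    L = P_prev.shape[0]
    # Assumption: wildcard row j may only serve column j.
    P_gen = np.where(np.eye(L, dtype=bool), P_gen, -np.inf)
    P = np.concatenate([P_prev, P_gen], axis=0).astype(float)   # (2L, L)
    assign = np.full(L, -1)                  # assign[j] = source row for column j
    while (assign < 0).any():
        free = np.flatnonzero(assign < 0)    # columns still without a source
        choice = P[:, free].argmax(axis=0)   # best remaining row per free column
        value = P[choice, free]
        for r in np.unique(choice):
            tied = choice == r               # free columns competing for row r
            winner = free[tied][value[tied].argmax()]
            assign[winner] = r
            P[r, :] = -np.inf                # row r can no longer be selected
    Pi_prev = np.zeros((L, L), dtype=int)
    Pi_gen = np.zeros((L, L), dtype=int)
    for j, r in enumerate(assign):
        if r < L:
            Pi_prev[r, j] = 1
        else:
            Pi_gen[r - L, j] = 1
    return Pi_prev, Pi_gen
```

Every pass assigns at least one column, and each column always has its own wildcard row available, so the loop terminates after at most $L$ passes.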
Let $\bm{\Pi}^{(prev)} \in \{0,1\}^{L \times L}$ and $\bm{\Pi}^{(gen)} \in \{0,1\}^{L \times L}$ denote the binarized versions of $\bm{P}^{(prev)}$ and $\bm{P}^{(gen)}$, respectively. $\bm{\Pi}^{(prev)}$ and $\bm{\Pi}^{(gen)}$ are the latent variables that represent token correspondence. The $j$-th token of the next state is determined by copying the $i$-th token of the previous state where $\Pi^{(prev)}_{ij} = 1$. If no such $i$ exists, which occurs only when $\Pi^{(gen)}_{jj} = 1$, the model instead samples from the transformer’s predicted distribution. The overall decoding rule is thus defined as

$$\mathbf{u}'_j = \begin{cases} \mathbf{u}_i, & \text{where } \Pi^{(prev)}_{ij} = 1, \\ \mathrm{sample}(\mathbf{p}_j), & \text{where } \Pi^{(gen)}_{jj} = 1, \end{cases} \quad \forall j \in \{0, \ldots, L-1\}. \tag{4}$$
By using the latent correspondence variables $\bm{\Pi}^{(prev)}$ and $\bm{\Pi}^{(gen)}$ in this way, the world model selectively reuses tokens that have a strong correspondence with the previous frame, while leveraging the transformer to generate new tokens for changes in the environment.
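The copy-or-generate rule of Equation 4 can be sketched as follows. This is a minimal illustration in which, for brevity, tokens are stored as integer codebook indices rather than one-hot vectors (our simplification); the function name and the `rng` argument are hypothetical.

```python
import numpy as np

def decode_next_tokens(Pi_prev, Pi_gen, u, p, rng):
    """Pi_prev, Pi_gen: (L, L) binary plans; u: (L,) previous token indices;
    p: (L, K) predicted distributions; rng: a numpy random Generator."""
    L, K = p.shape
    u_next = np.empty(L, dtype=int)
    for j in range(L):
        src = np.flatnonzero(Pi_prev[:, j])
        if src.size:                        # copy the matched previous token
            u_next[j] = u[src[0]]
        else:                               # wildcard case: Pi_gen[j, j] == 1
            u_next[j] = rng.choice(K, p=p[j])
    return u_next
```

Only positions without a matched previous token are sampled, so the transformer's stochasticity is confined to the parts of the frame that actually change.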
## 4 Experiments
### 4.1 Craftax-classic
##### Environment
We evaluate our method on the Craftax-classic environment (Matthews et al., [2024](https://arxiv.org/html/2605.16457#bib.bib31)). Craftax-classic is a fast implementation of Crafter, a challenging procedurally generated, partially observable environment featuring stochastic transitions and a complex hierarchy of achievements (Hafner, [2021](https://arxiv.org/html/2605.16457#bib.bib21)). These attributes demand both strong generalization and the ability to model object interactions across time.
##### Experiment Configuration
Each method is trained on Craftax-classic for 1M environment steps, using 10 different seeds per method. The baseline methods consist of DreamerV3 (Hafner et al., [2023](https://arxiv.org/html/2605.16457#bib.bib22)), IRIS (Micheli et al., [2022](https://arxiv.org/html/2605.16457#bib.bib40)), $\Delta$-IRIS (Micheli et al., [2024](https://arxiv.org/html/2605.16457#bib.bib24)), DART (Agarwal et al., [2024](https://arxiv.org/html/2605.16457#bib.bib44)), and Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)) (see footnote 1), which had the previous state-of-the-art return on Craftax-classic. Each experiment runs on a single Nvidia RTX 3090 GPU for 48.2 hours. See Appendix [A](https://arxiv.org/html/2605.16457#A1) for all hyperparameters.

Footnote 1: We use the (fast) variant from Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)), as the (slow) variant is prohibitively expensive to train.
##### Results
Figure [3](https://arxiv.org/html/2605.16457#S3.F3) shows that our proposed world model leads to substantially higher return and score, along with faster convergence compared to baseline methods (see footnote 2). Return and score are reported in Table [1](https://arxiv.org/html/2605.16457#S4.T1), as the mean and standard error over 10 seeds. After 1M environment interactions, our method achieves a final return and score surpassing all baselines. It also outperforms the previous best baseline during training at 0.5M environment interactions, demonstrating superior sample efficiency in a more data-constrained setting.

Footnote 2: Score is a metric defined as the geometric mean of the success rates for each achievement (Hafner, [2021](https://arxiv.org/html/2605.16457#bib.bib21)). Score puts more emphasis on unlocking a variety of achievements, in contrast to return, which is simply the sum of rewards for each episode.
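To make the score metric concrete, the sketch below follows the formula from the Crafter paper (Hafner, 2021) as we understand it: a geometric mean of success rates shifted by one percentage point, $\exp\left(\frac{1}{N}\sum_i \ln(1 + s_i)\right) - 1$ with $s_i$ in percent. Treat the exact form as an assumption to check against that paper; the function name is ours.

```python
import numpy as np

def crafter_score(success_rates_percent):
    """Shifted geometric mean of per-achievement success rates (in percent)."""
    s = np.asarray(success_rates_percent, dtype=float)
    return float(np.exp(np.mean(np.log1p(s))) - 1.0)
```

Because of the logarithm, an agent that unlocks one achievement every time and another never (rates 100 and 0) scores far lower than one that unlocks both half the time, which is why score rewards breadth.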
Table 1: Results on Craftax-classic after 0.5M and 1M environment interactions. Return is averaged over episodes of the final 50,000 environment interactions to smooth out variance. The final value for Score is reported directly, as it is already a cumulative metric and does not require additional smoothing. Metrics not reported by baselines are marked as —. † uses hyperparameters of ITC.

Table 2: Prediction accuracy on a dataset of 10,000 transitions. The first column reports overall accuracy, while the latter two break down accuracy based on whether the input state contains a randomly moving creature. Applying ITC to the transformer outputs increases accuracy by 3.39%. The transformer accuracy is especially low for transitions involving creatures, which ITC improves by 3.45%. † uses hyperparameters of ITC.
##### Accuracy Evaluation
To directly assess the contribution of ITC to world model prediction, we compare the prediction accuracy of ITC to the baseline. Our evaluation uses 10,000 environment transitions and counts how many next states are predicted perfectly (where every predicted token is correct). Table [2](https://arxiv.org/html/2605.16457#S4.T2) shows that ITC improves accuracy. The transformer alone has particularly low accuracy in cases involving randomly moving creatures, and ITC improves accuracy in these cases as well. By improving the accuracy of world model prediction, identifiable token correspondence leads to higher quality imaginations and improved policy performance.
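The exact-match criterion used here (a transition counts as correct only if every token matches) can be sketched as follows. This is a minimal illustration with a hypothetical function name and array layout, not the paper's evaluation script.

```python
import numpy as np

def exact_match_accuracy(pred_tokens, true_tokens):
    """pred_tokens, true_tokens: (N, L) integer token grids, one row per transition."""
    pred_tokens = np.asarray(pred_tokens)
    true_tokens = np.asarray(true_tokens)
    return float((pred_tokens == true_tokens).all(axis=1).mean())
```

A single wrong token anywhere in a frame makes the whole transition count as incorrect, so this metric is much stricter than per-token accuracy.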
##### Qualitative Analysis
Figure [4](https://arxiv.org/html/2605.16457#S4.F4) compares imaginations generated by our method vs. Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)). Our method excels in situations where tiles in the generated frame are correlated. For example, a creature in Craftax-classic can move to adjacent tiles, but it should only move to one destination tile and should not be duplicated to multiple destination tiles. However, because the transformer generates output tokens for a state in parallel, it cannot capture this constraint naturally. Therefore, during imagination, duplication or disappearance of creatures occurs, which is a critical defect in modeling environment dynamics. ITC eliminates this issue by capturing the appropriate constraint between output tiles. Solving this issue is particularly important because similar hallucinations arise in non-transformer world models as well. For example, Figure [5](https://arxiv.org/html/2605.16457#S4.F5) shows duplication and disappearance artifacts in an imagination rollout generated by DreamerV3 (Hafner et al., [2023](https://arxiv.org/html/2605.16457#bib.bib22)). Thus, by eliminating these hallucinations, ITC resolves a problem that is widespread among world models.
Figure 4: Comparison of imagined rollouts from different world models. (a) shows the ground-truth environment trajectory, while (b) and (c) illustrate imagined rollouts generated by the baseline and ITC, respectively. All rollouts begin from the same initial state $s_0$ (left of the yellow dashed line). ITC fixes inaccurate dynamics (red boxes) and duplication errors (blue boxes) produced by the baseline.

Figure 5: An imagination rollout of DreamerV3 compared to the ground-truth trajectory. DreamerV3’s imagination includes disappearance of trees (red boxes) and duplication of trees (blue boxes) over time, similar to duplication issues shown in Figure [4](https://arxiv.org/html/2605.16457#S4.F4) for Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)).

Table 3: Running times on a single Nvidia RTX 3090 GPU. WM training measures one epoch of world model training. Imagination measures one epoch of policy training in imagination. Total time represents end-to-end training time for 1M environment steps. † uses hyperparameters of ITC.
##### Compute Time Analysis
Table [3](https://arxiv.org/html/2605.16457#S4.T3) reports the running time of ITC and of its baseline, Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)), using the same hyperparameters. ITC increases the overall end-to-end training time by only 2.8%. Thus, ITC introduces negligible overhead to world model and policy training.
Figure 6: Example observations of Craftax. Craftax includes many more items, enemies, and underground levels than Crafter, making it much more difficult.
### 4.2 Craftax
Craftax is a more complex and difficult environment that builds on Craftax-classic (Matthews et al., [2024](https://arxiv.org/html/2605.16457#bib.bib31)). Craftax features a larger screen, more items, more enemies, and more levels compared to Craftax-classic (see Figure [6](https://arxiv.org/html/2605.16457#S4.F6)). On Craftax, we compare against Simulus (Cohen et al., [2025](https://arxiv.org/html/2605.16457#bib.bib52)) and Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)) (see footnote 3), which set the previous best return and score, respectively. All hyperparameters are listed in Appendix [C](https://arxiv.org/html/2605.16457#A3). Table [4](https://arxiv.org/html/2605.16457#S4.T4) reports return and score on Craftax, as the mean and standard error over 5 seeds. ITC achieves a return of 7.09% and a score of 2.40%, surpassing the baselines. These results demonstrate that ITC can generalize to more difficult environments.

Footnote 3: We report the (fast) variant from version arXiv:2502.01591v1 of Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)).
Table 4:Results on Craftax after 1M environment interactions\. Simulus does not report Score \(—\)\.
### 4\.3 MinAtar
To further validate the generalization performance of our approach, we also evaluate on the MinAtar benchmark \(Young and Tian, [2019](https://arxiv.org/html/2605.16457#bib.bib51); Lange, [2022](https://arxiv.org/html/2605.16457#bib.bib54)\)\. MinAtar consists of 4 Atari games with simplified symbolic observations of size $10 \times 10$\. We compare against the previous state of the art for model\-based RL, Dedieu et al\. \([2025](https://arxiv.org/html/2605.16457#bib.bib23)\), and the recent model\-free Artificial Dopamine \(AD\) agent \(Guan et al\., [2023](https://arxiv.org/html/2605.16457#bib.bib53)\)\. Each method is trained on each game in MinAtar for 1M environment steps \(except AD, which uses 5M steps\), using 10 seeds per game\. Table [5](https://arxiv.org/html/2605.16457#S4.T5) shows that ITC outperforms the baselines in all 4 games\. Return graphs for each game can be found in Appendix [D](https://arxiv.org/html/2605.16457#A4)\. By improving in every game, ITC demonstrates robust benefits across a variety of environments\.
Table 5: Returns on MinAtar after 1M environment interactions \(or 5M for AD\)\. Return is evaluated on 1,000 evaluation episodes at the end of training\.
### 4\.4 Atari 100K
Table 6: Aggregate metrics on Atari 100K after 100K environment interactions\. Return for each game is evaluated on 100 evaluation episodes at the end of training\.
We further assess performance on the popular Atari 100K benchmark \(Kaiser et al\., [2020](https://arxiv.org/html/2605.16457#bib.bib65)\), which trains on a suite of 26 Atari games for 100K environment interactions each, to validate the generalization of our method to non\-grid environments\. We use the current state\-of\-the\-art token\-based world model, Simulus \(Cohen et al\., [2025](https://arxiv.org/html/2605.16457#bib.bib52)\), as our baseline, since the Dedieu et al\. \([2025](https://arxiv.org/html/2605.16457#bib.bib23)\) baseline is not designed for or tested on Atari 100K\. We create an instantiation of ITC for Atari 100K by applying our method to Simulus\. Each method is trained on each game in Atari 100K for 100K environment steps, using 5 seeds per game\. Results on Atari 100K are reported as human\-normalized score, calculated as $\frac{\text{agent\_return} - \text{random\_agent\_return}}{\text{human\_return} - \text{random\_agent\_return}}$\. Table [6](https://arxiv.org/html/2605.16457#S4.T6) shows that ITC exceeds the baseline and achieves new state\-of\-the\-art performance in interquartile mean \(IQM\) and optimality gap, the robust metrics proposed by Agarwal et al\. \([2020](https://arxiv.org/html/2605.16457#bib.bib66)\)\. Detailed results for each game are presented in Table [13](https://arxiv.org/html/2605.16457#A5.T13) of Appendix [E](https://arxiv.org/html/2605.16457#A5)\. By excelling in Atari 100K, ITC shows that its performance generalizes across 2D visual RL environments\.
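As an illustration of the metrics named above, the following minimal sketch computes the human\-normalized score, the interquartile mean, and the optimality gap\. The function names and the example returns are hypothetical and are not taken from the released code\. The IQM here is a 25% trimmed mean over a flat list of scores, and the optimality gap uses a threshold of 1\.0 \(human\-level\), which we assume matches the usual convention\.

```python
def human_normalized(agent_return, random_agent_return, human_return):
    """Human-normalized score: 0 is random play, 1 is human-level play."""
    return (agent_return - random_agent_return) / (human_return - random_agent_return)


def interquartile_mean(scores):
    """Mean of the middle 50% of the scores (25% trimmed from each end)."""
    ordered = sorted(scores)
    cut = len(ordered) // 4
    middle = ordered[cut:len(ordered) - cut]
    return sum(middle) / len(middle)


def optimality_gap(scores, threshold=1.0):
    """Average amount by which scores fall short of the threshold."""
    return sum(max(threshold - s, 0.0) for s in scores) / len(scores)


# Hypothetical normalized scores, for illustration only.
example_scores = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 3.0]
iqm = interquartile_mean(example_scores)
gap = optimality_gap(example_scores)
```

Unlike the plain mean, the IQM is insensitive to a single outlier game such as the 3\.0 entry in this example, which is why it is preferred for aggregate reporting\.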
## 5 Related Work
### 5\.1 Transformer World Models
Transformer architectures have been used effectively in model\-based RL\. Transformer world models were first introduced by iris \(Micheli et al\., [2022](https://arxiv.org/html/2605.16457#bib.bib40)\)\. Building upon iris, $\Delta$\-iris proposed an agent architecture that encodes stochastic deltas between time steps, enhancing token efficiency by exploiting similarities between adjacent frames \(Micheli et al\., [2024](https://arxiv.org/html/2605.16457#bib.bib24)\)\. TWM, STORM, DART, and TWISTER also incorporated transformer world models, demonstrating their efficacy across different benchmarks \(Robine et al\., [2023](https://arxiv.org/html/2605.16457#bib.bib39); Zhang et al\., [2023](https://arxiv.org/html/2605.16457#bib.bib41); Agarwal et al\., [2024](https://arxiv.org/html/2605.16457#bib.bib44); Burchi and Timofte, [2025](https://arxiv.org/html/2605.16457#bib.bib55)\)\. Transformer world models further advanced with techniques including nearest neighbor tokenization and block teacher forcing, achieving state\-of\-the\-art performance on Craftax\-classic \(Dedieu et al\., [2025](https://arxiv.org/html/2605.16457#bib.bib23)\)\. Outside of transformers, other world models have used GRUs \(Hafner et al\., [2023](https://arxiv.org/html/2605.16457#bib.bib22)\), diffusion \(Alonso et al\., [2024](https://arxiv.org/html/2605.16457#bib.bib29)\), decoder\-free latent spaces \(Hansen et al\., [2024](https://arxiv.org/html/2605.16457#bib.bib56)\), and discrete codebook latent spaces \(Scannell et al\., [2025](https://arxiv.org/html/2605.16457#bib.bib57)\)\.
### 5\.2 Optimal Transport in RL
Optimal transport theory has been applied to RL in other contexts, specifically for curriculum and offline reinforcement learning\. CurrOT framed curriculum generation as a constrained optimal transport problem between task distributions \(Klink et al\., [2022](https://arxiv.org/html/2605.16457#bib.bib33)\)\. GRADIENT formulated curriculum reinforcement learning as an optimal transport problem with a tailored distance metric between tasks \(Huang et al\., [2022](https://arxiv.org/html/2605.16457#bib.bib34)\)\. Additionally, Achievement Distillation introduced a contrastive learning method using optimal transport to enhance the discovery of hierarchical achievements, leading to improved sample efficiency \(Moon et al\., [2023](https://arxiv.org/html/2605.16457#bib.bib28)\)\.
## 6 Conclusion
In this paper, we present ITC, a transformer world model that captures token correspondences between frames using optimal transport\. ITC identifies the underlying entities inherent in visual environments, preventing temporal inconsistency such as duplicated or disappearing objects\. By selectively reusing tokens from preceding frames, it effectively leverages frame\-to\-frame similarities to model next\-state tokens instead of solely relying on the transformer to regenerate each one\. This enables ITC to achieve new state\-of\-the\-art performance on the challenging Craftax\-classic, Craftax, MinAtar, and Atari 100K benchmarks\.
## Acknowledgements
This work was supported by the Air Force Office of Scientific Research under award number FA2386\-25\-1\-4013, a grant from KRAFTON AI, Institute of Information & Communications Technology Planning & Evaluation \(IITP\) grant funded by the Korea government \(MSIT\) \[No\. RS\-2026\-25524173, Ultra\-Long\-Term Hierarchical Memory and Reasoning Architecture for Next\-Generation Omnimodal Agents, 30%; No\. RS\-2020\-II200882, \(SW STAR LAB\) Development of deployable learning intelligence via self\-sustainable and trustworthy machine learning, 15%; No\. RS\-2022\-II220480, Development of Training and Inference Methods for Goal Oriented Artificial Intelligence Agents, 15%; No\. RS\-2026\-25522672, Development of Unified Reasoning Technology Mimicking Human Cognition for Hierarchical Understanding and Unbounded Problem Solving, 10%; and No\. RS\-2021\-II211343, Artificial Intelligence Graduate School Program \(Seoul National University\), 10%\], and Basic Science Research Program through the National Research Foundation of Korea \(NRF\) funded by the Ministry of Education \(RS\-2023\-00274280, 20%\)\. Hyun Oh Song is the corresponding author\.
## Impact Statement
This paper presents work whose goal is to advance the field of Machine Learning\. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here\.
## References
- N\. Agarwal, A\. Ali, M\. Bala, Y\. Balaji, E\. Barker, T\. Cai, P\. Chattopadhyay, Y\. Chen, Y\. Cui, Y\. Ding, et al\. \(2025\) Cosmos world foundation model platform for physical ai\. arXiv preprint arXiv:2501\.03575\.
- P\. Agarwal, S\. Andrews, and S\. E\. Kahou \(2024\) Learning to play atari in a world of tokens\. In ICML\.
- R\. Agarwal, M\. Schwarzer, P\. S\. Castro, A\. C\. Courville, and M\. Bellemare \(2020\) Deep reinforcement learning at the edge of the statistical precipice\. In NeurIPS\.
- E\. Alonso, A\. Jelley, V\. Micheli, A\. Kanervisto, A\. Storkey, T\. Pearce, and F\. Fleuret \(2024\) Diffusion for world modeling: visual details matter in atari\. In NeurIPS\.
- J\. L\. Ba, J\. R\. Kiros, and G\. E\. Hinton \(2016\) Layer normalization\. arXiv preprint arXiv:1607\.06450\.
- T\. Brox, A\. Bruhn, N\. Papenberg, and J\. Weickert \(2004\) High accuracy optical flow estimation based on a theory for warping\. In ECCV\.
- M\. Burchi and R\. Timofte \(2025\) Learning transformer\-based world models with contrastive predictive coding\. In ICLR\.
- L\. Cohen, K\. Wang, B\. Kang, U\. Gadot, and S\. Mannor \(2025\) Uncovering untapped potential in sample\-efficient world model agents\. arXiv preprint arXiv:2502\.11537\.
- M\. Cuturi, L\. Meng\-Papaxanthos, Y\. Tian, C\. Bunne, G\. Davis, and O\. Teboul \(2022\) Optimal transport tools \(ott\): a jax toolbox for all things wasserstein\. arXiv preprint arXiv:2201\.12324\.
- M\. Cuturi \(2013\) Sinkhorn distances: lightspeed computation of optimal transport\. In NeurIPS\.
- A\. Dedieu, J\. Ortiz, X\. Lou, C\. Wendelken, J\. S\. Guntupalli, W\. Lehrach, M\. Lazaro\-Gredilla, and K\. P\. Murphy \(2025\) Improving transformer world models for data\-efficient rl\. In ICML\.
- J\. Guan, S\. E\. Verch, C\. Voelcker, E\. C\. Jackson, N\. Papernot, and W\. A\. Cunningham \(2023\) Temporal\-difference learning using distributed error signals\. In NeurIPS\.
- D\. Hafner, J\. Pasukonis, J\. Ba, and T\. Lillicrap \(2023\) Mastering diverse domains through world models\. arXiv preprint arXiv:2301\.04104\.
- D\. Hafner \(2021\) Benchmarking the spectrum of agent capabilities\. arXiv preprint arXiv:2109\.06780\.
- N\. Hansen, H\. Su, and X\. Wang \(2024\) TD\-mpc2: scalable, robust world models for continuous control\. In ICLR\.
- P\. Huang, M\. Xu, J\. Zhu, L\. Shi, F\. Fang, and D\. Zhao \(2022\) Curriculum reinforcement learning using optimal transport via gradual domain adaptation\. In NeurIPS\.
- L\. Kaiser, M\. Babaeizadeh, P\. Milos, B\. Osinski, R\. H\. Campbell, K\. Czechowski, D\. Erhan, C\. Finn, P\. Kozakowski, S\. Levine, A\. Mohiuddin, R\. Sepassi, G\. Tucker, and H\. Michalewski \(2020\) Model\-based reinforcement learning for atari\. In ICLR\.
- J\. Kim, W\. Choo, and H\. O\. Song \(2020\) Puzzle mix: exploiting saliency and local statistics for optimal mixup\. In ICML\.
- D\. Kingma and J\. Ba \(2015\) Adam: a method for stochastic optimization\. In ICLR\.
- P\. Klink, H\. Yang, C\. D’Eramo, J\. Peters, and J\. Pajarinen \(2022\) Curriculum reinforcement learning via constrained optimal transport\. In ICML\.
- R\. T\. Lange \(2022\) gymnax: a JAX\-based reinforcement learning environment library\. External Links: [Link](http://github.com/RobertTLange/gymnax)\.
- M\. Matthews, M\. Beukman, B\. Ellis, M\. Samvelyan, M\. Jackson, S\. Coward, and J\. Foerster \(2024\) Craftax: a lightning\-fast benchmark for open\-ended reinforcement learning\. In ICML\.
- V\. Micheli, E\. Alonso, and F\. Fleuret \(2022\) Transformers are sample\-efficient world models\. arXiv preprint arXiv:2209\.00588\.
- V\. Micheli, E\. Alonso, and F\. Fleuret \(2024\) Efficient world models with context\-aware tokenization\. In ICML\.
- S\. Moon, J\. Yeom, B\. Park, and H\. O\. Song \(2023\) Discovering hierarchical achievements in reinforcement learning via contrastive learning\. In NeurIPS\.
- F\. Perazzi, J\. Pont\-Tuset, B\. McWilliams, L\. Van Gool, M\. Gross, and A\. Sorkine\-Hornung \(2016\) A benchmark dataset and evaluation methodology for video object segmentation\. In CVPR\.
- G\. Peyré and M\. Cuturi \(2019\) Computational optimal transport: with applications to data science\. Foundations and Trends® in Machine Learning 11, pp\. 355–607\. External Links: [Document](https://dx.doi.org/10.1561/2200000073)\.
- A\. Radford, J\. Wu, R\. Child, D\. Luan, D\. Amodei, and I\. Sutskever \(2019\) Language models are unsupervised multitask learners\. OpenAI Blog 1\(8\)\.
- P\. Ramachandran, B\. Zoph, and Q\. V\. Le \(2017\) Searching for activation functions\. arXiv preprint arXiv:1710\.05941\.
- J\. Robine, M\. Höftmann, T\. Uelwer, and S\. Harmeling \(2023\) Transformer\-based world models are happy with 100k interactions\. arXiv preprint arXiv:2303\.07109\.
- A\. Scannell, M\. Nakhaei, K\. Kujanpää, Y\. Zhao, K\. S\. Luck, A\. Solin, and J\. Pajarinen \(2025\) Discrete codebook world models for continuous control\. In ICLR\.
- J\. Schulman, F\. Wolski, P\. Dhariwal, A\. Radford, and O\. Klimov \(2017\) Proximal policy optimization algorithms\. arXiv preprint arXiv:1707\.06347\.
- J\. Su, M\. Ahmed, Y\. Lu, S\. Pan, W\. Bo, and Y\. Liu \(2024\) Roformer: enhanced transformer with rotary position embedding\. Neurocomputing 568, pp\. 127063\.
- R\. Sutton and A\. Barto \(2018\) Reinforcement learning: an introduction\. MIT press\.
- S\. Vedula, P\. Rander, R\. Collins, and T\. Kanade \(2005\) Three\-dimensional scene flow\. IEEE Transactions on Pattern Analysis and Machine Intelligence 27\(3\), pp\. 475–480\. External Links: [Document](https://dx.doi.org/10.1109/TPAMI.2005.63)\.
- P\. Wang, S\. Bai, S\. Tan, S\. Wang, Z\. Fan, J\. Bai, K\. Chen, X\. Liu, J\. Wang, W\. Ge, et al\. \(2024\) Qwen2\-vl: enhancing vision\-language model’s perception of the world at any resolution\. arXiv preprint arXiv:2409\.12191\.
- X\. Wei, X\. Liu, Y\. Zang, X\. Dong, P\. Zhang, Y\. Cao, J\. Tong, H\. Duan, Q\. Guo, J\. Wang, et al\. \(2025\) VideoRoPE: what makes for good video rotary position embedding?\. arXiv preprint arXiv:2502\.05173\.
- K\. Young and T\. Tian \(2019\) MinAtar: an atari\-inspired testbed for thorough and reproducible reinforcement learning experiments\. arXiv preprint arXiv:1903\.03176\.
- W\. Zhang, G\. Wang, J\. Sun, Y\. Yuan, and G\. Huang \(2023\) Storm: efficient stochastic transformer based world models for reinforcement learning\. In NeurIPS\.
## Appendix A Agent Training and Implementation
### A\.1 Training Loop
This section outlines the training procedure for the world model and the policy, which are trained concurrently through alternating update steps\. The overall training loop is composed of the following steps:
1. Environment interaction: Execute the current policy in the real environment and store the resulting experiences in a replay buffer\.
2. Policy update on real data: Update the policy using the most recent real environment experiences collected in Step 1\. The policy is trained on the data over $E_{\text{env}}$ epochs, with each batch split into $B_{\text{policy}}$ minibatches due to memory constraints\.
3. Tokenizer training: Sample experiences from the replay buffer to train the nearest neighbor tokenizer\. The tokenizer is updated on $U_{\text{tokenizer}}$ batches of sample trajectories\.
4. World model training: Sample experiences from the replay buffer to train the transformer world model\. The world model is updated on $U_{\text{WM}}$ batches of sample trajectories, using $B_{\text{WM}}$ minibatches\.
5. Policy update in imagination: For training steps $t > T_{\text{warmup}}$, generate $U_{\text{imag}}$ batches of imagined trajectories using the world model and the current policy, and update the policy on these synthetic rollouts\. During the initial $T_{\text{warmup}}$ real environment interactions, this step is skipped to allow the world model to reach sufficient accuracy before generating imaginations\. At the start of each imagination rollout, the policy uses $T_{\text{burn}}$ frames from the replay buffer to initialize its RNN hidden state\.
The overall training loop is repeated until the agent has performed a total of $T_{\text{total}}$ real environment interactions\.
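The alternating schedule above can be sketched as a simple control loop\. This is a minimal illustration under stated assumptions, not the released implementation: the function `run_training`, the phase names, and the fixed number of environment steps per iteration are all hypothetical, and each phase is a placeholder callable\.

```python
def run_training(t_total, t_warmup, steps_per_iteration, phases):
    """Alternate the five phases until t_total real environment steps are used.

    `phases` maps a phase name to a zero-argument callable (hypothetical API).
    Returns the ordered list of phase names that were executed.
    """
    env_steps = 0
    history = []
    while env_steps < t_total:
        # Step 1: collect real experience with the current policy.
        phases["collect"]()
        env_steps += steps_per_iteration
        history.append("collect")
        # Steps 2-4: policy update on real data, tokenizer, world model.
        for name in ("policy_real", "tokenizer", "world_model"):
            phases[name]()
            history.append(name)
        # Step 5: imagination is skipped during the warm-up period.
        if env_steps > t_warmup:
            phases["policy_imagination"]()
            history.append("policy_imagination")
    return history


# Placeholder phases that do nothing, for illustration only.
phase_names = ("collect", "policy_real", "tokenizer", "world_model", "policy_imagination")
phases = {name: (lambda: None) for name in phase_names}
history = run_training(t_total=30, t_warmup=10, steps_per_iteration=10, phases=phases)
```

With these toy values the loop runs three iterations, and the imagination phase is skipped in the first one because the warm\-up threshold has not yet been exceeded\.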
### A\.2 World Model Network
#### A\.2\.1 World Model Architecture
Our transformer world model follows the GPT\-2 architecture \(Radford et al\., [2019](https://arxiv.org/html/2605.16457#bib.bib49)\)\. The model operates over tokenized sequences that encode states and actions over $T$ consecutive frames\. These tokens are first mapped to 128\-dimensional embeddings via a learned embedding layer\. Absolute positional embeddings are then added, followed by an initial dropout layer\. The resulting embeddings are processed through a stack of three transformer blocks\. Each block consists of the following components:
1. Layer normalization
2. Multi\-head attention module, comprising:
    - \(a\) Self\-attention with a block causal mask\. In the block causal mask, tokens within the same timestep are decoded in parallel \(see Figure [7](https://arxiv.org/html/2605.16457#A1.F7)\) \(Dedieu et al\., [2025](https://arxiv.org/html/2605.16457#bib.bib23)\)\.
    - \(b\) A linear projection to the 128\-dimensional embedding space
    - \(c\) Dropout
3. Residual connection with the block input
4. Layer normalization
5. Feed\-forward multilayer perceptron \(MLP\) composed of:
    - \(a\) A hidden layer of dimension 512
    - \(b\) GeLU activation
    - \(c\) Dropout
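To make the block causal mask in the self\-attention step concrete, the following sketch builds it as a boolean matrix and contrasts it with the ordinary causal mask\. The function names are hypothetical, and we assume each frame contributes a fixed number of tokens \(the state tokens plus one action token\) laid out contiguously in the sequence\.

```python
def causal_mask(num_tokens):
    """mask[i][j] is True when token i may attend to token j (j not after i)."""
    return [[j <= i for j in range(num_tokens)] for i in range(num_tokens)]


def block_causal_mask(num_frames, tokens_per_frame):
    """Tokens attend to every token in their own frame and in earlier frames."""
    num_tokens = num_frames * tokens_per_frame
    return [
        [(j // tokens_per_frame) <= (i // tokens_per_frame) for j in range(num_tokens)]
        for i in range(num_tokens)
    ]


# Two frames with two state tokens and one action token each (toy sizes).
mask = block_causal_mask(num_frames=2, tokens_per_frame=3)
```

The only difference from the causal mask is inside each frame: the first token of a frame can already see the last token of the same frame, which is what allows a whole frame to be predicted in one forward pass\.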
Figure 7: Comparison between the causal attention mask and the block causal attention mask\. The token $s^{i}_{t}$ denotes the $i$\-th state token at timestep $t$, $a_{t}$ denotes the action, $\hat{r}_{t}$ denotes the predicted reward, and $\hat{d}_{t}$ denotes the predicted done signal\. Only two state tokens are shown per state for simplicity\. \(left\) In the causal mask, each token attends to the tokens preceding it\. The output embeddings of state tokens $s_{t}^{i}$ are used to predict the subsequent state tokens $s_{t}^{i+1}$\. The reward $\hat{r}_{t}$ and done signal $\hat{d}_{t}$ are predicted from $a_{t}$, and the output of $s_{t}^{L}$ is unused\. \(right\) In the block causal mask, all state and action tokens in the same timestep attend to each other, and they are used to predict the corresponding token in the next timestep \($a_{t}$ predicts $\hat{r}_{t}$ and $\hat{d}_{t}$\)\. This allows each frame to be predicted in parallel rather than token\-by\-token\.
After processing through the final block, the output undergoes a final layer normalization and is then passed to three separate prediction heads: one for the next state tokens, one for the reward signal, and one for the done signal\. We denote the output embeddings as
$$(E^{1}_{1}, \ldots, E^{L+1}_{1}, E^{1}_{2}, \ldots, E^{L+1}_{2}, \ldots, E^{1}_{T}, \ldots, E^{L+1}_{T}),$$
where $L$ represents the number of state tokens per frame, and $E^{i}_{t}$ corresponds to the $i$\-th output embedding at timestep $t$\. These embeddings are routed to prediction heads as follows:
1. For $i \leq L$, the embedding $E^{i}_{t}$ is input to the observation head, an MLP comprising a 128\-dimensional linear layer, a ReLU activation, and a final linear layer projecting to the codebook size $K$\. The output logits define a categorical distribution over the $K$ possible values of the predicted state token $s^{i}_{t+1}$\.
2. The embedding $E^{L+1}_{t}$, corresponding to the position of the action token, is passed to both the reward and done heads\. Each head is an MLP consisting of a 128\-dimensional linear layer, a ReLU activation, and a final linear layer projecting to two output classes\. Although the Craftax\-classic environment defines reward values of $-0.1$, $0.1$, and $1.0$, we follow Dedieu et al\. \([2025](https://arxiv.org/html/2605.16457#bib.bib23)\) and binarize the reward signal to improve stability, ignoring the $-0.1$ and $0.1$ cases\.
The model is trained on trajectories of length $T_{\text{WM}}$ sampled from the replay buffer\. The total loss is the sum of three components:
1. Cross\-entropy loss over next\-state token predictions \(across $K$ classes\)\.
2. Cross\-entropy loss for binary reward classification \(0 or 1\)\.
3. Cross\-entropy loss for done signal prediction\.
Optimization is performed using the Adam optimizer \(Kingma and Ba, [2015](https://arxiv.org/html/2605.16457#bib.bib50)\) with gradient norm clipping to stabilize training\. Hyperparameters for architecture and training are provided in Table [7](https://arxiv.org/html/2605.16457#A1.T7)\.
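The three\-term loss described above can be sketched for a single timestep as follows\. This is an illustrative sketch only: the function names are hypothetical, and averaging the state\-token term over the tokens of a frame is our assumption, since the text specifies only that the three terms are summed\.

```python
import math


def cross_entropy(logits, target):
    """Negative log-probability of `target` under softmax(logits)."""
    peak = max(logits)
    log_normalizer = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_normalizer - logits[target]


def world_model_loss(state_logits, state_targets,
                     reward_logits, reward_target,
                     done_logits, done_target):
    """Sum of state-token, reward, and done cross-entropy terms for one timestep."""
    # Assumption: the state term is averaged over the tokens of the frame.
    state_term = sum(
        cross_entropy(logits, target)
        for logits, target in zip(state_logits, state_targets)
    ) / len(state_targets)
    reward_term = cross_entropy(reward_logits, reward_target)  # two classes
    done_term = cross_entropy(done_logits, done_target)        # two classes
    return state_term + reward_term + done_term
```

Subtracting the largest logit before exponentiating keeps the log\-normalizer numerically stable without changing its value\.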
Table 7: World model hyperparameters\. Sweep range indicates the values tried per hyperparameter, with the final Value being chosen based on highest return\.
#### A\.2\.2 Transformer RoPE Implementation
Our implementation of 3D RoPE is based on VideoRoPE \(Wei et al\., [2025](https://arxiv.org/html/2605.16457#bib.bib25)\)\. While RoPE rotates pairs of embedding dimensions using frequencies based on a 1D position index, our implementation modulates the rotation amount based on three indices: two spatial and one temporal\. We divide dimension pairs in a 3:1 ratio between spatial and temporal encoding\. Pairs associated with lower rotation frequencies are used for temporal encoding and are rotated based on the temporal index\. In contrast, pairs with higher rotation frequencies are used for spatial encoding\. Given the 2D nature of spatial positions, spatial pairs are further split evenly between the horizontal and vertical axes\. These are interleaved across the embedding dimension to ensure balanced representation\. As a result, the axes contributing to rotation follow the pattern $(x, y, x, y, \ldots, x, y, t, t, \ldots, t)$, ordered by decreasing rotation frequency\. The positional encoding is implemented by applying block\-diagonal rotation matrices to the query and key vectors\. The matrix $\mathbf{R}_{xy}$ applies higher\-frequency rotations parameterized by the spatial coordinates $(x, y)$, while $\mathbf{R}_{t}$ applies lower\-frequency rotations parameterized by the temporal index $t$\. These rotations follow the standard RoPE formulation extended to two spatial dimensions and one temporal dimension:
$$\mathbf{R}_{xy} = \begin{pmatrix}
\cos\theta_{0}x & -\sin\theta_{0}x & 0 & 0 & \cdots & 0 & 0 \\
\sin\theta_{0}x & \cos\theta_{0}x & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & \cos\theta_{1}y & -\sin\theta_{1}y & \cdots & 0 & 0 \\
0 & 0 & \sin\theta_{1}y & \cos\theta_{1}y & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & \cos\theta_{k-1}y & -\sin\theta_{k-1}y \\
0 & 0 & 0 & 0 & \cdots & \sin\theta_{k-1}y & \cos\theta_{k-1}y
\end{pmatrix}$$
$$\mathbf{R}_{t} = \begin{pmatrix}
\cos\theta_{k}t & -\sin\theta_{k}t & \cdots & 0 & 0 \\
\sin\theta_{k}t & \cos\theta_{k}t & \cdots & 0 & 0 \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \cdots & \cos\theta_{D/2-1}t & -\sin\theta_{D/2-1}t \\
0 & 0 & \cdots & \sin\theta_{D/2-1}t & \cos\theta_{D/2-1}t
\end{pmatrix}$$
Here $\theta_{m}$ is the rotation frequency of the $m$\-th dimension pair, $k$ is the number of spatial pairs, and $D$ is the dimension of the rotated vector\.
Given a query vector $\mathbf{q}_{i}$ for token $i$ and a key vector $\mathbf{k}_{j}$ for token $j$, their rotary embeddings are obtained by applying the corresponding spatial and temporal rotations\. Let $x(i)$, $y(i)$, and $t(i)$ denote the spatial and temporal coordinates of the $i$\-th token\. Then the transformed query and key vectors are
$$\mathbf{q}_{i}' = \begin{pmatrix} \mathbf{R}_{x(i)y(i)} & \mathbf{0} \\ \mathbf{0} & \mathbf{R}_{t(i)} \end{pmatrix} \mathbf{q}_{i}, \qquad \mathbf{k}_{j}' = \begin{pmatrix} \mathbf{R}_{x(j)y(j)} & \mathbf{0} \\ \mathbf{0} & \mathbf{R}_{t(j)} \end{pmatrix} \mathbf{k}_{j},$$
and their inner product depends only on the relative coordinates of the two tokens:
$$\mathbf{q}_{i}'^{\top} \mathbf{k}_{j}' = \mathbf{q}_{i}^{\top} \begin{pmatrix} \mathbf{R}_{x(j)-x(i),\, y(j)-y(i)} & \mathbf{0} \\ \mathbf{0} & \mathbf{R}_{t(j)-t(i)} \end{pmatrix} \mathbf{k}_{j}.$$
As action tokens lack inherent spatial coordinates, assigning them fixed spatial positions would limit the effectiveness of 3D RoPE across the majority of embedding dimensions\. To address this, spatial coordinates for action tokens are defined along the diagonal, $(t, t)$, where $t$ represents the temporal index\. State tokens are assigned spatial coordinates offset from this diagonal, $(x + t, y + t)$, ensuring temporal alignment with action tokens while preserving spatial variation\.
To avoid positional collisions between state and action tokens, they are given different temporal indices\. That is, the state and action tokens $s^{1}_{t}, \ldots, s^{L}_{t}, a_{t}, s^{1}_{t+1}, \ldots, s^{L}_{t+1}, a_{t+1}$ are given temporal indices $2t, \ldots, 2t, 2t+1, 2(t+1), \ldots, 2(t+1), 2(t+1)+1$\. This staggered assignment ensures that each token occupies a unique spatio\-temporal location, maintaining positional distinctiveness throughout the sequence\.
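The position assignment and rotation described in this subsection can be sketched together as follows\. This is a minimal illustration, not the released implementation: the function names are hypothetical, the frequencies use the standard RoPE schedule with base 10000 as an assumption, and we assume the diagonal offset in the spatial coordinates uses the frame index $t$ rather than the staggered temporal index\.

```python
import math


def token_positions(num_frames, height, width):
    """Return (kind, x, y, t) for every token in sequence order."""
    positions = []
    for frame in range(num_frames):
        for row in range(height):
            for col in range(width):
                # State token: grid position offset along the diagonal, index 2t.
                positions.append(("state", col + frame, row + frame, 2 * frame))
        # Action token: placed on the diagonal, temporal index 2t + 1.
        positions.append(("action", frame, frame, 2 * frame + 1))
    return positions


def rope_3d(vec, x, y, t, base=10000.0):
    """Rotate consecutive pairs of `vec`; axes follow (x, y, ..., x, y, t, ..., t)."""
    dim = len(vec)
    num_pairs = dim // 2
    num_spatial = (3 * num_pairs) // 4  # 3:1 split between spatial and temporal
    rotated = []
    for m in range(num_pairs):
        theta = base ** (-2.0 * m / dim)  # frequency decreases with pair index
        if m < num_spatial:
            coord = x if m % 2 == 0 else y
        else:
            coord = t
        angle = theta * coord
        a, b = vec[2 * m], vec[2 * m + 1]
        rotated.append(a * math.cos(angle) - b * math.sin(angle))
        rotated.append(a * math.sin(angle) + b * math.cos(angle))
    return rotated


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))
```

Because each pair is rotated by an angle that is linear in its coordinate, shifting both tokens by the same amount along any axis leaves `dot(rope_3d(q, ...), rope_3d(k, ...))` unchanged, which is the relative\-position property stated in the last equation\.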
### A\.3 Policy Network
#### A\.3\.1 Policy Network Architecture
We adopt the policy network architecture introduced in Dedieu et al\. \([2025](https://arxiv.org/html/2605.16457#bib.bib23)\), which comprises three primary components: a convolutional encoder, a recurrent neural network \(RNN\), and separate MLP heads for action and value prediction\.
The convolutional encoder consists of three convolutional blocks with channel sizes \[64, 64, 128\]\. Each block contains an instance normalization layer, a $3 \times 3$ convolutional layer with stride 1, a $3 \times 3$ max\-pooling layer with stride 2, and two ResNet\-style sub\-blocks\. Each ResNet sub\-block includes a ReLU activation, instance normalization, a $3 \times 3$ convolution with stride 1, and a skip connection to preserve the input\. The encoder produces an output of shape $8 \times 8 \times 128$, which is flattened into an 8192\-dimensional vector, denoted by $z$\. The vector $z$ is then projected into a 256\-dimensional representation through a ReLU activation, a linear layer, and layer normalization\. This projected representation serves as input to a GRU recurrent module, which outputs a vector $y \in \mathbb{R}^{256}$ along with the updated hidden state $h \in \mathbb{R}^{256}$\.
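As a sanity check on the stated encoder output shape, the sketch below traces the spatial resolution through the three blocks\. It rests on assumptions not stated in this section: a $63 \times 63$ input image, padding that preserves resolution in the stride\-1 convolutions, and stride\-2 pooling that rounds the resolution up\. The function name is hypothetical\.

```python
def encoder_output_shape(height, width, channels=(64, 64, 128)):
    """Spatial size after one stride-2 pooling per block (rounding up)."""
    for _ in channels:
        # Stride-1 convolutions are assumed to keep the resolution;
        # only the stride-2 max-pooling halves it (ceiling division).
        height = -(-height // 2)
        width = -(-width // 2)
    return height, width, channels[-1]


# Assumed 63 x 63 input, for illustration only.
out_h, out_w, out_c = encoder_output_shape(63, 63)
flattened_size = out_h * out_w * out_c
```

Under these assumptions the resolution goes from 63 to 32 to 16 to 8, which agrees with the $8 \times 8 \times 128$ output and the 8192\-dimensional flattened vector reported above\.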
The action and value heads share an identical structure except for the final output projection\. Each head takes the concatenated vector\[z,y\]\[z,y\]as input and applies a sequence of transformations: ReLU activation, layer normalization, a linear projection to 2048, another ReLU activation, and a residual block composed of two linear layers with ReLU activations\. The output is passed through a final layer normalization, followed by the task\-specific output projection—either to action logits or a scalar value estimate\.
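The stated encoder output shape can be checked with simple shape arithmetic. The sketch below assumes a $63\times 63$ input (the Craftax-classic observation shape listed in Table 12) and 'SAME'-style padding, so that stride-1 convolutions preserve the spatial size and each stride-2 pooling layer rounds the size up after halving. The padding mode and the helper names are assumptions; the text does not state them.

```python
import math

def pooled_size(size, stride=2):
    """Spatial size after stride-2 pooling with 'SAME' padding (assumed)."""
    return math.ceil(size / stride)

def encoder_output_shape(height, width, channels=(64, 64, 128)):
    # Stride-1 convolutions with 'SAME' padding keep the spatial size,
    # so only the max-pooling layer in each block changes it.
    for _ in channels:
        height, width = pooled_size(height), pooled_size(width)
    return height, width, channels[-1]

h, w, c = encoder_output_shape(63, 63)
z_dim = h * w * c               # length of the flattened encoder output z
head_input_dim = z_dim + 256    # concatenation [z, y] with y in R^256
```

Under these assumptions the spatial size goes 63, 32, 16, 8, which matches the $8\times 8\times 128$ output and the 8192-dimensional vector $z$ described above.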
#### A\.3\.2Policy Training
We follow the policy training procedure described in Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)), using Proximal Policy Optimization (PPO) (Schulman et al., [2017](https://arxiv.org/html/2605.16457#bib.bib38)) as the underlying policy gradient algorithm.
Let the trajectory be denoted as $\tau=(o_{1:T+1},a_{1:T},r_{1:T},d_{1:T},h_{0:T})$, where $o_{t}$ represents the observations, $a_{t}$ the actions, $r_{t}$ the rewards, $d_{t}$ the done signals, and $h_{t}$ the hidden states of the RNN. At each timestep, PPO computes the value estimates $v_{1:T+1}=V_{\Phi_{\textrm{old}}}(o_{1:T+1})$ and the action probabilities $\pi_{\Phi_{\textrm{old}}}(a_{t}|o_{t})$ under the current fixed parameters $\Phi_{\textrm{old}}$. The policy is optimized by minimizing the following PPO objective:

$$
\mathcal{L}_{\textrm{PPO}}(\Phi)=\frac{1}{T}\sum_{t=1}^{T}\Big\{-\min\left(p_{t}(\Phi)A_{t},\textrm{clip}(p_{t}(\Phi))A_{t}\right)+\lambda_{\textrm{TD}}(V_{\Phi}(o_{t})-q_{t})^{2}-\lambda_{\textrm{ent}}\mathcal{H}\left(\pi_{\Phi}(\cdot|o_{t})\right)\Big\}
$$

where $p_{t}(\Phi)$ is the probability ratio $\frac{\pi_{\Phi}(a_{t}|o_{t})}{\pi_{\Phi_{\textrm{old}}}(a_{t}|o_{t})}$ and $\textrm{clip}(x)$ is the clipping function $\min(\max(x,1-\epsilon),1+\epsilon)$. Here, $A_{t}$ denotes a generalized advantage estimate, $q_{t}$ is a temporal difference (TD) target, and $\mathcal{H}$ is the entropy operator. The advantages $A_{t}$ and targets $q_{t}$ are computed as

$$
A_{t}=\delta_{t}+(1-\textrm{done}_{t})\gamma\lambda A_{t+1},\qquad q_{t}=A_{t}+v_{t},
$$

where $\delta_{t}=r_{t}+(1-\textrm{done}_{t})\gamma v_{t+1}-v_{t}$.
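As a rough illustration of the recursion and the objective above, the following NumPy sketch computes $A_{t}$, $q_{t}$, and the PPO loss for a single trajectory. It assumes that $\textrm{done}_{t}$ is the done signal $d_{t}$, that the recursion is initialized with $A_{T+1}=0$, and that the ratio, value predictions, and entropies are supplied as arrays. The function names are hypothetical, and this is not the authors' implementation.

```python
import numpy as np

def gae(rewards, dones, values, gamma=0.925, lam=0.625):
    """Compute advantages A_t and TD targets q_t.

    rewards, dones: arrays of length T; values: array of length T + 1.
    """
    T = len(rewards)
    advantages = np.zeros(T)
    next_adv = 0.0  # A_{T+1} is taken to be zero (assumption)
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        # delta_t = r_t + (1 - done_t) * gamma * v_{t+1} - v_t
        delta = rewards[t] + not_done * gamma * values[t + 1] - values[t]
        # A_t = delta_t + (1 - done_t) * gamma * lambda * A_{t+1}
        next_adv = delta + not_done * gamma * lam * next_adv
        advantages[t] = next_adv
    targets = advantages + values[:-1]  # q_t = A_t + v_t
    return advantages, targets

def ppo_loss(ratio, advantages, value_pred, targets, entropy,
             eps=0.2, td_coef=2.0, ent_coef=0.01):
    """Average of the three per-timestep terms of the PPO objective."""
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps)
    policy_term = -np.minimum(ratio * advantages, clipped * advantages)
    value_term = td_coef * (value_pred - targets) ** 2
    entropy_term = -ent_coef * entropy
    return np.mean(policy_term + value_term + entropy_term)
```

The default arguments mirror the values of $\gamma$, $\lambda$, $\epsilon$, $\lambda_{\textrm{TD}}$, and $\lambda_{\textrm{ent}}$ reported in Table 8.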
We incorporate two modifications to the standard PPO implementation:
- Generalized advantage estimates $A_{t}$ are standardized across training batches to stabilize learning.
- We track the moving average of the mean and standard deviation of $q_{t}$, with discount factor $\alpha$, and train the value function to predict the standardized targets.
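The two modifications can be sketched as follows. The exponential-moving-average form of the update, the initial statistics, and the small constant added to the denominator are assumptions made for illustration; the text only states that a moving average with discount factor $\alpha$ is tracked.

```python
import numpy as np

def standardize(x, eps=1e-8):
    """Shift and scale a batch to zero mean and unit standard deviation."""
    return (x - x.mean()) / (x.std() + eps)

class RunningTargetStats:
    """Moving average of the mean and std of the TD targets q_t."""

    def __init__(self, alpha=0.95):
        self.alpha = alpha
        self.mean = 0.0  # initial values are assumptions
        self.std = 1.0

    def update(self, targets):
        # Exponential moving average with discount factor alpha (assumed form).
        self.mean = self.alpha * self.mean + (1 - self.alpha) * targets.mean()
        self.std = self.alpha * self.std + (1 - self.alpha) * targets.std()

    def normalize(self, targets, eps=1e-8):
        # Standardized targets that the value function is trained to predict.
        return (targets - self.mean) / (self.std + eps)
```

In this sketch, `standardize` would be applied to the advantages of each training batch, while `RunningTargetStats` carries its statistics across batches.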
Optimization is performed using the Adam optimizer (Kingma and Ba, [2015](https://arxiv.org/html/2605.16457#bib.bib50)) with gradient norm clipping to stabilize training. Hyperparameters for architecture and training are provided in Table [8](https://arxiv.org/html/2605.16457#A1.T8).
Table 8: Policy hyperparameters. Sweep range indicates the values tried per hyperparameter, with the final value chosen based on the highest return.

| Area | Hyperparameter | Value | Sweep range |
| --- | --- | --- | --- |
| Environment | Environment interactions $T_{\text{total}}$ | 1,000,000 | |
| | Warmup interactions $T_{\text{warmup}}$ | 50,000 | {50k, 100k, 200k} |
| | Number of environments (batch size) | 48 | |
| | Rollout horizon in environment | 96 | |
| | Rollout horizon in imagination $T_{\text{WM}}$ | 20 | |
| | Burn-in horizon for RNN in imagination $T_{\text{burn}}$ | 5 | |
| Training | Number of updates in imagination $U_{\text{imag}}$ | 300 | {150, 300, 600, 1200} |
| | Number of epochs in environment $E_{\text{env}}$ | 4 | |
| | Number of epochs in imagination | 1 | |
| | Number of minibatches in environment $B_{\text{policy}}$ | 8 | |
| | Number of minibatches in imagination | 1 | |
| PPO | Discount factor $\gamma$ | 0.925 | |
| | TD weight $\lambda$ | 0.625 | |
| | Clipping value $\epsilon$ | 0.2 | |
| | TD loss coefficient $\lambda_{\textrm{TD}}$ | 2.0 | |
| | Entropy loss coefficient $\lambda_{\textrm{ent}}$ | 0.01 | |
| | PPO target discount factor $\alpha$ | 0.95 | |
| Optimization | Optimizer | Adam | |
| | Learning rate | 0.00045 | |
| | Max norm for gradient clipping | 0.5 | |
## Appendix BOptimal Transport Implementation
This section describes implementation details of the optimal transport component: the definition of the distance cost, how the outputs are used, and the hyperparameter search. Algorithm [3](https://arxiv.org/html/2605.16457#alg3) describes the Sinkhorn algorithm (Cuturi, [2013](https://arxiv.org/html/2605.16457#bib.bib37)). We use the OTT-JAX library for our Sinkhorn solver implementation (Cuturi et al., [2022](https://arxiv.org/html/2605.16457#bib.bib60)).
##### Distance Cost
For the affinity matrix $\bm{A}^{(prev)}$ in Equation [1](https://arxiv.org/html/2605.16457#S3.E1), we include a distance penalty $D((x_{i},y_{i}),(x_{j},y_{j}))$. This choice of distance cost can be adapted based on the environment. For our environments, we impose a movement constraint by defining the distance cost as

$$
D((x_{i},y_{i}),(x_{j},y_{j}))=\begin{cases}d,&\text{if }d\leq 4\\ +\infty&\text{otherwise,}\end{cases}\qquad\text{where }d=\|(x_{i},y_{i})-(x_{j},y_{j})\|_{2}^{2}.
$$
This constraint reflects the fact that in environments like Craftax\-classic, Craftax, and MinAtar, a token’s spatial displacement between consecutive frames is limited to a maximum of two positions in any direction\. For Craftax\-classic and Craftax, this accounts for the potential movement of one by the player and one by a creature token, if applicable\.
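A minimal NumPy sketch of this distance cost over a token grid is given below. The function name `distance_cost` and the row-major flattening of grid positions into token indices are assumptions for illustration.

```python
import numpy as np

def distance_cost(grid_h, grid_w, max_sq_dist=4):
    """Pairwise distance cost D between all token positions of a grid.

    Returns an (N, N) matrix with N = grid_h * grid_w, holding the squared
    Euclidean distance where it is at most max_sq_dist and +inf elsewhere.
    """
    ys, xs = np.meshgrid(np.arange(grid_h), np.arange(grid_w), indexing="ij")
    # Row-major list of (x, y) coordinates, one per token (assumed ordering).
    coords = np.stack([xs.ravel(), ys.ravel()], axis=1).astype(float)
    diff = coords[:, None, :] - coords[None, :, :]
    d = (diff ** 2).sum(axis=-1)  # squared Euclidean distance
    return np.where(d <= max_sq_dist, d, np.inf)
```

With the threshold $d\leq 4$ on the squared distance, a token may be matched to a position two steps away along one axis ($d=4$) or one step away diagonally ($d=2$), but not two steps away diagonally ($d=8$).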
**Algorithm 3** Sinkhorn implementation

**Input:** Cost matrix $\bm{C}$, Sinkhorn regularization parameter $\epsilon$, number of Sinkhorn iterations $T$

**Output:** Optimal transport plan $\bm{P}$

1. $\bm{K}=\exp(\bm{C}/\epsilon)$
2. Set uniform marginals: $\mathbf{r}=\frac{1}{\text{rows}(\bm{C})}$, $\mathbf{c}=\frac{1}{\text{cols}(\bm{C})}$
3. Initialize dual variables: $\mathbf{u}=\mathbf{1}$, $\mathbf{v}=\mathbf{1}$
4. **for** $t=1$ to $T$ **do**
   - $\mathbf{u}=\mathbf{r}\oslash(\bm{K}\mathbf{v})$, where $\oslash$ denotes element-wise division
   - $\mathbf{v}=\mathbf{c}\oslash(\bm{K}^{\top}\mathbf{u})$
5. **end for**
6. **Return** $\bm{P}=\text{diag}(\mathbf{u})\cdot\bm{K}\cdot\text{diag}(\mathbf{v})$
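The iterations of Algorithm 3 can be sketched in plain NumPy as below. This is an illustrative stand-in, not the OTT-JAX solver used in the paper. It uses the common convention $\bm{K}=\exp(-\bm{C}/\epsilon)$ for a cost matrix (Cuturi, 2013); Algorithm 3 as printed writes $\exp(\bm{C}/\epsilon)$, which coincides with this form if $\bm{C}$ holds negated costs, so the sign used here is an assumption.

```python
import numpy as np

def sinkhorn(cost, epsilon, num_iters):
    """Entropy-regularized optimal transport with uniform marginals."""
    n_rows, n_cols = cost.shape
    # Gibbs kernel; entries with infinite cost become exactly zero.
    K = np.exp(-cost / epsilon)
    # Uniform marginals over rows and columns.
    r = np.full(n_rows, 1.0 / n_rows)
    c = np.full(n_cols, 1.0 / n_cols)
    # Scaling (dual) variables, initialized to one.
    u = np.ones(n_rows)
    v = np.ones(n_cols)
    for _ in range(num_iters):
        u = r / (K @ v)      # rescale rows to match r
        v = c / (K.T @ u)    # rescale columns to match c
    # Transport plan P = diag(u) K diag(v).
    return u[:, None] * K * v[None, :]
```

Because the last update rescales the columns, the column sums of the returned plan match $\mathbf{c}$ up to floating-point error, while the row sums approach $\mathbf{r}$ as the number of iterations grows. A row or column of $\bm{K}$ that is entirely zero would cause a division by zero, which this sketch does not guard against.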
##### Choosing Between Transformer and Optimal Transport Output
Optimal transport provides an effective mechanism for reusing tokens from the previous frame\. However, it is less effective in scenarios where novel tokens must be introduced, such as when the agent moves to a previously unexplored area\. In such cases, optimal transport may fail to consistently route wildcard entries to the appropriate newly generated tokens\. Conversely, the transformer world model is capable of freely generating new tokens as needed, but lacks a mechanism for directly reusing tokens from prior frames\. Rather than committing to a single output modality, we adopt a hybrid strategy for Craftax\-classic and Craftax that selects between optimal transport and transformer outputs based on spatial position\. In Craftax\-classic and Craftax, new visual content appears along the screen boundaries as the player explores previously unseen regions\. Additionally, the inventory interface—fixed at the bottom of the screen—requires updates to token values without positional shifts\. To accommodate these patterns, we apply the optimal transport output to the central region of the screen, where token reuse is most appropriate, while using the transformer’s predictions for the screen edges and inventory regions, where new content is more likely\. For MinAtar, which does not have special behavior at the edges, the optimal transport output is used directly for the entire screen\.
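The position-based selection described above can be sketched with a boolean mask over the token grid. The border width, the number of inventory rows, and both function names are hypothetical parameters chosen for illustration; the text does not specify the exact extent of the central region.

```python
import numpy as np

def central_mask(grid_h, grid_w, border, inventory_rows):
    """True where the optimal transport output is used (central map region)."""
    mask = np.zeros((grid_h, grid_w), dtype=bool)
    map_h = grid_h - inventory_rows  # rows above the inventory interface
    mask[border:map_h - border, border:grid_w - border] = True
    return mask

def hybrid_select(ot_tokens, tf_tokens, mask):
    """Take optimal transport tokens inside the mask, transformer tokens elsewhere."""
    return np.where(mask, ot_tokens, tf_tokens)
```

For example, on a $9\times 9$ token grid with a hypothetical one-token border and two inventory rows, the mask selects a $5\times 7$ central block for optimal transport and leaves the screen edges and the inventory to the transformer. Setting `border=0` and `inventory_rows=0` recovers the MinAtar case, where the optimal transport output covers the whole screen.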
##### Choosing Hyperparameters for Optimal Transport
Two hyperparameters are introduced in optimal transport-based token correspondence: $c_{d}$, a coefficient for the distance cost, and $c_{w}$, a constant penalty for using a wildcard token. The best hyperparameters are chosen by grid search, but optimal transport-based token correspondence is robust to varying choices, as shown in Tables [9](https://arxiv.org/html/2605.16457#A2.T9) and [10](https://arxiv.org/html/2605.16457#A2.T10).

Table 9: Average returns and scores with respect to $c_{d}$, a coefficient of cost for distance.

Table 10: Average returns and scores with respect to $c_{w}$, a constant penalty for using a wildcard token.

Table 11: Hyperparameter differences between Craftax-classic and Craftax.
## Appendix CCraftax Hyperparameters
Following Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)), we change some hyperparameters for Craftax, as shown in Table [11](https://arxiv.org/html/2605.16457#A2.T11). In particular, to accommodate the larger screen and additional tokens in memory, the batch size and replay buffer size are reduced.
## Appendix DMinAtar Return Curves and Hyperparameters
Figure [8](https://arxiv.org/html/2605.16457#A4.F8) shows the return curves of ITC and the baseline of Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)) for each game in MinAtar. ITC outperforms the baseline in every game. Table [12](https://arxiv.org/html/2605.16457#A4.T12) lists the MinAtar hyperparameters whose values differ from those used for Craftax-classic. All hyperparameter changes follow Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)), except the hyperparameters specific to optimal transport. Also following Dedieu et al. ([2025](https://arxiv.org/html/2605.16457#bib.bib23)), the policy encoder uses layer normalization and the Swish activation function, and actor and value networks share weights except in their final linear layers (Ba et al., [2016](https://arxiv.org/html/2605.16457#bib.bib58); Ramachandran et al., [2017](https://arxiv.org/html/2605.16457#bib.bib59)).

Figure 8: Return curves for MinAtar. Shading indicates standard error among multiple seeds. The vertical dashed lines indicate the start of training in imagination after $T_{\text{warmup}}$ interactions.

Table 12: Hyperparameter differences between Craftax-classic and MinAtar. $K$ is the number of object types for each game (4 for Asterix, 4 for Breakout, 7 for Freeway, and 6 for SpaceInvaders). The number of actions $A$ is 5 for Asterix, 3 for Breakout, 3 for Freeway, and 4 for SpaceInvaders.

| Area | Hyperparameter | Craftax-classic | MinAtar |
| --- | --- | --- | --- |
| Environment | Observation shape | $63\times 63\times 3$ | $10\times 10\times K$ |
| | Number of possible actions | 17 | $A$ |
| | Warmup interactions $T_{\text{warmup}}$ | 50,000 | 200,000 |
| Tokenizer | Single patch shape | $7\times 7\times 3$ | $2\times 2\times K$ |
| Architecture | State tokens per frame $L$ | 81 | 25 |
| Optimal transport | Distance cost coefficient $c_{d}$ | 0.6 | 0.2 |
| | Wildcard cost $c_{w}$ | 0.3 | 0.05 |
| Training | Number of world model updates $U_{\text{WM}}$ | 500 | 2000 |
| | Number of policy updates in imagination $U_{\text{imag}}$ | 300 | 2000 |
| | Coefficient for reward prediction loss | 1 | 10 |
| | Coefficient for done prediction loss | 1 | 10 |
| PPO | Discount factor $\gamma$ | 0.925 | 0.95 |
| | TD weight $\lambda$ | 0.625 | 0.75 |
| | Entropy loss coefficient $\lambda_{\textrm{ent}}$ in imagination | 0.01 | 0.05 |
| | PPO target discount factor $\alpha$ | 0.95 | 0.925 |
## Appendix EAdditional Atari 100K Results and Hyperparameters
Table [13](https://arxiv.org/html/2605.16457#A5.T13) shows the average return for each of the 26 games in Atari 100K. We follow the Simulus hyperparameter settings (Cohen et al., [2025](https://arxiv.org/html/2605.16457#bib.bib52)), and additionally include our proposed hyperparameters: distance cost coefficient $c_{d}$ and wildcard cost $c_{w}$, set to $0.05$ and $0.01$, respectively.

Table 13: Mean returns on the 26 games of the Atari 100k benchmark, followed by averaged human-normalized performance metrics. Each game score is computed as the average of 5 runs with different seeds. Boldface marks the best score.