Autoregressive next token prediction and KV Cache in transformers

Hacker News Top Tools

Summary

Explains autoregressive next token prediction in transformers and the KV cache optimization technique used to speed up token generation.


Cached at: 05/20/26, 02:27 PM

# Autoregressive next token prediction & KV Cache in transformers

Source: [https://medium.com/advanced-deep-learning/autoregressive-next-token-prediction-kv-cache-in-transformers-afad22285baf](https://medium.com/advanced-deep-learning/autoregressive-next-token-prediction-kv-cache-in-transformers-afad22285baf)

By [Frederik vom Lehn](https://medium.com/@frederik.vl?source=post_page---byline--afad22285baf---------------------------------------)

Understand the optimization technique in LLMs to speed up token generation

The general overview (Image by author).

## The Big Picture

Before we dive into attention heads, KV caches, and the mechanics of generation, it helps to zoom out and see what an autoregressive language model actually *is* at a glance.

A prompt enters as plain text: “How are you?”. A tokenizer chops it into vocabulary IDs, here `3, 7, 1, 9`, prefixed with a BOS ("beginning of sequence") token. Each ID is just an integer pointing into a **lookup table**: a learned matrix of shape `(vocab_size, c)` where every row is the embedding vector for one token in the vocabulary. Selecting the rows for our 5 input IDs produces `X`, a `(5, 4)` matrix: five tokens, each living in a 4-dimensional embedding space. This is where text leaves the world of symbols and enters the world of vectors. We use toy dimensions for our examples here.

From here, `X` flows through a stack of **decoder blocks**. Each block has the same architecture, multi-head self-attention followed by an MLP, and each block transforms its input into a refined `(5, 4)` representation of the same shape. The trick that makes deep transformers trainable is the **residual connection** wrapped around every block: instead of replacing the input, each block *adds* to it (`X₁ = X + block_output`).
Information flows along a continuous "residual stream" that each layer edits rather than overwrites. Stack three of these and you get `X₃`, the final hidden state.

The last step inverts the first. The **unembedding matrix**, often the lookup table transposed since input and output vocabularies are the same, projects each row of `X₃` back into vocabulary space, producing a `(5, 12)` logits matrix: a score for every vocabulary token at every position. For next-token generation, only the last row matters. Its argmax is the token the model wants to say next. Here, that's token ID 5.

That’s the whole forward pass at altitude. The rest of this article zooms in on what happens inside one of those decoder blocks and on the optimization, **KV caching**, that makes generating long sequences feasible at all. Let's zoom in on what happens inside a single decoder layer during the first forward pass.

The Prefill Forward Pass (Image by author)

## The Prefill Forward Pass

Before a language model can generate a single new token, it has to process the prompt. This step (**prefill**) runs the entire input sequence through the network in one parallel forward pass. Its job is twofold: produce the first predicted token, and populate the KV cache so that subsequent decode steps stay cheap.

Let’s walk through what happens to a 5-token prompt in a tiny model with hidden dimension `c = 4`, 2 attention heads, and a vocabulary of 12 tokens.

### From tokens to Q, K, V

The input `X` arrives as a `(5, 4)` matrix: 5 tokens, each represented by a 4-dimensional embedding pulled from the lookup table. Three learned projection matrices `Wq`, `Wk`, `Wv`, each of shape `(4, 4)`, transform `X` into the query, key, and value matrices `Q`, `K`, `V`, all of shape `(5, 4)`. Because we have 2 heads, each `(5, 4)` matrix is split column-wise into two `(5, 2)` slices, one slice per head.
Each head will compute attention independently in its own 2-dimensional subspace.

### Attention within a head

Inside a single head, attention is a weighted lookup. The head’s `Q` slice `(5, 2)` is multiplied by the transpose of its `K` slice to produce a `(5, 5)` matrix of attention scores: every token's query dotted with every token's key. After scaling, causal masking (this is an autoregressive model, so token *t* must not see tokens at positions > *t*), and softmax, each row of this matrix becomes a probability distribution over "which past tokens should I pull information from." These weights then multiply the head’s `V` slice `(5, 2)`, yielding the head's output of shape `(5, 2)`: each token now holds a context-aware mix of value vectors from its allowed positions.

### Concatenation and projection

The two heads’ outputs are concatenated back into a `(5, 4)` matrix, then passed through an output projection `(4, 4)`. The result, `X'`, is again `(5, 4)`, the same shape as the input, but every row now reflects information gathered from across the sequence.

### The MLP

Each token’s vector is then sent independently through a two-layer MLP. `W_up` of shape `(4, 8)` expands each row to 8 dimensions, GeLU adds non-linearity, and `W_down` of shape `(8, 4)` projects back down. The output `X₁` is `(5, 4)`, and in a real model it would feed into the next transformer block. Stack a few of these (here, 3 layers) and you have the full forward pass. Let's assume this is the final layer.

### Logits and the first prediction

After the final layer, the `(5, 4)` hidden states are multiplied by the unembedding matrix `(12, 4).T` to produce logits of shape `(5, 12)`, a score for every vocabulary token at every position. For generation, only the **last row** matters: it tells us what the model thinks comes after the fifth and last prompt token. Argmax (or sampling) over that row gives us the first generated token. In our case, that is token ID 5.
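
The prefill walkthrough above can be sketched in NumPy. This is a minimal sketch under stated assumptions, not the author's code: the weights are random stand-ins for learned parameters, layer normalization is left out as in the walkthrough, the unembedding is assumed to be the lookup table transposed, BOS is assumed to have ID 0, only one decoder layer is run, and names such as `prefill_layer`, `embed`, and `K_cache` are hypothetical. With random weights the predicted ID will generally not be 5.

```python
import numpy as np

T, C, N_HEADS, VOCAB = 5, 4, 2, 12    # toy sizes from the walkthrough
HEAD_DIM = C // N_HEADS               # 2 dimensions per head

rng = np.random.default_rng(0)        # random stand-ins for learned weights
embed = rng.normal(size=(VOCAB, C))   # lookup table, reused as the unembedding (assumption)
Wq, Wk, Wv, Wo = (rng.normal(size=(C, C)) for _ in range(4))
W_up, W_down = rng.normal(size=(C, 2 * C)), rng.normal(size=(2 * C, C))

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def gelu(x):
    # tanh approximation of GeLU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def prefill_layer(X):
    """One decoder layer over the whole prompt; also returns K and V for the cache."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                                # each (5, 4)
    future = np.triu(np.ones((len(X), len(X)), dtype=bool), k=1)    # True where key position > query position
    heads = []
    for h in range(N_HEADS):
        cols = slice(h * HEAD_DIM, (h + 1) * HEAD_DIM)              # this head's column slice
        scores = Q[:, cols] @ K[:, cols].T / np.sqrt(HEAD_DIM)      # (5, 5)
        scores = np.where(future, -np.inf, scores)                  # causal mask before softmax
        heads.append(softmax(scores) @ V[:, cols])                  # (5, 2)
    X = X + np.concatenate(heads, axis=1) @ Wo                      # residual add around attention
    X = X + gelu(X @ W_up) @ W_down                                 # residual add around the MLP
    return X, K, V

token_ids = np.array([0, 3, 7, 1, 9])     # BOS (assumed ID 0) + the four prompt IDs
X = embed[token_ids]                      # (5, 4) rows selected from the lookup table
X1, K_cache, V_cache = prefill_layer(X)   # K and V are what the cache keeps
logits = X1 @ embed.T                     # (5, 12)
next_token = int(np.argmax(logits[-1]))   # only the last row matters for generation

print(X1.shape, K_cache.shape, logits.shape)   # (5, 4) (5, 4) (5, 12)
```

Only `K_cache` and `V_cache` need to outlive this call; `Q`, the attention weights, and the MLP activations are local to `prefill_layer` and can be thrown away.
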
### What the cache holds onto

Here’s the quiet but crucial part: during this single pass, every layer computed `K` and `V` of shape `(5, 4)` for the prompt. Those tensors get **stored**. They are everything future tokens will ever need to know about the prompt at this layer. The embeddings, the queries, the MLP activations: all discarded. From here on, generation moves into decode mode, processing one new token at a time and reading from this cache instead of redoing the work.

Now let’s look at what happens when we generate the next token with the KV cache.

Second Forward Pass with KV Cache (Image by author)

## The Decode Step with KV Cache

Once prefill is done, the model switches into **decode mode**. Every subsequent token is generated by a forward pass that looks structurally similar to prefill but operates on just *one* row at a time, leaning on the KV cache to remember everything that came before.

Let’s continue our example. Prefill predicted token ID 5, so we now feed that token back in as the input for the next step.

### One token in, one token out

The new input `X` is a single row of shape `(1, 4)`, which is just the embedding of token ID 5, looked up from the same table used during prefill. The previous 5 tokens of the prompt are **not** re-fed. They don't need to be: everything the model will ever need from them at this layer is already sitting in the cache.

Multiplying this `(1, 4)` row by `Wq`, `Wk`, `Wv` (each still `(4, 4)`) yields a fresh `Q`, `K`, and `V`, each of shape `(1, 4)`. Only the new token gets its query, key, and value computed.

### Appending to the cache

The newly computed `K` and `V` rows are appended to the cached `K` and `V` matrices from the previous step. The cache, which held `(5, 4)` after prefill, now holds `(6, 4)`: five rows from the prompt plus one fresh row for the new token. This concatenated tensor is what attention will read against.
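
The append step can be illustrated with a few lines of NumPy. This is a sketch with assumed names (`prompt_hidden`, `x_new`, `K_cache`, `V_cache`) and random stand-in weights, showing only the shape bookkeeping for one layer.

```python
import numpy as np

C = 4
rng = np.random.default_rng(0)
Wk, Wv = rng.normal(size=(C, C)), rng.normal(size=(C, C))   # stand-ins for learned weights

# State after prefill: keys and values for the 5 prompt positions.
prompt_hidden = rng.normal(size=(5, C))                     # hypothetical layer input for the prompt
K_cache, V_cache = prompt_hidden @ Wk, prompt_hidden @ Wv   # each (5, 4)

# Decode step: only the single new token is projected.
x_new = rng.normal(size=(1, C))                             # hypothetical embedding of the new token
k_new, v_new = x_new @ Wk, x_new @ Wv                       # each (1, 4)

# Append the fresh row to each cache.
K_cache = np.concatenate([K_cache, k_new], axis=0)
V_cache = np.concatenate([V_cache, v_new], axis=0)

print(K_cache.shape, V_cache.shape)   # (6, 4) (6, 4)
```

`np.concatenate` copies the whole array on every call, which is fine for a toy example; implementations that care about speed commonly preallocate a buffer for the maximum sequence length and write each new row in place.
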
### Attention against the cache

Splitting across heads as before, each head now has a query of shape `(1, 2)` and a full key/value matrix of shape `(6, 2)`. The dot product `Q · K^T` produces a `(1, 6)` score row: the new token's attention weights over all 6 positions, itself included. No causal mask is needed here: every cached position is in the past by construction, so every score is valid.

Softmax turns this into a probability distribution, and the weighted sum over the head's `V` slice `(6, 2)` produces a `(1, 2)` head output. Concatenating both heads gives `(1, 4)`, and the output projection `(4, 4)` yields `X'` of shape `(1, 4)`.

### Why this matters

Compare the shapes. Prefill processed a `(5, 4)` input and ran every operation on 5 rows in parallel, which is necessary to populate the cache. Decode processes a `(1, 4)` input and runs every operation on a single row, with the cache silently providing the historical context where it's needed (inside attention). The MLP, the projections, and the unembedding all do `1/N` of the work they'd do in a no-cache forward pass over a sequence of `N` tokens.

This is the whole reason long-context generation is tractable. Without the KV cache, every new token would mean redoing the entire prefill, slightly longer each time, so the cost of generating N tokens would grow at least quadratically. With it, each new token costs roughly the same amount of compute, plus a cheap attention sum over a growing cache.

Generating a token is, at its core, a small amount of fresh work standing on the shoulders of a lot of remembered work.
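
One way to convince yourself that the cached decode step loses nothing is to compare it against a full recompute. The sketch below is an illustration under assumptions, not the author's code: it covers the attention sub-layer of a single layer only, uses random stand-in weights, and the helper `attention` and all variable names are hypothetical.

```python
import numpy as np

C, N_HEADS = 4, 2
HEAD_DIM = C // N_HEADS
rng = np.random.default_rng(0)
Wq, Wk, Wv, Wo = (rng.normal(size=(C, C)) for _ in range(4))   # stand-ins for learned weights

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V, causal):
    """Multi-head attention. Q is (t_q, C); K and V are (t_k, C)."""
    heads = []
    for h in range(N_HEADS):
        cols = slice(h * HEAD_DIM, (h + 1) * HEAD_DIM)
        scores = Q[:, cols] @ K[:, cols].T / np.sqrt(HEAD_DIM)
        if causal:
            future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
            scores = np.where(future, -np.inf, scores)
        heads.append(softmax(scores) @ V[:, cols])
    return np.concatenate(heads, axis=1) @ Wo

X_all = rng.normal(size=(6, C))           # 5 prompt rows plus the row for the new token
X_prompt, x_new = X_all[:5], X_all[5:]

# No cache: recompute attention for all 6 rows with a causal mask, keep the last row.
full = attention(X_all @ Wq, X_all @ Wk, X_all @ Wv, causal=True)[-1:]

# With cache: K and V for the prompt were stored at prefill; project only the new row.
K_cache = np.concatenate([X_prompt @ Wk, x_new @ Wk])   # (6, 4)
V_cache = np.concatenate([X_prompt @ Wv, x_new @ Wv])   # (6, 4)
cached = attention(x_new @ Wq, K_cache, V_cache, causal=False)   # (1, 4), no mask needed

print(np.allclose(full, cached))   # True
```

The two results agree because the last row of a causally masked attention matrix already attends to every position, so dropping the mask for a single query against the full cache changes nothing.
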

Similar Articles

Models Take Notes at Prefill: KV Cache Can Be Editable and Composable

arXiv cs.LG

This paper proposes that the KV cache in transformers acts as a notebook of memoized conclusions, enabling surgical editing and composition without full recomputation. The method achieves significant latency reductions while preserving decision equivalence across model scales.

Self-Pruned Key-Value Attention: Learning When to Write by Predicting Future Utility

arXiv cs.LG

Introduces Self-Pruned Key-Value Attention (SP-KV), a mechanism that learns to predict the future utility of key-value pairs in order to dynamically prune the KV cache, reducing memory usage and decoding time by 3-10x with minimal performance degradation. The model and utility predictor are trained end-to-end using next-token prediction.