# Shard - getting to 10× KV cache compression
Source: [https://krishgarg.com/shard](https://krishgarg.com/shard)
Krish Garg and Kirrithan Sathananthan · May 2026 · [github](https://github.com/krish1905/shard)
**TL;DR\.** *Shard* is a drop\-in HuggingFace `Cache` that makes Llama\-3\.1\-8B's KV memory about 10× smaller at 8K context \(11× at 32K\) without measurable hits to NIAH or LongBench\. It started as a reimplementation of Google's TurboQuant[\[1\]](https://krishgarg.com/shard#fn1), stalled around 4×, and ended up as a different design once we noticed K and V need different treatments: PCA plus int4 quantization on K \(the matrix is effectively low\-rank once you undo RoPE\), and a Hadamard rotation plus vector quantization on V\. Attention runs directly on the compressed K, with no fp16 reconstruction\. Code: [`krish1905/shard`](https://github.com/krish1905/shard)\.
## 1\. Where this started
We were watching a Llama\-3\.1\-8B run on an 8K prompt and noticed the KV cache was eating **1 GB** of VRAM\. Per request\. For a model that's 16 GB on disk\. One\-sixteenth of the whole model just to remember what it already saw\. That ratio felt wrong\.
The math is simple\. For a decoder\-only transformer at sequence length $T$:
$$ \|\\text\{KV\}\| = 2 \\cdot L \\cdot H\_\{kv\} \\cdot d \\cdot T \\cdot \\text\{bytes\}\(fp16\) $$
Here $L$ is the number of layers, $H\_\{kv\}$ the number of KV heads, $d$ the head dimension, and the leading 2 counts K and V\. For Llama\-3\.1\-8B\-Instruct the hyperparameters are $L=32$, $H\_\{kv\}=8$, $d=128$\. At $T=8192$ tokens that's $2 \\times 32 \\times 8 \\times 128 \\times 8192 \\times 2 = 1\.07\\text\{ GB\}$\. At 128K context, it's about 17 GB, more than the model weights\.
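
As a quick sanity check on that arithmetic, here is a minimal sketch of the formula as a function\. The name `kv\_cache\_bytes` and its argument names are illustrative and not taken from the Shard codebase\.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """Bytes needed to hold K and V for every layer (fp16 by default)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem


# Llama-3.1-8B-Instruct hyperparameters quoted above: L=32, H_kv=8, d=128.
for seq_len in (8192, 32768, 131072):
    gb = kv_cache_bytes(32, 8, 128, seq_len) / 1e9
    print(f"T={seq_len:>6}: {gb:.2f} GB")
```

The cost is linear in $T$, so every doubling of context doubles the cache\.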
Every paper on LLM inference efficiency in the last two years has some take on this\. Quantize the KV cache[\[kivi\]](https://krishgarg.com/shard#fn-kivi)\. Evict unimportant tokens[\[h2o\]](https://krishgarg.com/shard#fn-h2o)\. Project to a low\-rank subspace[\[kvtc\]](https://krishgarg.com/shard#fn-kvtc)\. Restructure attention entirely[\[mla\]](https://krishgarg.com/shard#fn-mla)\. Different problems, different answers, different trade\-offs\.
We had one target written down: **10× compression on a real LLM, no accuracy tax\.** We didn't know if it was reachable\.
## 2\. The paper we tried to reimplement first
A friend sent us **TurboQuant**[\[1\]](https://krishgarg.com/shard#fn1) \(Zandieh, Daliri, Hadian, Mirrokni at Google Research \+ Google DeepMind \+ NYU; [arXiv:2504\.19874](https://arxiv.org/abs/2504.19874); ICLR 2026\)\. The headline claim is "absolute quality neutrality with 3\.5 bits per channel" and "marginal quality degradation with 2\.5 bits per channel" on Llama\-3\.1\-8B\-Instruct\.
We read it four times in two days\. The core trick is beautiful\. Take any unit vector $x \\in \\mathbb\{S\}^\{d\-1\}$ and multiply it by a random orthogonal matrix $\\Pi$\. Their **Lemma 1** says each coordinate of $\\Pi x$ follows a scaled Beta:
$$ f\_\{X\_j\}\(x\) = \\frac\{\\Gamma\(d/2\)\}\{\\sqrt\{\\pi\}\\,\\Gamma\(\(d\-1\)/2\)\} \(1\-x^2\)^\{\(d\-3\)/2\} $$
which, for moderately large $d$, is well approximated by $\\mathcal\{N\}\(0, 1/d\)$\. Two facts follow that turn the problem from hard to almost trivial:
1. **Every coordinate has the same distribution**, independent of the input data\. You can precompute one optimal scalar quantizer \(Lloyd\-Max[\[lloyd\]](https://krishgarg.com/shard#fn-lloyd)\) and apply it to every coordinate forever\.
2. **Distinct coordinates are near\-independent** in high dimensions \(formally a concentration\-of\-measure argument on the sphere\)\. Per\-coordinate scalar quantization then comes close to what you'd get from full vector quantization\.
Their upper bound on mean\-squared error after $b$ bits per coordinate \(**Theorem 1**\):
$$ D\_\{\\text\{mse\}\}\(Q\_\{\\text\{mse\}\}\) \\leq \\frac\{\\sqrt\{3\}\\,\\pi\}\{2\} \\cdot 4^\{\-b\} \\approx 2\.72 \\cdot 4^\{\-b\} $$
and their information\-theoretic lower bound \(**Theorem 3**, via Shannon and Yao's minimax\):
$$ D\_\{\\text\{mse\}\}\(Q\) \\geq 4^\{\-b\} $$
The gap is a small constant \($\\sqrt\{3\}\\pi/2 \\approx 2\.72$\)\. For inner\-product distortion, there's an analogous pair via their QJL \(quantized Johnson–Lindenstrauss\) residual correction[\[qjl\]](https://krishgarg.com/shard#fn-qjl): $D\_\{\\text\{prod\}\} \\leq \\frac\{\\sqrt\{3\}\\pi^2\}\{d\} \\cdot 4^\{\-b\} \\cdot \\\|y\\\|^2$\.
Nice math\. But way down in Section 4 we realized their actual Llama\-3\.1\-8B\-Instruct experiments top out at **2\.5–3\.5 bits per coordinate**, which means:
$$ \\text\{compression\} = \\frac\{16\}\{3\.5\} \\approx 4\.57\\times \\quad \\text\{or\} \\quad \\frac\{16\}\{2\.5\} \\approx 6\.4\\times $$
Good, not 10×\. We kept rereading the paper to find what we were missing, and eventually had to accept: *the quantizer is optimal against the worst\-case input distribution \(uniform on the sphere\)*\. That's what lets you precompute the codebook and skip calibration\. It's also what stops you from exploiting structure your data actually has\. And the KV cache, presumably, is not a uniform point on a sphere\.
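
To get a feel for the numbers, the sketch below evaluates both bounds and the implied fp16 compression ratio at a few bit widths\. It only restates the formulas above; `mse\_bounds` is a hypothetical helper name\.

```python
import math


def mse_bounds(bits):
    """Return (lower, upper) MSE bounds for unit vectors at `bits` per coordinate."""
    lower = 4.0 ** (-bits)                          # Theorem 3: any quantizer
    upper = (math.sqrt(3) * math.pi / 2) * lower    # Theorem 1: TurboQuant
    return lower, upper


for b in (2.0, 2.5, 3.0, 3.5, 4.0):
    lo, hi = mse_bounds(b)
    print(f"b={b}: {lo:.5f} <= D_mse <= {hi:.5f}   fp16 ratio = {16 / b:.2f}x")
```

Each extra bit cuts both bounds by 4×, while the compression ratio only moves as $16/b$\. That is why a worst\-case\-optimal quantizer runs out of room well before 10×\.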
## 3\. Looking at the K cache
We ran Llama\-3\.1\-8B\-Instruct on 8K fineweb prompts and dumped the K tensor for layer 12\. Ran an SVD\.
The **K matrix is effectively low\-rank**\. Not mathematically \(it has full rank\), but 192 out of 1024 singular values captured over 99\.5% of the Frobenius norm\. The tail falls off a cliff\.
That first SVD was taken after undoing RoPE on K\. Then we did the same thing without first unapplying RoPE\. The singular values decayed much more slowly\. RoPE was hiding the structure\.
**Why\.** Llama's $W\_K$ weight matrix has limited effective rank\. This is known, and weight\-level low\-rank compression like ASVD[\[asvd\]](https://krishgarg.com/shard#fn-asvd) exploits it directly\. When activations go through $W\_K$, the outputs live on a low\-dimensional subspace\. But RoPE rotates each token's K by a *different* angle, which spreads energy across all dimensions and masks the subspace when you stack tokens together\. If you undo RoPE on each token \(which you can do exactly because you know the angle\), the subspace reemerges\.
KVTC[\[kvtc\]](https://krishgarg.com/shard#fn-kvtc) \(Staniszewski & Łańcucki, [arXiv:2511\.01815](https://arxiv.org/abs/2511.01815), ICLR 2026\) says this out loud: *"Remove rotary position embeddings from keys before PCA \(they obscure the low\-rank structure\)\."* Exactly\.
SVD of K \(one layer, 8K tokens\), normalized singular values:

- with RoPE applied: ████████████████████████████████\.\.\. \(slow decay\)
- RoPE undone: ███████▓▓▒▒▒░░░░ \(sharp elbow ≈ 192/1024\)

So K has structure\. And the TurboQuant lower bound $D\_\{\\text\{mse\}\} \\geq 4^\{\-b\}$ *does not apply to structured data*; it's for uniform points on a sphere\. If K lives on a rank\-$r$ subspace of a $d$\-dimensional ambient space, storing $r$ coefficients instead of $d$ coordinates gives you a factor of $d/r$ for free, before the first bit of quantization\.
This was the first *oh* moment\.
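
The masking effect is easy to reproduce on synthetic data\. The sketch below is a toy illustration, not the measurement behind the numbers above: it assumes a random rank\-8 matrix standing in for no\-RoPE K, a single head of dimension 64, and a RoPE base of 10000, then compares how much squared singular\-value energy the top 8 components hold before and after rotate\-half RoPE\. It uses NumPy, and all names are hypothetical\.

```python
import numpy as np


def rope(x, positions, base=10000.0):
    """Rotate-half RoPE applied to the rows of x (shape [T, d])."""
    d = x.shape[1]
    inv_freq = base ** (-np.arange(d // 2) * 2.0 / d)
    ang = positions[:, None] * inv_freq[None, :]            # [T, d/2]
    cos = np.concatenate([np.cos(ang), np.cos(ang)], axis=1)
    sin = np.concatenate([np.sin(ang), np.sin(ang)], axis=1)
    rot = np.concatenate([-x[:, d // 2:], x[:, :d // 2]], axis=1)
    return x * cos + rot * sin


def top_energy(m, r):
    """Fraction of squared singular-value energy in the top r components."""
    s = np.linalg.svd(m, compute_uv=False)
    return float((s[:r] ** 2).sum() / (s ** 2).sum())


rng = np.random.default_rng(0)
T, d, r = 512, 64, 8
k_plain = rng.standard_normal((T, r)) @ rng.standard_normal((r, d))  # exactly rank r
k_rope = rope(k_plain, np.arange(T, dtype=np.float64))

e_plain, e_rope = top_energy(k_plain, r), top_energy(k_rope, r)
print(f"top-{r} energy without RoPE: {e_plain:.4f}, with RoPE: {e_rope:.4f}")
```

The high\-frequency pairs are modulated by a different angle at every position, so their columns leave the original 8\-dimensional subspace and the top\-8 energy drops noticeably\.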
## 4\. But V doesn't have that structure
Same dump, same SVD, on V this time\. Singular values were close to flat\. V looks basically random\. No low\-rank story to exploit\.
This felt important\. Every paper we'd read applied the *same* compression technique to both K and V\. TurboQuant rotates and scalar\-quantizes both; KIVI[\[kivi\]](https://krishgarg.com/shard#fn-kivi) quantizes both; PolarQuant[\[polar\]](https://krishgarg.com/shard#fn-polar) rotates both\. But K and V are **structurally different**\.
We went searching for prior work that already understood this\. Two hits:
- **AsymKV**[\[asym1\]](https://krishgarg.com/shard#fn-asym1) \(Tao et al\., [arXiv:2410\.13212](https://arxiv.org/abs/2410.13212), Oct 2024\) argues for different *bit widths* on K and V, but applies the same method \(scalar quantization\) to both\.
- **"Homogeneous Keys, Heterogeneous Values"**[\[asym2\]](https://krishgarg.com/shard#fn-asym2) \([arXiv:2506\.05410](https://arxiv.org/abs/2506.05410), NeurIPS 2025\) observes that adjacent keys get similar attention \(so keys can be merged\) while values vary, and proposes a different merging scheme per side\.
Neither combines *low\-rank PCA on K* with *vector quantization on V* in one pipeline\. That's the gap we wanted to fill\.
## 5\. Building the K path
The K side of Shard ended up at:
1. **Undo RoPE\.** Llama uses rotate\-half RoPE\. For each token at position $p$, apply the inverse: $$ k\_\{\\text\{no\-rope\}\} = k \\odot \\cos\(\\theta p\) \- \\text\{rotate\\\_half\}\(k\) \\odot \\sin\(\\theta p\) $$ where $\\text\{rotate\\\_half\}\(\[a, b\]\) = \[\-b, a\]$\.
2. **Per\-layer SVD\.** Flatten K to $\(B \\cdot T, H\_\{kv\} \\cdot d\)$, subtract the mean, run `torch\.svd\_lowrank\(centered, q=192, niter=24\)`\. **niter=24 was not optional\.** At niter=6 the basis quality swung across runs; at niter=24 it stabilised\. We verified with a local test: run\-to\-run variation of the SVD\-recovered eigenvalues across 10 random seeds dropped from 3% at niter=6 to 0\.1% at niter=24\.
3. **DP bit allocation\.** Not all 192 components matter equally\. Split into groups of 64 and run a DP over bit options $\\\{0, 2, 4, 6, 8\\\}$ under a total budget\. This mirrors KVTC's strategy[\[kvtc\]](https://krishgarg.com/shard#fn-kvtc)\. The twist: we added a **4× penalty on the zero\-bit option**\. More on that in a moment\.
4. **Quantize the basis in int8** with per\-column scale\. **Quantize the coefficients in symmetric int4** \(range $\[\-7, 7\]$\) with per\-component scale\. This is what lives on GPU\.
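
Step 4 can be sketched in a few lines\. This is a minimal pure\-Python illustration of symmetric int4 with one scale per component, assuming a max\-abs scale rule; the function names are hypothetical and the real implementation works on GPU tensors\.

```python
def quantize_int4_symmetric(columns):
    """columns: one list of coefficients per PCA component.

    Returns (codes, scales) with integer codes in [-7, 7] and one scale per component.
    """
    codes, scales = [], []
    for col in columns:
        max_abs = max(abs(x) for x in col)
        scale = max_abs / 7.0 if max_abs > 0 else 1.0   # assumed max-abs scaling
        codes.append([max(-7, min(7, round(x / scale))) for x in col])
        scales.append(scale)
    return codes, scales


def dequantize(codes, scales):
    """Map integer codes back to floats using the per-component scale."""
    return [[q * s for q in col] for col, s in zip(codes, scales)]


coeffs = [[0.9, -0.35, 0.1, -0.9],      # a high-variance component
          [0.02, 0.01, -0.03, 0.0]]     # a low-variance component
codes, scales = quantize_int4_symmetric(coeffs)
recon = dequantize(codes, scales)
print(codes)
```

Because each component has its own scale, the low\-variance component still uses the full $\[\-7, 7\]$ range instead of collapsing to zero\.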
Per\-token K storage:
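$$ \\underbrace\{\\text\{rank\} \\cdot 4\}\_\{\\text\{int4 coeffs\}\} \+ \\underbrace\{\\frac\{\\text\{rank\} \\cdot 16\}\{T\}\}\_\{\\text\{fp16 scale, amortized\}\} \+ \\underbrace\{\\frac\{\\text\{rank\} \\cdot H\_\{kv\} \\cdot d \\cdot 8\}\{T\}\}\_\{\\text\{int8 basis, amortized\}\} \\text\{ bits\} $$
The basis is an $\(H\_\{kv\} \\cdot d\) \\times \\text\{rank\}$ matrix, so its amortized term carries a factor of rank\. For rank=192 at 8K tokens, the coefficient term alone is $192 \\cdot 4 / 1024 =$ **0\.75 bits per K element**; the amortized basis adds roughly 0\.19 bits at 8K and shrinks as context grows\.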
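
A small calculator makes the amortization visible\. It is a sketch under one stated assumption: the basis is a full $\(H\_\{kv\} \\cdot d\) \\times \\text\{rank\}$ int8 matrix, which is what the shapes in the fused\-attention code later in this post suggest\. The function name is hypothetical\.

```python
def k_bits_per_element(rank, n_kv_heads, head_dim, seq_len,
                       coeff_bits=4, scale_bits=16, basis_bits=8):
    """Return (total, coeff_only) bits per K element for one layer."""
    width = n_kv_heads * head_dim                    # K elements per token
    coeffs = rank * coeff_bits                       # paid for every token
    scales = rank * scale_bits / seq_len             # one fp16 scale per component
    basis = width * rank * basis_bits / seq_len      # assumed (width x rank) int8 basis
    return (coeffs + scales + basis) / width, coeffs / width


for seq_len in (4096, 8192, 32768):
    total, coeff_only = k_bits_per_element(192, 8, 128, seq_len)
    print(f"T={seq_len:>5}: coeffs {coeff_only:.2f} bits, total {total:.3f} bits")
```

The coefficient cost is fixed per token, while the basis and scale costs fade with longer context, which matches compression improving from 4K to 32K\.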
### The 4× drop penalty
We started with a naive MSE objective in the DP\. Every benchmark regressed from the moment we added it\. The fix came from a paper called "Quantization Dominates Rank Reduction for KV\-Cache Compression"[\[qdrr\]](https://krishgarg.com/shard#fn-qdrr), which argues formally that under softmax\-attention's Fisher metric, **dropping a direction is quadratically worse** than scalar\-quantizing it at the same bit cost\. The intuition: if you kill the direction that determined routing for a particular token, that token's argmax flips to something else, a categorical, non\-recoverable error\. Scalar quant only adds noise\.
They derive a ratio of $3 \\times 2^\{2b\}$ \(768× at INT4\) in favor of quantization over deletion\. We didn't need that much; we just added a 4× multiplier:
```
err = gv * DROP_PENALTY if bits == 0 else gv / (3.0 * (1 << (2 * bits)))
```
One constant\. NIAH went from 0\.92 back to 1\.000\.
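
To show where that one\-line cost plugs in, here is a minimal sketch of a DP over groups\. It assumes equal\-sized groups, a budget expressed as the sum of per\-group bit widths, and the error model from the line above; the function names and the example variances are hypothetical, and the real allocator may differ\.

```python
DROP_PENALTY = 4.0
BIT_OPTIONS = (0, 2, 4, 6, 8)


def group_error(gv, bits):
    """Error model from the post: penalised variance if dropped, else uniform-quant noise."""
    return gv * DROP_PENALTY if bits == 0 else gv / (3.0 * (1 << (2 * bits)))


def allocate_bits(group_vars, budget):
    """Pick one bit width per group, minimising total error with sum(bits) <= budget."""
    best = {0: (0.0, [])}                  # bits used so far -> (error, choices)
    for gv in group_vars:
        nxt = {}
        for used, (err, picks) in best.items():
            for bits in BIT_OPTIONS:
                tot = used + bits
                if tot > budget:
                    continue
                cand = err + group_error(gv, bits)
                if tot not in nxt or cand < nxt[tot][0]:
                    nxt[tot] = (cand, picks + [bits])
        best = nxt
    err, picks = min(best.values(), key=lambda t: t[0])
    return picks, err


# Three groups of 64 components with decaying variance, average budget of 4 bits.
picks, err = allocate_bits([10.0, 1.0, 0.1], budget=12)
print(picks, err)
```

With the penalty in place, the allocator prefers giving the weakest group 2 bits over dropping it, even though the leading group would gladly take the extra bits\.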
## 6\. Building the V path
V's flat singular\-value spectrum ruled out low\-rank tricks\. We tried three variants in sequence:
1. **Per\-channel NF4 scalar quant** on the raw V\. Acceptable but lossy\.
2. **Hadamard rotation \+ NF4 scalar quant**\. Hadamard decorrelates the channels \(each one independently looks Gaussian after rotation\), so NF4's Gaussian\-optimal codebook works better\. Small improvement\.
3. **Hadamard \+ K\-means vector quantization on groups of 4 channels, 256\-entry codebook\.** VQ captures joint structure that scalar quant misses\. 256 entries works well: big enough to cover the distribution, small enough that a lookup is one cache line\.
Option 3 won by a clear margin on reconstruction cosine and downstream quality\. Per\-token V storage: $H\_\{kv\} \\cdot d / 4 \\cdot 8 = 2048$ bits \(256 bytes\) for the whole layer, or **2 bits per V element**\.
Combined with the 0\.75 bits/elem for K, we target $\\approx 1\.5$ bits per element on average, which lines up with the 10–11× compression actually measured at 8K–32K context\.
```
import torch

# kmeans and pack_nbits are helpers not shown in this excerpt.
def vq_encode(v, group_size, codebook_size):
    bh, seq, hd = v.shape
    ng = hd // group_size
    # Per-channel max-abs over the sequence, used to normalise into [-1, 1].
    ch_max = v.abs().amax(dim=1).clamp(min=1e-8)
    # One row per group of `group_size` adjacent channels.
    flat = (v / ch_max.unsqueeze(1)).reshape(bh * seq, ng, group_size).reshape(-1, group_size)
    centroids, idx = kmeans(flat.float(), codebook_size, n_iter=30)
    # Bits per index: 8 for a 256-entry codebook.
    bits = max(1, (codebook_size - 1).bit_length())
    packed, n_orig = pack_nbits(idx.to(torch.uint8).long(), bits)
    return {"centroids": centroids.half(), "packed": packed,
            "ch_maxabs": ch_max.half(), "shape": (bh, seq, hd),
            "group_size": group_size, "bits": bits, "n_orig": n_orig}
```
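
The bit budget for the two paths can be tallied with a short sketch\. It ignores the codebook, per\-channel scales, and the fp16 sink and window tokens, and the helper names are hypothetical\.

```python
def vq_bits_per_element(group_size, codebook_size):
    """Index bits spent per V element when one index covers `group_size` channels."""
    index_bits = max(1, (codebook_size - 1).bit_length())
    return index_bits / group_size


def compression_ratio(k_bits, v_bits, baseline_bits=16):
    """fp16 ratio when K and V have the same number of elements."""
    return baseline_bits / ((k_bits + v_bits) / 2)


v_bits = vq_bits_per_element(group_size=4, codebook_size=256)
for k_bits in (0.75, 1.0):
    print(f"K={k_bits} bits, V={v_bits} bits -> {compression_ratio(k_bits, v_bits):.1f}x")
```

Since V costs 2 bits per element no matter what, the K path has to stay near or below 1 bit for the average to land around 1\.5 bits\.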
## 7\. Attention sinks
We ran the full pipeline end\-to\-end and got NIAH recall of 0\.3\. Not just bad, catastrophic\. A day of bisecting turned up two stacked problems\.
**Attention sinks\.** Xiao et al\.'s StreamingLLM[\[sinks\]](https://krishgarg.com/shard#fn-sinks) \([arXiv:2309\.17453](https://arxiv.org/abs/2309.17453), ICLR 2024\) observed that LLMs dump a huge fraction of their attention onto the first few tokens, even when those tokens are semantically trivial\. Softmax has to sum to 1; if no current token is "the right place to look," attention parks itself on the start\. If you compress the first 4 tokens lossily, the sink gets distorted, and every subsequent attention head produces garbage\. We added **4 FP16 sink tokens**, preserved exactly\.
**Recency bias\.** Language has a strong "look at what just happened" prior\. Compressing the most recent tokens lossily hurts next\-token prediction badly\. We added a **64\-token FP16 residual window**, also preserved exactly\.
With both in place, NIAH jumped from 0\.3 to 1\.000 on the next run\. The sink\+window is a rounding error of storage at 8K context \(68 fp16 tokens out of 8192, under 1% of the sequence\)\.
Final storage layout \(per layer\): \[ 4 sink fp16 \]\[ middle: PCA\(K\) \+ VQ\(V\) \]\[ 64 window fp16 \]\. The compressed middle segment is where the 10× comes from\.
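
How much the fp16 segments cost in ratio terms can be estimated with a few lines\. The sketch assumes the compressed middle averages 1\.5 bits per element \(the target from section 6\) and that sink and window tokens stay at 16 bits; `effective\_ratio` is a hypothetical name\.

```python
def effective_ratio(seq_len, middle_bits, n_sink=4, n_window=64, fp16_bits=16):
    """Overall fp16 compression ratio for the [sink][middle][window] layout."""
    n_fp16 = min(seq_len, n_sink + n_window)     # tokens kept exactly
    n_mid = seq_len - n_fp16                     # tokens stored compressed
    stored = n_fp16 * fp16_bits + n_mid * middle_bits
    return seq_len * fp16_bits / stored


for seq_len in (2048, 8192, 32768):
    with_fp16 = effective_ratio(seq_len, 1.5)
    without = effective_ratio(seq_len, 1.5, n_sink=0, n_window=0)
    print(f"T={seq_len:>5}: {with_fp16:.2f}x with sink+window, {without:.2f}x without")
```

The fixed 68\-token cost matters more at short context, which is one reason the measured ratio climbs as the prompt gets longer\.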
## 8\. Decode tokens: the drift problem
The 10× number is for prefill\. During decode, the model generates one new token at a time\. Where do those go?
Attempt one was dumb: compress each new decode token through the same PCA basis\. Good for the first 20 tokens\. Fell apart at 60\.
The reason is subtle\. Our PCA basis was fit on the prefill K distribution\. Decode tokens are a *different distribution*: they're the model's own outputs, conditioned on the prefill\. Projecting them onto a basis that's too small for them produces a biased reconstruction\. Because each decode step conditions on the previous decode tokens, **errors compound**\. Token 20 is 5% off, token 60 is 30% off, token 150 is unreadable\.
This is where TurboQuant came back\. Its whole point is *data\-oblivious* quantization: the codebook doesn't depend on the data, so every token is independently quantized and there's no drift\. We split the pipeline:

| Segment | Method | Why |
| --- | --- | --- |
| Prefill middle | PCA\(K\) \+ VQ\(V\) | data\-dependent, exploits structure, 10× |
| Sink \+ window | FP16 | tiny, preserved exactly |
| Decode stream | Hadamard \+ Lloyd\-Max \(\+ QJL optional\) | data\-oblivious, no drift |

For the decode stream we implemented Lloyd\-Max centroids for $\\mathcal\{N\}\(0, 1/d\)$ via the closed\-form update $E\[X \\mid X \\in \[a, b\]\]$:
$$ c\_k^\{\\text\{new\}\} = \\frac\{\\phi\(a\_k/\\sigma\) \- \\phi\(b\_k/\\sigma\)\}\{\\Phi\(b\_k/\\sigma\) \- \\Phi\(a\_k/\\sigma\)\} \\cdot \\sigma $$
where $a\_k, b\_k$ are the Voronoi boundaries \(midpoints between adjacent centroids\) and $\\phi, \\Phi$ are the standard normal PDF/CDF\. Converges in ~20 iterations for bit widths $b \\in \\\{2, 3, 4, 8\\\}$, with centroids precomputed once at init\.
```
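import math
from statistics import NormalDist

import torch

# Stand-in for the _inv_cdf helper, which is not shown here (an assumption):
# the standard normal quantile function.
_inv_cdf = NormalDist().inv_cdf

def lloyd_max_codebook(d, bits, n_iter=100):
    n = 1 << bits
    sigma = 1.0 / math.sqrt(d)
    # Start from evenly spaced quantiles of N(0, sigma^2).
    c = [_inv_cdf((i + 0.5) / n) * sigma for i in range(n)]
    phi = lambda z: math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)  # standard normal PDF
    Phi = lambda z: 0.5 * (1 + math.erf(z / math.sqrt(2)))           # standard normal CDF
    for _ in range(n_iter):
        # Voronoi boundaries: midpoints between adjacent centroids.
        bounds = [-1e9] + [(c[k-1] + c[k]) / 2 for k in range(1, n)] + [1e9]
        nc = []
        for k in range(n):
            a, b = bounds[k], bounds[k+1]
            m = Phi(b/sigma) - Phi(a/sigma)
            # Conditional mean of the cell; keep the old centroid if the cell is empty.
            nc.append(c[k] if m < 1e-12 else (phi(a/sigma) - phi(b/sigma)) * sigma / m)
        if max(abs(x - y) for x, y in zip(nc, c)) < 1e-9: break
        c = nc
    return torch.tensor(c, dtype=torch.float32)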
```
Tested four streaming configurations on 150\-token decodes across 5 diverse prompts on an 8K prefix:

| max\_new | off \(fp16 window\) | 4\-bit \+ QJL | 8\-bit |
| --- | --- | --- | --- |
| 20 | 100% | 100% | 100% |
| 60 | 100% | 93\.7% | **100%** |
| 150 | 100% | 85\.3% | **100% \(750/750\)** |

8\-bit streaming is our token\-exact lossless decode path: every single generated token matches FP16 across 150 decode tokens on each of 5 prompts\. That was the second *oh*: the theoretical result let us *guarantee* lossless decode rather than empirically hope for it\.
## 9\. Fused attention: deriving the per\-pair Δ identity
At this point the compression pipeline worked end\-to\-end, but there was still a problem\. To compute attention, we had to **decompress the whole K back to fp16**, run the matmul, free\. The persistent storage savings were real; the peak memory savings *during* attention weren't\. If we wanted real memory savings we needed $Q \\cdot K$ to happen on int4 PCA coefficients directly\.
The naive version is easy\. If $K\_t = K\_\{\\text\{coefs\}\}\[t\] \\cdot B^\\top \+ \\mu$ \(no RoPE\):
$$ Q \\cdot K\_t^\\top = \(Q \\cdot B\) \\cdot K\_\{\\text\{coefs\}\}\[t\]^\\top \+ Q \\cdot \\mu^\\top $$
Precompute $Q \\cdot B$ once \(size $r=192$\), then for each K token a rank\-192 inner product against int4 coefficients\. No FP16 K ever materializes\. Great\.
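
The no\-RoPE identity is plain linearity, and a tiny pure\-Python check makes that concrete\. The dimensions and random values below are toy assumptions, not Llama's\.

```python
import random


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


random.seed(0)
d, r, n_tokens = 8, 3, 5
basis = [[random.gauss(0, 1) for _ in range(r)] for _ in range(d)]        # B: d x r
mu = [random.gauss(0, 1) for _ in range(d)]                               # mean vector
coefs = [[random.randint(-7, 7) for _ in range(r)] for _ in range(n_tokens)]
q = [random.gauss(0, 1) for _ in range(d)]

# Reference: reconstruct every K_t in full, then take the inner product.
k_full = [[dot(basis[j], coefs[t]) + mu[j] for j in range(d)] for t in range(n_tokens)]
reference = [dot(q, k_full[t]) for t in range(n_tokens)]

# Fused: project q onto the basis once, then work only with the coefficients.
q_b = [sum(q[j] * basis[j][c] for j in range(d)) for c in range(r)]
q_mu = dot(q, mu)
fused = [dot(q_b, coefs[t]) + q_mu for t in range(n_tokens)]

max_diff = max(abs(x - y) for x, y in zip(reference, fused))
print(max_diff)
```

The fused path touches $r$ numbers per token instead of $d$, and the full K rows never exist\.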
**But that's with no RoPE\.** When K actually has RoPE applied post\-reconstruction, the identity breaks\. Q's RoPE angle depends on Q's position; K's depends on K's position\. They don't commute with the basis\.
We spent days\. Read PALU[\[palu\]](https://krishgarg.com/shard#fn-palu) \(low\-rank on the *weight* matrices, different problem\), CommVQ[\[commvq\]](https://krishgarg.com/shard#fn-commvq) \(Apple, ICML 2025; specifically RoPE\-commutative but uses a very constrained codebook\), EliteKV[\[elite\]](https://krishgarg.com/shard#fn-elite) \(joint K/V projection, RoPE\-frequency\-selection\)\. None of them had exactly what we needed\.
Then we did the algebra from first principles\. Llama uses **rotate\-half RoPE**\. Split $q$ into halves $q = \[a, b\]$ with $a, b \\in \\mathbb\{R\}^\{d/2\}$\. Then:
$$ \\text\{RoPE\}\(q, p\) = q \\odot \\cos\(\\theta p\) \+ \\text\{rotate\\\_half\}\(q\) \\odot \\sin\(\\theta p\) $$
where $\\text\{rotate\\\_half\}\(\[a, b\]\) = \[\-b, a\]$, and $\\theta$ is the per\-pair frequency vector $\\theta\_i = \\text\{base\}^\{\-2i/d\}$\.
Pick a Q\-token at $p\_q$, a K\-token at $p\_t$, with $\\Delta = p\_t \- p\_q$\. Let $q, k$ be the no\-RoPE vectors, $q = \[a, b\]$, $k = \[c, d\]$\.
The full algebra, writing $\\cos\_q = \\cos\(\\theta\_i p\_q\)$, $\\cos\_k = \\cos\(\\theta\_i p\_t\)$ and likewise for $\\sin$:
$\\text\{RoPE\}\(q, p\_q\) = \[a \\cos\(\\theta p\_q\) \- b \\sin\(\\theta p\_q\),\\ b \\cos\(\\theta p\_q\) \+ a \\sin\(\\theta p\_q\)\]$
$\\text\{RoPE\}\(k, p\_t\) = \[c \\cos\(\\theta p\_t\) \- d \\sin\(\\theta p\_t\),\\ d \\cos\(\\theta p\_t\) \+ c \\sin\(\\theta p\_t\)\]$
Inner product \(contract over all head\-dim positions, then group per pair $i$\):
$\\langle \\text\{RoPE\}\(q, p\_q\),\\ \\text\{RoPE\}\(k, p\_t\) \\rangle\_i = \(a\_i \\cos\_q \- b\_i \\sin\_q\)\(c\_i \\cos\_k \- d\_i \\sin\_k\) \+ \(b\_i \\cos\_q \+ a\_i \\sin\_q\)\(d\_i \\cos\_k \+ c\_i \\sin\_k\)$
Expanding and collecting terms \(using $\\cos\(a\-b\) = \\cos a \\cos b \+ \\sin a \\sin b$ and $\\sin\(a\-b\) = \\sin a \\cos b \- \\cos a \\sin b$\):
$= \(a\_i c\_i \+ b\_i d\_i\)\(\\cos\_q \\cos\_k \+ \\sin\_q \\sin\_k\) \+ \(a\_i d\_i \- b\_i c\_i\)\(\\sin\_q \\cos\_k \- \\cos\_q \\sin\_k\)$
$= \(a\_i c\_i \+ b\_i d\_i\)\\cos\(\\theta\_i\(p\_q \- p\_t\)\) \+ \(a\_i d\_i \- b\_i c\_i\)\\sin\(\\theta\_i\(p\_q \- p\_t\)\)$
Using $\\cos\(\-x\) = \\cos\(x\)$ and $\\sin\(\-x\) = \-\\sin\(x\)$ with $\\Delta = p\_t \- p\_q$, the sign of the second coefficient flips:
$= \(a\_i c\_i \+ b\_i d\_i\)\\cos\(\\theta\_i \\Delta\) \+ \(b\_i c\_i \- a\_i d\_i\)\\sin\(\\theta\_i \\Delta\)$ \(QED\)
The result:
$$ \\langle \\text\{RoPE\}\(q, p\_q\), \\text\{RoPE\}\(k, p\_t\) \\rangle = \\sum\_\{i=1\}^\{d/2\} \\Big\[ \(a\_i c\_i \+ b\_i d\_i\)\\cos\(\\theta\_i \\Delta\) \+ \(b\_i c\_i \- a\_i d\_i\)\\sin\(\\theta\_i \\Delta\) \\Big\] $$
The inner product depends *only on $\\Delta = p\_t \- p\_q$* and the *no\-RoPE* halves\. This is basic relative\-position RoPE algebra, what RoFormer[\[rope\]](https://krishgarg.com/shard#fn-rope) got famous for\. We didn't invent that\. What we hadn't seen done before is **plugging a low\-rank compressed K into this identity\.**
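
The identity is easy to check numerically\. The sketch below applies rotate\-half RoPE to random vectors at two positions and compares the direct inner product against the per\-pair formula\. The head dimension, base, and positions are toy assumptions, and `delta\_identity` is a hypothetical name\.

```python
import math
import random


def rope(x, pos, inv_freq):
    """Rotate-half RoPE: x*cos + rotate_half(x)*sin with rotate_half([a, b]) = [-b, a]."""
    h = len(x) // 2
    a, b = x[:h], x[h:]
    cos = [math.cos(pos * f) for f in inv_freq]
    sin = [math.sin(pos * f) for f in inv_freq]
    first = [a[i] * cos[i] - b[i] * sin[i] for i in range(h)]
    second = [b[i] * cos[i] + a[i] * sin[i] for i in range(h)]
    return first + second


def delta_identity(q, k, delta, inv_freq):
    """Right-hand side of the identity, using only no-RoPE halves and delta = p_t - p_q."""
    h = len(q) // 2
    a, b, c, d_ = q[:h], q[h:], k[:h], k[h:]
    return sum((a[i] * c[i] + b[i] * d_[i]) * math.cos(inv_freq[i] * delta)
               + (b[i] * c[i] - a[i] * d_[i]) * math.sin(inv_freq[i] * delta)
               for i in range(h))


random.seed(0)
dim, base = 16, 10000.0
inv_freq = [base ** (-2 * i / dim) for i in range(dim // 2)]
q = [random.gauss(0, 1) for _ in range(dim)]
k = [random.gauss(0, 1) for _ in range(dim)]
p_q, p_t = 1234, 87

lhs = sum(x * y for x, y in zip(rope(q, p_q, inv_freq), rope(k, p_t, inv_freq)))
rhs = delta_identity(q, k, p_t - p_q, inv_freq)
print(lhs, rhs)
```

Both sides agree to floating\-point precision, and the right\-hand side never applies RoPE to $k$\. That is the property the compressed path relies on\.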
Substitute $k\_n = K\_\{\\text\{coefs\}\} B^\\top \+ \\mu$ \(no\-RoPE reconstruction\)\. The pair contribution becomes \(splitting the basis into halves $B\_c$ for rows $0\.\.d/2\-1$ and $B\_d$ for rows $d/2\.\.d\-1$\):
$$ a\_i c\_i \+ b\_i d\_i = \\sum\_r K\_\{\\text\{coefs\}\}\[r\] \\cdot \\underbrace\{\(a\_i B\_c\[i, r\] \+ b\_i B\_d\[i, r\]\)\}\_\{A\_i\[r\]\} \+ \(a\_i \\mu^c\_i \+ b\_i \\mu^d\_i\) $$
and similarly $b\_i c\_i \- a\_i d\_i = \\sum\_r K\_\{\\text\{coefs\}\}\[r\] \\cdot B^\*\_i\[r\] \+ \(\\ldots\)$, with $B^\*\_i\[r\] = b\_i B\_c\[i, r\] \- a\_i B\_d\[i, r\]$\.
Per decode step, we precompute matrices $A$ and $B^\*$ of shape $\(n\_q, d/2, r\)$ \(small: $32 \\times 64 \\times 192 \\approx 400$K floats\), then for each K token compute two rank\-$r$ inner products against the int4 coefficients, and mix them with per\-pair cos/sin of $\\theta\_i \\Delta$\.
We wrote it in PyTorch and added a unit test\. Random Q, random compressed K at Llama dimensions \($n\_q=32$, $n\_\{kv\}=8$, $d=128$, $r=192$, $n\_\{\\text\{comp\}\}=1024$\)\. Against the standard dequant\-then\-matmul reference: **max abs diff 0\.0023**, mean 0\.0004\. Within fp16 tolerance\. Then end\-to\-end with the Llama attention monkey\-patched: **120/120 tokens match FP16 across 3 prompts**\.
```
import math
import torch

# rotate_half and _unpack_int4 are helpers not shown in this excerpt.
def compressed_scores(q_rope, gc, pos_q, n_q, nkv, hd, inv_freq):
    B, dev = q_rope.shape[0], q_rope.device
    n_total, rank = gc["k_coeffs_shape"]
    n_comp = gc["n_compress"]; gqa = n_q // nkv; hd2 = hd // 2
    # undo RoPE on Q at pos_q
    freqs = torch.tensor([pos_q], device=dev).float().unsqueeze(-1) * inv_freq.to(dev).float()
    cos_q = torch.cat([freqs, freqs], dim=-1).cos().to(torch.float16).view(1,1,1,hd)
    sin_q = torch.cat([freqs, freqs], dim=-1).sin().to(torch.float16).view(1,1,1,hd)
    q_n = q_rope * cos_q - rotate_half(q_rope) * sin_q
    a, b = q_n[..., :hd2].contiguous(), q_n[..., hd2:].contiguous()
    # split basis halves; broadcast over GQA
    k_basis = gc["k_basis_q"].half() * gc["k_basis_scale"].half()
    bph = k_basis.view(nkv, hd, rank)
    bc, bd = bph[:, :hd2, :], bph[:, hd2:, :]
    kv_idx = torch.arange(n_q, device=dev) // gqa
    bc, bd = bc[kv_idx], bd[kv_idx]
    coeff_scale = gc["k_coeff_scale"].half()
    bc_e, bd_e = bc.unsqueeze(0).unsqueeze(2), bd.unsqueeze(0).unsqueeze(2)
    # A and Bm are the per-pair matrices A and B* from the derivation
    A = (a.unsqueeze(-1) * bc_e + b.unsqueeze(-1) * bd_e) * coeff_scale
    Bm = (b.unsqueeze(-1) * bc_e - a.unsqueeze(-1) * bd_e) * coeff_scale
    coefs = _unpack_int4(gc["k_coeffs"], n_total, rank).to(dev).view(B, n_comp, rank)
    U = torch.einsum("bhspr,bcr->bhspc", A, coefs)
    V = torch.einsum("bhspr,bcr->bhspc", Bm, coefs)
    # per-pair cos(Δ), sin(Δ), reduce over pair dim
    pos_t = torch.arange(gc["compress_start"], gc["compress_start"] + n_comp, device=dev)
    phase = (pos_t - pos_q).float().unsqueeze(-1) * inv_freq.to(dev).float()
    cos_d, sin_d = phase.cos().half(), phase.sin().half()
    return (torch.einsum("bhspc,cp->bhsc", U, cos_d)
            + torch.einsum("bhspc,cp->bhsc", V, sin_d)) / math.sqrt(hd)
```
This derivation was the third *oh*\. Nothing exotic, just rotate\-half RoPE algebra combined with the linearity of $k\_n$\. But we spent three days thinking it was impossible before we tried it\.
## 10\. Fused V: moving the Hadamard past the sum
With K fused, V still materialized to FP16 before the weighted sum\. Another opportunity\.
Recall V storage is Hadamard\-rotated VQ\. On decode:
$$ V\_t = H\(V^\{\\text\{rot\}\}\_t\), \\quad V^\{\\text\{rot\}\}\_t = \\text\{codebook\}\[\\text\{idx\}\[t\]\] \\odot \\text\{chMax\} $$
where $H$ is the normalized Hadamard matrix\. The attention output for query $t\_q$ is $\\sum\_t w\_t \\cdot V\_t = \\sum\_t w\_t \\cdot H\(V^\{\\text\{rot\}\}\_t\)$\. Since Hadamard is linear:
$$ \\sum\_t w\_t \\cdot H\(V^\{\\text\{rot\}\}\_t\) = H\\Big\(\\sum\_t w\_t \\cdot V^\{\\text\{rot\}\}\_t\\Big\) $$
So we never need to apply the inverse Hadamard to every token\. We do *one* Hadamard on the final weighted sum\. The intermediate $V^\{\\text\{rot\}\}\_t$ is cheap to compute on demand \(a codebook lookup \+ broadcast multiply\)\.
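
The linearity step can be verified with a toy fast Walsh\-Hadamard transform\. This is a pure\-Python sketch with made\-up sizes; the `hadamard` function here is a stand\-in written for the example, not the one used in Shard\.

```python
import math
import random


def hadamard(x):
    """Normalized fast Walsh-Hadamard transform; len(x) must be a power of two."""
    y = list(x)
    n = len(y)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                y[j], y[j + h] = y[j] + y[j + h], y[j] - y[j + h]
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in y]


random.seed(0)
d, n_tokens = 8, 5
v_rot = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_tokens)]
w = [random.random() for _ in range(n_tokens)]
total = sum(w)
w = [x / total for x in w]               # attention-like weights summing to 1

# Per-token path: transform every token, then take the weighted sum.
transformed = [hadamard(v) for v in v_rot]
per_token = [sum(w[t] * transformed[t][j] for t in range(n_tokens)) for j in range(d)]

# Fused path: weighted sum in rotated space, then a single transform.
mixed = [sum(w[t] * v_rot[t][j] for t in range(n_tokens)) for j in range(d)]
fused = hadamard(mixed)

max_diff = max(abs(a - b) for a, b in zip(per_token, fused))
print(max_diff)
```

The per\-token path runs one transform per token, the fused path runs one in total, and the outputs match to rounding error\.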
Impact:
- Before: materialize $\(n\_\{\\text\{comp\}\}, d\)$ fp32 intermediate, apply Hadamard to each, then weighted sum\.
- After: codebook\-lookup in rotated space, weighted sum, one Hadamard at the end\.
Local numerical check: max abs diff **0\.000043** \(much smaller than the K path because Hadamard is exact and linear; the only loss is VQ itself\)\.
```
import torch

# unpack_nbits and hadamard are helpers not shown in this excerpt.
def _v_weighted_sum_compressed(weights_mid, gc, B, nkv, n_comp, hd):
    qd = gc["v_qdata"]
    gs = qd["group_size"]; ng = hd // gs
    idx = unpack_nbits(qd["packed"].to(weights_mid.device), qd["bits"], qd["n_orig"])
    idx = idx[:B * nkv * n_comp * ng].long().view(B * nkv, n_comp, ng)
    centroids = qd["centroids"].to(weights_mid.device).float()
    ch_max = qd["ch_maxabs"].to(weights_mid.device).float()
    # Codebook lookup plus per-channel rescale, still in Hadamard-rotated space.
    v_rot = centroids[idx].view(B * nkv, n_comp, hd) * ch_max.unsqueeze(1)
    v_rot = v_rot.view(B, nkv, n_comp, hd)
    gqa = weights_mid.shape[1] // nkv
    v_rot_q = v_rot.repeat_interleave(gqa, dim=1)
    out_rot = torch.einsum("bhsk,bhkd->bhsd", weights_mid, v_rot_q.float())
    # One Hadamard on the weighted sum instead of one per token.
    return hadamard(out_rot.to(torch.float16)).float()
```
## 11\. The Triton kernel
The PyTorch einsums for `compressed\_scores` are correct but slow\. The hot matmul is
$$ \(M, R\) \\times \(R, N\)\_\{\\text\{from int4\}\} \\to \(M, N\), \\quad M \\approx n\_q \\cdot d/2,\\ R = 192,\\ N = n\_\{\\text\{comp\}\} $$
where the right\-hand matrix is packed int4 \(one byte per two values\) with a per\-column fp16 scale\. Standard Triton matmul structure, plus inline int4 unpacking:
```
import triton
import triton.language as tl

@triton.jit
def _int4_matmul_kernel(A_ptr, B_packed_ptr, B_scale_ptr, Out_ptr,
                        M, N, R, sam, sar, sbn, sbr, som, son,
                        BM: tl.constexpr, BN: tl.constexpr, BR: tl.constexpr):
    pid_m = tl.program_id(0); pid_n = tl.program_id(1)
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    acc = tl.zeros((BM, BN), dtype=tl.float32)
    for r_start in range(0, R, BR):
        offs_r = r_start + tl.arange(0, BR)
        r_mask = offs_r < R
        a = tl.load(A_ptr + offs_m[:, None] * sam + offs_r[None, :] * sar,
                    mask=(offs_m[:, None] < M) & r_mask[None, :], other=0.0)
        # Two int4 values per byte: even r in the low nibble, odd r in the high nibble.
        half = offs_r // 2
        is_hi = (offs_r % 2) == 1
        b_bytes = tl.load(B_packed_ptr + offs_n[:, None] * sbn + half[None, :] * sbr,
                          mask=(offs_n[:, None] < N) & r_mask[None, :], other=0)
        # Nibbles are stored with a +8 offset; subtract it to recover signed values.
        lo = (b_bytes & 0x0F).to(tl.float32) - 8.0
        hi = ((b_bytes >> 4) & 0x0F).to(tl.float32) - 8.0
        b_vals = tl.where(is_hi[None, :], hi, lo)
        s = tl.load(B_scale_ptr + offs_r, mask=r_mask, other=0.0)
        b_scaled = (b_vals * s[None, :]).to(tl.float16)
        acc += tl.dot(a.to(tl.float16), tl.trans(b_scaled))
    tl.store(Out_ptr + offs_m[:, None] * som + offs_n[None, :] * son,
             acc.to(tl.float16),
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
```
Block sizes tuned to $BM=64, BN=128, BR=64$\. Wired in as an optional fast path in `compressed\_scores`, gated on `is\_cuda & HAS\_TRITON & B==1`, with the PyTorch fallback verified at **0\.0 max diff** against the reference\.
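
The nibble layout the kernel decodes can be illustrated on the CPU\. This sketch infers the layout from the kernel above \(even index in the low nibble, odd index in the high nibble, values stored with a \+8 offset\); the function names are hypothetical and this is not the repo's packing code\.

```python
def pack_int4(values):
    """Pack signed ints in [-8, 7] two per byte: even index -> low nibble, odd -> high."""
    vals = list(values)
    if len(vals) % 2:
        vals.append(0)                       # pad to an even count
    out = bytearray()
    for lo, hi in zip(vals[0::2], vals[1::2]):
        out.append(((hi + 8) << 4) | (lo + 8))
    return bytes(out)


def unpack_int4(packed, n):
    """Inverse of pack_int4; returns the first n signed values."""
    vals = []
    for byte in packed:
        vals.append((byte & 0x0F) - 8)
        vals.append(((byte >> 4) & 0x0F) - 8)
    return vals[:n]


coeffs = [7, -7, 0, 3, -1]
packed = pack_int4(coeffs)
restored = unpack_int4(packed, len(coeffs))
print(len(packed), restored)
```

Five coefficients fit in three bytes, and the round trip is exact\. The kernel does the same unpacking per tile, so the fp16 values exist only in registers\.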
## 12\. Throughput
The compressed cache was small at rest, but the decode hot path was rebuilding too much fp16 state and launching too many small operations\. Memory looked good in a checkpoint, then tok/s made it obvious\.
We split the benchmark into three clocks \(prefill, one\-time compression, and steady\-state decode\) and the profiler made the bottleneck visible: extra materialization, Python\-level tensor plumbing, cat/copy churn, and the VQ weighted\-sum path\.
### Profiler view: what changed after optimization

- Shard decode before: 15\.057s
- Shard decode after: 3\.050s

The first profiler pass showed a 15\.057s Shard decode region\. After the profiling\-driven fixes, the same mixed profile dropped to 3\.050s; the cleaner steady\-state decode\-only profile measured 849ms versus 237ms for FP16\.
The first profile pointed at the obvious problem: the implementation was still paying a decompression tax in places where the algorithm said it should not\. The fix sequence was:
1. **Stop materializing K for attention\.** The relative\-Δ RoPE identity made it possible to compute attention scores from int4 PCA coefficients directly\.
2. **Fuse the K\-side rank matmul\.** The PyTorch einsum reference was correct, but the hot shape is a small fixed\-rank int4 matmul\. Moving that into Triton removed a lot of framework overhead\.
3. **Stop materializing V before the weighted sum\.** The Hadamard linearity trick moved the inverse transform after the softmax\-weighted sum\.
4. **Move the VQ weighted sum into a single Triton kernel\.** Codebook lookup, per\-channel scale, and accumulation happen together instead of as a chain of indexing operations, and one launch covers all KV heads\.
5. **Add lazy per\-layer eviction\.** Reconstructed per\-layer tensors are freed as the decode loop advances through layers, so peak memory tracks the compressed representation instead of accumulating all reconstructed layers\.
```
Before profiling-driven optimization:
  shard_decode_loop CUDA total: 15.057s
  fp16_decode_loop CUDA total: 1.096s
  aten::cat CUDA total: 295.298ms, 171.40GB CUDA mem, 45,915 calls
  aten::copy_ CUDA total: 225.648ms, 56,610 calls
  _int4_matmul_kernel: 160.195ms, 2,044 calls
  _vq_weighted_sum_all_kernel: 239.290ms, 511 calls

After profiling-driven optimization:
  shard_decode_loop CUDA total: 3.050s
  fp16_decode_loop CUDA total: 837.203ms
  aten::cat CUDA total: 94.797ms, 31.37GB CUDA mem, 23,554 calls
  aten::copy_ CUDA total: 162.949ms, 34,556 calls
  _int4_matmul_kernel: 80.289ms, 1,024 calls
  _vq_weighted_sum_all_kernel: 238.762ms, 512 calls

Clean decode-only profile after optimization:
  shard_decode_only CUDA total: 849.314ms
  fp16_decode_only CUDA total: 236.553ms
  aten::cat CUDA total: 15.968ms, 1.07GB CUDA mem, 5,768 calls
  aten::copy_ CUDA total: 19.736ms, 7,752 calls
  _int4_matmul_kernel: 6.052ms, 256 calls
  _vq_weighted_sum_all_kernel: 14.844ms, 128 calls
```
After those changes, the result is no longer catastrophically slow, but it is also not faster than FP16\. The latest decode throughput benchmark reports:

| context | FP16 tok/s | Shard tok/s | relative speed | compression |
| --- | --- | --- | --- | --- |
| 2K | 28\.85 | 14\.09 | 0\.49× | 7\.1× |
| 4K | 28\.63 | 13\.72 | 0\.48× | 8\.8× |
| 8K | 29\.84 | 11\.62 | 0\.39× | 10\.0× |

Quality is unchanged, storage compression is strong, and decode throughput is the remaining engineering problem\. The bottleneck now is many small launches and tensor reshapes around attention, competing against a very optimized FP16 baseline\.
## 13\. Lazy per\-layer eviction
Small but real implementation detail\. During decode, when attention is on layer $i$, nothing needs layer $i\-1$'s reconstructed FP16 K/V anymore\. By default HuggingFace's Cache API keeps it cached on the layer forever, so peak memory = all 32 layers' K/V materialized simultaneously\.
One\-line invariant: when layer $i$'s attention starts, free layer $i\-1$'s FP16 K/V \(they get rebuilt from compressed buffers next time\)\. At any instant, **at most one layer's decompressed K/V exists in GPU memory\.**
```
def _lazy_evict(self):
    c = self._cache_ref
    if c is None or not c._lazy_per_layer: return
    last = c._last_accessed_layer
    if last != -1 and last != self.layer_idx:
        prev = c.layers[last]
        # Only drop fp16 tensors that can be rebuilt from a compressed copy.
        if getattr(prev, "_gpu_compressed", None) is not None:
            prev._keys = None
            prev._values = None
            prev._needs_decompress = True
    c._last_accessed_layer = self.layer_idx
```
Not a research idea\. Just the kind of plumbing that's the difference between "we technically save memory" and "we fit one more request on the GPU\."
## 14\. Final results
All on **Llama\-3\.1\-8B\-Instruct**, single NVIDIA B200\. Same model as TurboQuant's Table 1 so the comparison is direct\.
### Needle\-in\-a\-haystack \(Fu et al\. setup[\[fu\]](https://krishgarg.com/shard#fn-fu), 5 depths × 4 lengths\)

| context | Shard recall | compression | TurboQuant paper |
| --- | --- | --- | --- |
| 4K | **1\.000** | 8\.8× | 0\.997 @ 4–6× |
| 8K | **1\.000** | 10\.0× | 0\.997 @ 4–6× |
| 16K | **1\.000** | 10\.8× | 0\.997 @ 4–6× |
| 32K | **1\.000** | 11\.2× | 0\.997 @ 4–6× |

### LongBench\-E \(8 tasks, F1 / ROUGE\-L, 15 samples each\)

| task | FP16 | Shard | Δ | compression |
| --- | --- | --- | --- | --- |
| qasper | 12\.08 | 11\.32 | −0\.76 | 8\.1× |
| multifieldqa\_en | 16\.62 | 16\.77 | \+0\.15 | 7\.9× |
| hotpotqa | 15\.27 | 15\.26 | −0\.00 | 8\.7× |
| 2wikimqa | 14\.85 | 15\.13 | \+0\.27 | 8\.9× |
| gov\_report | 22\.34 | 22\.54 | \+0\.20 | 5\.7× |
| multi\_news | 14\.47 | 13\.50 | −0\.97 | 4\.5× |
| triviaqa | 16\.68 | 17\.21 | \+0\.53 | 8\.5× |
| samsum | 17\.62 | 17\.82 | \+0\.19 | 7\.6× |
| **average** | **16\.24** | **16\.19** | **−0\.05** | **7\.5×** |

### WikiText\-2 PPL

| | PPL | Δ |
| --- | --- | --- |
| FP16 baseline | 6\.45 | — |
| Shard \(10×\) | 6\.47 | **\+0\.26%** |

### Compression by context

| context | compression | stored cache |
| --- | --- | --- |
| 4K | 8\.8× | 58 MB |
| 8K | 10\.0× | 102 MB |
| 16K | 10\.8× | 190 MB |
| 32K | 11\.2× | 366 MB |

### Streaming quality

| path | target | result |
| --- | --- | --- |
| 8\-bit lossless streaming | 100% match vs FP16 at 150 decode tokens | **750/750 \(100\.0%\)** |

Compared to the TurboQuant paper's Table 1 on the same model: they report near\-lossless quality around 4–6× compression\. Shard reaches 10× at 8K and 11\.2× at 32K with the same NIAH ceiling and a \+0\.26% WikiText\-2 PPL delta\. The remaining caveat is throughput: the compressed decode path currently runs at 0\.39–0\.49× FP16, so the current implementation is primarily a memory\-capacity win, not a latency win\.
## 15\. Stuff we tried that didn't work
- **Cross\-layer K prediction\.** Predict layer $\\ell$'s K from layer $\\ell\-1$'s K via a learned linear map, compress only the residual\. R² ≈ 0\.76 on held\-out data\. In practice: two days of work, then we realised cumulative error compounds across 32 layers and crushed NIAH to 1\.2%\. Dead\.
- **K delta encoding\.** Compress $K\[t\] \- K\[t\-1\]$ instead of $K\[t\]$\. Same cumulative error story\. 29% NIAH\. Dead\.
- **Asymmetric int4 range $\[\-8, 7\]$\.** Uses one more code point but loses the symmetry that zero\-centered PCA coefficients naturally have\. Measured quality dropped\. Reverted to symmetric $\[\-7, 7\]$\.
- **Per\-token 4\-bit requantization of prefill K/V after PCA reconstruction\.** Double quantization \(float → PCA → int4 → fp16 → int4\) killed quality by \+6% PPL\. Reverted\.
- **Decompress\-per\-layer during attention \(no lazy eviction, explicit recompress\)\.** 5–8× slower than the fp16 baseline because decompression reruns on every decode step\. Reverted\.
- **zlib entropy coding of VQ indices\.** Gave another 3–8% storage boost, but moved the bytes to CPU, incompatible with the GPU\-resident format that the fused kernels need\. Reverted\.
- **PALU\-style Hadamard\-in\-the\-PCA\-coefficient\-space\.** Tried to rotate the PCA coefficients by a Hadamard matrix to uniformize their variance\. Broke the DP bit allocation \(which assumes the variance\-ordered structure of PCA coords\)\. Reverted\.
---
Code: [github\.com/krish1905/shard](https://github.com/krish1905/shard)