Less Data, Faster Training: repeating smaller datasets speeds up learning via sampling biases

arXiv cs.LG Papers

Summary

This paper investigates the 'small-vs-large gap', where training on fewer samples with more repetitions can lead to faster learning and compute savings compared to using larger datasets, attributing the speedup to layer-wise growth enabled by sampling biases. The findings suggest that smaller datasets with repetition can be proactively leveraged as favorable inductive biases, particularly in reasoning tasks.

arXiv:2605.20314v1 Announce Type: new Abstract: This work investigates the “small-vs-large gap”, where training repeatedly on fewer samples can lead to compute savings during training compared to using a larger dataset. This is observed across algorithmic tasks, architectures and optimizers and cannot be explained using prior theory. We argue that the speedup comes from appropriate layer-wise growth enabled by sampling biases, which is more pronounced when the dataset size is smaller. We provide both theoretical analysis and empirical evidence from various interventions. Our results suggest that using a smaller dataset with more repetitions is not just a fallback strategy under data scarcity, but can be proactively leveraged as a favorable inductive bias for optimization, particularly in reasoning tasks.

Cached at: 05/21/26, 06:23 AM

# Less Data, Faster Training: repeating smaller datasets speeds up learning via sampling biases
Source: [https://arxiv.org/html/2605.20314](https://arxiv.org/html/2605.20314)

Jingwen Liu Columbia University

jingwenliu@cs.columbia.edu

Ezra Edelman University of Pennsylvania

ezrzae@cis.upenn.edu

Surbhi Goel University of Pennsylvania

surbhig@cis.upenn.edu

Bingbin Liu Kempner Institute, Harvard University

bliu@g.harvard.edu

## 1 Introduction

The conventional wisdom on data use is that more is better, a view supported by both classic generalization theory and extensive empirical evidence (Hernandez et al., [2022](https://arxiv.org/html/2605.20314#bib.bib34); Muennighoff et al., [2023](https://arxiv.org/html/2605.20314#bib.bib3)). Recent work has reported a counterintuitive phenomenon: fewer samples can lead to faster learning. One example is online SGD for single-index models, where taking more than one gradient step on the same batch can lead to faster convergence in terms of steps (Dandi et al., [2024](https://arxiv.org/html/2605.20314#bib.bib5); Arnaboldi et al., [2024](https://arxiv.org/html/2605.20314#bib.bib6); Lee et al., [2025](https://arxiv.org/html/2605.20314#bib.bib8)). Similarly, an empirical study on Transformers observes that, for a given number of training steps, multi-epoch training on a randomly sampled dataset can achieve better test performance than training with fresh samples at every step, for various algorithmic tasks (Charton and Kempe, [2024](https://arxiv.org/html/2605.20314#bib.bib2)). In LLM post-training, a concurrent work (Kopiczko et al., [2026](https://arxiv.org/html/2605.20314#bib.bib39)) also observes that more epochs on fewer samples can lead to better performance under a fixed compute budget for math and coding tasks.

These are examples of what we refer to as small-vs-large gaps, where training on a smaller number of samples reduces the training compute a given model needs to reach a target performance. Here compute is defined as the total number of (possibly repeated) samples on which the model performs gradient updates (e.g., training steps × batch size).

This work aims to better understand such small-vs-large gaps. We begin by extending prior work, confirming that small-vs-large gaps appear across a variety of settings ([Figures 1](https://arxiv.org/html/2605.20314#S1.F1) and [2](https://arxiv.org/html/2605.20314#S3.F2)), including different tasks, architectures, and optimizers, and under both mini-batch and full-batch updates. In contrast to prior studies, many of the settings we examine are not explained by existing theory ([Section 4.1](https://arxiv.org/html/2605.20314#S4.SS1)). The existing explanations we examine include comparisons of CSQ and SQ lower bounds (Dandi et al., [2024](https://arxiv.org/html/2605.20314#bib.bib5); Arnaboldi et al., [2024](https://arxiv.org/html/2605.20314#bib.bib6); Lee et al., [2025](https://arxiv.org/html/2605.20314#bib.bib8)), gradient variance reduction (Kotha et al., [2025](https://arxiv.org/html/2605.20314#bib.bib37)), and curriculum learning (Valiant, [2012](https://arxiv.org/html/2605.20314#bib.bib15); Abbe et al., [2023b](https://arxiv.org/html/2605.20314#bib.bib17)) or learning under biased distributions (Kalai et al., [2009](https://arxiv.org/html/2605.20314#bib.bib16); Cornacchia et al., [2025](https://arxiv.org/html/2605.20314#bib.bib19)). Notably, the small-vs-large gap exists even under full-batch gradient updates ([Figure 2](https://arxiv.org/html/2605.20314#S3.F2)), implying that explanations based on stochastic gradient updates are not sufficient.

Instead, we show that the small-vs-large gap primarily results from favorable optimization biases due to sampling biases of the dataset. Intuitively, repeating the dataset reinforces the bias induced by sampling, which helps adjust the relative growth of different layers and in turn speeds up feature learning. This becomes more evident when the dataset is smaller, since the sampling bias is stronger. We formalize this intuition in [Section 4.2](https://arxiv.org/html/2605.20314#S4.SS2) and show that training on smaller datasets can reduce the number of steps required for convergence ([Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1)). This is further supported by the fact that the small-vs-large gap can be removed with proper selection of layer-wise initialization or learning rates; we provide empirical evidence from various interventions in [Section 5](https://arxiv.org/html/2605.20314#S5). Such sampling biases make the model more robust to the choice of learning rate and initialization, leading to a gap under standard parameterization.

In summary, our work characterizes the small\-vs\-large gap with the following contributions:

- We confirm that the small-vs-large gap exists across tasks, architectures, and optimizers. The gap is evident in both the number of optimization steps and the overall compute complexity, which depends on both the number of steps and the per-step cost (proportional to the batch size).
- We show that sampling biases induced by smaller datasets are a primary driver of the small-vs-large gap ([Section 4](https://arxiv.org/html/2605.20314#S4)): sampling biases modulate the relative magnitude of updates across layers, which in turn helps with feature learning. We identify regimes where existing theory fails to explain the gap ([Section 4.1](https://arxiv.org/html/2605.20314#S4.SS1)), and theoretically show that training on smaller datasets reduces the number of steps required for convergence in an MLP ([Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1)).
- We further support the theoretical explanation with a broad set of empirical evidence. First, training on a small dataset with random labels leads to a speedup comparable to that observed with real labels ([Section 5.1](https://arxiv.org/html/2605.20314#S5.SS1)), indicating that sampling bias is the main mechanism, as the gap persists without task-relevant signal. Moreover, parameter-wise interventions substantially reduce the small-vs-large gap ([Section 5.2](https://arxiv.org/html/2605.20314#S5.SS2)), including adjustments to initialization scales and parameter-wise learning rates across both MLPs and Transformers. For Transformers, our findings additionally suggest that the widely used QK normalization has nuanced effects on optimization that merit further investigation.

We discuss the implications and limitations of our results in [Section 6](https://arxiv.org/html/2605.20314#S6). Together, our results suggest that training on a smaller dataset with increased repetitions is not merely a fallback under data scarcity, but a source of beneficial optimization inductive biases that can be leveraged more proactively, particularly for reasoning tasks.

[Figure 1 panels: (a) Sparse parity; (b) Single-index model; (c) ICL linear regression; (d) Modular addition]

Figure 1: Small-vs-large gap exists in various tasks. Across various feature learning and algorithmic tasks ([Section 2](https://arxiv.org/html/2605.20314#S2)), training on a smaller dataset (yellow curves) leads to faster convergence than training on a larger dataset (blue curves). Results are based on 2-layer Transformers optimized with mini-batch AdamW. An “$n$-phase” schedule denotes that the training set size is progressively increased over $n$ phases ([Section 2](https://arxiv.org/html/2605.20314#S2)).

### 1.1 Related work

It is widely believed that in deep learning more is better, as captured by the study of scaling laws. However, different resources need not be scaled together. For instance, data repetition, which keeps the sample size fixed while scaling up compute, can achieve performance similar to compute-matched online training when the amount of repetition is moderate (Xu et al., [2021](https://arxiv.org/html/2605.20314#bib.bib13); Sekhari et al., [2021](https://arxiv.org/html/2605.20314#bib.bib9); Muennighoff et al., [2023](https://arxiv.org/html/2605.20314#bib.bib3); Lin et al., [2025](https://arxiv.org/html/2605.20314#bib.bib1); Yan et al., [2025](https://arxiv.org/html/2605.20314#bib.bib10)). We are interested in a more extreme phenomenon, termed the small-vs-large gap, where reducing the sample size while holding compute constant can improve performance. The small-vs-large gap has been observed in recent work on algorithmic tasks (Charton and Kempe, [2024](https://arxiv.org/html/2605.20314#bib.bib2)), in-context learning (Zucchet et al., [2025](https://arxiv.org/html/2605.20314#bib.bib35)), and language model finetuning on reasoning tasks (Kopiczko et al., [2026](https://arxiv.org/html/2605.20314#bib.bib39)). Prior work has shown this for learning single-index models, where taking more than one gradient step on the same set of samples can reduce the total number of gradient steps (Dandi et al., [2024](https://arxiv.org/html/2605.20314#bib.bib5); Arnaboldi et al., [2024](https://arxiv.org/html/2605.20314#bib.bib6); Lee et al., [2025](https://arxiv.org/html/2605.20314#bib.bib8)). The intuition is that while online SGD lies in the class of correlational statistical query (CSQ) algorithms, SGD with sample repeats belongs to the more general class of statistical query (SQ) algorithms.

In contrast, we find that the compute savings from using less data hold even in regimes where the CSQ-SQ distinction does not apply, including training with full-batch gradient descent and tasks with discrete domains. Close to our quadratic setting in [Section 4.2](https://arxiv.org/html/2605.20314#S4.SS2), a concurrent work by Kovačević et al. ([2026](https://arxiv.org/html/2605.20314#bib.bib24)) shows the statistical advantage of full-batch gradient descent over SGD where mini-batches are sampled fresh from the population. A key difference is that the model in Kovačević et al. ([2026](https://arxiv.org/html/2605.20314#bib.bib24)) has only a single layer (i.e., $f(x)=\sigma(w^\top x)$), hence the effect of relative weight norms across layers does not apply.

Previous work has also studied how multi-pass SGD can improve the sample complexity over single-pass SGD in various settings, including linear regression (Pillaud-Vivien et al., [2018](https://arxiv.org/html/2605.20314#bib.bib14); Lin et al., [2025](https://arxiv.org/html/2605.20314#bib.bib1)), general stochastic convex optimization (Sekhari et al., [2021](https://arxiv.org/html/2605.20314#bib.bib9)), and non-convex problems under the PL condition (Xu et al., [2021](https://arxiv.org/html/2605.20314#bib.bib13)). A crucial difference from our work is that these results focus on saving samples but not compute: they show that the population error achieved by $T$ online SGD steps, where each step is taken on an iid sample from the population, can be achieved by $T$ steps of multi-pass SGD, where each step is taken on an iid sample drawn from an empirical distribution of size smaller than $T$. In contrast, we will show that it is possible to achieve the same error as $T$ online SGD steps using fewer than $T$ steps of multi-pass SGD.

We will show that the key mechanism behind such speedup is the strong sampling bias of smaller datasets, which effectively adjusts the relative update speeds across layers, leading to faster learning. Such adjustments relate to the idea of balancing contributions from different layers, which has been studied extensively in the optimization and feature learning literature (Yang and Hu, [2020](https://arxiv.org/html/2605.20314#bib.bib18); Azulay et al., [2021](https://arxiv.org/html/2605.20314#bib.bib36); Yang et al., [2022](https://arxiv.org/html/2605.20314#bib.bib7), [2023](https://arxiv.org/html/2605.20314#bib.bib22); Everett et al., [2024](https://arxiv.org/html/2605.20314#bib.bib28)).

## 2 Setup

##### Tasks

We consider synthetic tasks, which have tunable parameters and thus allow for explicit control over task complexity.

We start with two classic feature learning tasks that have been extensively studied in the literature.

- Single-index model (SIM): the input is a Gaussian vector $x\sim\mathcal{N}(0,I_d)$, and the label is given by $y:=\phi(\langle w^*,x\rangle)$, where $w^*$ is the ground-truth feature vector and $\phi:\mathbb{R}\rightarrow\mathbb{R}$ is an unknown link function. Our experiments take the link function to be a Hermite polynomial, denoted $\mathrm{He}_k$ for some order $k$.
- $(d,k)$-sparse parity: the input is a Boolean vector $x\sim\mathrm{Unif}(\{\pm 1\}^d)$, and the label is given by $y:=\prod_{i\in S}x_i$, where $S\subset[d]$ is an unknown support of size $k$.
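A minimal sketch of the two data generators above, using only the Python standard library. The function names, the use of the probabilists' Hermite polynomial $\mathrm{He}_k$ computed by its three-term recurrence, and normalizing $w^*$ to unit norm are illustrative assumptions, not details taken from the paper's code.

```python
import math
import random


def hermite_he(k, z):
    """Probabilists' Hermite polynomial He_k(z) via
    He_{n+1}(z) = z * He_n(z) - n * He_{n-1}(z), with He_0 = 1 and He_1 = z."""
    prev, cur = 1.0, z
    if k == 0:
        return prev
    for n in range(1, k):
        prev, cur = cur, z * cur - n * prev
    return cur


def sample_sim(n, d, k, w_star, rng):
    """Single-index model: x ~ N(0, I_d), y = He_k(<w*, x>)."""
    data = []
    for _ in range(n):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        z = sum(wi * xi for wi, xi in zip(w_star, x))
        data.append((x, hermite_he(k, z)))
    return data


def sample_sparse_parity(n, d, support, rng):
    """(d, k)-sparse parity: x ~ Unif({-1, +1}^d), y = prod_{i in S} x_i."""
    data = []
    for _ in range(n):
        x = [rng.choice((-1, 1)) for _ in range(d)]
        data.append((x, math.prod(x[i] for i in support)))
    return data


rng = random.Random(0)
d, k = 20, 6
support = sorted(rng.sample(range(d), k))      # hidden support S with |S| = k
parity_data = sample_sparse_parity(1024, d, support, rng)

d_sim = 40
w_star = [rng.gauss(0.0, 1.0) for _ in range(d_sim)]
norm = math.sqrt(sum(v * v for v in w_star))
w_star = [v / norm for v in w_star]            # unit-norm ground-truth direction
sim_data = sample_sim(256, d_sim, 3, w_star, rng)
```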

We consider two additional algorithmic tasks for Transformers:

- In-context linear regression: the input is a sequence $x_1,y_1,x_2,y_2,\ldots,x_k,y_k,x_q$ of length $2k+1$, where for each sequence we independently sample $w\sim\mathcal{N}(0,I_n)$ and $x_i\sim\mathcal{N}(0,I_n)$, set $y_i=w^\top x_i$ for all $i\in[k]$, and the label is given by $y:=w^\top x_q$.
- $(N,p)$-modular addition: the inputs are two numbers $x,z\sim\mathrm{Unif}([N])$, and the label is given by $y:=(x+z)\bmod p$ for some prime $p$. For Transformer experiments, $x$ and $z$ are each represented by $\lceil\log_b N\rceil$ digits in base $b$, and the output logits have size $p$.
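The two Transformer tasks can be sketched the same way. This stdlib-only sketch returns the flattened prompt $x_1,y_1,\ldots,x_k,y_k,x_q$ as a Python list and encodes each modular-addition operand as a fixed-width list of base-$b$ digits. The helper names, drawing $x_q$ from the same Gaussian as the $x_i$, drawing operands from $\{0,\ldots,N-1\}$ (so that $\lceil\log_b N\rceil$ digits suffice), and most-significant-first digit order are all assumptions for illustration.

```python
import random


def sample_icl_regression(k, n_dim, rng):
    """One prompt: w, x_i, x_q ~ N(0, I_n) (x_q's distribution is assumed), y_i = w^T x_i."""
    w = [rng.gauss(0.0, 1.0) for _ in range(n_dim)]

    def draw():
        x = [rng.gauss(0.0, 1.0) for _ in range(n_dim)]
        return x, sum(wi * xi for wi, xi in zip(w, x))

    prompt = []
    for _ in range(k):
        x_i, y_i = draw()
        prompt.extend([x_i, y_i])          # x_1, y_1, ..., x_k, y_k
    x_q, y = draw()
    prompt.append(x_q)                     # total length 2k + 1
    return prompt, y


def num_digits(N, b):
    """Smallest width with b**width >= N; equals ceil(log_b N) for N > 1, without float error."""
    width = 1
    while b ** width < N:
        width += 1
    return width


def to_digits(v, b, width):
    """Fixed-width base-b digits of v, most significant first."""
    digits = []
    for _ in range(width):
        v, r = divmod(v, b)
        digits.append(r)
    return digits[::-1]


def sample_mod_addition(N, p, b, rng):
    x, z = rng.randrange(N), rng.randrange(N)   # operands in {0, ..., N-1}
    width = num_digits(N, b)
    tokens = to_digits(x, b, width) + to_digits(z, b, width)
    return tokens, (x + z) % p                  # label is one of p classes
```

The integer loop in `num_digits` is deliberate: `math.ceil(math.log(125, 5))` can round up to 4 because of floating-point error.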

##### Data reuse strategies

We consider both mini-batch stochastic gradient descent (SGD) and (full-batch) gradient descent (GD) over datasets of different sizes (see [Figure 18](https://arxiv.org/html/2605.20314#A2.F18) for an ablation on the dataset size). For mini-batch SGD, batches are sampled uniformly with replacement from the dataset. We additionally consider multi-phase training, where the dataset size can vary across phases. In particular, for $T$-phase repeat, batches at the $i$-th stage are sampled from a subset $\mathcal{S}_i$ for $i\in[T]$, where $\mathcal{S}_i\subset\mathcal{S}_j$ for $j>i$. (We also experimented with an alternative where each subset is drawn independently, without requiring it to be a superset of the previous subsets. The results were similar, so we keep the subset requirement, which has the additional benefit of smaller sample complexity.) An example is 2-phase training, where the first phase uses a subset randomly sampled from the population, and the second phase optimizes on the full population. This is similar to the two-set training proposed in Charton and Kempe ([2024](https://arxiv.org/html/2605.20314#bib.bib2)), where each batch is a mix of samples from two sets: one small set which is repeated, and one large set consisting of online samples. General multi-phase training requires specifying the sizes and number of steps per phase. A heuristic is to make (1) the first few subsets relatively small, so that the model can both quickly reach non-trivial train set performance and deviate non-trivially from initialization; and (2) the final subset $\mathcal{S}_T$ sufficiently large to ensure good generalization. As an ablation, we experiment with auto-scheduling, which suggests that this heuristic is effective ([Figure 19](https://arxiv.org/html/2605.20314#A2.F19)). Details are provided in [Section B.2.2](https://arxiv.org/html/2605.20314#A2.SS2.SSS2).
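The $T$-phase repeat schedule can be sketched as a batch generator. This is a minimal sketch under stated assumptions: the population is represented by an index range, the nested subsets $\mathcal{S}_1\subset\cdots\subset\mathcal{S}_T$ are built as prefixes of a single random permutation, and each phase is given as a hypothetical `(subset_size, num_steps)` pair. Within a phase, mini-batches are drawn uniformly with replacement from the current subset.

```python
import random


def multiphase_batches(population_size, phases, batch_size, rng):
    """Yield (phase, batch_indices) for T-phase repeat training.

    phases: list of (subset_size, num_steps) with non-decreasing subset sizes.
    Prefixes of one shuffled order guarantee S_i is contained in S_j for j > i.
    """
    sizes = [size for size, _ in phases]
    if sizes != sorted(sizes) or sizes[-1] > population_size:
        raise ValueError("subset sizes must be non-decreasing and fit in the population")
    order = list(range(population_size))
    rng.shuffle(order)
    for phase, (size, num_steps) in enumerate(phases):
        subset = order[:size]                   # S_{phase}
        for _ in range(num_steps):
            # uniform sampling with replacement from the current subset
            yield phase, [rng.choice(subset) for _ in range(batch_size)]


# 2-phase example: a small repeated subset first, then the full population.
rng = random.Random(0)
schedule = [(2**10, 300), (2**16, 700)]
seen = [set(), set()]
for phase, batch in multiphase_batches(2**16, schedule, 32, rng):
    seen[phase].update(batch)
```

Using prefixes of one permutation is one simple way to satisfy the nesting requirement; drawing each subset independently (the alternative mentioned above) would only change how `subset` is built.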

##### Experimental setup

Our primary focus is the performance under a given compute, which is measured as batch size × number of optimization steps. We report expected performance under a fixed compute by taking the accuracy or loss averaged over random seeds. For tasks where the accuracy exhibits sharp phase transitions, the average accuracy can also be interpreted as the probability of success.

We train both MLPs and Transformers. The MLP has ReLU activations and no residual connections. The Transformer has optional QK normalization. Models have depth 2 unless specified otherwise. All weights are initialized with PyTorch defaults; for example, $W_{ij}\sim\mathrm{Unif}[-1/\sqrt{d_{\text{in}}},1/\sqrt{d_{\text{in}}}]$. We use the SGD optimizer for MLPs and AdamW for Transformers unless specified otherwise, and sweep over the learning rate for each setup.

More details are provided in [Appendix B](https://arxiv.org/html/2605.20314#A2).

## 3 Small-vs-large gap: less data can lead to faster learning

We first present empirical evidence that less data leads to accelerated learning across tasks and setups.

##### Mini-batch updates

We start with training using mini-batch updates, which is the common training strategy in practice and the setting where prior work has also reported the small-vs-large gap (Charton and Kempe, [2024](https://arxiv.org/html/2605.20314#bib.bib2)). As shown in [Figure 1](https://arxiv.org/html/2605.20314#S1.F1), smaller datasets lead to faster convergence for all tasks. For SIM, in-context linear regression and modular addition, multi-phase training is used to balance accelerated optimization and good generalization.

However, mini-batch updates introduce a confounding factor, the number of repetitions: when trained for the same number of steps, each sample in a smaller dataset is reused more frequently over the course of training. It is therefore unclear whether this increased repetition is the primary source of the observed speedup. Our subsequent gradient descent results show that it is not.

[Figure 2 panels: (a) SGD on (20, 6)-parity; (b) GD on (20, 6)-parity; (c) SGD on SIM, $d=40$; (d) GD on SIM, $d=40$]

Figure 2: Small-vs-large gap exists for both mini-batch and full-batch training. Results are based on SIM and parity with 2-layer MLPs, optimized with both mini-batch (SGD) and full-batch (GD) updates. The small-vs-large gap with GD is a notable example that prior theory fails to capture ([Section 4.1](https://arxiv.org/html/2605.20314#S4.SS1)).

##### Gradient descent (full-batch updates)

We sample datasets of varying sizes from the population and run (full-batch) gradient descent on each dataset. As shown in [Figure 2](https://arxiv.org/html/2605.20314#S3.F2), smaller datasets have better performance at every step throughout training. Moreover, the total saving in compute is much more significant than the reduction in steps, since smaller datasets also incur a lower per-step computational cost. For instance, for $(20,6)$-sparse parity, training on $N=2^{14}$ samples converges in 1500 steps, whereas training on the full population (i.e., $N=2^{20}$) requires more than 2000 steps, leading to a 100x speedup in compute.
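Because compute is counted as samples per step times steps, and a full-batch GD step on $N$ samples processes all $N$ of them, the compute saving above follows from simple arithmetic. The sketch below plugs in the $(20,6)$-parity numbers quoted in the text; since the full-population run needs more than 2000 steps, the ratio evaluated at exactly 2000 steps is only a lower bound, and the second line computes the full-run step count at which the ratio reaches 100. The helper names are hypothetical.

```python
def gd_compute(num_samples, num_steps):
    """Full-batch GD compute: every step processes all num_samples examples."""
    return num_samples * num_steps


small_run = gd_compute(2**14, 1500)             # N = 2^14, converged in 1500 steps


def compute_ratio(large_steps):
    """Compute of a full-population run (N = 2^20) relative to the small run."""
    return gd_compute(2**20, large_steps) / small_run


print(compute_ratio(2000))       # about 85.3; a lower bound, since the full run needs > 2000 steps
print(100 * small_run / 2**20)   # 2343.75: full-run steps at which the ratio reaches 100
```

Most of the saving comes from the $2^{20}/2^{14}=64$ ratio in per-step cost, with the step reduction contributing the remaining factor.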

## 4 Unpacking the efficiency gain from smaller datasets

In this section, we use sparse parity as a sandbox for understanding the small-vs-large gap. We first explain why alternative theories in existing work are not sufficient, and then provide our explanation, which hinges on dataset sampling biases.

### 4.1 Prior theories are insufficient

#### 4.1.1 SQ-CSQ difference

One mechanism of acceleration identified by prior work is that taking multiple gradient updates on the same data effectively transforms (batch) SGD from a correlational statistical query (CSQ) algorithm into a statistical query (SQ) algorithm (Dandi et al., [2024](https://arxiv.org/html/2605.20314#bib.bib5); Arnaboldi et al., [2024](https://arxiv.org/html/2605.20314#bib.bib6); Lee et al., [2025](https://arxiv.org/html/2605.20314#bib.bib8)), the latter being a more powerful class of algorithms. Prior work uses this to explain the speedup for batch SGD on the single-index model (SIM), for which there is a known gap between CSQ and SQ lower bounds.

While insightful, the CSQ-SQ gap cannot fully explain the observed speedup. It cannot explain why there is a speedup for SIM even with (full-batch) GD, where smaller and larger datasets are both repeated, and the same number of times. Moreover, it is not applicable to the wide range of tasks where the SQ and CSQ lower bounds coincide, which include all discrete problems such as sparse parity and modular addition.

#### 4.1.2 Gradient variance reduction

It is known that reducing gradient variance helps speed up convergence (Johnson and Zhang, [2013](https://arxiv.org/html/2605.20314#bib.bib38)), and more recently Kotha et al. ([2025](https://arxiv.org/html/2605.20314#bib.bib37)) reported that reducing gradient variance across batches can accelerate training in sparse parity and in-context linear regression tasks. However, the variance reduction view cannot explain the speedup with full-batch updates, where each step uses the full dataset and hence has no sampling-induced variance.

#### 4.1.3 Biased (input) distribution

Specific to sparse parity and SIM, prior theory suggests another explanation: the benefit of biased distributions. The focus there was on biases in the input distribution, which can either be explicitly constructed sparsity (Valiant, [2012](https://arxiv.org/html/2605.20314#bib.bib15); Abbe et al., [2023b](https://arxiv.org/html/2605.20314#bib.bib17)) or randomly perturbed distributions (Kalai et al., [2009](https://arxiv.org/html/2605.20314#bib.bib16); Cornacchia et al., [2025](https://arxiv.org/html/2605.20314#bib.bib19)). Our setup is closer to the latter, due to the deviation of the empirical mean from the population mean. However, a critical difference is that the signal strength (in terms of first-order Fourier coefficients) in Cornacchia et al. ([2025](https://arxiv.org/html/2605.20314#bib.bib19)) is exponential in the sparsity $k$, whereas the sampling bias in our analysis depends only on the dataset size and is independent of the sparsity. In particular, for a random subset of size $N$, the signal strength in Cornacchia et al. ([2025](https://arxiv.org/html/2605.20314#bib.bib19)) is $O(N^{-k/2})$, which is much smaller than the $O(N^{-1/2})$ sampling bias detailed in [Section 4.2](https://arxiv.org/html/2605.20314#S4.SS2).
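To get a feel for the size of this difference, plug in the values $N=2^{14}$ and $k=6$ used in the parity experiments of this section. This is an illustrative calculation that ignores the constants hidden in $O(\cdot)$:

```latex
N^{-k/2} = \left(2^{14}\right)^{-3} = 2^{-42} \approx 2.3\times 10^{-13},
\qquad
N^{-1/2} = \left(2^{14}\right)^{-1/2} = 2^{-7} \approx 7.8\times 10^{-3}.
```

At this dataset size the $N^{-1/2}$ scale is larger by a factor of $2^{35}$.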

We additionally provide two pieces of empirical evidence that input distribution biases are not the primary driver of the gap: the speedup persists even when the dataset's input bias is removed, and online training with a biased input distribution does not lead to the same amount of speedup.

Small datasets without biases still lead to speedup. We show that the small-vs-large gap exists when we remove input biases in the small datasets. For parity, we ensure the training set satisfies $\hat{\mathbb{E}}[x]=0$, optionally further requiring $\hat{\mathbb{E}}[y]=0$ and $\hat{\mathbb{E}}[x|y=-1]=\hat{\mathbb{E}}[x|y=1]=0$. The small-vs-large gap persists, as shown in [Figure 3](https://arxiv.org/html/2605.20314#S4.F3)(a). Consistent results are observed on Transformers for both parity and SIM ([Figure 16](https://arxiv.org/html/2605.20314#A2.F16)). For SIM, we whiten the inputs to match both first- and second-order statistics, i.e., transform the dataset with $\tilde{x}=\hat{\Sigma}^{-1/2}(x-\hat{\mu})$, where $\hat{\mu}$ and $\hat{\Sigma}$ are the empirical mean and covariance.
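The text does not specify how the debiased parity sets are constructed. One simple construction that meets all three constraints exactly, offered here only as an illustrative assumption, is antipodal pairing. For even $k$ (such as $k=6$), $y(-x)=(-1)^k y(x)=y(x)$, so adding $-x$ alongside every sampled $x$ zeroes $\hat{\mathbb{E}}[x]$ and both class-conditional means, and drawing equally many pairs per label also zeroes $\hat{\mathbb{E}}[y]$. The function names are hypothetical.

```python
import math
import random


def debiased_parity_set(n_pairs_per_label, d, support, rng):
    """Parity training set with empirical E[x] = 0, E[y] = 0 and E[x | y] = 0.

    Requires an even-sized support so that x and -x share the same label.
    """
    if len(support) % 2:
        raise ValueError("antipodal pairing preserves labels only for even k")
    need = {1: n_pairs_per_label, -1: n_pairs_per_label}
    data = []
    while need[1] or need[-1]:
        x = [rng.choice((-1, 1)) for _ in range(d)]
        y = math.prod(x[i] for i in support)
        if need[y]:                              # keep labels balanced
            need[y] -= 1
            data.append((x, y))
            data.append(([-v for v in x], y))    # antipodal partner, same label
    return data


def empirical_mean(vectors):
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]
```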

Online training with biased distributions has minor effects. Conversely, we train with freshly sampled online data whose per-coordinate Bernoulli parameters are set to the biases of a finite offline dataset. Specifically, for $d=20$ and $k=6$, we take the biases from $2^i$ samples for $i\in\{4,6,8,10,12\}$. Cornacchia et al. ([2025](https://arxiv.org/html/2605.20314#bib.bib19)) indicate that biasing the distribution improves training speed, which we also confirm in [Figure 3](https://arxiv.org/html/2605.20314#S4.F3)(b). However, unless the sample size is exceedingly small (e.g., fewer than $2^5=32$ samples), the speedup is much more moderate than that from training on a small subset, and the small-vs-large gap persists.
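A minimal sketch of this biased online baseline. It assumes the "bias" of a finite offline set is its empirical per-coordinate mean $\hat{\mu}_j\in[-1,1]$, so each fresh coordinate is drawn as $+1$ with probability $(1+\hat{\mu}_j)/2$; labels remain the exact parity of the fresh input. The function names and the specific support are hypothetical.

```python
import math
import random


def coordinate_biases(offline_xs):
    """Empirical per-coordinate mean of a finite set of +-1 vectors."""
    n = len(offline_xs)
    return [sum(col) / n for col in zip(*offline_xs)]


def biased_online_batch(batch_size, mu, support, rng):
    """Fresh samples with P(x_j = +1) = (1 + mu_j) / 2, labelled by the sparse parity."""
    batch = []
    for _ in range(batch_size):
        x = [1 if rng.random() < (1 + m) / 2 else -1 for m in mu]
        batch.append((x, math.prod(x[i] for i in support)))
    return batch


rng = random.Random(0)
d, support = 20, [1, 4, 6, 9, 13, 17]
offline = [[rng.choice((-1, 1)) for _ in range(d)] for _ in range(2**6)]
mu = coordinate_biases(offline)                 # biases taken from 2^6 offline samples
batch = biased_online_batch(4096, mu, support, rng)
```

Unlike repeating the offline set itself, every batch here is fresh, so only the input bias is carried over.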

[Figure 3 panels: (a) (20, 6)-parity, accuracy; (b) (20, 6)-parity, accuracy]

Figure 3: Small-vs-large gap is not explained by input distribution biases. (Left) Removing input biases does not affect the performance of training on a small set (size $2^{14}$). Removing biases means requiring $\mathbb{E}[x]=0$, or additionally requiring $\mathbb{E}[y]=0$ and $\mathbb{E}[x|y]=0$. (Right) Introducing biases to the large set does not bridge the small-vs-large gap. The biases are taken from the empirical distribution of a size-$2^m$ set, for varying $m$. When biased with $m=14$, large-set training still has a performance gap to training on the small set of size $2^{14}$. The maximum speedup would require $m=5$, which is much smaller than the actual small set size and not sufficient for learning. Similar results are shown for SIM and with Transformers ([Figure 16](https://arxiv.org/html/2605.20314#A2.F16)).

### 4.2 Our explanation: dataset sampling bias accelerates learning by adjusting the relative norm growth

We claim that a primary source of the speedup from smaller datasets is their sampling biases, which adjust the norms across layers and effectively change the per-layer learning rates. Intuitively, for parity learned with 2-layer MLPs, feature learning occurs in the first (i.e., input) layer, hence growing the second layer can speed up the feature learning in the first layer, and smaller datasets have stronger biases that grow the second layer faster. The rest of this section formalizes this intuition. We provide empirical evidence in [Section 5](https://arxiv.org/html/2605.20314#S5).

For the theoretical analysis, we consider learning the $2$-sparse parity task with a 2-layer network with quadratic activation, i.e., $f(x)=a\sigma(w^\top x)-1$ with $\sigma(z):=\frac{1}{2}z^2$. The model is optimized using the correlation loss, i.e., $L(f)=\mathbb{E}_{x,y}[\ell(y,f(x))]$, where $\ell(y,y')=-yy'$. Let $\eta$ denote the learning rate. We consider projected updates

$$a^{(t+1)}=\max\Big\{-1,\min\big\{1,a^{(t)}-\eta\nabla_{a^{(t)}}L\big\}\Big\},\qquad w^{(t+1)}=\frac{w^{(t)}-\eta\nabla_{w^{(t)}}L}{\bigl\|w^{(t)}-\eta\nabla_{w^{(t)}}L\bigr\|_2}.\tag{1}$$

Following standard practice, we initialize with $w^{(0)}\sim\mathrm{Unif}(\mathbb{S}^{d-1})$ and $a^{(0)}\sim\mathcal{N}(0,1/m)$, where $m$ can be considered a model width parameter. We focus on the gradient descent (GD) process; the SGD process is a noisy version of GD and can be analyzed using techniques from Arous et al. ([2020](https://arxiv.org/html/2605.20314#bib.bib40)); Abbe et al. ([2023a](https://arxiv.org/html/2605.20314#bib.bib23)). Since the correlation loss does not introduce interactions across neurons, the analysis below can be considered as focusing on one neuron of a width-$m$ network. Without loss of generality, the support is on the first $2$ coordinates, and the minimizer $w^\star$ has $|w_1^\star|=|w_2^\star|=\frac{1}{\sqrt{2}}$ and $w_i^\star=0$ otherwise.
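Under these definitions the gradients have a closed form. Writing $\widehat{\bm{M}}=\widehat{\mathbb{E}}[yxx^\top]$, the correlation loss is $L=-\tfrac{a}{2}w^\top\widehat{\bm{M}}w+\widehat{\mathbb{E}}[y]$, so $\nabla_a L=-\tfrac{1}{2}w^\top\widehat{\bm{M}}w$ and $\nabla_w L=-a\widehat{\bm{M}}w$; the $w$-update in Equation (1) is therefore a normalized power-iteration step with step size $\eta a$. The stdlib-only sketch below implements Equation (1) for one neuron given a finite set of $(x,y)$ pairs. The helper names, the guard against a zero-norm update, and the small test configuration ($d=6$, $m=64$, $\eta=1$, 200 steps) are assumptions for illustration.

```python
import math
import random


def moment_matrix(data):
    """M_hat = (1/N) * sum_s y_s x_s x_s^T for a list of (x, y) pairs."""
    n, d = len(data), len(data[0][0])
    M = [[0.0] * d for _ in range(d)]
    for x, y in data:
        for i in range(d):
            for j in range(d):
                M[i][j] += y * x[i] * x[j] / n
    return M


def matvec(M, v):
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in M]


def projected_step(a, w, M, eta):
    """One step of Equation (1) for f(x) = a * (w^T x)^2 / 2 - 1 under the correlation loss."""
    Mw = matvec(M, w)
    q = sum(wi * mi for wi, mi in zip(w, Mw))           # q = w^T M w
    grad_a = -0.5 * q                                    # dL/da
    grad_w = [-a * mi for mi in Mw]                      # dL/dw = -a M w
    a_new = max(-1.0, min(1.0, a - eta * grad_a))        # clip a to [-1, 1]
    u = [wi - eta * gi for wi, gi in zip(w, grad_w)]     # w + eta * a * M w
    norm = math.sqrt(sum(ui * ui for ui in u))
    w_new = [ui / norm for ui in u] if norm > 0 else w   # project back onto the unit sphere
    return a_new, w_new


# Usage: population gradient for 2-sparse parity on {-1, +1}^d, support {0, 1} (0-indexed).
d, m, eta = 6, 64, 1.0
population = []
for bits in range(2 ** d):
    x = [1 if (bits >> i) & 1 else -1 for i in range(d)]
    population.append((x, x[0] * x[1]))
M = moment_matrix(population)

rng = random.Random(0)
w = [rng.gauss(0.0, 1.0) for _ in range(d)]
s = math.sqrt(sum(v * v for v in w))
w = [v / s for v in w]                        # w^(0) ~ Unif(S^{d-1})
a = rng.gauss(0.0, 1.0 / math.sqrt(m))        # a^(0) ~ N(0, 1/m)
for _ in range(200):
    a, w = projected_step(a, w, M, eta)
```

Passing a random subset of size $N$ to `moment_matrix` instead of the full cube gives the phase-1 update on $\widehat{\bm{M}}$; the full cube gives the population update used in phase 2.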

We show that a smaller $N$ leads to faster convergence:

###### Theorem 1 (2-phase training from standard initialization).

Consider 2-phase training with $m>d$ and learning rate $\eta=\Theta(1)$. The first phase updates with [Equation 1](https://arxiv.org/html/2605.20314#S4.E1) using a randomly sampled dataset of size $d\leq N\leq d^2$, until $|a|\geq a_\star$ for some $0<a_\star\lesssim\frac{1}{(Nd)^{1/4}(\log(d/\delta))^{1/2}}$; the second phase updates with [Equation 1](https://arxiv.org/html/2605.20314#S4.E1) using the full population gradient, until reaching a $\hat{w}$ such that $\|\hat{w}-w^\star\|_2\lesssim\sqrt{\varepsilon}$. Let $T_1,T_2$ denote the numbers of steps required in the two phases, respectively. Let $p_{\mathrm{all}}\in(0,1)$ be a universal constant with $p_{\mathrm{all}}=\Theta(1)$ ($p_{\mathrm{all}}$ is formally defined in [Lemma 10](https://arxiv.org/html/2605.20314#Thmlemma10)). Then, with probability at least $p_{\mathrm{all}}-\delta$ over the random initialization and the phase-1 samples,

$$
T_1 \lesssim \frac{a_\star \sqrt{N}}{\eta}, \quad T_2 \lesssim \frac{2}{\eta a_\star} \log\Big(\frac{d}{\varepsilon}\Big). \tag{2}
$$

With the optimal choice of $a_\star$, the total number of steps is $O\Big((Nd)^{1/4} \log\big(\frac{d}{\varepsilon}\big) \log^{1/2}\big(\frac{d}{\delta}\big)\Big)$.

[Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1) implies that a smaller $N$ leads to a direct saving in the number of steps $T$. One could alternatively skip Phase 1 and train directly on the full population. This requires $O(m^{1/2} \log(d/\varepsilon))$ steps ([Lemma 6](https://arxiv.org/html/2605.20314#Thmlemma6)), which is worse than the 2-phase convergence when $m \gg d^2$.
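As a sanity check on the final step count, the sketch below plugs the bounds of Equation (2) in as equalities, with all hidden constants set to 1 (an assumption made only for illustration), and picks $a_\star$ to minimize $T_1 + T_2$ subject to the cap on $a_\star$ in Theorem 1. The function name and the parameter values ($d = 100$, $\varepsilon = 10^{-3}$, $\delta = 0.1$) are hypothetical.

```python
import math

def two_phase_steps(N, d, eps, delta, eta=1.0):
    """Evaluate the bounds of Eq. (2) with all hidden constants set to 1 (illustrative assumption)."""
    log_de = math.log(d / eps)
    a_free = math.sqrt(2.0 * log_de / math.sqrt(N))                   # argmin of T1 + T2, ignoring the cap
    a_cap = 1.0 / ((N * d) ** 0.25 * math.sqrt(math.log(d / delta)))  # cap on a_star from Theorem 1
    a_star = min(a_free, a_cap)
    T1 = a_star * math.sqrt(N) / eta
    T2 = 2.0 / (eta * a_star) * log_de
    return a_star, a_cap, T1, T2

d, eps, delta = 100, 1e-3, 0.1
for N in (d, 10 * d, d * d):
    a_star, a_cap, T1, T2 = two_phase_steps(N, d, eps, delta)
    rate = (N * d) ** 0.25 * math.log(d / eps) * math.sqrt(math.log(d / delta))
    print(N, a_star == a_cap, round((T1 + T2) / rate, 3))   # each line shows True and 2.001
```

In this setting the cap binds, so $T_2$ dominates and the total is about twice $(Nd)^{1/4}\log(d/\varepsilon)\log^{1/2}(d/\delta)$, matching the stated rate. Since $N \le d^2$, this is at most of order $d^{3/4}$ up to log factors, compared with the $m^{1/2}$ scaling of direct population training.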

Proof sketch. The gradient magnitude of $w$ depends on the magnitude of $a$, and since $a$ is initialized to be very small ($O(1/\sqrt{m})$), this slows down the learning of $w$. However, the sampling bias of small datasets can quickly grow the magnitude of $a$ at the initial stage of training. The gradient of $a$ is given by $q^{(t)} := (w^{(t)})^\top \widehat{\boldsymbol{M}} w^{(t)}$, where $\widehat{\boldsymbol{M}} := \widehat{\mathbb{E}}[y x x^\top] := \frac{1}{N}\sum_{s=1}^{N} y^{(s)} x^{(s)} x^{(s)\top}$. Due to the anti-concentration of $\widehat{\boldsymbol{M}}$, $q^{(t)}$ is on the order of $N^{-1/2}$, whereas the population quantity is on the order of $1/d$. Therefore, in the first phase, $a$ grows at a linear rate of $N^{-1/2}$, so the number of steps for $a$ to reach $a_\star$ is proportional to $N^{1/2}$. Once $a$ reaches $a_\star$, we switch to the second phase, which uses population updates. The analysis of this phase reduces to power iteration on the true moment matrix $\boldsymbol{M}$, whose contraction rate depends on $\eta a_\star$.
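The anti-concentration step can be checked numerically. The sketch below assumes Rademacher inputs $x \in \{-1,+1\}^d$ and the target $y = x_1 x_2$, for which $\boldsymbol{M} = \mathbb{E}[yxx^\top]$ has $M_{12} = M_{21} = 1$ and zeros elsewhere, consistent with a minimizer supported on the first two coordinates with $|w^\star_1| = |w^\star_2| = 1/\sqrt{2}$. At a random unit $w$, it measures the typical size of the sampling bias $w^\top(\widehat{\boldsymbol{M}} - \boldsymbol{M})w$; if the $N^{-1/2}$ scaling holds, the rescaled value should stay roughly constant across $N$. The data distribution, the target, and the function name are illustrative assumptions, not the paper's experimental setup.

```python
import numpy as np

def sampling_bias(d, N, trials, rng):
    """Median |w^T (M_hat - M) w| over random unit w and fresh size-N datasets,
    assuming Rademacher inputs and the target y = x_1 * x_2."""
    devs = []
    for _ in range(trials):
        w = rng.standard_normal(d)
        w /= np.linalg.norm(w)                      # w ~ Unif(S^{d-1})
        X = rng.choice([-1.0, 1.0], size=(N, d))
        y = X[:, 0] * X[:, 1]
        q_hat = np.mean(y * (X @ w) ** 2)           # w^T M_hat w without forming M_hat
        q_pop = 2.0 * w[0] * w[1]                   # w^T M w with M_12 = M_21 = 1
        devs.append(abs(q_hat - q_pop))
    return float(np.median(devs))

rng = np.random.default_rng(0)
d = 50
bias = {N: sampling_bias(d, N, trials=200, rng=rng) for N in (d, 10 * d, d * d)}
for N, b in bias.items():
    print(N, round(b, 4), round(b * np.sqrt(N), 2))   # last column should stay roughly constant
```

For comparison, the population term $w^\top \boldsymbol{M} w = 2w_1w_2$ is of order $1/d$ at a random $w$, so for $d \le N \le d^2$ the sampling bias is at least as large as the population signal.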

In fact, the first-phase analysis primarily relies on the anti-concentration of the empirical moment matrix $\widehat{\boldsymbol{M}}$, which is largely independent of the true label signal. This suggests that the early-stage acceleration should persist even when the labels are random, which we study further in [Section 5.1](https://arxiv.org/html/2605.20314#S5.SS1). Furthermore, the convergence rate of the feature direction $w$ depends on the strength of its gradient signal relative to its scale. The second-phase analysis reveals that the gradient strength is governed by the learning rate and the initial magnitude of $a$. This directly motivates the empirical interventions explored in [Section 5.2](https://arxiv.org/html/2605.20314#S5.SS2), where we manipulate per-layer learning rates and layer-wise initialization scales.

Notably, the $N^{-1/2}$ bias in Phase 1 contrasts with the result in Cornacchia et al. ([2025](https://arxiv.org/html/2605.20314#bib.bib19)), where a per-coordinate bias $\eta$ induces a Fourier coefficient on the order of $\eta^k$. Such a coefficient is non-negligible only when $k$ is bounded. In contrast, our analysis is independent of $k$, suggesting that the small-vs-large gap persists even in dense regimes where $k$ can be as large as $d$, as confirmed empirically ([Figure 14](https://arxiv.org/html/2605.20314#A2.F14)). Moreover, the notion of “bias” differs between the two settings. Cornacchia et al. ([2025](https://arxiv.org/html/2605.20314#bib.bib19)) assume $\eta = \Omega(1)$, which is substantially larger than the $O(N^{-1/2})$ sampling bias unless $N$ is unreasonably small, as discussed in [Section 4.1.3](https://arxiv.org/html/2605.20314#S4.SS1.SSS3).

## 5 Empirical evidence for relative norm growth

This section provides empirical evidence supporting the claim in [Section 4](https://arxiv.org/html/2605.20314#S4) that less data leads to faster learning by affecting the relative growth rates of the two layers. [Figure 12](https://arxiv.org/html/2605.20314#A2.F12) provides direct observational evidence: during the initial phase of training, the weight-norm ratio $\frac{\|a\|_2}{\|W\|_F}$ increases more rapidly when the dataset is smaller, for both parity and SIM.

We further consider interventions on the training process to provide stronger evidence, in terms of 1) data, where we show that the small-set speedup exists even when the labels are random; 2) weight norms, by changing the initialization scale or adopting normalization layers; and 3) layer-wise learning-rate control. All results support our hypothesis, as detailed below.

### 5.1 Small-vs-large gap exists when training first on random labels

One implication of [Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1) is that the role of small-dataset biases is to help grow the second layer faster, which in turn amplifies the update to the input layer, the layer responsible for feature learning. This suggests that any method that helps grow the second layer should achieve a similar acceleration, even when the growth results from signals unrelated to the target function.

Training with random labels is one example that provides such a speedup. Specifically, consider a modified 2-phase training procedure similar to [Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1), where we alter the first phase to train with random labels, i.e., $y$ is sampled i.i.d. from $\mathrm{Uniform}\{-1,+1\}$. We show that the gradient of the second layer ($a$) is still of order $\Theta(N^{-1/2})$ around initialization.

###### Corollary 2 (2-phase training, with Phase 1 on random labels).

Consider the setting in [Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1) where the first phase uses a randomly sampled dataset of size $d \le N \le d^2$, with labels sampled uniformly from $\{-1,+1\}$. Similar to [Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1), let $p_{\mathrm{all}} \in (0,1)$ be a universal constant with $p_{\mathrm{all}} = \Theta(1)$. Then, with probability at least $p_{\mathrm{all}} - \delta$ over the random initialization and the phase-1 samples, the total number of steps required to reach a $\hat{w}$ such that $\|\hat{w} - w^\star\|_2 \le \varepsilon$ is

$$
T = O\left(\frac{\sqrt{N}}{\eta\sqrt{d}} + \frac{\sqrt{d}}{\eta}\log\Big(\frac{d}{\varepsilon}\Big)\right).
$$
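The two-phase procedure of Theorem 1 and Corollary 2 can be simulated end to end in a few lines. This is a sketch under illustrative assumptions: a single neuron $f(x) = a\,(w^\top x)^2$ trained with the correlation loss $L = -a\, w^\top \boldsymbol{M} w$, Rademacher inputs, and the target $y = x_1 x_2$ (so the population moment has $M_{12} = M_{21} = 1$). Phase 1 uses the empirical moment $\widehat{\boldsymbol{M}}$ built from either true or i.i.d. random $\pm 1$ labels, and Phase 2 uses the population moment. The helper names and the default values of $d$, $N$, $m$, $\eta$, $a_\star$, and $\varepsilon$ are hypothetical.

```python
import numpy as np

def projected_step(a, w, M, eta):
    # Eq. (1) for the assumed model f(x) = a (w^T x)^2 with correlation loss L = -a w^T M w
    grad_a = -float(w @ M @ w)
    grad_w = -2.0 * a * (M @ w)
    a = float(np.clip(a - eta * grad_a, -1.0, 1.0))
    v = w - eta * grad_w
    return a, v / np.linalg.norm(v)

def dist_to_minimizers(w):
    # Distance to the nearest vector with |w_1| = |w_2| = 1/sqrt(2) and zeros elsewhere
    target = np.zeros_like(w)
    target[:2] = np.sign(w[:2]) / np.sqrt(2.0)
    return float(np.linalg.norm(w - target))

def two_phase(d=30, N=200, m=10_000, eta=1.0, a_star=0.1, eps=1e-4,
              random_labels=False, seed=0, max_steps=10_000):
    rng = np.random.default_rng(seed)
    X = rng.choice([-1.0, 1.0], size=(N, d))                      # phase-1 inputs
    y = rng.choice([-1.0, 1.0], size=N) if random_labels else X[:, 0] * X[:, 1]
    M_hat = (X * y[:, None]).T @ X / N                            # (1/N) sum_s y_s x_s x_s^T
    M_pop = np.zeros((d, d))
    M_pop[0, 1] = M_pop[1, 0] = 1.0                               # E[y x x^T] for y = x_1 x_2
    w = rng.standard_normal(d)
    w /= np.linalg.norm(w)                                        # w^(0) ~ Unif(S^{d-1})
    a = float(rng.normal(0.0, 1.0 / np.sqrt(m)))                  # a^(0) ~ N(0, 1/m)
    T1 = 0
    while abs(a) < a_star and T1 < max_steps:                     # Phase 1: finite data
        a, w = projected_step(a, w, M_hat, eta)
        T1 += 1
    T2 = 0
    while dist_to_minimizers(w) > np.sqrt(eps) and T2 < max_steps:  # Phase 2: population
        a, w = projected_step(a, w, M_pop, eta)
        T2 += 1
    return T1, T2, dist_to_minimizers(w)

for random_labels in (False, True):
    print(random_labels, two_phase(random_labels=random_labels))
```

With random labels, $\widehat{\boldsymbol{M}}$ carries no information about the target, yet Phase 1 still drives $|a|$ past $a_\star$ because the sampling bias $w^\top \widehat{\boldsymbol{M}} w$ is of order $N^{-1/2}$; Phase 2 then recovers the target direction from population updates.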

We verify the speedup empirically by training an MLP first on a small set with random labels and then switching to large-set training. Our hypothesis is supported if the first phase on random labels accelerates learning of the actual task. We experiment with MLPs on parity and SIM: for parity, the random labels are obtained by sampling $y$ uniformly from $\{-1,+1\}$; for SIM, we sample a random feature vector $w_{\text{random}} \sim \mathcal{N}(0, I)$ and use $w_{\text{random}}$ together with the true link function to label the small dataset. As shown in [Figure 4](https://arxiv.org/html/2605.20314#S5.F4), both the accuracy and the weight-norm growth (measured by $\|a\|_2/\|W\|_F$) are sensitive only to the dataset size, not to the label choice (i.e., true or random labels). This also agrees with our theory (i.e., the phase-1 analysis) that the benefit of sampling bias exists even with random labels. We observe similar results for Transformers learning mod addition, where an initial phase of random-label training speeds up the subsequent learning of the actual task ([Figure 20](https://arxiv.org/html/2605.20314#A2.F20)).

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/random_label_comparison_d20k6_m64_parity.png) (a) $(20,6)$-parity, acc.
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/mlp_parity_a_over_w_norm_random_label.png) (b) $(20,6)$-parity, $\frac{\|a\|_2}{\|W\|_F}$
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/mlp_gd_m64_sim_random_label.png) (c) SIM, $d=40$, loss
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/mlp_gd_m64_sim_a_over_w_norm_random_label.png) (d) SIM, $d=40$, $\frac{\|a\|_2}{\|W\|_F}$

Figure 4: Training on small datasets with random labels leads to faster learning. For GD on both parity and SIM, the initial random-label training leads to a significant speedup and faster growth of $\|a\|_2/\|W\|_F$. The blue/yellow curves correspond to large/small sets. The green curves correspond to training first on a small set with random labels and then switching to large sets with true labels.

### 5.2 Small-vs-large gap diminishes with parameter-wise interventions

This section investigates direct interventions on layer-wise scalings. We find that, consistent with the analysis in [Section 4.2](https://arxiv.org/html/2605.20314#S4.SS2), these interventions can significantly reduce and sometimes eliminate the small-vs-large performance gap. However, smaller sets exhibit favorable optimization biases that make training more robust to hyperparameter choices. Results are reported for MLPs trained with the SGD optimizer and Transformers trained with AdamW (footnote 4: AdamW affects MLP and Transformer training differently, which we discuss in more detail in [Section B.2.2](https://arxiv.org/html/2605.20314#A2.SS2.SSS2)).

#### 5.2.1 MLP: layer-wise interventions

One takeaway from [Section 4.2](https://arxiv.org/html/2605.20314#S4.SS2) is that the small-vs-large gap comes from the fact that smaller-set training better adjusts the relative norms across layers. Thus, interventions that directly adjust the relative layer norms should reduce the gap. Such relative norm changes can be achieved in two ways: 1) at initialization, by directly adjusting the standard deviations from which the weights are sampled, and 2) throughout training, by explicitly supplying layer-specific learning rates.

For initialization scales, we multiply the initial weights by layer-specific constants. [Figure 5](https://arxiv.org/html/2605.20314#S5.F5) shows the final-iterate test accuracy for a large sweep over scales for a 2-hidden-layer MLP, where we consider 2 scaling parameters, one for the first-layer weights and one for the other layers. Note that there exist scaling constants that completely close the small-vs-large gap between size-$2^{14}$ random subsets and the full population ($2^{20}$). In particular, we observe that the larger dataset performs worse at the default initialization (marked by the red star), but shrinking the first-layer scale and growing the other layers improves its performance and eventually makes the gap vanish. However, identifying the right constants requires searching through a large set of combinations, whereas small-set training adjusts the scaling automatically and is much more robust to the initialization scale.
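The two-constant initialization sweep can be expressed as a small helper. The sketch assumes a bias-free ReLU MLP whose default weights are drawn from $\mathcal{N}(0, 1/\text{fan-in})$; this default, the layer sizes, and the helper names are assumptions, not necessarily the paper's configuration. `first_scale` multiplies the first layer and `rest_scale` multiplies all other layers. Because ReLU is positively homogeneous, the rescaling changes the network output at initialization only by the product of the scales, while the relative layer norms change freely.

```python
import numpy as np

def init_mlp(dims, rng):
    """Bias-free ReLU MLP; weights ~ N(0, 1/fan_in) (an assumed default)."""
    return [rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_out, n_in))
            for n_in, n_out in zip(dims[:-1], dims[1:])]

def scale_layers(weights, first_scale, rest_scale):
    """Multiply the first layer by first_scale and every other layer by rest_scale."""
    return [W * (first_scale if i == 0 else rest_scale) for i, W in enumerate(weights)]

def forward(weights, X):
    H = X
    for W in weights[:-1]:
        H = np.maximum(H @ W.T, 0.0)     # ReLU hidden layers
    return H @ weights[-1].T             # linear output layer

rng = np.random.default_rng(0)
dims = [20, 64, 64, 1]                   # 2 hidden layers of width 64 (illustrative)
base = init_mlp(dims, rng)
scaled = scale_layers(base, first_scale=0.5, rest_scale=2.0)
X = rng.choice([-1.0, 1.0], size=(8, 20))
# Positive homogeneity: the output at init is rescaled by 0.5 * 2.0**2 = 2.0
print(np.allclose(forward(scaled, X), 2.0 * forward(base, X)))   # True
```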

For MLPs with ReLU activations, growing one layer leads to larger updates on another. Therefore, shrinking the input layer and scaling up the output layer can effectively be viewed as using layer-wise learning rates that are larger for the input layer and smaller for the other. The empirical evidence agrees with this hypothesis. Let $\eta_1, \eta_2$ denote the learning rates for the first and second layers, respectively. [Figure 6](https://arxiv.org/html/2605.20314#S5.F6) shows that the optimal choice of $(\eta_1, \eta_2)$, where $\eta_1 \gg \eta_2$, significantly reduces the small-vs-large gap between the full population ($N = 2^{20}$) and a random subset ($N = 2^{14}$).
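Layer-wise learning rates amount to applying different step sizes to different parameter blocks. Below is a minimal NumPy sketch of one full-batch GD step for a 2-layer ReLU network $f(x) = a^\top \mathrm{ReLU}(Wx)$, with `lr_W` playing the role of $\eta_1$ and `lr_a` the role of $\eta_2$. The squared loss, the $(20,6)$-parity data, the initialization, and the specific learning rates are assumptions for illustration; the helper names are hypothetical.

```python
import numpy as np

def forward(W, a, X):
    H = np.maximum(X @ W.T, 0.0)             # hidden ReLU features, shape (n, m)
    return H @ a, H

def mse(W, a, X, y):
    pred, _ = forward(W, a, X)
    return 0.5 * np.mean((pred - y) ** 2)

def layerwise_gd_step(W, a, X, y, lr_W, lr_a):
    """One GD step on 0.5 * mean((f(x) - y)^2) with separate learning rates for W and a."""
    n = X.shape[0]
    pred, H = forward(W, a, X)
    r = pred - y                                     # residuals, shape (n,)
    grad_a = H.T @ r / n                             # shape (m,)
    G = (r[:, None] * a[None, :]) * (H > 0)          # backprop through ReLU, shape (n, m)
    grad_W = G.T @ X / n                             # shape (m, d)
    return W - lr_W * grad_W, a - lr_a * grad_a

rng = np.random.default_rng(0)
n, d, m = 256, 20, 64
X = rng.choice([-1.0, 1.0], size=(n, d))
y = np.prod(X[:, :6], axis=1)                        # (20, 6)-parity labels
W = rng.normal(0.0, 1.0 / np.sqrt(d), size=(m, d))
a = rng.normal(0.0, 1.0 / np.sqrt(m), size=m)
W1, a1 = layerwise_gd_step(W, a, X, y, lr_W=1e-2, lr_a=1e-3)   # eta_1 > eta_2 (illustrative)
print(mse(W1, a1, X, y) < mse(W, a, X, y))           # a small step should lower the loss
```

In a framework optimizer, the same effect is usually obtained by placing each layer's parameters in its own parameter group with its own learning rate.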

##### What is the optimal scaling?

The above results show that the small-vs-large gap can be bridged with proper layer-wise initialization scaling or learning rates. It is then desirable to identify such a scheme without extensive hyperparameter search. A natural candidate is the $\mu P$ parameterization (Yang and Hu, [2020](https://arxiv.org/html/2605.20314#bib.bib18); Yang et al., [2022](https://arxiv.org/html/2605.20314#bib.bib7)), which is designed to maximize feature learning. Empirically, we find $\mu P$ to be a strong starting point, bridging the gap in 2-layer MLPs for parity ([Figure 7](https://arxiv.org/html/2605.20314#S5.F7)). However, the small-vs-large gap still exists for SIM. Following the argument in [Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1), we additionally consider a one-dimensional search over a single parameter $\alpha$, which adjusts the initialization scaling by dividing the first layer's standard deviation by $\alpha$ and multiplying the second layer's by $\alpha$. As shown in [Figure 7](https://arxiv.org/html/2605.20314#S5.F7), such $\alpha$ scaling suffices to bridge the small-vs-large gap in both parity and SIM, and the optimal $\alpha$ remains constant across model widths ([Figure 23](https://arxiv.org/html/2605.20314#A2.F23)).
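The $\alpha$ scaling is a one-parameter special case of layer-wise initialization. For a 2-layer ReLU network $f(x) = a^\top \mathrm{ReLU}(Wx)$, dividing $W$ by $\alpha$ and multiplying $a$ by $\alpha$ leaves the function at initialization unchanged while multiplying the norm ratio $\|a\|_2/\|W\|_F$ by $\alpha^2$. The sketch below checks both facts; it rescales one fixed draw, which matches rescaling the standard deviations in distribution. The default standard deviations, the value $\alpha = 4$, and the helper names are assumptions.

```python
import numpy as np

def alpha_scale(W, a, alpha):
    """Divide the first-layer weights by alpha and multiply the second-layer weights by alpha."""
    return W / alpha, a * alpha

def mlp(W, a, X):
    return np.maximum(X @ W.T, 0.0) @ a               # f(x) = a^T relu(W x)

def norm_ratio(W, a):
    return np.linalg.norm(a) / np.linalg.norm(W)      # ||a||_2 / ||W||_F

rng = np.random.default_rng(0)
d, m, alpha = 20, 64, 4.0
W = rng.normal(0.0, 1.0 / np.sqrt(d), size=(m, d))    # assumed default stds
a = rng.normal(0.0, 1.0 / np.sqrt(m), size=m)
W_s, a_s = alpha_scale(W, a, alpha)
X = rng.choice([-1.0, 1.0], size=(8, d))
print(np.allclose(mlp(W_s, a_s, X), mlp(W, a, X)))   # True: same function at initialization
print(norm_ratio(W_s, a_s) / norm_ratio(W, a))       # alpha**2 = 16, up to rounding
```

Since the initial predictions are identical, any effect of $\alpha$ must come from the training dynamics: for example, the gradient with respect to $W$ is proportional to $a$, so a larger $\alpha$ enlarges the first-layer updates relative to the first-layer weights.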

One could consider a more thorough search over the design choices of layer-wise initialization or learning-rate schemes, for which there is a vast existing literature, as discussed in [Section 1.1](https://arxiv.org/html/2605.20314#S1.SS1). Our results on the small-vs-large gap provide an orthogonal angle, suggesting that the data use strategy can offer helpful inductive biases that make the model robust to the scale or learning rates.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ezra_figs/mlp_online_heatmap.png) (a) Full population (size $2^{20}$)
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ezra_figs/mlp_subset_heatmap.png) (b) Random subsets of size $2^{14}$

Figure 5: Proper initialization removes the small-vs-large gap, though smaller-set training is more robust to the initialization scale. Results are shown for $(20,6)$-parity with MLP. The heatmaps show the accuracies (averaged over 256 seeds) using the per-setup best learning rate.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/layerwise_lr_MLP_GD_d20k6_m64.png)

Figure 6: Layer-wise learning rates remove the small-vs-large gap. The optimal learning rate for $w$ is larger than that for $a$. Results are shown for MLP on $(20,6)$-parity, trained with GD.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/muP_comparison_d20k6_m64.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/sim_mlp_mup.png)

Figure 7: Comparison to $\mu P$ (Yang et al., [2022](https://arxiv.org/html/2605.20314#bib.bib7)). $\mu P$ and the $\alpha$ scaling both close the small-vs-large gap in 2-layer width-64 MLPs.

#### 5.2.2 Transformer: scaling of $W_q, W_k$

The small-vs-large gap is observed in Transformers for both mini-batch updates ([Figure 1](https://arxiv.org/html/2605.20314#S1.F1)) and full-batch updates ([Figure 15](https://arxiv.org/html/2605.20314#A2.F15)). Similar to MLPs, we find that the gap can be attributed to favorable optimization biases induced by small-set training, and proper interventions can reduce the gap. We focus on interventions on the $W_q, W_k$ matrices, which directly affect attention, and provide ablation studies on other parameters in [Section B.2.2](https://arxiv.org/html/2605.20314#A2.SS2.SSS2).

The small-vs-large gap is reduced when increasing the learning rate for $W_q, W_k$ or increasing the initialization scale of $W_q, W_k$ for parity ([Figure 8](https://arxiv.org/html/2605.20314#S5.F8)). Specifically, for the learning-rate intervention, we tune a separate learning rate for $W_q, W_k$ while keeping the learning rate for the other parameters fixed at the optimal global learning rate. In [Figure 8(a)](https://arxiv.org/html/2605.20314#S5.F8.sf1), the optimal QK learning rate is 3.6x the optimal global learning rate for large-set training (i.e., using online mini-batches), 2.5x for small-set training with 100 epochs, and 1.25x for small-set 6-phase training. For QK initialization, large-set training benefits greatly from an increased initialization scale; the optimal scale is 8x the default initialization, which sharpens the attention logits by a factor of 64 when considering both $W_q$ and $W_k$. In contrast, small-set training gets much more moderate improvements from tuning the initialization scale: 100-epoch training gets the best speedup when scaling the default by 2x, while 6-phase training sees no benefit from scaling changes. Similar phenomena are observed in SIM and ICL as well ([Figure 9](https://arxiv.org/html/2605.20314#S5.F9)), where tuning the QK initialization narrows (but does not necessarily close) the small-vs-large gap.
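The factor of 64 follows from bilinearity: single-head attention logits are $(XW_q)(XW_k)^\top/\sqrt{d_{\text{head}}}$, so scaling both $W_q$ and $W_k$ by $8$ multiplies every logit by $8 \times 8 = 64$, which sharpens the softmax. A minimal single-head NumPy sketch; the sizes and the $1/\sqrt{d_{\text{model}}}$ default standard deviation are arbitrary choices for illustration.

```python
import numpy as np

def attention_logits(X, W_q, W_k):
    """Single-head attention logits (Q K^T) / sqrt(d_head), with Q = X W_q and K = X W_k."""
    d_head = W_q.shape[1]
    return (X @ W_q) @ (X @ W_k).T / np.sqrt(d_head)

def softmax(Z):
    Z = Z - Z.max(axis=-1, keepdims=True)   # subtract row max for numerical stability
    E = np.exp(Z)
    return E / E.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
T, d_model, d_head = 6, 32, 16
X = rng.standard_normal((T, d_model))
W_q = rng.normal(0.0, 1.0 / np.sqrt(d_model), size=(d_model, d_head))
W_k = rng.normal(0.0, 1.0 / np.sqrt(d_model), size=(d_model, d_head))

base = attention_logits(X, W_q, W_k)
scaled = attention_logits(X, 8.0 * W_q, 8.0 * W_k)    # 8x initialization scale on both matrices
print(np.allclose(scaled, 64.0 * base))               # True: logits grow by 8 * 8 = 64
print(softmax(scaled).max(axis=-1) >= softmax(base).max(axis=-1))   # sharper attention per row
```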

##### Connection to and implications of QK normalization

The attention-logit scaling adjustment mentioned above is reminiscent of QK normalization, which normalizes the key and query vectors to a sphere (Henry et al., [2020](https://arxiv.org/html/2605.20314#bib.bib26); Dehghani et al., [2023](https://arxiv.org/html/2605.20314#bib.bib31); Zhai et al., [2023](https://arxiv.org/html/2605.20314#bib.bib25)). For example, the commonly adopted RMSNorm imposes a scaling of $\asymp \sqrt{d}$ compared to the default non-normalized version. QK normalization is now common practice for preventing training instabilities, as it constrains the attention-logit magnitudes (Dehghani et al., [2023](https://arxiv.org/html/2605.20314#bib.bib31); Zhai et al., [2023](https://arxiv.org/html/2605.20314#bib.bib25); Wortsman et al., [2023](https://arxiv.org/html/2605.20314#bib.bib30)).
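A minimal sketch of QK normalization, assuming RMSNorm without a learnable gain applied to each query and key vector before the scaled dot product (many implementations add a learnable scale, omitted here; the sizes and helper names are illustrative). After normalization each vector has norm at most $\sqrt{d_{\text{head}}}$, so by Cauchy-Schwarz every logit is bounded by $\sqrt{d_{\text{head}}}$, and the logits become essentially insensitive to the scale of $W_q$ and $W_k$.

```python
import numpy as np

def rms_norm(Z, eps=1e-6):
    # RMSNorm over the last axis, without a learnable gain (simplifying assumption)
    return Z / np.sqrt(np.mean(Z ** 2, axis=-1, keepdims=True) + eps)

def qk_norm_logits(X, W_q, W_k):
    d_head = W_q.shape[1]
    Q = rms_norm(X @ W_q)          # each row now has norm <= sqrt(d_head)
    K = rms_norm(X @ W_k)
    return Q @ K.T / np.sqrt(d_head)

rng = np.random.default_rng(0)
T, d_model, d_head = 6, 32, 16
X = rng.standard_normal((T, d_model))
W_q = rng.normal(0.0, 1.0 / np.sqrt(d_model), size=(d_model, d_head))
W_k = rng.normal(0.0, 1.0 / np.sqrt(d_model), size=(d_model, d_head))

logits = qk_norm_logits(X, W_q, W_k)
print(np.abs(logits).max() <= np.sqrt(d_head))                                   # True
print(np.allclose(qk_norm_logits(X, 8.0 * W_q, 8.0 * W_k), logits, rtol=1e-4))   # True
```

The second check illustrates why, under QK normalization, the initialization-scale intervention on $W_q, W_k$ no longer changes the attention logits directly.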

Our results suggest that its implications for optimization may be more nuanced. On the plus side, for sparse parity and SIM, QK normalization can significantly speed up optimization for online training, removing the initial saddle-like behavior and almost entirely closing the small-vs-large gap ([Figure 8(c)](https://arxiv.org/html/2605.20314#S5.F8.sf3), [Figure 9(b)](https://arxiv.org/html/2605.20314#S5.F9.sf2)). However, such acceleration is not universal across tasks. In particular, QK normalization hurts both online and subset training for ICL ([Figure 9(d)](https://arxiv.org/html/2605.20314#S5.F9.sf4)) and mod addition ([Figure 27(b)](https://arxiv.org/html/2605.20314#A2.F27.sf2)). Moreover, when data is repeated during training, QK normalization can exacerbate overfitting, which is observed in both mini-batch ([Figure 8(c)](https://arxiv.org/html/2605.20314#S5.F8.sf3), 6-phase) and full-batch training ([Figure 27(a)](https://arxiv.org/html/2605.20314#A2.F27.sf1)); a train-test comparison is provided in [Figure 28](https://arxiv.org/html/2605.20314#A2.F28). Understanding the exact mechanism is left as future work.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/sgd_transformer_d20_k6_compare_qk_lr.png) (a) QK learning rate
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/sgd_transformer_d20_k6_compare_qk_init.png) (b) QK initialization scale
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/sgd_transformer_d20_k6_compare_qk_norm.png) (c) QK normalization

Figure 8: Small-vs-large gap in Transformers can be reduced with interventions on $W_Q, W_K$. Results are shown for $(20,6)$-parity with two-layer Transformers; similar results are also observed for SIM and ICL ([Figure 9](https://arxiv.org/html/2605.20314#S5.F9)). The small-vs-large gap is reduced by (a) tuning the learning rate on $W_Q, W_K$ separately from the rest of the parameters; (b) using the optimal initialization scaling of $W_Q, W_K$; (c) using QK-layernorm. Solid lines are the default setup, where we observe clear small-vs-large gaps, and dashed lines are interventions that reduce or even reverse the small-vs-large gap. The interventions are especially unhelpful for 6-phase training: for (b), the optimal initialization scaling is the default one, hence we omit the corresponding dashed line; for (c), QK-layernorm slows down 6-phase training.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/transformer_sim_qk_scale.png) (a) SIM QK scale
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/transformer_sim_qk_norm.png) (b) SIM QK normalization
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/transformer_sgd_icl_qk_scale.png) (c) ICL QK scale
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/transformer_sgd_icl_qk_norm.png) (d) ICL QK normalization

Figure 9: Small-vs-large gap in Transformers can be reduced with interventions on $W_Q, W_K$, for SIM (a, b) and in-context regression (ICL; c, d), trained using two-layer Transformers. Solid lines are the default setup, where we observe clear small-vs-large gaps, and dashed lines are interventions that reduce or even reverse the small-vs-large gap.

## 6 Discussions

This work studies the small-vs-large gap, a phenomenon where, under a fixed compute budget, repeating a smaller dataset outperforms training on a larger dataset. We identify the sampling bias of smaller datasets as a key mechanism underlying this effect, which helps adjust the relative norms across layers. We formalize this mechanism theoretically and substantiate it with empirical results. Notably, across tasks and for both MLPs and Transformers, we find that the small-set speedup persists even when the initial small-set phase uses randomly labeled data, and that proper choices of initialization scale and layer-wise learning rates can substantially reduce or even close the gap.

Beyond the small-vs-large gap itself, our results suggest that dataset size and repetition can be leveraged as favorable optimization biases. In particular, repeated training on smaller subsets can act as an implicit layer-wise preconditioner that steers models into a more favorable feature-learning regime, partially substituting for carefully tuned initialization, normalization, or layer-wise learning rates. This perspective motivates training pipelines that explicitly separate an early “optimization-shaping” phase from a later “coverage/generalization” phase, and suggests that compute-optimal training may require jointly considering data, optimizer, and parameterization.

Below we discuss other implications of our findings.

##### How to enlarge the small-vs-large gap

We investigate how the small-vs-large gap changes as we scale along different axes. The following factors are found to widen the gap:

- **Increasing model depth.** The small-vs-large gap is due to the relative norm growth across layers, which suggests that the gap should be more pronounced for deeper models, since the scaling effect compounds exponentially with depth. Our empirical findings confirm this: the small-vs-large gap widens for both MLPs ([Figure 11(a)](https://arxiv.org/html/2605.20314#S6.F11.sf1)) and Transformers ([Figure 26](https://arxiv.org/html/2605.20314#A2.F26)).
- **Reducing model width.** We find that the small-vs-large gap is widest at a small model width (e.g., width 64), as shown in [Figure 11(b)](https://arxiv.org/html/2605.20314#S6.F11.sf2) for parity and [Figure 24](https://arxiv.org/html/2605.20314#A2.F24) for SIM. We hypothesize that this is because models with standard initialization approach the kernel regime as the width increases, suggesting that the small-vs-large gap is specific to feature learning.
- **Increasing task complexity.** More difficult tasks lead to a wider small-vs-large gap, across tasks and architectures. [Figure 10(a)](https://arxiv.org/html/2605.20314#S6.F10.sf1) shows SIM results for MLPs trained with full-batch updates, where the gap is wider for $d=80$ than for $d=40$. [Figure 10(b)](https://arxiv.org/html/2605.20314#S6.F10.sf2) shows parity results for Transformers trained with mini-batch updates, where $(20,6)$-parity sees a bigger gap than $(10,6)$-parity.
- **Smoother transition between training phases.** Compared to the 2-phase analysis in [Section 4.2](https://arxiv.org/html/2605.20314#S4.SS2), growing the repeating subset size more gradually improves training and hence increases the small-vs-large gap. In [Figure 8](https://arxiv.org/html/2605.20314#S5.F8), 6-phase training of a Transformer greatly enlarges the gap compared to 1-phase training on parity.
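The multi-phase variant can be sketched as a schedule of nested subsets whose sizes grow from a small repeated set to the full dataset. The geometric growth, the equal number of steps per phase, the sizes, and the helper names below are hypothetical choices for illustration, not the exact 6-phase schedule used in the experiments.

```python
import numpy as np

def nested_subset_schedule(n_total, start_size, num_phases, rng):
    """Hypothetical schedule: geometrically growing, nested index subsets ending at the full dataset."""
    sizes = np.geomspace(start_size, n_total, num_phases).round().astype(int)
    order = rng.permutation(n_total)          # one shuffle, so each subset is a prefix of the next
    return [order[:s] for s in sizes]

def phase_batches(subsets, steps_per_phase, batch_size, rng):
    """Yield mini-batch indices, repeatedly sampling each phase's subset for steps_per_phase steps."""
    for subset in subsets:
        for _ in range(steps_per_phase):
            yield rng.choice(subset, size=batch_size, replace=True)

rng = np.random.default_rng(0)
subsets = nested_subset_schedule(n_total=2**20, start_size=2**14, num_phases=6, rng=rng)
print([len(s) for s in subsets])              # grows from 2**14 to 2**20
batches = list(phase_batches(subsets, steps_per_phase=100, batch_size=256, rng=rng))
print(len(batches))                           # 600 = 6 phases x 100 steps
```

Keeping the subsets nested means each phase revisits the examples already seen, so the transition adds fresh data gradually rather than switching abruptly to a disjoint set.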

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/mlp_gd_m64_sim.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/mlp_gd_m64_sim_d80.png)

(a) MLP on SIM with dimension $d=40$ (left) and $d=80$ (right), trained with GD.
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ezra_figs/transformer_d=10k=6.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/sgd_transformer_d20_k6.png)

(b) Transformer on parity with dimension $d=10$ (left) and $d=20$ (right), trained with batched AdamW.

Figure 10: Increasing the task complexity increases the small-vs-large gap.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/MLP_GD_d20k6_m64_full_vs_subset.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/MLP_GD_d20k6_m64_full_vs_subset_depth4.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/MLP_GD_d20k6_m64_full_vs_subset_depth6.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/MLP_GD_d20k6_m64_full_vs_subset_depth8.png)

(a) MLP on $(20,6)$-parity, across varying depths (2, 4, 6, 8).
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/vary_width_gd_mlp_d20_k6_width32.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/vary_width_gd_mlp_d20_k6_width64.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/vary_width_gd_mlp_d20_k6_width256.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/vary_width_gd_mlp_d20_k6_width1024.png)

(b) MLP on $(20,6)$-parity, across varying widths (32, 64, 256, 1024).

Figure 11: Model size affects the small-vs-large gap. Increasing model depth widens the gap (top row), whereas increasing model width reduces the gap (bottom row). Results are shown on sparse parity learned with MLPs using full-batch updates; similar results are also observed on SIM with MLPs ([Figure 24](https://arxiv.org/html/2605.20314#A2.F24)) and parity with Transformers ([Figure 26](https://arxiv.org/html/2605.20314#A2.F26)).

##### When is data repetition (not) helpful?

Even though the small-vs-large gap has been observed across various choices of tasks, architectures, and optimizers, we do not believe it exists universally. A classic example is linear regression, which does not exhibit a small-vs-large gap. While prior work has shown that multi-epoch training improves the statistical complexity for linear regression (Pillaud-Vivien et al., [2018](https://arxiv.org/html/2605.20314#bib.bib14); Lin et al., [2025](https://arxiv.org/html/2605.20314#bib.bib1)), no result has suggested an improvement in terms of optimization steps or compute cost. We hypothesize that the speedup from data repetition is exclusive to non-convex optimization, and that a computational-statistical gap is likely required. Connecting to practice, especially in the era of large-language-model training, we do not expect this phenomenon to be directly observable in many real-world language corpora, due to both the (near-)duplicates widely present in web datasets and the lack of clear structure in free-form text. However, we hypothesize that the small-vs-large gap may be of interest for more structured tasks such as formal reasoning. As empirical evidence, the concurrent work (Kopiczko et al., [2026](https://arxiv.org/html/2605.20314#bib.bib39)) observes the small-vs-large gap in LLM post-training for math and coding tasks.

## Acknowledgment

We thank Alex Damian, Aditi Raghunathan, Andrej Risteski, Daniel Hsu, Eric Wong, and Samuel Deng for helpful discussions and feedback. JL is supported by NSF award DMS-2502259 and ONR N00014-22-1-2713. SG was supported in part by an AI2050 Early Career Fellowship from Schmidt Sciences. BL was supported by the Kempner Fellowship from the Kempner Institute at Harvard. This work was enabled in part by a gift from the Chan Zuckerberg Initiative Foundation to establish the Kempner Institute for the Study of Natural and Artificial Intelligence. Part of this work was done when the authors were participating in the Special Year on Large Language Models and Transformers and the Program on Modern Paradigms in Generalization at the Simons Institute for the Theory of Computing at UC Berkeley.

## References

- SGD learning on neural networks: leap complexity and saddle-to-saddle dynamics. Annual Conference on Computational Learning Theory. [Document](https://dx.doi.org/10.48550/arXiv.2302.11055)
- E. Abbe, E. Cornacchia, and A. Lotfi (2023b). Provable advantage of curriculum learning on parity targets with mixed inputs. Advances in Neural Information Processing Systems 36, pp. 24291–24321.
- L. Arnaboldi, Y. Dandi, F. Krzakala, L. Pesce, and L. Stephan (2024). Repetita iuvant: data repetition allows SGD to learn high-dimensional multi-index functions. arXiv preprint arXiv:2405.15459.
- G. B. Arous, R. Gheissari, and A. Jagannath (2020). Online stochastic gradient descent on non-convex losses from high-dimensional inference. Journal of Machine Learning Research.
- S. Azulay, E. Moroshko, M. S. Nacson, B. E. Woodworth, N. Srebro, A. Globerson, and D. Soudry (2021). On the implicit bias of initialization shape: beyond infinitesimal mirror descent. In Proceedings of the 38th International Conference on Machine Learning, M. Meila and T. Zhang (Eds.), Proceedings of Machine Learning Research, Vol. 139, pp. 468–477. [Link](https://proceedings.mlr.press/v139/azulay21a.html)
- B. Barak, B. L. Edelman, S. Goel, S. Kakade, E. Malach, and C. Zhang (2022). Hidden progress in deep learning: SGD learns parities near the computational limit. Neural Information Processing Systems. [Document](https://dx.doi.org/10.48550/arXiv.2207.08799)
- F. Charton and J. Kempe (2024). Emergent properties with repeated examples. arXiv preprint arXiv:2410.07041.
- E. Cornacchia, D. Mikulincer, and E. Mossel (2025). Low-dimensional functions are efficiently learnable under randomly biased distributions. In The Thirty Eighth Annual Conference on Learning Theory, 30-4 July 2025, Lyon, France, N. Haghtalab and A. Moitra (Eds.), Proceedings of Machine Learning Research, Vol. 291, pp. 1331–1365. [Link](https://proceedings.mlr.press/v291/cornacchia25a.html)
- Y. Dandi, E. Troiani, L. Arnaboldi, L. Pesce, L. Zdeborová, and F. Krzakala (2024). The benefits of reusing batches for gradient descent in two-layer networks: breaking the curse of information and leap exponents. In Forty-first International Conference on Machine Learning, ICML 2024, Vienna, Austria, July 21-27, 2024. [Link](https://openreview.net/forum?id=iKkFruh4d5)
- M. Dehghani, J. Djolonga, B. Mustafa, P. Padlewski, J. Heek, J. Gilmer, A. P. Steiner, M. Caron, R. Geirhos, I. Alabdulmohsin, R. Jenatton, L. Beyer, M. Tschannen, A. Arnab, X. Wang, C. Riquelme Ruiz, M. Minderer, J. Puigcerver, U. Evci, M. Kumar, S. V. Steenkiste, G. F. Elsayed, A. Mahendran, F. Yu, A. Oliver, F. Huot, J. Bastings, M. Collier, A. A. Gritsenko, V. Birodkar, C. N. Vasconcelos, Y. Tay, T. Mensink, A. Kolesnikov, F. Pavetic, D. Tran, T. Kipf, M. Lucic, X. Zhai, D. Keysers, J. J. Harmsen, and N. Houlsby (2023). Scaling vision transformers to 22 billion parameters. In Proceedings of the 40th International Conference on Machine Learning, A. Krause, E. Brunskill, K. Cho, B. Engelhardt, S. Sabato, and J. Scarlett (Eds.), Proceedings of Machine Learning Research, Vol. 202, pp. 7480–7512. [Link](https://proceedings.mlr.press/v202/dehghani23a.html)
- K. E. Everett, L. Xiao, M. Wortsman, A. A. Alemi, R. Novak, P. J. Liu, I. Gur, J. Sohl-Dickstein, L. P. Kaelbling, J. Lee, and J. Pennington (2024). Scaling exponents across parameterizations and optimizers. In Forty-first International Conference on Machine Learning, ICML 2024, Vienna, Austria, July 21-27, 2024. [Link](https://openreview.net/forum?id=0ksNeD1SJT)
- A\. Henry, P\. R\. Dachapally, S\. Pawar, and Y\. Chen \(2020\)Query\-key normalization for transformers\.arXiv preprint arXiv: 2010\.04245\.Cited by:[§5\.2\.2](https://arxiv.org/html/2605.20314#S5.SS2.SSS2.Px1.p1.1)\.
- D\. Hernandez, T\. Brown, T\. Conerly, N\. DasSarma, D\. Drain, S\. El\-Showk, N\. Elhage, Z\. Hatfield\-Dodds, T\. Henighan, T\. Hume, S\. Johnston, B\. Mann, C\. Olah, C\. Olsson, D\. Amodei, N\. Joseph, J\. Kaplan, and S\. McCandlish \(2022\)Scaling laws and interpretability of learning from repeated data\.arXiv preprint arXiv: 2205\.10487\.Cited by:[§1](https://arxiv.org/html/2605.20314#S1.p1.1)\.
- R\. Johnson and T\. Zhang \(2013\)Accelerating stochastic gradient descent using predictive variance reduction\.InAdvances in Neural Information Processing Systems,C\.J\. Burges, L\. Bottou, M\. Welling, Z\. Ghahramani, and K\.Q\. Weinberger \(Eds\.\),Vol\.26,pp\.\.External Links:[Link](https://proceedings.neurips.cc/paper_files/paper/2013/file/ac1dd209cbcc5e5d1c6e28598e8cbbe8-Paper.pdf)Cited by:[§4\.1\.2](https://arxiv.org/html/2605.20314#S4.SS1.SSS2.p1.1)\.
- A\. T\. Kalai, A\. Samorodnitsky, and S\. Teng \(2009\)Learning and smoothed analysis\.In2009 50th Annual IEEE Symposium on Foundations of Computer Science,pp\. 395–404\.Cited by:[§1](https://arxiv.org/html/2605.20314#S1.p3.1),[§4\.1\.3](https://arxiv.org/html/2605.20314#S4.SS1.SSS3.p1.4)\.
- D\. J\. Kopiczko, S\. Vaze, T\. Blankevoort, and Y\. M\. Asano \(2026\)Data repetition beats data scaling in long\-cot supervised fine\-tuning\.External Links:2602\.11149,[Link](https://arxiv.org/abs/2602.11149)Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p1.1),[§1](https://arxiv.org/html/2605.20314#S1.p1.1),[§6](https://arxiv.org/html/2605.20314#S6.SS0.SSS0.Px2.p1.1)\.
- S\. Kotha, U\. Girit, T\. Kumar, G\. R\. Ghosal, and A\. Raghunathan \(2025\)Lowering data diversity can accelerate training: case studies in synthetic tasks\.External Links:[Link](https://openreview.net/forum?id=xlxDTVAbNM)Cited by:[§1](https://arxiv.org/html/2605.20314#S1.p3.1),[§4\.1\.2](https://arxiv.org/html/2605.20314#S4.SS1.SSS2.p1.1)\.
- F\. Kovačević, H\. C\. Ji, D\. Wu, M\. Soltanolkotabi, and M\. Mondelli \(2026\)Full\-batch gradient descent outperforms one\-pass sgd: sample complexity separation in single\-index learning\.arXiv preprint arXiv: 2602\.02431\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p1.1)\.
- J\. D\. Lee, K\. Oko, T\. Suzuki, and D\. Wu \(2025\)Neural network learns low\-dimensional polynomials with sgd near the information\-theoretic limit\.Advances in Neural Information Processing Systems37,pp\. 58716–58756\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p1.1),[§1](https://arxiv.org/html/2605.20314#S1.p1.1),[§1](https://arxiv.org/html/2605.20314#S1.p3.1),[§4\.1\.1](https://arxiv.org/html/2605.20314#S4.SS1.SSS1.p1.1)\.
- L\. Lin, J\. Wu, and P\. L\. Bartlett \(2025\)Improved scaling laws in linear regression via data reuse\.arXiv preprint arXiv: 2506\.08415\.Cited by:[§B\.1](https://arxiv.org/html/2605.20314#A2.SS1.SSS0.Px2.p2.1),[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p1.1),[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p2.5),[§6](https://arxiv.org/html/2605.20314#S6.SS0.SSS0.Px2.p1.1)\.
- N\. Muennighoff, A\. Rush, B\. Barak, T\. Le Scao, N\. Tazi, A\. Piktus, S\. Pyysalo, T\. Wolf, and C\. A\. Raffel \(2023\)Scaling data\-constrained language models\.Advances in Neural Information Processing Systems36,pp\. 50358–50376\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p1.1),[§1](https://arxiv.org/html/2605.20314#S1.p1.1)\.
- L\. Pillaud\-Vivien, A\. Rudi, and F\. Bach \(2018\)Statistical optimality of stochastic gradient descent on hard learning problems through multiple passes\.Neural Information Processing Systems\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p2.5),[§6](https://arxiv.org/html/2605.20314#S6.SS0.SSS0.Px2.p1.1)\.
- A\. Sekhari, K\. Sridharan, and S\. Kale \(2021\)Sgd: the role of implicit regularization, batch\-size and multiple\-epochs\.Advances In Neural Information Processing Systems34,pp\. 27422–27433\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p1.1),[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p2.5)\.
- G\. Valiant \(2012\)Finding correlations in subquadratic time, with applications to learning parities and juntas\.In2012 IEEE 53rd Annual Symposium on Foundations of Computer Science,pp\. 11–20\.Cited by:[§1](https://arxiv.org/html/2605.20314#S1.p3.1),[§4\.1\.3](https://arxiv.org/html/2605.20314#S4.SS1.SSS3.p1.4)\.
- M\. Wortsman, P\. J\. Liu, L\. Xiao, K\. Everett, A\. Alemi, B\. Adlam, J\. D\. Co\-Reyes, I\. Gur, A\. Kumar, R\. Novak, J\. Pennington, J\. N\. Sohl\-Dickstein, K\. Xu, J\. Lee, J\. Gilmer, and S\. Kornblith \(2023\)Small\-scale proxies for large\-scale transformer training instabilities\.International Conference on Learning Representations\.External Links:[Document](https://dx.doi.org/10.48550/arXiv.2309.14322)Cited by:[§5\.2\.2](https://arxiv.org/html/2605.20314#S5.SS2.SSS2.Px1.p1.1)\.
- J\. Wu, P\. L\. Bartlett, J\. D\. Lee, S\. M\. Kakade, and B\. Yu \(2025\)Risk comparisons in linear regression: implicit regularization dominates explicit regularization\.arXiv preprint arXiv: 2509\.17251\.Cited by:[§B\.1](https://arxiv.org/html/2605.20314#A2.SS1.SSS0.Px2.p2.1)\.
- Y\. Xu, Q\. Qian, H\. Li, and R\. Jin \(2021\)Why does multi\-epoch training help?\.arXiv preprint arXiv: 2105\.06015\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p1.1),[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p2.5)\.
- T\. Yan, H\. Wen, B\. Li, K\. Luo, W\. Chen, and K\. Lyu \(2025\)Larger datasets can be repeated more: a theoretical analysis of multi\-epoch scaling in linear regression\.arXiv preprint arXiv: 2511\.13421\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p1.1)\.
- G\. Yang, E\. J\. Hu, I\. Babuschkin, S\. Sidor, X\. Liu, D\. Farhi, N\. Ryder, J\. Pachocki, W\. Chen, and J\. Gao \(2022\)Tensor programs v: tuning large neural networks via zero\-shot hyperparameter transfer\.arXiv preprint arXiv: 2203\.03466\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p3.1),[Figure 7](https://arxiv.org/html/2605.20314#S5.F7),[Figure 7](https://arxiv.org/html/2605.20314#S5.F7.6.3.2),[§5\.2\.1](https://arxiv.org/html/2605.20314#S5.SS2.SSS1.Px1.p1.7)\.
- G\. Yang and E\. J\. Hu \(2020\)Feature learning in infinite\-width neural networks\.arXiv preprint arXiv: 2011\.14522\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p3.1),[§5\.2\.1](https://arxiv.org/html/2605.20314#S5.SS2.SSS1.Px1.p1.7)\.
- G\. Yang, J\. B\. Simon, and J\. Bernstein \(2023\)A spectral condition for feature learning\.arXiv preprint arXiv: 2310\.17813\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p3.1)\.
- S\. Zhai, T\. Likhomanenko, E\. Littwin, D\. Busbridge, J\. Ramapuram, Y\. Zhang, J\. Gu, and J\. Susskind \(2023\)Stabilizing transformer training by preventing attention entropy collapse\.arXiv preprint arXiv: 2303\.06296\.Cited by:[§5\.2\.2](https://arxiv.org/html/2605.20314#S5.SS2.SSS2.Px1.p1.1)\.
- N\. Zucchet, F\. d’Angelo, A\. K\. Lampinen, and S\. C\. Y\. Chan \(2025\)The emergence of sparse attention: impact of data distribution and benefits of repetition\.arXiv preprint arXiv: 2505\.17863\.Cited by:[§1\.1](https://arxiv.org/html/2605.20314#S1.SS1.p1.1)\.

###### Contents

1. [1 Introduction](https://arxiv.org/html/2605.20314#S1)
   1. [1.1 Related work](https://arxiv.org/html/2605.20314#S1.SS1)
2. [2 Setup](https://arxiv.org/html/2605.20314#S2)
3. [3 Small-vs-large gap: less data can lead to faster learning](https://arxiv.org/html/2605.20314#S3)
4. [4 Unpacking the efficiency gain from smaller datasets](https://arxiv.org/html/2605.20314#S4)
   1. [4.1 Prior theories are insufficient](https://arxiv.org/html/2605.20314#S4.SS1)
      1. [4.1.1 SQ-CSQ difference](https://arxiv.org/html/2605.20314#S4.SS1.SSS1)
      2. [4.1.2 Gradient variance reduction](https://arxiv.org/html/2605.20314#S4.SS1.SSS2)
      3. [4.1.3 Biased (input) distribution](https://arxiv.org/html/2605.20314#S4.SS1.SSS3)
   2. [4.2 Our explanation: dataset sampling bias accelerates learning by adjusting the relative norm growth](https://arxiv.org/html/2605.20314#S4.SS2)
5. [5 Empirical evidence for relative norm growth](https://arxiv.org/html/2605.20314#S5)
   1. [5.1 Small-vs-large gap exists when training first on random labels](https://arxiv.org/html/2605.20314#S5.SS1)
   2. [5.2 Small-vs-large gap diminishes with parameter-wise interventions](https://arxiv.org/html/2605.20314#S5.SS2)
      1. [5.2.1 MLP: layer-wise interventions](https://arxiv.org/html/2605.20314#S5.SS2.SSS1)
      2. [5.2.2 Transformer: scaling of $W_q, W_k$](https://arxiv.org/html/2605.20314#S5.SS2.SSS2)
6. [6 Discussions](https://arxiv.org/html/2605.20314#S6)
7. [References](https://arxiv.org/html/2605.20314#bib)
8. [A Proof details](https://arxiv.org/html/2605.20314#A1)
   1. [A.1 Setup](https://arxiv.org/html/2605.20314#A1.SS1)
   2. [A.2 Analysis of Phase 1](https://arxiv.org/html/2605.20314#A1.SS2)
      1. [A.2.1 Constant-probability lower bound on $|q^{(0)}| = \Omega(1/\sqrt{N})$](https://arxiv.org/html/2605.20314#A1.SS2.SSS1)
      2. [A.2.2 Stability of $q^{(t)}$ for small $a^{(t)}$](https://arxiv.org/html/2605.20314#A1.SS2.SSS2)
      3. [A.2.3 Linear growth of $a^{(t)}$](https://arxiv.org/html/2605.20314#A1.SS2.SSS3)
      4. [A.2.4 Choosing largest stable $a_\star$](https://arxiv.org/html/2605.20314#A1.SS2.SSS4)
   3. [A.3 Analysis of Phase 2](https://arxiv.org/html/2605.20314#A1.SS3)
   4. [A.4 Combining both phases](https://arxiv.org/html/2605.20314#A1.SS4)
      1. [A.4.1 $w^{(t)}$ grows slowly](https://arxiv.org/html/2605.20314#A1.SS4.SSS1)
      2. [A.4.2 Connecting $\alpha_0$ and $a_\star$](https://arxiv.org/html/2605.20314#A1.SS4.SSS2)
      3. [A.4.3 Final convergence bound](https://arxiv.org/html/2605.20314#A1.SS4.SSS3)
   5. [A.5 Proof of Corollary 2: training first phase on random labels](https://arxiv.org/html/2605.20314#A1.SS5)
9. [B Experiment details and additional results](https://arxiv.org/html/2605.20314#A2)
   1. [B.1 Experiment details](https://arxiv.org/html/2605.20314#A2.SS1)
   2. [B.2 Additional empirical results](https://arxiv.org/html/2605.20314#A2.SS2)
      1. [B.2.1 More setups with the small-vs-large gap](https://arxiv.org/html/2605.20314#A2.SS2.SSS1)
      2. [B.2.2 Ablation studies](https://arxiv.org/html/2605.20314#A2.SS2.SSS2)

## Appendix A Proof details

Motivated by the empirical evidence in [Sections 3](https://arxiv.org/html/2605.20314#S3), [4](https://arxiv.org/html/2605.20314#S4) and [5](https://arxiv.org/html/2605.20314#S5), we analyze a minimal setting: a single quadratic neuron learning a 2-sparse parity under a two-phase schedule. Phase 1 runs gradient descent (GD) on a fixed dataset of size $N$, and phase 2 switches to the full population.

### A.1 Setup

We assume the input $x \in \{\pm 1\}^d$ is sampled uniformly from the hypercube and the label $y = x_1 x_2 \in \{\pm 1\}$ is a 2-sparse parity. We study the quadratic neuron

$$f(x) = \frac{1}{2} a (w^\top x)^2,$$

trained with the correlation loss $\ell(y, \hat{y}) = -y \hat{y}$. Let $w^\star$ denote a global minimizer, for which $|w_1^\star| = |w_2^\star| = \frac{1}{\sqrt{2}}$ and $w_i^\star = 0$ for all $i \geq 3$.

##### Projection\.

To ensure that the weights are bounded at the solution, we use the following projections:

- **Output weight clipping:** after each update we set $a \leftarrow \mathrm{clip}_{[-1,1]}(a)$.
- **Input weight renormalization:** after each update we set $w \leftarrow w / \|w\|_2$.

We consider 2-phase training, where Phase 1 uses a randomly sampled dataset of size $N$ and Phase 2 uses the full population.

##### Phase 1 (fixed batch of size $N$).

Fix a dataset $\{(x^{(s)}, y^{(s)})\}_{s=1}^{N}$ and define the empirical moment matrix

$$\widehat{\boldsymbol{M}} := \widehat{\mathbb{E}}[y x x^\top] := \frac{1}{N} \sum_{s=1}^{N} y^{(s)} x^{(s)} x^{(s)\top}.$$

Since the empirical correlation loss is $\widehat{\mathbb{E}}[-y f(x)] = -\frac{a}{2} w^\top \widehat{\boldsymbol{M}} w$, one step of (projected) gradient descent on this fixed batch takes the form

$$a^{(t+1)} = \mathrm{clip}_{[-1,1]}\Big(a^{(t)} + \frac{\eta}{2} (w^{(t)})^\top \widehat{\boldsymbol{M}} w^{(t)}\Big), \tag{A.3}$$

$$w^{(t+1)} = \frac{w^{(t)} + \eta a^{(t)} \widehat{\boldsymbol{M}} w^{(t)}}{\big\| w^{(t)} + \eta a^{(t)} \widehat{\boldsymbol{M}} w^{(t)} \big\|_2}. \tag{A.4}$$

We also define

$$q^{(t)} := (w^{(t)})^\top \widehat{\boldsymbol{M}} w^{(t)},$$

which will be useful in the analysis.
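To make (A.3) and (A.4) concrete, here is a minimal pure-Python sketch of one phase-1 step on a fixed parity batch. It is our illustration, not the authors' code: the helper names (`sample_parity_batch`, `empirical_moment`, `phase1_step`) and the sizes $d = 8$, $N = 32$, $m = 16$, $\eta = 1/2$ are assumptions chosen for readability.

```python
import math
import random


def sample_parity_batch(n, d, rng):
    """Draw n inputs uniformly from {-1, +1}^d with 2-sparse parity labels y = x1 * x2."""
    xs = [[rng.choice((-1, 1)) for _ in range(d)] for _ in range(n)]
    ys = [x[0] * x[1] for x in xs]
    return xs, ys


def empirical_moment(xs, ys):
    """Empirical moment matrix M_hat = (1/N) * sum_s y_s x_s x_s^T as a d x d nested list."""
    n, d = len(xs), len(xs[0])
    return [[sum(y * x[i] * x[j] for x, y in zip(xs, ys)) / n for j in range(d)]
            for i in range(d)]


def matvec(m, v):
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def phase1_step(a, w, m_hat, eta):
    """One projected GD step on the empirical correlation loss -a/2 * w^T M_hat w."""
    mw = matvec(m_hat, w)
    q = sum(wi * mwi for wi, mwi in zip(w, mw))               # q^(t) = w^T M_hat w
    a_next = max(-1.0, min(1.0, a + 0.5 * eta * q))           # (A.3): step, then clip to [-1, 1]
    w_tilde = [wi + eta * a * mwi for wi, mwi in zip(w, mw)]  # (A.4) uses the old a^(t)
    norm = math.sqrt(sum(v * v for v in w_tilde))
    return a_next, [v / norm for v in w_tilde]                # (A.4): renormalize to the sphere


rng = random.Random(0)
d, n, width = 8, 32, 16
xs, ys = sample_parity_batch(n, d, rng)
m_hat = empirical_moment(xs, ys)
g = [rng.gauss(0.0, 1.0) for _ in range(d)]
w = [v / math.sqrt(sum(u * u for u in g)) for v in g]  # w^(0) ~ Unif(S^{d-1})
a = rng.gauss(0.0, 1.0 / math.sqrt(width))              # a^(0) ~ N(0, 1/m)
a, w = phase1_step(a, w, m_hat, eta=0.5)
```

Both updates read the old pair $(a^{(t)}, w^{(t)})$, which is why `phase1_step` builds `w_tilde` from the incoming `a` rather than from `a_next`.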

##### Phase 2 \(population\)\.

After phase 1, we switch to population gradients, replacing $\widehat{\boldsymbol{M}}$ in the above updates by the population matrix

$$\boldsymbol{M} := \mathbb{E}[y x x^\top] = e_1 e_2^\top + e_2 e_1^\top,$$

since $\mathbb{E}[x_1 x_2 x_i x_j] = 1$ if $\{i, j\} = \{1, 2\}$ and $0$ otherwise.
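The closed form of $\boldsymbol{M}$ can be checked exactly in a small dimension by averaging over all $2^d$ points of the hypercube. A minimal sketch; $d = 4$ and the helper name `population_moment` are illustrative assumptions, and indices are 0-based in the code, so $e_1, e_2$ correspond to indices 0 and 1.

```python
import itertools


def population_moment(d):
    """Exact E[y x x^T] for x uniform on {-1, +1}^d and y = x1 * x2 (enumerates all 2^d points)."""
    points = list(itertools.product((-1, 1), repeat=d))
    return [[sum(x[0] * x[1] * x[i] * x[j] for x in points) / len(points) for j in range(d)]
            for i in range(d)]


M = population_moment(4)
# Only the (0, 1) and (1, 0) entries survive, matching e1 e2^T + e2 e1^T.
expected = [[1.0 if {i, j} == {0, 1} else 0.0 for j in range(4)] for i in range(4)]
print(M == expected)
```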

We restate [Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1) below; it shows that using a smaller $N$ in Phase 1 improves convergence.

###### Theorem (2-phase training from standard initialization; [Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1) restated).

Consider 2-phase training with $m > d \geq 3$ (the condition $d \geq 3$ is required for the proof of [Lemma 10](https://arxiv.org/html/2605.20314#Thmlemma10) regarding the probability of the Beta distribution) and learning rate $\eta \leq \frac{1}{2}$. The first phase uses a randomly sampled dataset of size $d \leq N \leq d^2$ until $|a| \geq a_\star$, for some $a_\star \in (0, 1)$ with $a_\star \lesssim \frac{1}{(Nd)^{1/4} \sqrt{\log(d/\delta)}}$; the second phase uses the full population gradient until it reaches a $\hat{w}$ such that $\|\hat{w} - w^\star\|_2 \lesssim \sqrt{\varepsilon}$. Let $T_1, T_2$ denote the numbers of steps required in each phase, respectively, and let $p_{\mathrm{all}} \in (0, 1)$ be a universal constant with $p_{\mathrm{all}} = \Theta(1)$ (formally defined in [Lemma 10](https://arxiv.org/html/2605.20314#Thmlemma10)). Then, with probability at least $p_{\mathrm{all}} - \delta$ over the random initialization and the phase-1 samples,

$$T_1 \lesssim \frac{a_\star \sqrt{N}}{\eta}, \qquad T_2 \lesssim \frac{2}{\eta a_\star} \log\Big(\frac{d}{\varepsilon}\Big). \tag{A.5}$$

On the same event, with the optimal choice of $a_\star$, the total number of steps is $O\Big((Nd)^{1/4} \log\big(\frac{d}{\varepsilon}\big) \log^{1/2}\big(\frac{d}{\delta}\big)\Big)$.
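To see how the two terms of (A.5) trade off, the sketch below evaluates them with every hidden constant set to 1 and $a_\star$ set to the largest value the constraint $a_\star \lesssim 1/\big((Nd)^{1/4}\sqrt{\log(d/\delta)}\big)$ allows. Both choices are our assumptions for illustration, as are $d = 64$, $\varepsilon = 10^{-3}$, $\delta = 0.1$, $\eta = 1/2$ and the helper name `phase_step_bounds`. Under them $T_1 \propto (N/d)^{1/4}$ and $T_2 \propto (Nd)^{1/4}$, so quadrupling $N$ should multiply the total by $\sqrt{2}$.

```python
import math


def phase_step_bounds(n, d, eps, delta, eta=0.5):
    """Evaluate the (A.5) bounds with hidden constants set to 1 (illustrative assumption)."""
    # Largest a_star allowed by the theorem's constraint, capped at 1.
    a_star = min(1.0, 1.0 / ((n * d) ** 0.25 * math.sqrt(math.log(d / delta))))
    t1 = a_star * math.sqrt(n) / eta               # phase 1: fixed batch of size N
    t2 = 2.0 / (eta * a_star) * math.log(d / eps)  # phase 2: population gradient
    return t1, t2


d = 64
for n in (64, 256, 1024, 4096):  # d <= N <= d^2
    t1, t2 = phase_step_bounds(n, d, eps=1e-3, delta=0.1)
    print(f"N={n:5d}  T1<~{t1:8.1f}  T2<~{t2:8.1f}  total<~{t1 + t2:8.1f}")
```

Phase 2 dominates the total here, which is why a larger admissible $a_\star$ (enabled by a smaller $N$) pays off.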

### A.2 Analysis of Phase 1

First, we show that a smaller dataset size $N$ leads to faster growth of the outer weight.

###### Theorem 3 (Phase 1 is faster with fewer fixed-batch samples).

Initialize $w^{(0)} \sim \mathrm{Unif}(\mathbb{S}^{d-1})$ and $a^{(0)} \sim \mathcal{N}(0, 1/m)$, where $m$ can be viewed as a model-width parameter. Fix any target $a_\star \in (0, 1]$ and learning rate $\eta > 0$. Consider phase-1 training on a fixed batch of size $N$, and let

$$T_\star(N) := \min\{t : |a^{(t)}| \geq a_\star\}.$$

For some choice of $c_0 \in (0, \frac{\sqrt{3}}{2})$, define the event

$$\mathcal{G}_0 := \Big\{\operatorname{sign}(a^{(0)}) = \operatorname{sign}(q^{(0)}) \ \text{and}\ |q^{(0)}| \geq c_0/\sqrt{N}\Big\},$$

and let $p_{\mathrm{good}} := \frac{1}{2} p_{PZ}(c_0)$, where $p_{PZ}(c_0)$ is the constant from [Lemma 1](https://arxiv.org/html/2605.20314#Thmlemma1). Then $\Pr[\mathcal{G}_0] \geq p_{\mathrm{good}}$. On the intersection of $\mathcal{G}_0$ with the stability event $\{\eta a_\star \|\widehat{\boldsymbol{M}}\|_2 \leq 1/2\}$,

$$T_\star(N) \leq \left\lceil \frac{2 (a_\star - |a^{(0)}|)_+}{\eta c_0} \sqrt{N} \right\rceil = O\Big(\frac{a_\star}{\eta} \sqrt{N}\Big). \tag{A.6}$$

In particular, for every $\delta \in (0, 1)$, if $a_\star \leq \min\big\{1, \frac{1}{2 \eta B_{N,d,\delta}}\big\}$ with $B_{N,d,\delta}$ from [Corollary 4](https://arxiv.org/html/2605.20314#Thmtheorem4), then equation [A.6](https://arxiv.org/html/2605.20314#A1.E6) holds with probability at least $p_{\mathrm{good}} - \delta$. Consequently, holding $(\eta, a_\star)$ fixed, the bound on the number of phase-1 steps scales as $\sqrt{N}$.

###### Proof sketch\.

The proof consists of three parts\.

1. **Initialization gives a nontrivial $q^{(0)}$ with constant probability.** For $w^{(0)}$ sampled uniformly from the unit sphere and an i.i.d. batch of size $N$, the empirical quadratic form $q^{(0)} = (w^{(0)})^\top \widehat{\boldsymbol{M}} w^{(0)} = \frac{1}{N} \sum_{s=1}^{N} y^{(s)} (x^{(s)\top} w^{(0)})^2$ has magnitude $\Omega(1/\sqrt{N})$ with constant probability ([Lemma 1](https://arxiv.org/html/2605.20314#Thmlemma1)). Since $a^{(0)}$ is initialized symmetric about $0$ and independent of $q^{(0)}$, we also have $\Pr[\operatorname{sign}(a^{(0)}) = \operatorname{sign}(q^{(0)})] = 1/2$ conditional on $q^{(0)} \neq 0$ ([Lemma 2](https://arxiv.org/html/2605.20314#Thmlemma2)). Together these yield the constant-probability "good" event $\mathcal{G}_0 = \{\operatorname{sign}(a^{(0)}) = \operatorname{sign}(q^{(0)}),\ |q^{(0)}| \geq c_0/\sqrt{N}\}$.
2. **While $|a^{(t)}| \leq a_\star$, the sign of $q^{(t)}$ is stable and $|q^{(t)}|$ does not decrease.** Under the stability condition $\eta a_\star \|\widehat{\boldsymbol{M}}\|_2 \leq 1/2$, the normalized update to the inner weights is a signed power iteration. [Lemma 3](https://arxiv.org/html/2605.20314#Thmlemma3) shows that the one-step increment $q^{(t+1)} - q^{(t)}$ has the same sign as $a^{(t)}$. On $\mathcal{G}_0$, the outer-weight update (with clipping inactive, since $|a^{(t)}|$ has not yet grown to $a_\star < 1$) is $a^{(t+1)} = a^{(t)} + \frac{\eta}{2} q^{(t)}$, hence $\operatorname{sign}(a^{(t)})$ cannot flip as long as $\operatorname{sign}(a^{(t)}) = \operatorname{sign}(q^{(t)})$. Combining these gives, by induction, that for all $t < T_\star$, $\operatorname{sign}(a^{(t)}) = \operatorname{sign}(q^{(t)}) = \operatorname{sign}(q^{(0)})$, and therefore $|q^{(t)}| \geq |q^{(0)}|$.
3. **Linear growth of $|a^{(t)}|$ and the $\sqrt{N}$ time scale.** On the same event, for all $t < T_\star$ we have $|a^{(t+1)}| = |a^{(t)}| + \frac{\eta}{2} |q^{(t)}| \geq |a^{(t)}| + \frac{\eta}{2} |q^{(0)}|$, so $|a^{(t)}|$ grows at least linearly until it reaches $a_\star$. Thus, on $\mathcal{G}_0$, $T_\star \leq \left\lceil \frac{2 (a_\star - |a^{(0)}|)_+}{\eta |q^{(0)}|} \right\rceil \leq \left\lceil \frac{2 (a_\star - |a^{(0)}|)_+}{\eta c_0} \sqrt{N} \right\rceil$, which is the claimed $O(\sqrt{N})$ bound.

Finally, [Lemma 5](https://arxiv.org/html/2605.20314#Thmlemma5) provides a high-probability bound on $\|\widehat{\boldsymbol{M}}\|_2$, which yields an explicit stable choice of $a_\star$ ([Corollary 4](https://arxiv.org/html/2605.20314#Thmtheorem4)). ∎
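The mechanism in the proof sketch can be watched in a small simulation. The sketch below is our own illustration, not one of the paper's experiments: it runs phase 1 on fixed parity batches of several sizes $N$ and records $T_\star(N)$. The values $d = 8$, $\eta = 0.1$, $a_\star = 0.5$, $m = 10^4$, the seeds, and the helper name `steps_to_reach` are assumptions. Because $\|\widehat{\boldsymbol{M}}\|_2 \leq d$ for every batch, these defaults give $\eta a_\star \|\widehat{\boldsymbol{M}}\|_2 \leq 0.4 \leq 1/2$, so the stability event always holds. Individual runs are noisy since $|q^{(0)}|$ is random, so the usage loop reports medians; Theorem 3 only bounds $T_\star(N)$ from above, so a small run need not show a clean $\sqrt{N}$ trend.

```python
import math
import random
import statistics


def steps_to_reach(n, d=8, a_star=0.5, eta=0.1, width=10_000, seed=0, max_steps=50_000):
    """Run phase-1 projected GD on one fixed parity batch of size n.

    Returns T_star(n) = min{t : |a^(t)| >= a_star}, or None if max_steps is hit.
    With the default arguments, eta * a_star * d = 0.4 <= 1/2 and ||M_hat||_2 <= d,
    so the stability condition of Theorem 3 holds for every batch.
    """
    rng = random.Random(seed)
    xs = [[rng.choice((-1, 1)) for _ in range(d)] for _ in range(n)]
    # M_hat[i][j] = (1/n) * sum_s y_s x_si x_sj with y_s = x_s1 * x_s2.
    m_hat = [[sum(x[0] * x[1] * x[i] * x[j] for x in xs) / n for j in range(d)]
             for i in range(d)]
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(v * v for v in g))
    w = [v / norm for v in g]                    # w^(0) ~ Unif(S^{d-1})
    a = rng.gauss(0.0, 1.0 / math.sqrt(width))   # a^(0) ~ N(0, 1/m)
    for t in range(max_steps):
        if abs(a) >= a_star:
            return t
        mw = [sum(mij * wj for mij, wj in zip(row, w)) for row in m_hat]
        q = sum(wi * mwi for wi, mwi in zip(w, mw))               # q^(t)
        w_tilde = [wi + eta * a * mwi for wi, mwi in zip(w, mw)]  # uses the old a^(t)
        norm = math.sqrt(sum(v * v for v in w_tilde))
        a = max(-1.0, min(1.0, a + 0.5 * eta * q))                # (A.3)
        w = [v / norm for v in w_tilde]                           # (A.4)
    return None


for n in (8, 16, 32, 64):  # d <= N <= d^2 with d = 8
    runs = [steps_to_reach(n, seed=s) for s in range(10)]
    finished = [t for t in runs if t is not None]
    print(f"N={n:3d}  median T_star={statistics.median(finished)}")
```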

#### A.2.1 Constant-probability lower bound on $|q^{(0)}| = \Omega(1/\sqrt{N})$

We now prove that, in the parity setting, the fixed-batch quadratic form at initialization

$$q^{(0)} = (w^{(0)})^\top \widehat{\boldsymbol{M}} w^{(0)} = \frac{1}{N} \sum_{s=1}^{N} y^{(s)} \big(x^{(s)\top} w^{(0)}\big)^2$$

has magnitude $\Omega(1/\sqrt{N})$ with *constant* probability.

###### Lemma 1\.

Assume $w^{(0)}$ is uniform on the unit sphere (equivalently, $w^{(0)} = g/\|g\|_2$ for $g \sim \mathcal{N}(0, \boldsymbol{I})$). Let $c_0 \in (0, \frac{\sqrt{3}}{2})$ and $p_{PZ}(c) := (1 - 1/\sqrt{2}) \cdot \frac{(1 - \frac{4}{3} c^2)^2}{3^8}$. Then for all $N \geq 1$ and all $d \geq 3$,

$$\Pr\left[|q^{(0)}| \geq \frac{c_0}{\sqrt{N}}\right] \geq p_{PZ}(c_0).$$

As an example, we can choose $c_0 = \sqrt{3/8}$, in which case $p_{PZ}(c_0) = \frac{2 - \sqrt{2}}{8 \cdot 3^8}$.
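The constant in Lemma 1 is tiny; the argument only needs it to be independent of $N$ and $d$. This small sketch (the helper name `p_pz` is ours) evaluates $p_{PZ}$ and checks the closed form quoted for $c_0 = \sqrt{3/8}$.

```python
import math


def p_pz(c):
    """Lemma 1's constant p_PZ(c) = (1 - 1/sqrt(2)) * (1 - 4c^2/3)^2 / 3^8, for 0 < c < sqrt(3)/2."""
    if not 0.0 < c < math.sqrt(3.0) / 2.0:
        raise ValueError("c must lie in (0, sqrt(3)/2)")
    return (1.0 - 1.0 / math.sqrt(2.0)) * (1.0 - 4.0 * c * c / 3.0) ** 2 / 3 ** 8


c0 = math.sqrt(3.0 / 8.0)
closed_form = (2.0 - math.sqrt(2.0)) / (8 * 3 ** 8)  # value stated for c0 = sqrt(3/8)
print(p_pz(c0), closed_form)  # both about 1.1e-5
```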

###### Proof\.

Fix $w = w^{(0)}$ and define the single-sample random variable $Z := y (x^\top w)^2$, so that $q^{(0)} = \frac{1}{N} \sum_{s=1}^{N} Z_s$ for i.i.d. copies $Z_s$.

##### Step 1: conditional mean\.

$$\mu(w) := \mathbb{E}[Z \mid w] = \mathbb{E}\left[x_1 x_2 \Big(\sum_{i=1}^{d} w_i x_i\Big)^2\right] = \sum_{i,j} w_i w_j \mathbb{E}[x_1 x_2 x_i x_j] = 2 w_1 w_2.$$
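Step 1 can be verified exactly by enumerating the hypercube for a small $d$: only the terms with $\{i, j\} = \{1, 2\}$ survive. In this sketch $d = 6$, the seed, and the helper name `conditional_mean` are illustrative assumptions; $w$ need not be a unit vector for this identity.

```python
import itertools
import random


def conditional_mean(w):
    """E[y (x^T w)^2 | w] for x uniform on {-1, +1}^d and y = x1 * x2, by exact enumeration."""
    points = list(itertools.product((-1, 1), repeat=len(w)))
    total = sum(x[0] * x[1] * sum(wi * xi for wi, xi in zip(w, x)) ** 2 for x in points)
    return total / len(points)


rng = random.Random(0)
w = [rng.gauss(0.0, 1.0) for _ in range(6)]
print(conditional_mean(w), 2 * w[0] * w[1])  # agree up to floating-point rounding
```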

##### Step 2: conditional variance is bounded below on a constant-probability event over $w$.

Since $y^2 \equiv 1$,

$$\mathbb{E}[Z^2 \mid w] = \mathbb{E}\big[(x^\top w)^4 \mid w\big] = 3\|w\|_2^4 - 2 \sum_{i=1}^{d} w_i^4 \geq \|w\|_2^4 = 1.$$

Therefore,

$$\mathrm{Var}(Z \mid w) = \mathbb{E}[Z^2 \mid w] - \mu(w)^2 \geq 1 - 4 w_1^2 w_2^2 \geq 1 - (w_1^2 + w_2^2)^2.$$

Define the event $\mathcal{E} := \{w_1^2 + w_2^2 \leq 1/2\}$. On $\mathcal{E}$ we have $\mathrm{Var}(Z \mid w) \geq 1 - 1/4 = 3/4$. Moreover, since $w$ is uniform on the sphere and $d \geq 3$, the random variable $w_1^2 + w_2^2$ has a $\mathrm{Beta}(1, (d-2)/2)$ distribution, hence

$$\Pr[\mathcal{E}] = \Pr[w_1^2 + w_2^2 \leq 1/2] = 1 - (1 - 1/2)^{(d-2)/2} \geq 1 - 2^{-1/2} =: p_1,$$

where $p_1 > 0$ is an absolute constant.
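Both facts used in Step 2, the fourth-moment formula $\mathbb{E}[(x^\top w)^4] = 3\|w\|_2^4 - 2\sum_i w_i^4$ and the Beta-distribution probability of $\mathcal{E}$, can be sanity-checked numerically. The sketch enumerates the hypercube exactly for the moment identity and uses Monte Carlo for $\Pr[\mathcal{E}]$; the helper names, $d = 6$ for the enumeration, the seeds, and the $5 \times 10^4$ trials are our assumptions.

```python
import itertools
import math
import random


def fourth_moment(w):
    """E[(x^T w)^4] for x uniform on {-1, +1}^d, by exact enumeration of all 2^d points."""
    points = list(itertools.product((-1, 1), repeat=len(w)))
    return sum(sum(wi * xi for wi, xi in zip(w, x)) ** 4 for x in points) / len(points)


def prob_event_e(d, trials=50_000, seed=0):
    """Monte Carlo estimate of Pr[w1^2 + w2^2 <= 1/2] for w uniform on S^{d-1}."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        # w = g / ||g||, so w1^2 + w2^2 <= 1/2  <=>  g1^2 + g2^2 <= ||g||^2 / 2.
        if g[0] ** 2 + g[1] ** 2 <= 0.5 * sum(v * v for v in g):
            hits += 1
    return hits / trials


rng = random.Random(1)
g = [rng.gauss(0.0, 1.0) for _ in range(6)]
w = [v / math.sqrt(sum(u * u for u in g)) for v in g]  # a unit vector, as in Step 2
closed = 3 * sum(v * v for v in w) ** 2 - 2 * sum(v ** 4 for v in w)
variance = fourth_moment(w) - (2 * w[0] * w[1]) ** 2   # Var(Z | w) = E[Z^2 | w] - mu(w)^2
print(fourth_moment(w), closed, variance)

for d in (3, 5, 10):
    exact = 1 - 0.5 ** ((d - 2) / 2)  # Beta(1, (d-2)/2) CDF evaluated at 1/2
    print(f"d={d:2d}  Monte Carlo={prob_event_e(d):.3f}  exact={exact:.3f}")
```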

##### Step 3: Paley-Zygmund on $(q^{(0)})^2$.

Condition on $w \in \mathcal{E}$. We already have $\sigma^2(w) := \mathrm{Var}(Z \mid w) \geq 3/4$, hence

$$\mathbb{E}[(q^{(0)})^2 \mid w] = \mu(w)^2 + \frac{\sigma^2(w)}{N} \geq \frac{3}{4N}.$$

We also need an upper bound on the second moment of $(q^{(0)})^2$, i.e., a fourth-moment upper bound for $q^{(0)}$. For fixed $w$, the random variable $q^{(0)}$ is a polynomial of total degree at most $4$ in the independent Rademacher variables $\{x_i^{(s)}\}_{i \in [d], s \in [N]}$ (after multilinearization using $x_i^2 \equiv 1$). By the Bonami-Beckner (hypercontractive) inequality, for any polynomial $f$ of degree at most $k$ in Rademacher variables,

$$\big(\mathbb{E}[|f|^4]\big)^{1/4} \leq (4 - 1)^{k/2} \big(\mathbb{E}[|f|^2]\big)^{1/2}.$$

Applying this with $f = q^{(0)}$ and $k = 4$ (conditional on $w$) gives

$$\mathbb{E}[(q^{(0)})^4 \mid w] \leq 3^8 \, \mathbb{E}[(q^{(0)})^2 \mid w]^2. \tag{A.7}$$
Apply Paley-Zygmund to the non-negative random variable $Y := (q^{(0)})^2$, conditional on $w \in \mathcal{E}$:

$$\Pr\Big[(q^{(0)})^2 \geq \theta\, \mathbb{E}[(q^{(0)})^2 \mid w] \,\Big|\, w\Big] \geq \frac{(1 - \theta)^2\, \mathbb{E}[(q^{(0)})^2 \mid w]^2}{\mathbb{E}[(q^{(0)})^4 \mid w]} \geq \frac{(1 - \theta)^2}{3^8},$$

where we used equation [A.7](https://arxiv.org/html/2605.20314#A1.E7). With $\theta = 1/2$,

$$\Pr\Big[(q^{(0)})^2 \geq \tfrac{1}{2}\, \mathbb{E}[(q^{(0)})^2 \mid w] \,\Big|\, w\Big] \geq \frac{1}{4 \cdot 3^8} =: p_2.$$

On this event,

$$|q^{(0)}| \geq \sqrt{\tfrac{1}{2} \mathbb{E}[(q^{(0)})^2 \mid w]} \geq \sqrt{\frac{3}{8}} \cdot \frac{1}{\sqrt{N}}.$$

Thus, for $c_0 := \sqrt{3/8}$,

$$\Pr\left[|q^{(0)}| \geq \frac{c_0}{\sqrt{N}}\right] \geq \Pr[\mathcal{E}] \cdot p_2 \geq p_1 p_2 = p_{PZ}(c_0).$$

For a general $c_0 \in (0, \frac{\sqrt{3}}{2})$, the same argument with $\theta = \frac{4}{3} c_0^2 < 1$ (so that $\theta\, \mathbb{E}[(q^{(0)})^2 \mid w] \geq c_0^2/N$) gives the bound $p_1 (1 - \frac{4}{3} c_0^2)^2 / 3^8 = p_{PZ}(c_0)$. ∎
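The Paley-Zygmund constant is very conservative. A quick Monte Carlo estimate of $\Pr[|q^{(0)}| \geq c_0/\sqrt{N}]$ with $c_0 = \sqrt{3/8}$ should come out far above $p_{PZ}(c_0) \approx 1.1 \times 10^{-5}$ and roughly stable as $N$ grows, which is what "constant probability" means here. This is our illustration; $d = 8$, the batch sizes, the trial count, the seed, and the helper name `estimate_good_q_prob` are assumptions.

```python
import math
import random


def estimate_good_q_prob(n, d, trials=1_000, seed=0):
    """Monte Carlo estimate of Pr[|q^(0)| >= sqrt(3/8) / sqrt(n)] over w^(0) and the batch."""
    rng = random.Random(seed)
    threshold = math.sqrt(3.0 / 8.0) / math.sqrt(n)
    hits = 0
    for _ in range(trials):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(v * v for v in g))
        w = [v / norm for v in g]                # w^(0) uniform on the sphere
        q0 = 0.0
        for _ in range(n):                       # fresh i.i.d. parity batch of size n
            x = [rng.choice((-1, 1)) for _ in range(d)]
            q0 += x[0] * x[1] * sum(wi * xi for wi, xi in zip(w, x)) ** 2 / n
        if abs(q0) >= threshold:
            hits += 1
    return hits / trials


lemma_bound = (2.0 - math.sqrt(2.0)) / (8 * 3 ** 8)  # p_PZ(sqrt(3/8)), about 1.1e-5
for n in (8, 32, 64):
    print(f"N={n:3d}  estimate={estimate_good_q_prob(n, d=8):.3f}  lemma bound={lemma_bound:.2e}")
```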

We also need the favorable event $\operatorname{sign}(a^{(0)}) = \operatorname{sign}(q^{(0)})$, so that the updates do not flip the sign of the output weight.

###### Lemma 2\.

Assume $a^{(0)}$ is independent of $(w^{(0)}, \widehat{\boldsymbol{M}})$, symmetric about $0$, and satisfies $\Pr[a^{(0)} = 0] = 0$. Then for every threshold $q_\star > 0$,

$$\Pr\Big[\operatorname{sign}(a^{(0)}) = \operatorname{sign}(q^{(0)}) \ \text{and}\ |q^{(0)}| \geq q_\star\Big] = \frac{1}{2} \Pr\big[|q^{(0)}| \geq q_\star\big].$$

###### Proof\.

Condition on $(w^{(0)}, \widehat{\boldsymbol{M}})$, so that $q^{(0)}$ is fixed. On the event $q^{(0)} \neq 0$ (which holds whenever $|q^{(0)}| \geq q_\star > 0$), symmetry and independence of $a^{(0)}$ imply $\Pr[\operatorname{sign}(a^{(0)}) = \operatorname{sign}(q^{(0)}) \mid q^{(0)}] = 1/2$. Multiply by $\mathbf{1}\{|q^{(0)}| \geq q_\star\}$ and average over $(w^{(0)}, \widehat{\boldsymbol{M}})$. ∎

#### A.2.2 Stability of $q^{(t)}$ for small $a^{(t)}$

Next, we show that under the stability condition $\eta |a| \|\widehat{\boldsymbol{M}}\|_2 \leq 1/2$, the update in $q$ has the same sign as $a$.

###### Lemma 3\.

Fix any unit vector $w \in \mathbb{R}^d$, scalar $a \in \mathbb{R}$, and learning rate $\eta > 0$. Define

$$\tilde{w} := (\boldsymbol{I} + \eta a \widehat{\boldsymbol{M}}) w, \qquad w^+ := \tilde{w} / \|\tilde{w}\|_2.$$

Let $q := w^\top \widehat{\boldsymbol{M}} w$ as before, and define $q^+ := (w^+)^\top \widehat{\boldsymbol{M}} w^+$ similarly.

Then, under the stability condition $\eta |a| \|\widehat{\boldsymbol{M}}\|_2 \leq 1/2$, $q^+ - q$ has the same sign as $a$ (or is zero).

###### Proof\.

Define the following quantities

$$s := w^\top \widehat{\boldsymbol{M}}^2 w, \qquad r := w^\top \widehat{\boldsymbol{M}}^3 w. \tag{A.8}$$

Since $\widehat{\boldsymbol{M}}$ is symmetric, we have

$$q^+ = \frac{\tilde{w}^\top \widehat{\boldsymbol{M}} \tilde{w}}{\tilde{w}^\top \tilde{w}} = \frac{w^\top (\widehat{\boldsymbol{M}} + 2 \eta a \widehat{\boldsymbol{M}}^2 + \eta^2 a^2 \widehat{\boldsymbol{M}}^3) w}{w^\top (\boldsymbol{I} + 2 \eta a \widehat{\boldsymbol{M}} + \eta^2 a^2 \widehat{\boldsymbol{M}}^2) w} = \frac{q + 2 \eta a s + \eta^2 a^2 r}{1 + 2 \eta a q + \eta^2 a^2 s},$$

and the update in $q$ is

$$q^+ - q = \frac{2 \eta a (s - q^2) + \eta^2 a^2 (r - q s)}{1 + 2 \eta a q + \eta^2 a^2 s}. \tag{A.9}$$

The denominator in equation [A.9](https://arxiv.org/html/2605.20314#A1.E9) is positive: it equals $\|\tilde{w}\|_2^2 \geq (1 - \eta |a| \|\widehat{\boldsymbol{M}}\|_2)^2 \geq 1/4$ by the stability assumption $\eta |a| \|\widehat{\boldsymbol{M}}\|_2 \leq 1/2$. The sign of $q^+ - q$ is therefore determined by the two terms in the numerator, which we bound separately below.
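Identity (A.9) is easy to check numerically: compute $q^+ - q$ once by actually taking the normalized step and once from the closed form. A minimal pure-Python sketch; the dimension, batch size, seed, helper names, and the use of the Frobenius norm as a conservative stand-in for $\|\widehat{\boldsymbol{M}}\|_2$ when choosing $a$ are our assumptions.

```python
import math
import random


def matvec(m, v):
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def q_increment_direct(m_hat, w, a, eta):
    """q^+ - q obtained by taking the normalized step w -> (I + eta*a*M) w / ||(I + eta*a*M) w||."""
    w_tilde = [wi + eta * a * mwi for wi, mwi in zip(w, matvec(m_hat, w))]
    norm = math.sqrt(dot(w_tilde, w_tilde))
    w_plus = [v / norm for v in w_tilde]
    return dot(w_plus, matvec(m_hat, w_plus)) - dot(w, matvec(m_hat, w))


def q_increment_formula(m_hat, w, a, eta):
    """q^+ - q from (A.9); assumes m_hat is symmetric and w is a unit vector."""
    mw = matvec(m_hat, w)
    q = dot(w, mw)
    s = dot(mw, mw)                 # w^T M^2 w, using symmetry of M_hat
    r = dot(mw, matvec(m_hat, mw))  # w^T M^3 w
    num = 2 * eta * a * (s - q * q) + (eta * a) ** 2 * (r - q * s)
    den = 1 + 2 * eta * a * q + (eta * a) ** 2 * s
    return num / den


rng = random.Random(0)
d, n, eta = 6, 12, 0.5
xs = [[rng.choice((-1, 1)) for _ in range(d)] for _ in range(n)]
m_hat = [[sum(x[0] * x[1] * x[i] * x[j] for x in xs) / n for j in range(d)] for i in range(d)]
g = [rng.gauss(0.0, 1.0) for _ in range(d)]
w = [v / math.sqrt(dot(g, g)) for v in g]
frob = math.sqrt(sum(v * v for row in m_hat for v in row))  # ||M_hat||_2 <= ||M_hat||_F
for a in (0.5 / (eta * frob), -0.5 / (eta * frob)):        # so eta * |a| * ||M_hat||_2 <= 1/2
    print(a, q_increment_direct(m_hat, w, a, eta), q_increment_formula(m_hat, w, a, eta))
```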

First, let $\widehat{\mathbf{M}}=\sum_{i=1}^{d}\lambda_{i}u_{i}u_{i}^{\top}$ be an eigendecomposition and set $\alpha_{i}:=\langle w,u_{i}\rangle^{2}$; since $w$ is a unit vector, $\alpha$ is a probability vector. Let $\Lambda$ be the random variable taking value $\lambda_{i}$ with probability $\alpha_{i}$. Then

$$q=\sum_{i}\alpha_{i}\lambda_{i}=\mathbb{E}[\Lambda],\qquad s=\sum_{i}\alpha_{i}\lambda_{i}^{2}=\mathbb{E}[\Lambda^{2}],\qquad r=\sum_{i}\alpha_{i}\lambda_{i}^{3}=\mathbb{E}[\Lambda^{3}].$$

This directly gives $s-q^{2}=\mathrm{Var}(\Lambda)\geq 0$.

For the second term, note that

$$r-qs=\mathbb{E}[\Lambda^{3}]-\mathbb{E}[\Lambda]\,\mathbb{E}[\Lambda^{2}]=\mathbb{E}\big[(\Lambda-\mathbb{E}[\Lambda])^{2}(\Lambda+\mathbb{E}[\Lambda])\big].$$

Since $|\Lambda+\mathbb{E}[\Lambda]|\leq 2\|\widehat{\mathbf{M}}\|_{2}$, we obtain

$$|r-qs|\leq 2\|\widehat{\mathbf{M}}\|_{2}\,\mathbb{E}[(\Lambda-\mathbb{E}[\Lambda])^{2}]=2\|\widehat{\mathbf{M}}\|_{2}(s-q^{2}).$$

Combining this with the stability assumption $\eta|a|\,\|\widehat{\mathbf{M}}\|_{2}\leq 1/2$, we can bound the second term of the numerator in [Equation A.9](https://arxiv.org/html/2605.20314#A1.E9) by

$$\bigl|\eta^{2}a^{2}(r-qs)\bigr|\leq 2\eta^{2}|a|^{2}\|\widehat{\mathbf{M}}\|_{2}(s-q^{2})\leq\eta|a|(s-q^{2}).$$
Therefore,

- if $a\geq 0$, then $2\eta a(s-q^{2})+\eta^{2}a^{2}(r-qs)\geq 2\eta a(s-q^{2})-\eta a(s-q^{2})=\eta a(s-q^{2})\geq 0$;
- if $a\leq 0$, then $2\eta a(s-q^{2})+\eta^{2}a^{2}(r-qs)\leq 2\eta a(s-q^{2})+\eta|a|(s-q^{2})=\eta a(s-q^{2})\leq 0$.

Thus the numerator in equation [A.9](https://arxiv.org/html/2605.20314#A1.E9), and hence $q^{+}-q$, has the same sign as $a$ (or is zero). ∎
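As a quick numerical sanity check of Lemma 3, the sketch below draws random symmetric matrices as stand-ins for $\widehat{\mathbf{M}}$ (the Gaussian test distribution is an illustrative assumption, not the paper's data), random unit vectors $w$ and nonzero scalars $a$, chooses $\eta$ so that the stability condition holds, and counts trials in which $q^{+}-q$ has the sign opposite to $a$. Because the spectral norm is awkward to compute in pure Python, it uses the Frobenius norm, which upper-bounds $\|\widehat{\mathbf{M}}\|_2$, so the condition enforced here is slightly stronger than the lemma's. The helper names are hypothetical; the lemma predicts a count of zero.

```python
import math
import random


def matvec(M, v):
    """Dense matrix-vector product for a list-of-lists matrix."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def dot(u, v):
    return sum(u_i * v_i for u_i, v_i in zip(u, v))


def normalize(v):
    n = math.sqrt(dot(v, v))
    return [v_i / n for v_i in v]


def random_symmetric(d, rng):
    """Symmetric test matrix with i.i.d. standard Gaussian upper triangle (arbitrary choice)."""
    M = [[0.0] * d for _ in range(d)]
    for i in range(d):
        for j in range(i, d):
            M[i][j] = M[j][i] = rng.gauss(0.0, 1.0)
    return M


def frobenius(M):
    # ||M||_F >= ||M||_2, so eta*|a|*||M||_F <= 1/2 implies the stability condition.
    return math.sqrt(sum(x * x for row in M for x in row))


def q_increment(M, w, a, eta):
    """q^+ - q after one step w^+ = (I + eta*a*M) w / ||(I + eta*a*M) w||."""
    Mw = matvec(M, w)
    q = dot(w, Mw)
    w_plus = normalize([w_i + eta * a * mw_i for w_i, mw_i in zip(w, Mw)])
    return dot(w_plus, matvec(M, w_plus)) - q


def count_sign_violations(trials=2000, d=6, seed=0):
    """Count stable random trials where q^+ - q and a have strictly opposite signs."""
    rng = random.Random(seed)
    violations = 0
    for _ in range(trials):
        M = random_symmetric(d, rng)
        w = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])
        a = rng.choice((-1.0, 1.0)) * rng.uniform(0.05, 1.0)
        eta = rng.uniform(0.0, 1.0) / (2.0 * abs(a) * frobenius(M))
        if q_increment(M, w, a, eta) * a < -1e-12:  # tolerance for round-off
            violations += 1
    return violations


print(count_sign_violations())
```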

#### A.2.3 Linear growth of $a^{(t)}$

Next, we show that $|a^{(t)}|$ grows linearly when conditioned on the lucky event in [Lemma 2](https://arxiv.org/html/2605.20314#Thmlemma2) and the stability assumption in [Lemma 3](https://arxiv.org/html/2605.20314#Thmlemma3), from which an upper bound on $T_{\star}$ (i.e., the time for $|a|$ to grow to $a_{\star}$) directly follows.

###### Lemma 4\.

Assume the initialization event

$$\operatorname{sign}(a^{(0)})=\operatorname{sign}(q^{(0)})\qquad\text{and}\qquad|q^{(0)}|\geq q_{\star}>0,$$

for $q_{\star}:=c_{0}/\sqrt{N}$ from [Lemma 1](https://arxiv.org/html/2605.20314#Thmlemma1). Further, assume the stability condition

$$\eta a_{\star}\|\widehat{\mathbf{M}}\|_{2}\leq\frac{1}{2}.$$

Then for all $t<T_{\star}$ we have $\operatorname{sign}(a^{(t)})=\operatorname{sign}(q^{(t)})=\operatorname{sign}(q^{(0)})$ and $|q^{(t)}|\geq|q^{(0)}|\geq q_{\star}$. Consequently,

$$T_{\star}\leq\left\lceil\frac{2(a_{\star}-|a^{(0)}|)_{+}}{\eta q_{\star}}\right\rceil.\tag{A.10}$$

###### Proof\.

Fix any $t<T_{\star}$. Since $|a^{(t)}|\leq a_{\star}$ and $\eta a_{\star}\|\widehat{\mathbf{M}}\|_{2}\leq 1/2$, we may apply [Lemma 3](https://arxiv.org/html/2605.20314#Thmlemma3) to the inner-weight update at time $t$. It implies that $q^{(t+1)}-q^{(t)}$ has the same sign as $a^{(t)}$ (or is zero).

We next show by induction that $\operatorname{sign}(a^{(t)})=\operatorname{sign}(q^{(t)})=\operatorname{sign}(q^{(0)})$ and $|q^{(t)}|\geq|q^{(0)}|$ for all $t<T_{\star}$. The base case $t=0$ holds by assumption. Assume it holds at time $t$. Because $t<T_{\star}$, we have $|a^{(t)}|<1$, so clipping is inactive and

$$a^{(t+1)}=a^{(t)}+\frac{\eta}{2}q^{(t)}.$$

Since $\operatorname{sign}(a^{(t)})=\operatorname{sign}(q^{(t)})$, we get $\operatorname{sign}(a^{(t+1)})=\operatorname{sign}(a^{(t)})$ and

$$|a^{(t+1)}|=|a^{(t)}|+\frac{\eta}{2}|q^{(t)}|.$$

Thus $\operatorname{sign}(a^{(t)})$ remains constant and equal to $\operatorname{sign}(q^{(0)})$ throughout $t<T_{\star}$. Returning to [Lemma 3](https://arxiv.org/html/2605.20314#Thmlemma3), this means $q^{(t)}$ is pushed monotonically in the direction of $\operatorname{sign}(a^{(t)})=\operatorname{sign}(q^{(0)})$ and therefore cannot cross $0$. Hence $\operatorname{sign}(q^{(t)})=\operatorname{sign}(q^{(0)})$ and $|q^{(t)}|\geq|q^{(0)}|\geq q_{\star}$ for all $t<T_{\star}$.

Finally, using $|q^{(t)}|\geq q_{\star}$ gives the linear growth bound

$$|a^{(t+1)}|=|a^{(t)}|+\frac{\eta}{2}|q^{(t)}|\geq|a^{(t)}|+\frac{\eta}{2}q_{\star},$$

so $|a^{(t)}|\geq|a^{(0)}|+t\cdot\frac{\eta}{2}q_{\star}$ while $t<T_{\star}$. Solving for the first $t$ such that $|a^{(t)}|\geq a_{\star}$ yields equation [A.10](https://arxiv.org/html/2605.20314#A1.E10). ∎
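The induction above can be illustrated with a small simulation. This is a sketch under stated assumptions: a fixed random symmetric matrix plays the role of $\widehat{\mathbf{M}}$, its Frobenius norm stands in for $\|\widehat{\mathbf{M}}\|_2$ when choosing $\eta$, and $q_\star$ is taken to be $|q^{(0)}|$. It runs the Phase-1 updates until $|a^{(t)}|\geq a_\star$, compares the stopping time with the bound (A.10), and records $q^{(t)}$ so that sign stability and $|q^{(t)}|\geq|q^{(0)}|$ can be checked. All helper names are hypothetical.

```python
import math
import random


def matvec(M, v):
    """Dense matrix-vector product for a list-of-lists matrix."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def dot(u, v):
    return sum(u_i * v_i for u_i, v_i in zip(u, v))


def normalize(v):
    n = math.sqrt(dot(v, v))
    return [v_i / n for v_i in v]


def random_symmetric(d, rng):
    """Illustrative stand-in for M_hat: symmetric with i.i.d. Gaussian upper triangle."""
    M = [[0.0] * d for _ in range(d)]
    for i in range(d):
        for j in range(i, d):
            M[i][j] = M[j][i] = rng.gauss(0.0, 1.0)
    return M


def phase1(M, w0, a0, eta, a_star, max_steps=100_000):
    """Iterate a <- clip(a + (eta/2) q), w <- normalize((I + eta a M) w) until |a| >= a_star.

    Returns T_star and the list of q^{(t)} for t < T_star.
    """
    w, a, qs = list(w0), a0, []
    for t in range(max_steps):
        if abs(a) >= a_star:
            return t, qs
        Mw = matvec(M, w)
        q = dot(w, Mw)
        qs.append(q)
        w = normalize([w_i + eta * a * mw_i for w_i, mw_i in zip(w, Mw)])
        a = max(-1.0, min(1.0, a + 0.5 * eta * q))
    raise RuntimeError("|a| did not reach a_star within max_steps")


def run_demo(seed=0, d=8, a_star=0.5, a0_mag=0.01):
    rng = random.Random(seed)
    M = random_symmetric(d, rng)
    while True:  # resample w0 until |q0| is not tiny, so the bound is informative
        w0 = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])
        q0 = dot(w0, matvec(M, w0))
        if abs(q0) >= 0.1:
            break
    a0 = math.copysign(a0_mag, q0)  # the "lucky" event sign(a0) = sign(q0)
    fro = math.sqrt(sum(x * x for row in M for x in row))  # ||M||_F >= ||M||_2
    eta = 1.0 / (2.0 * a_star * fro)  # so eta * a_star * ||M||_2 <= 1/2
    T_star, qs = phase1(M, w0, a0, eta, a_star)
    bound = math.ceil(2.0 * (a_star - abs(a0)) / (eta * abs(q0)))  # (A.10), q_star = |q0|
    return T_star, bound, q0, qs


T_star, bound, q0, qs = run_demo()
print(f"T_star = {T_star}, bound (A.10) = {bound}")
```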

#### A.2.4 Choosing the largest stable $a_{\star}$

To find out how large we can set $a_{\star}$ without violating the stability condition, we first bound $\|\widehat{\mathbf{M}}\|_{2}$.

###### Lemma 5 (Matrix Bernstein bound for $\|\widehat{\mathbf{M}}\|_{2}$).

For every $\delta\in(0,1)$, with probability at least $1-\delta$,

$$\|\widehat{\mathbf{M}}\|_{2}\leq 1+C\left(\sqrt{\frac{d\log(2d/\delta)}{N}}+\frac{d\log(2d/\delta)}{N}\right)$$

for a universal constant $C>0$.

###### Proof\.

Let $A_{s}:=y^{(s)}x^{(s)}x^{(s)\top}$. Since $(y^{(s)})^{2}=1$ and $\|x^{(s)}\|_{2}^{2}=d$, we have $A_{s}^{2}=(x^{(s)}x^{(s)\top})^{2}=\|x^{(s)}\|_{2}^{2}\,x^{(s)}x^{(s)\top}=d\,x^{(s)}x^{(s)\top}$, and hence

$$\mathbb{E}[A_{s}^{2}]=d\,\mathbb{E}[xx^{\top}]=d\,\mathbf{I}.$$

Write the population matrix as

$$\mathbf{M}:=\mathbb{E}[A_{s}]=\mathbb{E}[yxx^{\top}]=e_{1}e_{2}^{\top}+e_{2}e_{1}^{\top},\qquad\|\mathbf{M}\|_{2}=1.$$

Define centered summands $X_{s}:=A_{s}-\mathbf{M}$, so that $\mathbb{E}[X_{s}]=0$ and

$$\widehat{\mathbf{M}}-\mathbf{M}=\frac{1}{N}\sum_{s=1}^{N}X_{s}.$$

We bound $\|X_{s}\|_{2}\leq\|A_{s}\|_{2}+\|\mathbf{M}\|_{2}\leq d+1\leq 2d$, so we may take $R:=2d$. For the variance term,

$$\mathbb{E}[X_{s}^{2}]=\mathbb{E}[(A_{s}-\mathbf{M})^{2}]=\mathbb{E}[A_{s}^{2}]-\mathbf{M}^{2}\preceq\mathbb{E}[A_{s}^{2}]=d\,\mathbf{I},$$

hence

$$\sigma^{2}:=\left\|\sum_{s=1}^{N}\mathbb{E}[X_{s}^{2}]\right\|_{2}\leq Nd.$$

Matrix Bernstein (for sums of independent, mean-zero, self-adjoint matrices) then yields that, with probability at least $1-\delta$ (absorbing the factor $2$ in $R=2d$ into the universal constant $C$),

$$\left\|\sum_{s=1}^{N}X_{s}\right\|_{2}\leq C\left(\sqrt{\sigma^{2}\log(2d/\delta)}+R\log(2d/\delta)\right)\leq C\left(\sqrt{Nd\log(2d/\delta)}+d\log(2d/\delta)\right).$$

Dividing by $N$ gives

$$\|\widehat{\mathbf{M}}-\mathbf{M}\|_{2}\leq C\left(\sqrt{\frac{d\log(2d/\delta)}{N}}+\frac{d\log(2d/\delta)}{N}\right).$$

Finally, $\|\widehat{\mathbf{M}}\|_{2}\leq\|\mathbf{M}\|_{2}+\|\widehat{\mathbf{M}}-\mathbf{M}\|_{2}$ and $\|\mathbf{M}\|_{2}=1$, proving the claim. ∎
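To see the concentration of Lemma 5 numerically, the sketch below assumes a concrete data model consistent with this proof: $x$ uniform on $\{-1,+1\}^d$ (so $\|x\|_2^2=d$ and $\mathbb{E}[xx^\top]=\mathbf{I}$) with label $y=x_1x_2$, which gives $\mathbb{E}[yxx^\top]=e_1e_2^\top+e_2e_1^\top$. This model is an assumption made for illustration. Spectral norms are estimated by power iteration on $A^2$, which returns a lower estimate that tightens as it converges. The deviation $\|\widehat{\mathbf{M}}-\mathbf{M}\|_2$ should shrink roughly like $1/\sqrt{N}$, and $\|\widehat{\mathbf{M}}\|_2$ should approach $1$.

```python
import math
import random


def matvec(A, v):
    return [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in A]


def dot(u, v):
    return sum(u_i * v_i for u_i, v_i in zip(u, v))


def normalize(v):
    n = math.sqrt(dot(v, v))
    return [v_i / n for v_i in v]


def empirical_M(N, d, rng):
    """M_hat = (1/N) sum_s y_s x_s x_s^T under the ASSUMED model x ~ Unif{-1,+1}^d, y = x_1 x_2."""
    S = [[0.0] * d for _ in range(d)]
    for _ in range(N):
        x = [rng.choice((-1.0, 1.0)) for _ in range(d)]
        y = x[0] * x[1]
        for i in range(d):
            yx_i = y * x[i]
            row = S[i]
            for j in range(d):
                row[j] += yx_i * x[j]
    return [[s / N for s in row] for row in S]


def spectral_norm_sym(A, iters=300, seed=0):
    """||A||_2 for symmetric A via power iteration on A^2 (a lower estimate until converged)."""
    rng = random.Random(seed)
    v = normalize([rng.gauss(0.0, 1.0) for _ in range(len(A))])
    for _ in range(iters):
        v = normalize(matvec(A, matvec(A, v)))
    Av = matvec(A, v)
    return math.sqrt(dot(Av, Av))


def deviation(N, d=8, seed=0):
    """Return (||M_hat - M||_2, ||M_hat||_2) for one draw of N samples."""
    Mh = empirical_M(N, d, random.Random(seed))
    diff = [row[:] for row in Mh]
    diff[0][1] -= 1.0  # subtract the population matrix M = e1 e2^T + e2 e1^T
    diff[1][0] -= 1.0
    return spectral_norm_sym(diff), spectral_norm_sym(Mh)


for N in (100, 1000, 5000):
    err, norm = deviation(N)
    print(f"N={N:5d}  ||M_hat - M||_2 ~ {err:.3f}  ||M_hat||_2 ~ {norm:.3f}")
```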

Substituting this bound into the stability condition of [Lemma 4](https://arxiv.org/html/2605.20314#Thmlemma4) gives the following.

###### Corollary 4\.

Fix $\delta\in(0,1)$ and step size $\eta>0$. Let $B_{N,d,\delta}:=1+C\left(\sqrt{\frac{d\log(2d/\delta)}{N}}+\frac{d\log(2d/\delta)}{N}\right)$. If

$$a_{\star}\leq\min\left\{1,\ \frac{1}{2\eta B_{N,d,\delta}}\right\},$$

then with probability at least $1-\delta$ the stability event

$$\eta|a^{(t)}|\,\|\widehat{\mathbf{M}}\|_{2}\leq\tfrac{1}{2}\qquad\text{for all }t<T_{\star}:=\min\{t:\ |a^{(t)}|\geq a_{\star}\}$$

holds.
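Corollary 4 is a closed-form recipe. A minimal sketch, assuming a placeholder value $C=1$ for the unspecified universal constant, computes $B_{N,d,\delta}$ and the largest admissible $a_\star$; only the qualitative dependence on $N$, $d$, $\delta$, and $\eta$ is meaningful, and the function names are hypothetical.

```python
import math


def bernstein_B(N, d, delta, C=1.0):
    """B_{N,d,delta} from Lemma 5; C = 1.0 is a placeholder for the unspecified constant."""
    L = d * math.log(2.0 * d / delta) / N
    return 1.0 + C * (math.sqrt(L) + L)


def largest_stable_a_star(eta, N, d, delta, C=1.0):
    """Largest a_star allowed by Corollary 4: min{1, 1 / (2 eta B_{N,d,delta})}."""
    return min(1.0, 1.0 / (2.0 * eta * bernstein_B(N, d, delta, C)))


for N in (100, 1_000, 10_000, 100_000):
    print(N, round(largest_stable_a_star(eta=1.0, N=N, d=10, delta=0.1), 4))
```

As $N$ grows, $\widehat{\mathbf{M}}$ concentrates, $B_{N,d,\delta}$ decreases toward $1$, and the admissible $a_\star$ increases toward $\min\{1,1/(2\eta)\}$.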

### A.3 Analysis of Phase 2

In Phase 2, we replace $\widehat{\mathbf{M}}$ by the population matrix $\mathbf{M}=e_{1}e_{2}^{\top}+e_{2}e_{1}^{\top}$. Its spectrum is explicit: let

$$u_{+}:=\frac{e_{1}+e_{2}}{\sqrt{2}},\qquad u_{-}:=\frac{e_{1}-e_{2}}{\sqrt{2}};$$

then $\mathbf{M}u_{+}=u_{+}$, $\mathbf{M}u_{-}=-u_{-}$, and $\mathbf{M}v=0$ for all $v\perp\mathrm{span}\{e_{1},e_{2}\}$.
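The eigen-structure of $\mathbf{M}$ can be checked directly. In this sketch coordinates are 0-indexed, so $e_1,e_2$ correspond to indices 0 and 1, and `apply_M` is a hypothetical helper that applies $\mathbf{M}$ without forming the matrix (it simply swaps the first two coordinates and zeroes the rest).

```python
import math


def apply_M(v):
    """Apply M = e1 e2^T + e2 e1^T (0-indexed coordinates 0 and 1) to v."""
    out = [0.0] * len(v)
    out[0], out[1] = v[1], v[0]
    return out


d = 5
s = 1.0 / math.sqrt(2.0)
u_plus = [s, s] + [0.0] * (d - 2)
u_minus = [s, -s] + [0.0] * (d - 2)
v_perp = [0.0, 0.0, 0.3, -1.2, 0.7]  # any vector orthogonal to span{e1, e2}

print(apply_M(u_plus))   # equals u_plus   (eigenvalue +1)
print(apply_M(u_minus))  # equals -u_minus (eigenvalue -1)
print(apply_M(v_perp))   # the zero vector (eigenvalue 0)
```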

We show that in Phase 2, $w$ converges quickly to one of $u_{+},u_{-}$, following a power iteration on $\mathbf{M}$. (Footnote 7: The upper bound on $T_{2}$ is likely improvable to $O(\log(1/a^{(0)}))$.)

###### Lemma 6 (Population contraction).

Assume $\eta\leq\tfrac{1}{2}$, $a^{(0)}\neq 0$, and $\operatorname{sign}(a^{(0)})=\operatorname{sign}(q^{(0)})$. Consider the projected updates

$$a^{(t+1)}=\mathrm{clip}_{[-1,1]}\!\left(a^{(t)}+\tfrac{\eta}{2}q^{(t)}\right),\qquad w^{(t+1)}=\frac{w^{(t)}+\eta a^{(t)}\mathbf{M}w^{(t)}}{\|w^{(t)}+\eta a^{(t)}\mathbf{M}w^{(t)}\|_{2}},\qquad q^{(t)}:=(w^{(t)})^{\top}\mathbf{M}w^{(t)}.$$

Let

$$u_{a}:=\frac{e_{1}+\operatorname{sign}(a^{(0)})e_{2}}{\sqrt{2}},\quad u_{-a}:=\frac{e_{1}-\operatorname{sign}(a^{(0)})e_{2}}{\sqrt{2}},\qquad\alpha_{t}:=|\langle w^{(t)},u_{a}\rangle|,\qquad r_{t}:=\frac{\sqrt{1-\alpha_{t}^{2}}}{\alpha_{t}}.$$

Then:

1. Sign stability and monotonicity. For all $t\geq 0$, $\operatorname{sign}(a^{(t)})=\operatorname{sign}(q^{(t)})=\operatorname{sign}(a^{(0)})$, and $|a^{(t)}|$ is non-decreasing.
2. Alignment contraction. For all $t\geq 0$, $r_{t+1}\leq\frac{1}{1+\eta|a^{(t)}|}\,r_{t}\leq\frac{1}{1+\eta|a^{(0)}|}\,r_{t}$.

Consequently, after

$$T_{2}:=\left\lceil\frac{2}{\eta|a^{(0)}|}\log\Big(\frac{1}{\alpha_{0}^{2}\varepsilon}\Big)\right\rceil$$

steps we have $\alpha_{T_{2}}^{2}\geq 1-\varepsilon$ for any $\varepsilon\in(0,1/2)$.

###### Proof\.

Because $\eta\leq 1/2$ and $|a^{(t)}|\leq 1$, the stability condition of [Lemma 3](https://arxiv.org/html/2605.20314#Thmlemma3) (i.e., $\eta|a^{(t)}|\,\|\mathbf{M}\|_{2}\leq\tfrac{1}{2}$) holds with $\widehat{\mathbf{M}}=\mathbf{M}$ at every step. Hence $q^{(t+1)}-q^{(t)}$ has the same sign as $a^{(t)}$ (or is $0$). Since $\operatorname{sign}(a^{(0)})=\operatorname{sign}(q^{(0)})$ and

$$a^{(t+1)}=a^{(t)}+\tfrac{\eta}{2}q^{(t)}\quad\text{as long as clipping is inactive,}$$

the signs of $a^{(t)}$ and $q^{(t)}$ cannot flip; moreover, $|a^{(t)}|$ is non-decreasing, and clipping preserves the sign once $|a^{(t)}|$ hits $1$. This proves (1).

For (2), decompose $w^{(t)}$ as

$$w^{(t)}=c_{t}u_{a}+b_{t}u_{-a}+v_{t},$$

where $v_{t}\perp\mathrm{span}\{e_{1},e_{2}\}$. Writing $a=a^{(t)}$, which has the same sign as $a^{(0)}$ by (1), we have $(\mathbf{I}+\eta a\mathbf{M})u_{a}=(1+\eta|a|)u_{a}$, $(\mathbf{I}+\eta a\mathbf{M})u_{-a}=(1-\eta|a|)u_{-a}$, and $(\mathbf{I}+\eta a\mathbf{M})v=v$ for every $v\perp\mathrm{span}\{e_{1},e_{2}\}$. Thus, before normalization,

$$(\mathbf{I}+\eta a\mathbf{M})w^{(t)}=(1+\eta|a|)c_{t}u_{a}+(1-\eta|a|)b_{t}u_{-a}+v_{t}.$$

Normalization does not change ratios between components, so the ratio between the component orthogonal to $u_{a}$ and the $u_{a}$ component is

$$r_{t+1}=\frac{\sqrt{b_{t+1}^{2}+\|v_{t+1}\|^{2}}}{|c_{t+1}|}=\frac{\sqrt{(1-\eta|a^{(t)}|)^{2}b_{t}^{2}+\|v_{t}\|^{2}}}{(1+\eta|a^{(t)}|)|c_{t}|}\leq\frac{\sqrt{b_{t}^{2}+\|v_{t}\|^{2}}}{(1+\eta|a^{(t)}|)|c_{t}|}=\frac{r_{t}}{1+\eta|a^{(t)}|}.\tag{A.11}$$

Combined with (1), this shows that $r_{t}$ contracts by at least a factor of $\frac{1}{1+\eta|a^{(0)}|}$ per step, which is strictly smaller than $1$ since $a^{(0)}\neq 0$.

The convergence time $T_{2}$ follows from $\log(1+\eta|a|)\geq\frac{\eta|a|}{1+\eta|a|}\geq\frac{\eta|a|}{2}$, since $\eta|a|\leq 1$. Indeed, $\alpha_{t}^{2}=1/(1+r_{t}^{2})$ and $r_{0}\leq 1/\alpha_{0}$, so $\alpha_{t}^{2}\geq 1-\varepsilon$ as soon as $r_{t}^{2}\leq\varepsilon$, which holds once $(1+\eta|a^{(0)}|)^{2t}\geq 1/(\alpha_{0}^{2}\varepsilon)$; by the logarithm bound, this is guaranteed for $t\geq T_{2}$.

∎
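The contraction in Lemma 6 can be watched in a short simulation of the projected population updates. The sketch below is illustrative: the dimension, step size, initial $|a^{(0)}|$, tolerance $\varepsilon$, and the rejection of nearly orthogonal starts are arbitrary choices, and the helper names are hypothetical. It records $a^{(t)}$, $r_t$, and $\alpha_t$, so that sign stability, monotonicity of $|a^{(t)}|$, inequality (A.11), and $\alpha_{T_2}^2\geq 1-\varepsilon$ can be checked. It computes $r_t$ from the components $c_t$, $b_t$, $v_t$ rather than from $\sqrt{1-\alpha_t^2}$ to avoid cancellation, and the test skips (A.11) once $r_t$ reaches round-off level.

```python
import math
import random


def apply_M(v):
    """Apply M = e1 e2^T + e2 e1^T (0-indexed coordinates 0 and 1)."""
    out = [0.0] * len(v)
    out[0], out[1] = v[1], v[0]
    return out


def dot(u, v):
    return sum(u_i * v_i for u_i, v_i in zip(u, v))


def normalize(v):
    n = math.sqrt(dot(v, v))
    return [v_i / n for v_i in v]


def ratio_and_alignment(w, s):
    """Return (r, alpha) for u_a = (e1 + s e2)/sqrt(2), using c = <w, u_a>,
    b = <w, u_{-a}> and the tail w[2:] directly (avoids sqrt(1 - alpha^2))."""
    c = (w[0] + s * w[1]) / math.sqrt(2.0)
    b = (w[0] - s * w[1]) / math.sqrt(2.0)
    tail = sum(x * x for x in w[2:])
    return math.sqrt(b * b + tail) / abs(c), abs(c)


def population_phase2(w0, a0, eta, steps):
    """Projected population updates of Lemma 6; returns histories of a_t, r_t, alpha_t."""
    s = math.copysign(1.0, a0)
    w, a = list(w0), a0
    a_hist, r_hist, alpha_hist = [], [], []
    for _ in range(steps + 1):
        r, alpha = ratio_and_alignment(w, s)
        a_hist.append(a)
        r_hist.append(r)
        alpha_hist.append(alpha)
        Mw = apply_M(w)
        q = dot(w, Mw)
        w = normalize([w_i + eta * a * m_i for w_i, m_i in zip(w, Mw)])
        a = max(-1.0, min(1.0, a + 0.5 * eta * q))
    return a_hist, r_hist, alpha_hist


def demo(seed=0, d=10, eta=0.5, a0_mag=0.1, eps=0.01):
    rng = random.Random(seed)
    while True:
        w0 = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])
        a0 = math.copysign(a0_mag, 2.0 * w0[0] * w0[1])  # sign(a0) = sign(q0), q0 = 2 w0[0] w0[1]
        _, alpha0 = ratio_and_alignment(w0, math.copysign(1.0, a0))
        if alpha0 ** 2 >= 0.02:  # skip nearly orthogonal starts (illustrative choice)
            break
    T2 = math.ceil(2.0 / (eta * a0_mag) * math.log(1.0 / (alpha0 ** 2 * eps)))
    return T2, population_phase2(w0, a0, eta, T2)


T2, (a_hist, r_hist, alpha_hist) = demo()
print(f"T2 = {T2}, alpha_T2^2 = {alpha_hist[T2] ** 2:.6f}")
```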

### A.4 Combining both phases

We have shown that $T_{\star}\leq\left\lceil\frac{2(a_{\star}-|a^{(0)}|)_{+}}{\eta q_{\star}}\right\rceil$ ([Lemma 4](https://arxiv.org/html/2605.20314#Thmlemma4)) and, since Phase 2 starts with $|a|\geq a_{\star}$, $T_{2}\leq\left\lceil\frac{2}{\eta a_{\star}}\log\Big(\frac{1}{\alpha_{0}^{2}\varepsilon}\Big)\right\rceil$ ([Lemma 6](https://arxiv.org/html/2605.20314#Thmlemma6)). To reason about the overall time $T_{\star}+T_{2}$, it remains to check how $\alpha_{0}$ depends on $a_{\star}$, which in turn depends on how much $w$ moves during Phase 1.

In the following, we first bound the drift of $w$ ([Lemma 7](https://arxiv.org/html/2605.20314#Thmlemma7)), which then allows us to relate $a_{\star}$ and $\alpha_{0}$ ([Lemma 8](https://arxiv.org/html/2605.20314#Thmlemma8)), and we present the final convergence bound in [Section A.4.3](https://arxiv.org/html/2605.20314#A1.SS4.SSS3).

#### A.4.1 $w^{(t)}$ drifts slowly

We first need to bound how much $w$ drifts in Phase 1. We show that, under the assumptions of [Lemma 4](https://arxiv.org/html/2605.20314#Thmlemma4), the inner weight $w^{(t)}$ changes little up to time $T_{\star}$.

###### Lemma 7 (Control of inner-weight drift up to $T_{\star}$).

Under the assumptions of [Lemma 4](https://arxiv.org/html/2605.20314#Thmlemma4),

$$\|w^{(T_{\star})}-w^{(0)}\|_{2}\leq\frac{8\|\widehat{\mathbf{M}}\|_{2}}{q_{\star}}a_{\star}(a_{\star}-|a^{(0)}|)_{+}+4\eta\|\widehat{\mathbf{M}}\|_{2}a_{\star}.\tag{A.12}$$

###### Proof\.

Write the pre\-normalization iterate as

$$\tilde{w}^{(t+1)}=w^{(t)}+\eta a^{(t)}\widehat{\mathbf{M}}w^{(t)},\qquad w^{(t+1)}=\tilde{w}^{(t+1)}/\|\tilde{w}^{(t+1)}\|_{2}.$$

Let $u^{(t)}:=\eta a^{(t)}\widehat{\mathbf{M}}w^{(t)}$, so $\tilde{w}^{(t+1)}=w^{(t)}+u^{(t)}$. For $t<T_{\star}$ we have $|a^{(t)}|\leq a_{\star}$, hence

$$\|u^{(t)}\|_{2}\leq\eta|a^{(t)}|\,\|\widehat{\mathbf{M}}\|_{2}\leq\eta a_{\star}\|\widehat{\mathbf{M}}\|_{2}\leq\frac{1}{2}.$$

For any unit vector $w$ and any $u$ with $\|u\|_{2}\leq 1/2$, one has the standard normalization Lipschitz bound

$$\left\|\frac{w+u}{\|w+u\|_{2}}-w\right\|_{2}\leq 4\|u\|_{2},$$

which we apply with $(w,u)=(w^{(t)},u^{(t)})$ to get

$$\|w^{(t+1)}-w^{(t)}\|_{2}\leq 4\|u^{(t)}\|_{2}\leq 4\eta|a^{(t)}|\,\|\widehat{\mathbf{M}}\|_{2}.$$

Summing over $t=0,1,\dots,T_{\star}-1$ and using the triangle inequality yields

$$\|w^{(T_{\star})}-w^{(0)}\|_{2}\leq 4\eta\|\widehat{\mathbf{M}}\|_{2}\sum_{t<T_{\star}}|a^{(t)}|.$$

Using the crude bound $\sum_{t<T_{\star}}|a^{(t)}|\leq T_{\star}a_{\star}$, which suffices for the final scaling, together with equation [A.10](https://arxiv.org/html/2605.20314#A1.E10) and $\lceil x\rceil\leq x+1$ gives

$$\sum_{t<T_{\star}}|a^{(t)}|\leq a_{\star}\left(\frac{2(a_{\star}-|a^{(0)}|)_{+}}{\eta q_{\star}}+1\right)=\frac{2a_{\star}(a_{\star}-|a^{(0)}|)_{+}}{\eta q_{\star}}+a_{\star}.$$

Plugging in yields

$$\|w^{(T_{\star})}-w^{(0)}\|_{2}\leq 4\eta\|\widehat{\mathbf{M}}\|_{2}\left(\frac{2a_{\star}(a_{\star}-|a^{(0)}|)_{+}}{\eta q_{\star}}+a_{\star}\right)=\frac{8\|\widehat{\mathbf{M}}\|_{2}}{q_{\star}}a_{\star}(a_{\star}-|a^{(0)}|)_{+}+4\eta\|\widehat{\mathbf{M}}\|_{2}a_{\star},$$

which is equation [A.12](https://arxiv.org/html/2605.20314#A1.E12). ∎
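The drift bound (A.12), and the per-step normalization bound it relies on, can be checked on the same kind of toy Phase-1 run. This sketch assumes a random symmetric stand-in for $\widehat{\mathbf{M}}$, uses its Frobenius norm in place of $\|\widehat{\mathbf{M}}\|_2$ (an upper bound, which keeps both the stability condition and the right-hand side of (A.12) valid), and takes $q_\star:=|q^{(0)}|$. It reports the total drift, the bound, and the largest observed ratio $\|w^{(t+1)}-w^{(t)}\|_2/\|u^{(t)}\|_2$, which the proof bounds by $4$. The helper names are hypothetical.

```python
import math
import random


def matvec(M, v):
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def dot(u, v):
    return sum(u_i * v_i for u_i, v_i in zip(u, v))


def normalize(v):
    n = math.sqrt(dot(v, v))
    return [v_i / n for v_i in v]


def random_symmetric(d, rng):
    """Illustrative stand-in for M_hat: symmetric with i.i.d. Gaussian upper triangle."""
    M = [[0.0] * d for _ in range(d)]
    for i in range(d):
        for j in range(i, d):
            M[i][j] = M[j][i] = rng.gauss(0.0, 1.0)
    return M


def phase1_with_drift(M, w0, a0, eta, a_star):
    """Phase-1 updates until |a| >= a_star. Returns (T_star, w at T_star, largest observed
    ratio ||w^{(t+1)} - w^{(t)}|| / ||u^{(t)}||), where u^{(t)} = eta a^{(t)} M w^{(t)}."""
    w, a, t, worst_ratio = list(w0), a0, 0, 0.0
    while abs(a) < a_star:
        Mw = matvec(M, w)
        q = dot(w, Mw)
        u = [eta * a * m for m in Mw]
        w_next = normalize([w_i + u_i for w_i, u_i in zip(w, u)])
        worst_ratio = max(worst_ratio, math.dist(w_next, w) / math.sqrt(dot(u, u)))
        w, a, t = w_next, max(-1.0, min(1.0, a + 0.5 * eta * q)), t + 1
    return t, w, worst_ratio


def drift_vs_bound(seed=0, d=8, a_star=0.3, a0_mag=0.01):
    rng = random.Random(seed)
    M = random_symmetric(d, rng)
    while True:  # resample w0 until |q0| is not tiny
        w0 = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])
        q0 = dot(w0, matvec(M, w0))
        if abs(q0) >= 0.1:
            break
    a0 = math.copysign(a0_mag, q0)  # sign(a0) = sign(q0)
    m_norm = math.sqrt(sum(x * x for row in M for x in row))  # Frobenius >= spectral norm
    eta = 1.0 / (2.0 * a_star * m_norm)  # stability: eta * a_star * ||M||_2 <= 1/2
    _, w_T, worst_ratio = phase1_with_drift(M, w0, a0, eta, a_star)
    q_star = abs(q0)
    bound = (8.0 * m_norm / q_star) * a_star * (a_star - a0_mag) + 4.0 * eta * m_norm * a_star
    return math.dist(w_T, w0), bound, worst_ratio


drift, bound, worst_ratio = drift_vs_bound()
print(f"drift = {drift:.4f}, bound (A.12) = {bound:.4f}, max step ratio = {worst_ratio:.3f}")
```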

#### A.4.2 Connecting $\alpha_{0}$ and $a_{\star}$

###### Lemma 8 (Lower bound on $\alpha_{0}^{2}$ in terms of $a_{\star}$).

Let $u:=(e_{1}+\operatorname{sign}(a^{(T_{\star})})e_{2})/\sqrt{2}$ and define

$$\alpha_{0}:=|\langle w^{(T_{\star})},u\rangle|.$$

On the event $\|w^{(T_{\star})}-w^{(0)}\|_{2}\leq\varepsilon_{\mathrm{drift}}$, we have

$$\alpha_{0}^{2}\geq\bigl(|\langle w^{(0)},u\rangle|-\varepsilon_{\mathrm{drift}}\bigr)_{+}^{2}.$$

Under the assumptions of [Lemma 7](https://arxiv.org/html/2605.20314#Thmlemma7), we may take

$$\varepsilon_{\mathrm{drift}}:=\frac{8\|\widehat{\mathbf{M}}\|_{2}}{q_{\star}}a_{\star}(a_{\star}-|a^{(0)}|)_{+}+4\eta\|\widehat{\mathbf{M}}\|_{2}a_{\star},$$

so $\alpha_{0}^{2}$ is explicitly lower bounded in terms of $a_{\star}$.

###### Proof\.

By Cauchy–Schwarz, $|\langle w^{(T_{\star})},u\rangle-\langle w^{(0)},u\rangle|\leq\|w^{(T_{\star})}-w^{(0)}\|_{2}\|u\|_{2}=\|w^{(T_{\star})}-w^{(0)}\|_{2}$. This implies $|\langle w^{(T_{\star})},u\rangle|\geq|\langle w^{(0)},u\rangle|-\varepsilon_{\mathrm{drift}}$ on the event; since $\alpha_{0}\geq 0$, taking the positive part and squaring gives the claim. The explicit choice of $\varepsilon_{\mathrm{drift}}$ is equation [A.12](https://arxiv.org/html/2605.20314#A1.E12). ∎

###### Lemma 9 (Constant-probability lower bound on random initialization alignment).

Let $u\in\mathbb{R}^{d}$ be any fixed unit vector and let $w^{(0)}$ be uniform on the unit sphere. Then there exists a universal constant $p_{\mathrm{align}}>0$ such that for all $d\geq 2$,

$$\Pr\left[\,|\langle w^{(0)},u\rangle|\geq\frac{1}{2\sqrt{d}}\,\right]\geq p_{\mathrm{align}}.$$

###### Proof\.

Write $w^{(0)}=g/\|g\|_{2}$ for $g\sim\mathcal{N}(0,\mathbf{I})$ and rotate so that $u=e_{1}$. Then $|\langle w^{(0)},u\rangle|=|g_{1}|/\|g\|_{2}$. On the event $\{|g_{1}|\geq 1\}\cap\{\|g\|_{2}\leq 2\sqrt{d}\}$ we have $|g_{1}|/\|g\|_{2}\geq 1/(2\sqrt{d})$, and

$$\Pr\left[\{|g_{1}|\geq 1\}\cap\{\|g\|_{2}\leq 2\sqrt{d}\}\right]\geq\Pr\left[|g_{1}|\geq 1\right]-\Pr\left[\|g\|_{2}>2\sqrt{d}\right].$$

The first probability is a constant ($\Pr[|g_{1}|\geq 1]\approx 0.317$) and the second decays exponentially with $d$, so the intersection has probability at least some universal constant $p_{\mathrm{align}}>0$. ∎
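A Monte Carlo estimate makes Lemma 9 concrete. The sketch below uses rotational invariance to take $u=e_1$, so only $g_1$ and $\|g\|_2$ matter, and samples $\|g\|_2^2=g_1^2+\chi^2_{d-1}$ by drawing the chi-square variable as a Gamma variable with shape $(d-1)/2$ and scale $2$ (Python's `random.gammavariate`). For $d=2$ the exact probability is $2\arccos(1/(2\sqrt{2}))/\pi$, and for large $d$ it approaches $\Pr[|g_1|\geq 1/2]$. The estimates are of the probability itself, not of the unspecified constant $p_{\mathrm{align}}$; sample sizes are arbitrary.

```python
import math
import random


def alignment_probability(d, trials=20_000, seed=0):
    """Monte Carlo estimate of Pr[|<w, u>| >= 1/(2 sqrt d)] for w uniform on the unit sphere.

    By rotational invariance take u = e1 and w = g/||g||, g ~ N(0, I_d). Only g1 and ||g||
    matter, and ||g||^2 = g1^2 + chi^2_{d-1}, where chi^2_k ~ Gamma(shape=k/2, scale=2).
    """
    rng = random.Random(seed)
    threshold = 1.0 / (2.0 * math.sqrt(d))
    hits = 0
    for _ in range(trials):
        g1 = rng.gauss(0.0, 1.0)
        norm = math.sqrt(g1 * g1 + rng.gammavariate((d - 1) / 2.0, 2.0))
        if abs(g1) / norm >= threshold:
            hits += 1
    return hits / trials


exact_d2 = 2.0 * math.acos(1.0 / (2.0 * math.sqrt(2.0))) / math.pi
for d in (2, 10, 100, 10_000):
    print(d, alignment_probability(d))
print("exact value for d = 2:", exact_d2)
```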

###### Lemma 10 (Constant-probability simultaneous Phase-1 bootstrap and population sign alignment).

Let

$$u_{\pm}:=\frac{e_{1}\pm e_{2}}{\sqrt{2}},\qquad P_{12}:=u_{+}u_{+}^{\top}+u_{-}u_{-}^{\top},\qquad q_{\mathrm{pop}}(w):=w^{\top}\mathbf{M}w.$$

There exist universal constants $c_{0}>0$ and $p_{\mathrm{all}}>0$ such that, with

$$\mathcal{P}_{0}:=\left\{|q_{\mathrm{pop}}(w^{(0)})|\geq\frac{3}{4d},\quad\|P_{12}w^{(0)}\|_{2}\leq\frac{1}{\sqrt{d}}\right\}$$

and

$$\mathcal{G}_{\mathrm{sign}}:=\left\{\operatorname{sign}(a^{(0)})=\operatorname{sign}(q^{(0)})=\operatorname{sign}(q_{\mathrm{pop}}(w^{(0)})),\quad|q^{(0)}|\geq\frac{c_{0}}{\sqrt{N}}\right\},$$

we have

$$\Pr\big[\mathcal{G}_{\mathrm{sign}}\cap\mathcal{P}_{0}\big]\geq p_{\mathrm{all}}.$$

In particular, on this event, the empirical sign used to grow $a$ in Phase 1 is already the population sign needed at the start of Phase 2.

###### Proof\.

Write $z_{\pm}:=\langle w^{(0)},u_{\pm}\rangle$ and $r^{2}:=z_{+}^{2}+z_{-}^{2}=\|P_{12}w^{(0)}\|_{2}^{2}$. Conditional on $r$, the direction $(z_{+},z_{-})/r$ is uniform on the unit circle, and $r^{2}\sim\mathrm{Beta}(1,(d-2)/2)$. Consider the event

$$\mathcal{R}:=\left\{\frac{9}{10d}\leq r^{2}\leq\frac{1}{d},\qquad|\cos(2\theta)|\geq\frac{5}{6}\right\},\qquad\text{where }(z_{+},z_{-})=r(\cos\theta,\sin\theta).$$

On $\mathcal{R}$,

$$|q_{\mathrm{pop}}(w^{(0)})|=|z_{+}^{2}-z_{-}^{2}|=r^{2}|\cos(2\theta)|\geq\frac{9}{10d}\cdot\frac{5}{6}=\frac{3}{4d},\qquad\|P_{12}w^{(0)}\|_{2}=r\leq\frac{1}{\sqrt{d}},$$

so $\mathcal{R}\subseteq\mathcal{P}_{0}$. The radial event $\{9/(10d)\leq r^{2}\leq 1/d\}$ has probability bounded below by a universal constant for all $d\geq 3$ (after decreasing the constant to cover the finitely many small dimensions), and the angular event $\{|\cos(2\theta)|\geq 5/6\}$ also has a universal positive probability. Since the angle is independent of $r$, the two probabilities multiply, and hence $\Pr[\mathcal{P}_{0}]\geq\Pr[\mathcal{R}]\geq p_{\mathrm{pop}}>0$.
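The probability of $\mathcal{P}_0$ can also be estimated by simulation. By rotational invariance one may take $(z_+,z_-)=(g_1,g_2)/\|g\|_2$ with $\|g\|_2^2=g_1^2+g_2^2+\chi^2_{d-2}$, so only three random numbers are needed per sample. The sketch below (dimensions and sample sizes are arbitrary; the function name is hypothetical) estimates both $\Pr[\mathcal{P}_0]$ and $\Pr[\mathcal{R}]$, which should reflect the inclusion $\mathcal{R}\subseteq\mathcal{P}_0$ and stay bounded away from zero as $d$ grows.

```python
import math
import random


def p0_and_r_probabilities(d, trials=50_000, seed=0):
    """Monte Carlo estimates of Pr[P_0] and Pr[R] (Lemma 10) for w uniform on S^{d-1}, d >= 3.

    By rotational invariance (z_+, z_-) = (g1, g2)/||g|| with ||g||^2 = g1^2 + g2^2 + chi^2_{d-2}.
    """
    rng = random.Random(seed)
    hits_p0 = hits_r = 0
    for _ in range(trials):
        g1, g2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        norm2 = g1 * g1 + g2 * g2 + rng.gammavariate((d - 2) / 2.0, 2.0)
        zp2, zm2 = g1 * g1 / norm2, g2 * g2 / norm2
        r2 = zp2 + zm2     # ||P_12 w||_2^2
        q_pop = zp2 - zm2  # w^T M w = z_+^2 - z_-^2 = r^2 cos(2 theta)
        if abs(q_pop) >= 3.0 / (4.0 * d) and r2 <= 1.0 / d:
            hits_p0 += 1
        if 9.0 / (10.0 * d) <= r2 <= 1.0 / d and abs(q_pop) >= (5.0 / 6.0) * r2:
            hits_r += 1
    return hits_p0 / trials, hits_r / trials


for d in (5, 50, 500):
    p0, pr = p0_and_r_probabilities(d)
    print(f"d={d:4d}  Pr[P_0] ~ {p0:.4f}  Pr[R] ~ {pr:.4f}")
```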

Fix any $w\in\mathcal{P}_{0}$ and set $\mu:=q_{\mathrm{pop}}(w)$. For one Phase-1 sample, let $Z:=y(x^{\top}w)^{2}$, so that $\mathbb{E}[Z\mid w]=\mu$ and $q^{(0)}=N^{-1}\sum_{s=1}^{N}Z_{s}$. Since $\|P_{12}w\|_{2}^{2}\leq 1/d\leq 1/2$, the variance lower bound in [Lemma 1](https://arxiv.org/html/2605.20314#Thmlemma1) gives $\mathrm{Var}(Z\mid w)\geq 3/4$. The same hypercontractive fourth-moment bound used in [Lemma 1](https://arxiv.org/html/2605.20314#Thmlemma1), applied to $\sqrt{N}(q^{(0)}-\mu)$, gives a universal fourth-moment upper bound.

We use the following elementary one-sided consequence of these two moment bounds: if $X$ is mean zero, $\mathbb{E}[X^{2}]=\sigma^{2}$, and $\mathbb{E}[X^{4}]\leq K\sigma^{4}$, then there are constants $c_{K},p_{K}>0$ depending only on $K$ such that $\Pr[X\geq c_{K}\sigma]\geq p_{K}$. Indeed, writing $X_{+}=\max\{X,0\}$ and $X_{-}=\max\{-X,0\}$, the identity $\mathbb{E}[X_{+}]=\mathbb{E}[X_{-}]$ and interpolation between the $L_{1}$, $L_{2}$, and $L_{4}$ norms give $\mathbb{E}[X_{+}]\geq\sigma/(2^{3/2}\sqrt{K})$. Therefore, with $\theta:=1/(2^{5/2}\sqrt{K})$, splitting $\mathbb{E}[X_{+}]$ according to whether $X\geq\theta\sigma$, applying Cauchy–Schwarz to the second part, and using $\mathbb{E}[X_{+}^{2}]\leq\sigma^{2}$,

$$\mathbb{E}[X_{+}]\leq\theta\sigma+\big(\mathbb{E}[X_{+}^{2}]\big)^{1/2}\Pr[X\geq\theta\sigma]^{1/2}\leq\theta\sigma+\sigma\Pr[X\geq\theta\sigma]^{1/2}.$$

Since $\mathbb{E}[X_{+}]\geq\sigma/(2^{3/2}\sqrt{K})=2\theta\sigma$, this implies $\Pr[X\geq\theta\sigma]\geq\theta^{2}$. Thus one may take $c_{K}=\theta$ and $p_{K}=\theta^{2}$. Applying this to $X:=\operatorname{sign}(\mu)\sqrt{N}(q^{(0)}-\mu)$, and using $\operatorname{sign}(\mu)\sqrt{N}\,q^{(0)}=X+|\mu|\sqrt{N}\geq X$, gives, after decreasing $c_{0}$ if necessary, a universal $p_{\mathrm{one}}>0$ such that

$$\Pr\left[\operatorname{sign}(\mu)\,q^{(0)}\geq\frac{c_{0}}{\sqrt{N}}\ \middle|\ w^{(0)}=w\right]\geq p_{\mathrm{one}}.$$

Equivalently, conditional on $w^{(0)}=w\in\mathcal{P}_{0}$, the empirical quadratic form has the same sign as the population quadratic form and has magnitude at least $c_{0}/\sqrt{N}$ with probability at least $p_{\mathrm{one}}$. Finally, $a^{(0)}$ is independent of the samples and symmetric about zero, so conditional on $(w^{(0)},\widehat{\mathbf{M}})$ the event $\operatorname{sign}(a^{(0)})=\operatorname{sign}(q^{(0)})$ contributes an additional factor $1/2$. Therefore

$$\Pr\big[\mathcal{G}_{\mathrm{sign}}\cap\mathcal{P}_{0}\big]\geq\frac{1}{2}p_{\mathrm{pop}}p_{\mathrm{one}}=:p_{\mathrm{all}}>0.$$

∎
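The elementary one-sided bound used above ($\theta=1/(2^{5/2}\sqrt{K})$ and $p_K=\theta^2$) can be checked exactly on small discrete distributions, where all moments and tail probabilities are finite sums. The distributions below are arbitrary test cases, including strongly skewed centered Bernoulli variables whose upper tail is small; `one_sided_check` is a hypothetical helper name.

```python
import math


def one_sided_check(values, probs):
    """Center a finite distribution, set K = E[X^4] / sigma^4, and return
    (Pr[X >= theta * sigma], theta^2) with theta = 1 / (2^{5/2} sqrt(K))."""
    mean = sum(v * p for v, p in zip(values, probs))
    xs = [v - mean for v in values]
    var = sum(p * x * x for x, p in zip(xs, probs))
    sigma = math.sqrt(var)
    K = sum(p * x ** 4 for x, p in zip(xs, probs)) / var ** 2
    theta = 1.0 / (2.0 ** 2.5 * math.sqrt(K))
    tail = sum(p for x, p in zip(xs, probs) if x >= theta * sigma)
    return tail, theta ** 2


cases = {
    "Rademacher": ([-1.0, 1.0], [0.5, 0.5]),
    "Bernoulli(0.1)": ([0.0, 1.0], [0.9, 0.1]),
    "Bernoulli(0.01)": ([0.0, 1.0], [0.99, 0.01]),
    "Bernoulli(0.9)": ([0.0, 1.0], [0.1, 0.9]),
    "three-point": ([-3.0, 0.0, 1.0], [0.1, 0.3, 0.6]),
}
for name, (values, probs) in cases.items():
    tail, p_K = one_sided_check(values, probs)
    print(f"{name:16s} Pr[X >= theta*sigma] = {tail:.4f} >= theta^2 = {p_K:.5f}")
```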

###### Lemma 11 (Phase-1 drift preserves the population sign).

Define $q_{\mathrm{pop}}(w):=w^{\top}\mathbf{M}w$ and $P_{12}:=u_{+}u_{+}^{\top}+u_{-}u_{-}^{\top}$ as in [Lemma 10](https://arxiv.org/html/2605.20314#Thmlemma10). Suppose

$$|q_{\mathrm{pop}}(w^{(0)})|\geq\frac{3}{4d},\qquad\|P_{12}w^{(0)}\|_{2}\leq\frac{1}{\sqrt{d}},\qquad\|w^{(T_{\star})}-w^{(0)}\|_{2}\leq\frac{1}{4\sqrt{d}}.$$

Then

$$\operatorname{sign}\big(q_{\mathrm{pop}}(w^{(T_{\star})})\big)=\operatorname{sign}\big(q_{\mathrm{pop}}(w^{(0)})\big).$$

Moreover, if $s:=\operatorname{sign}(q_{\mathrm{pop}}(w^{(0)}))$ and $u_{s}:=(e_{1}+se_{2})/\sqrt{2}$, then

$$|\langle w^{(T_{\star})},u_{s}\rangle|\geq\left(\frac{\sqrt{3}}{2}-\frac{1}{4}\right)\frac{1}{\sqrt{d}}\geq\frac{1}{2\sqrt{d}}.$$

###### Proof\.

Let $\Delta:=w^{(T_{\star})}-w^{(0)}$. Since $\|\mathbf{M}\|_{2}=1$ and $\|\mathbf{M}w^{(0)}\|_{2}=\|P_{12}w^{(0)}\|_{2}$, we have

$$\begin{aligned}\left|q_{\mathrm{pop}}(w^{(T_{\star})})-q_{\mathrm{pop}}(w^{(0)})\right|&=\left|2(w^{(0)})^{\top}\mathbf{M}\Delta+\Delta^{\top}\mathbf{M}\Delta\right|\\&\leq 2\|P_{12}w^{(0)}\|_{2}\|\Delta\|_{2}+\|\Delta\|_{2}^{2}\\&\leq\frac{1}{2d}+\frac{1}{16d}=\frac{9}{16d}<\frac{3}{4d}.\end{aligned}$$

Thus the perturbation is smaller than the initial population margin, so the sign of $q_{\mathrm{pop}}$ is preserved.

For the alignment claim, write $z_{s}:=\langle w^{(0)},u_{s}\rangle$ and $z_{-s}:=\langle w^{(0)},u_{-s}\rangle$ with $u_{-s}:=(e_{1}-se_{2})/\sqrt{2}$. Since $s\,q_{\mathrm{pop}}(w^{(0)})=z_{s}^{2}-z_{-s}^{2}\geq 3/(4d)$, we have $|z_{s}|\geq\sqrt{3}/(2\sqrt{d})$. By Cauchy–Schwarz,

$$|\langle w^{(T_{\star})},u_{s}\rangle|\geq|\langle w^{(0)},u_{s}\rangle|-\|w^{(T_{\star})}-w^{(0)}\|_{2}\geq\left(\frac{\sqrt{3}}{2}-\frac{1}{4}\right)\frac{1}{\sqrt{d}}\geq\frac{1}{2\sqrt{d}}.$$

∎
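Lemma 11 is a deterministic statement, so random spot checks are easy. The sketch below rejection-samples initializations in $\mathcal{P}_0$, perturbs each by a random vector of norm at most $1/(8\sqrt{d})$ and renormalizes (renormalizing a unit vector plus $u$ moves it by at most $2\|u\|_2$, so the displacement stays within the lemma's budget $1/(4\sqrt{d})$), and then tests both conclusions. Dimensions and trial counts are arbitrary, and the helper names are hypothetical.

```python
import math
import random


def dot(u, v):
    return sum(u_i * v_i for u_i, v_i in zip(u, v))


def normalize(v):
    n = math.sqrt(dot(v, v))
    return [v_i / n for v_i in v]


def q_pop(w):
    """w^T M w for M = e1 e2^T + e2 e1^T (0-indexed coordinates 0 and 1)."""
    return 2.0 * w[0] * w[1]


def in_P0(w, d):
    """Event P_0 of Lemma 10: |q_pop(w)| >= 3/(4d) and ||P_12 w||_2 <= 1/sqrt(d)."""
    return abs(q_pop(w)) >= 3.0 / (4.0 * d) and math.hypot(w[0], w[1]) <= 1.0 / math.sqrt(d)


def check_lemma11(d=20, trials=300, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        w0 = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])
        while not in_P0(w0, d):  # rejection-sample an initialization in P_0
            w0 = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])
        # Perturb by at most 1/(8 sqrt d) and renormalize: the displacement is then at
        # most 2 * 1/(8 sqrt d) = 1/(4 sqrt d), the lemma's drift budget.
        direction = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])
        scale = rng.uniform(0.0, 1.0) / (8.0 * math.sqrt(d))
        w1 = normalize([x + scale * e for x, e in zip(w0, direction)])
        assert math.dist(w1, w0) <= 1.0 / (4.0 * math.sqrt(d))  # hypothesis of Lemma 11
        s = math.copysign(1.0, q_pop(w0))
        u_s = [1.0 / math.sqrt(2.0), s / math.sqrt(2.0)] + [0.0] * (d - 2)
        sign_kept = q_pop(w1) * q_pop(w0) > 0
        aligned = abs(dot(w1, u_s)) >= 1.0 / (2.0 * math.sqrt(d))
        if not (sign_kept and aligned):
            return False
    return True


print(check_lemma11())
```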

#### A.4.3 Final Convergence Bound

Fix $\varepsilon\in(0,1/2)$ and $\delta\in(0,1)$. Run phase 1 on a fixed batch of size $N$ using the updates in equations [(A.3)](https://arxiv.org/html/2605.20314#A1.E3)–[(A.4)](https://arxiv.org/html/2605.20314#A1.E4) until time

$$T_1:=\min\{t:\ |a^{(t)}|\geq a_\star\}.$$
Let $p_{\mathrm{all}}>0$ be the universal constant from Lemma [10](https://arxiv.org/html/2605.20314#Thmlemma10). Then, intersecting the event of Lemma [10](https://arxiv.org/html/2605.20314#Thmlemma10) with the operator-norm event $\{\|\widehat{\boldsymbol{M}}\|_2\leq B_{N,d,\delta}\}$ and applying a union bound, we get

$$\Pr\big[\mathcal{G}_{\mathrm{sign}}\cap\mathcal{P}_0\cap\{\|\widehat{\boldsymbol{M}}\|_2\leq B_{N,d,\delta}\}\big]\geq p_{\mathrm{all}}-\delta.$$
On this event, the following hold:

1. Phase 1 time. We have
   $$T_1\leq\left\lceil\frac{2(a_\star-|a^{(0)}|)_+}{\eta\,q_\star}\right\rceil\leq\left\lceil\frac{2a_\star}{\eta\,c_0}\sqrt{N}\right\rceil.$$
2. Phase 2 alignment. Switch to population gradients ($\widehat{\boldsymbol{M}}\leftarrow\boldsymbol{M}$) and run the population contraction on $w$ for
   $$T_2:=\left\lceil\frac{2}{\eta a_\star}\log\Big(\frac{16d}{\varepsilon}\Big)\right\rceil$$
   steps. By Lemma [7](https://arxiv.org/html/2605.20314#Thmlemma7) and the above choice of $a_\star$, we have
   $$\|w^{(T_1)}-w^{(0)}\|_2\leq\varepsilon_{\mathrm{drift}}\leq\frac{1}{4\sqrt{d}}.$$
   Since $\mathcal{G}_{\mathrm{sign}}$ gives $\operatorname{sign}(a^{(0)})=\operatorname{sign}(q_{\mathrm{pop}}(w^{(0)}))$ and phase 1 preserves $\operatorname{sign}(a^{(t)})$ up to $T_1$, Lemma [11](https://arxiv.org/html/2605.20314#Thmlemma11) gives the missing population sign condition
   $$\operatorname{sign}(a^{(T_1)})=\operatorname{sign}\big((w^{(T_1)})^\top\boldsymbol{M}w^{(T_1)}\big).$$
   The same lemma also gives
   $$\alpha_0^2:=\left|\left\langle w^{(T_1)},\frac{e_1+\operatorname{sign}(a^{(T_1)})e_2}{\sqrt{2}}\right\rangle\right|^2\geq\frac{1}{4d}.$$
   Therefore Lemma [6](https://arxiv.org/html/2605.20314#Thmlemma6), applied from the phase-2 starting point and using $|a^{(T_1)}|\geq a_\star$, yields $\alpha_{T_2}^2\geq 1-\varepsilon$, where $\alpha_t:=|\langle w^{(t)},u\rangle|$ with $u=(e_1+\operatorname{sign}(a^{(T_1)})e_2)/\sqrt{2}$.

##### Interpreting the Result

Let

$$B_{N,d,\delta}:=1+C\left(\sqrt{\frac{d\log(2d/\delta)}{N}}+\frac{d\log(2d/\delta)}{N}\right)$$
be the deterministic bound from Lemma [5](https://arxiv.org/html/2605.20314#Thmlemma5), so that $\Pr[\|\widehat{\boldsymbol{M}}\|_2\leq B_{N,d,\delta}]\geq 1-\delta$. Let $c_0>0$ be the universal constant from Lemma [10](https://arxiv.org/html/2605.20314#Thmlemma10), chosen small enough that Lemma [1](https://arxiv.org/html/2605.20314#Thmlemma1) also applies, and set $q_\star:=c_0/\sqrt{N}$. Choose

$$a_\star:=\min\left\{1,\;\frac{1}{2\eta B_{N,d,\delta}},\;\frac{1}{32\eta B_{N,d,\delta}\sqrt{d}},\;\sqrt{\frac{c_0}{64\,B_{N,d,\delta}\sqrt{N}\sqrt{d}}}\right\}.$$
In particular, for $d\leq N\leq d^2$,

$$a_\star=\sqrt{\frac{c_0}{64\,B_{N,d,\delta}\sqrt{N}\sqrt{d}}}=O\left(\frac{1}{(Nd)^{1/4}}\right),$$
which yields

$$T_1\lesssim\frac{N^{1/4}}{\eta d^{1/4}},\qquad T_2\lesssim\frac{(Nd)^{1/4}}{\eta}\log\Big(\frac{d}{\varepsilon}\Big).\tag{A.13}$$
The total number of steps needed decreases as $N$ gets smaller. The benefit of the two-phase schedule is that it avoids entering phase 2 with too small an outer gain $|a|$: since the $w$-update is scaled by $a_t$, a small $|a_t|$ slows representation learning even if $w_0$ has typical random alignment. Phase 1 increases $|a_t|$ using the stronger fixed-batch bootstrap signal $|q^{(0)}|\sim 1/\sqrt{N}$, whereas the analogous population bootstrap signal at random initialization is only $|w^\top\boldsymbol{M}w|=\Theta(1/d)$, which is smaller whenever $N<d^2$.
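To make the scaling in (A.13) concrete, the sketch below evaluates its two leading-order terms for several dataset sizes at fixed $d$. It drops every constant, treats $B_{N,d,\delta}$ as $O(1)$ as in the derivation of (A.13), and uses arbitrary illustrative values for $\eta$ and $\varepsilon$ (our choices), so only the trend in $N$ is meaningful, not the step counts themselves.

```python
import math


def a13_leading_order(n, d, eta=1.0, eps=0.01):
    """Leading-order T1 and T2 from (A.13), with all constants dropped.

    Valid only in the regime d <= N <= d^2; treats B_{N,d,delta} as O(1).
    """
    if not d <= n <= d * d:
        raise ValueError("(A.13) is stated for d <= N <= d^2")
    t1 = n ** 0.25 / (eta * d ** 0.25)
    t2 = (n * d) ** 0.25 / eta * math.log(d / eps)
    return t1, t2


d = 256
for n in (d, 4 * d, 16 * d, 64 * d, d * d):
    t1, t2 = a13_leading_order(n, d)
    print(f"N = {n:>6}:  T1 ~ {t1:6.2f}   T2 ~ {t2:9.1f}   total ~ {t1 + t2:9.1f}")
```

Both terms grow as $N^{1/4}$, so the total is smallest at the low end of the range, $N=d$.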

### A.5 Proof of [Corollary 2](https://arxiv.org/html/2605.20314#Thmtheorem2): training first phase on random labels

We prove [Corollary 2](https://arxiv.org/html/2605.20314#Thmtheorem2), which gives the convergence guarantee for a modified two-phase training procedure in which the first phase uses random labels $y$ drawn i.i.d. uniformly from $\{-1,+1\}$. The proof largely follows that of [Theorem 1](https://arxiv.org/html/2605.20314#Thmtheorem1); we highlight the modifications below.

##### Phase 1: small\-set training with random labels

We state the random-label versions of the constant-probability lower bound for $|q^{(0)}|$ ([Lemma 1](https://arxiv.org/html/2605.20314#Thmlemma1)) and of the matrix Bernstein bound for $\widehat{\boldsymbol{M}}$ ([Lemma 5](https://arxiv.org/html/2605.20314#Thmlemma5)).

###### Corollary 5\.

Assume $w^{(0)}$ is uniform on the unit sphere and the label $y$ for each sample is drawn uniformly from $\{-1,1\}$. Then there exists a universal constant $c_r>0$ such that for all $N\geq 1$ and all $d\geq 3$,

$$\Pr\left[|q^{(0)}|\geq\frac{c_r}{\sqrt{N}}\right]\geq p_{PZ}(c_r),$$
where $p_{PZ}(c):=(1-1/\sqrt{2})\cdot\frac{(1-\frac{4}{3}c^2)^2}{3^8}$ as in [Lemma 1](https://arxiv.org/html/2605.20314#Thmlemma1).

###### Proof\.

The calculation follows that of [Lemma 1](https://arxiv.org/html/2605.20314#Thmlemma1). For a single sample, let $Z:=y\,(x^\top w)^2$. Its conditional mean and second moment are

$$\mu(w)=\mathbb{E}[Z\mid w]=\mathbb{E}\left[y\Big(\sum_{i=1}^d w_ix_i\Big)^2\right]=\mathbb{E}[y]\,\mathbb{E}\left[\Big(\sum_{i=1}^d w_ix_i\Big)^2\right]=0,\qquad\mathbb{E}[Z^2\mid w]=\mathbb{E}\big[(x^\top w)^4\mid w\big]\geq 1,$$
respectively, where the factorization uses the independence of $y$ and $x$. Therefore, we have

$$\mathbb{E}[(q^{(0)})^2\mid w]=\frac{1}{N}\mathbb{E}[Z^2\mid w]+\frac{1}{N^2}\sum_{i\neq j}\mathbb{E}[Z_i\mid w]\,\mathbb{E}[Z_j\mid w]\geq\frac{1}{N},$$
since the cross terms vanish by $\mu(w)=0$. Likewise, applying the hypercontractivity inequality and the Paley–Zygmund inequality to the nonnegative random variable $Y:=(q^{(0)})^2$, we obtain

$$\Pr\left[|q^{(0)}|\geq\frac{1}{\sqrt{2}}\frac{1}{\sqrt{N}}\right]\geq p_r.$$
∎
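A small Monte Carlo sketch can illustrate the random-label bootstrap bound. It assumes Rademacher inputs $x\in\{-1,+1\}^d$ (our assumption, matching the parity setting) and estimates both the probability that $|q^{(0)}|\geq 1/\sqrt{2N}$ and the scaled second moment $N\,\mathbb{E}[(q^{(0)})^2]=\mathbb{E}[(x^\top w)^4]$, which for Rademacher inputs and $\|w\|_2=1$ lies between 1 and 3. It estimates, rather than proves, these quantities; the sample sizes are arbitrary.

```python
import math
import random


def random_unit_vector(d, rng):
    """Uniform sample from the unit sphere in R^d (normalized standard Gaussian)."""
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def random_label_q0(w, n, rng):
    """q^(0) = w^T M~ w with M~ = (1/N) sum_s xi_s x_s x_s^T.

    Assumes Rademacher inputs x_s and labels xi_s uniform on {-1,+1}, independent
    of x_s. Uses w^T x x^T w = (x^T w)^2, so M~ is never formed explicitly.
    """
    total = 0.0
    for _ in range(n):
        proj = sum(wi * rng.choice((-1.0, 1.0)) for wi in w)  # x^T w
        xi = rng.choice((-1.0, 1.0))
        total += xi * proj * proj
    return total / n


def estimate_bootstrap(d=20, n=100, trials=300, seed=0):
    """Return (fraction of trials with |q0| >= 1/sqrt(2N), empirical N * E[q0^2])."""
    rng = random.Random(seed)
    threshold = 1.0 / math.sqrt(2 * n)
    hits, scaled_second_moment = 0, 0.0
    for _ in range(trials):
        w = random_unit_vector(d, rng)
        q0 = random_label_q0(w, n, rng)
        hits += abs(q0) >= threshold
        scaled_second_moment += n * q0 * q0
    return hits / trials, scaled_second_moment / trials


frac, nq2 = estimate_bootstrap()
print(f"Pr[|q0| >= 1/sqrt(2N)] ~ {frac:.2f},   N*E[q0^2] ~ {nq2:.2f}")
```

The empirical probability is typically far above the Paley–Zygmund lower bound, which is deliberately loose.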

###### Corollary 6 (Bound for random-label $\|\widehat{\boldsymbol{M}}\|_2$).

For $y$ uniformly sampled from $\{-1,+1\}$ and any $\delta\in(0,1)$, with probability at least $1-\delta$,

$$\|\widehat{\boldsymbol{M}}\|_2\leq O\left(\sqrt{\frac{d\log(2d/\delta)}{N}}+\frac{d\log(2d/\delta)}{N}\right).$$

###### Proof\.

The calculation is the same as in [Lemma 5](https://arxiv.org/html/2605.20314#Thmlemma5), with $\boldsymbol{M}:=\mathbb{E}[y\,xx^\top]=0$ because $y$ is uniform on $\{-1,+1\}$ and independent of $x$. ∎
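The $\sqrt{d/N}$ scaling in Corollary 6 can be observed numerically. The sketch below builds the random-label matrix $\widetilde{\boldsymbol{M}}$ for Rademacher inputs (our assumption) and estimates its spectral norm by power iteration on $\widetilde{\boldsymbol{M}}^2$, which sidesteps the sign ambiguity of the top eigenvalue of a symmetric matrix. The dimensions are arbitrary; this is an illustration, not a check of the constant $C$.

```python
import math
import random


def random_label_matrix(d, n, rng):
    """M~ = (1/N) sum_s xi_s x_s x_s^T, assuming Rademacher inputs x_s in {-1,+1}^d
    and labels xi_s uniform on {-1,+1}, independent of x_s."""
    m = [[0.0] * d for _ in range(d)]
    for _ in range(n):
        x = [rng.choice((-1.0, 1.0)) for _ in range(d)]
        c = rng.choice((-1.0, 1.0)) / n
        for i in range(d):
            cxi = c * x[i]
            row = m[i]
            for j in range(d):
                row[j] += cxi * x[j]
    return m


def matvec(m, v):
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def spectral_norm_sym(m, iters=100, seed=1):
    """||M||_2 for symmetric M: power iteration on M^2, then ||M v|| for the final unit v."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(len(m))]
    for _ in range(iters):
        v = matvec(m, matvec(m, v))
        norm = math.sqrt(sum(x * x for x in v))
        v = [x / norm for x in v]
    return math.sqrt(sum(x * x for x in matvec(m, v)))


rng = random.Random(0)
d = 30
for n in (300, 1200):
    norm = spectral_norm_sym(random_label_matrix(d, n, rng))
    print(f"N = {n:>5}: ||M~||_2 = {norm:.3f},  ratio to sqrt(d/N) = {norm / math.sqrt(d / n):.2f}")
```

Quadrupling $N$ should roughly halve $\|\widetilde{\boldsymbol{M}}\|_2$, while the ratio to $\sqrt{d/N}$ stays of order one.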

Replacing [Lemma 1](https://arxiv.org/html/2605.20314#Thmlemma1) and [Lemma 5](https://arxiv.org/html/2605.20314#Thmlemma5) with [Corollary 5](https://arxiv.org/html/2605.20314#Thmtheorem5) and [Corollary 6](https://arxiv.org/html/2605.20314#Thmtheorem6), we set $q_\star:=c_r/\sqrt{N}$; the rest of the first-phase analysis follows the true-label case.

##### Phase 2: population training with real labels

We give the random-label version of the simultaneous sign alignment in [Lemma 10](https://arxiv.org/html/2605.20314#Thmlemma10).

###### Lemma 12 (Random-label bootstrap and population sign alignment).

Suppose that in phase 1 the labels are independent random signs $\xi^{(s)}\sim\mathrm{Unif}\{-1,+1\}$, independent of the inputs, and define the random-label empirical matrix

$$\widetilde{\boldsymbol{M}}:=\frac{1}{N}\sum_{s=1}^N\xi^{(s)}x^{(s)}x^{(s)\top}.$$
Let

$$q^{(0)}:=w^{(0)\top}\widetilde{\boldsymbol{M}}\,w^{(0)},\qquad q_{\mathrm{pop}}(w):=w^\top\boldsymbol{M}w.$$
There exist universal constants $c_r>0$ and $p_{\mathrm{all}}^{\mathrm{rand}}>0$ such that, with

$$\mathcal{P}_0^{\mathrm{rand}}:=\left\{\operatorname{sign}(a^{(0)})\,q_{\mathrm{pop}}(w^{(0)})\geq\frac{3}{4d},\quad\|P_{12}w^{(0)}\|_2\leq\frac{1}{\sqrt{d}}\right\}$$
and

$$\mathcal{G}_{\mathrm{sign}}^{\mathrm{rand}}:=\left\{\operatorname{sign}(a^{(0)})=\operatorname{sign}(q^{(0)}),\quad|q^{(0)}|\geq\frac{c_r}{\sqrt{N}}\right\},$$
we have

$$\Pr\big[\mathcal{G}_{\mathrm{sign}}^{\mathrm{rand}}\cap\mathcal{P}_0^{\mathrm{rand}}\big]\geq p_{\mathrm{all}}^{\mathrm{rand}}.$$

###### Proof\.

The population part follows exactly as in Lemma [10](https://arxiv.org/html/2605.20314#Thmlemma10). Namely, writing

$$z_\pm:=\langle w^{(0)},u_\pm\rangle,\qquad r^2:=z_+^2+z_-^2=\|P_{12}w^{(0)}\|_2^2,$$
the same radial–angular argument gives a universal constant $p_{\mathrm{pop}}>0$ such that, with probability at least $p_{\mathrm{pop}}$,

$$|q_{\mathrm{pop}}(w^{(0)})|\geq\frac{3}{4d},\qquad\|P_{12}w^{(0)}\|_2\leq\frac{1}{\sqrt{d}}.$$
Since $a^{(0)}$ is independent of $w^{(0)}$ and symmetric about zero, with an additional probability factor $1/2$ we also have

$$\operatorname{sign}(a^{(0)})\,q_{\mathrm{pop}}(w^{(0)})=|q_{\mathrm{pop}}(w^{(0)})|\geq\frac{3}{4d}.$$
Thus $\mathcal{P}_0^{\mathrm{rand}}$ holds with probability at least $p_{\mathrm{pop}}/2$.
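The constant $p_{\mathrm{pop}}$ is universal but can be small. The following Monte Carlo sketch estimates it for $w^{(0)}$ uniform on the sphere, again assuming $q_{\mathrm{pop}}(w)=2w_1w_2$ (i.e., $\boldsymbol{M}=e_1e_2^\top+e_2e_1^\top$, consistent with $q_{\mathrm{pop}}=z_+^2-z_-^2$ but our assumption). Only the first two coordinates enter, so the squared norm of the remaining $d-2$ Gaussian coordinates is drawn directly as a chi-square variable. Multiplying by $1/2$ gives the probability of $\mathcal{P}_0^{\mathrm{rand}}$.

```python
import math
import random


def estimate_p_pop(d=100, trials=20000, seed=0):
    """Monte Carlo estimate of Pr[|q_pop(w0)| >= 3/(4d) and ||P12 w0||_2 <= 1/sqrt(d)]
    for w0 uniform on the unit sphere in R^d, assuming q_pop(w) = 2 * w1 * w2."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        g1, g2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        rest = rng.gammavariate((d - 2) / 2, 2.0)  # chi-square with d - 2 degrees of freedom
        norm_sq = g1 * g1 + g2 * g2 + rest
        w1, w2 = g1 / math.sqrt(norm_sq), g2 / math.sqrt(norm_sq)
        if abs(2.0 * w1 * w2) >= 3 / (4 * d) and w1 * w1 + w2 * w2 <= 1 / d:
            hits += 1
    return hits / trials


p_pop = estimate_p_pop()
print(f"p_pop ~ {p_pop:.4f},  Pr[P_0^rand] ~ {p_pop / 2:.4f}")
```

Sampling $(g_1,g_2)$ and a chi-square remainder is equivalent to normalizing a full $d$-dimensional Gaussian, but costs $O(1)$ per trial instead of $O(d)$.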

It remains to control the random-label empirical bootstrap. Conditional on $w^{(0)}$, for a single phase-1 sample, define

$$Z:=\xi(x^\top w^{(0)})^2.$$
As in the proof of Corollary [5](https://arxiv.org/html/2605.20314#Thmtheorem5), the Paley–Zygmund argument applied to $(q^{(0)})^2$ gives universal constants $c_r,p_r>0$ such that

$$\Pr\left[|q^{(0)}|\geq\frac{c_r}{\sqrt{N}}\,\middle|\,w^{(0)}\right]\geq p_r.$$
Moreover, by the symmetry of the random labels, the conditional distribution of $q^{(0)}$ is symmetric about zero. Hence, conditional on $w^{(0)}$ and $a^{(0)}$,

$$\Pr\left[\operatorname{sign}(q^{(0)})=\operatorname{sign}(a^{(0)}),\quad|q^{(0)}|\geq\frac{c_r}{\sqrt{N}}\,\middle|\,w^{(0)},a^{(0)}\right]\geq\frac{p_r}{2}.$$
Combining this conditional event with $\mathcal{P}_0^{\mathrm{rand}}$ gives

$$\Pr\left[\mathcal{G}_{\mathrm{sign}}^{\mathrm{rand}}\cap\mathcal{P}_0^{\mathrm{rand}}\right]\geq\frac{p_{\mathrm{pop}}\,p_r}{4}=:p_{\mathrm{all}}^{\mathrm{rand}}>0.$$
∎

Replacing [Lemma 10](https://arxiv.org/html/2605.20314#Thmlemma10) with [Lemma 12](https://arxiv.org/html/2605.20314#Thmlemma12), the second-phase analysis follows from [Lemmas 6](https://arxiv.org/html/2605.20314#Thmlemma6) and [11](https://arxiv.org/html/2605.20314#Thmlemma11). Choose

$$a_\star:=\min\left\{1,\;\frac{1}{2\eta B_{N,d,\delta}},\;\frac{1}{32\eta B_{N,d,\delta}\sqrt{d}},\;\sqrt{\frac{c_r}{64\,B_{N,d,\delta}\sqrt{N}\sqrt{d}}}\right\},$$
where

$$B_{N,d,\delta}:=C\left(\sqrt{\frac{d\log(2d/\delta)}{N}}+\frac{d\log(2d/\delta)}{N}\right)$$
in this setting (there is no leading $1$, since the population matrix $\mathbb{E}[y\,xx^\top]$ vanishes under random labels). When $d<N<d^2$, $B_{N,d,\delta}\simeq\sqrt{d/N}$ up to logarithmic factors, and $a_\star$ is upper bounded as $a_\star\lesssim 1/\sqrt{d}$. Combining the first and second phases, the total number of steps is bounded as

$$T(a_\star)=T_1+T_2\lesssim\frac{a_\star\sqrt{N}}{\eta}+\frac{2}{\eta a_\star}\log\Big(\frac{d}{\varepsilon}\Big).$$
$T$ is a monotonically decreasing function of $a_\star$ when $a_\star\lesssim 1/N^{1/4}$. Therefore,

- for $N<d^2$, $a_\star=O(1/\sqrt{d})$, and this gives $T=O\left(\frac{\sqrt{N}}{\eta\sqrt{d}}+\frac{\sqrt{d}}{\eta}\log\Big(\frac{d}{\varepsilon}\Big)\right)$;
- for $N=d^2$, $a_\star=O(1/\sqrt{d})=O(1/N^{1/4})$, and we obtain $T=O\left(\frac{N^{1/4}}{\eta}+\frac{N^{1/4}}{\eta}\log\Big(\frac{d}{\varepsilon}\Big)\right)$.
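The monotonicity claim can be checked directly. Writing $f(a)$ for the right-hand side of the bound on $T(a_\star)$, treated as a function of a continuous $a>0$ and ignoring the constants hidden by $\lesssim$:

$$f(a)=\frac{a\sqrt{N}}{\eta}+\frac{2}{\eta a}\log\Big(\frac{d}{\varepsilon}\Big),\qquad f'(a)=\frac{\sqrt{N}}{\eta}-\frac{2}{\eta a^2}\log\Big(\frac{d}{\varepsilon}\Big),$$

$$f'(a)<0\iff a^2<\frac{2\log(d/\varepsilon)}{\sqrt{N}}\iff a<\frac{\sqrt{2\log(d/\varepsilon)}}{N^{1/4}}.$$

So $f$ decreases up to $a=\sqrt{2\log(d/\varepsilon)}/N^{1/4}$, where it attains its minimum; this is the stated condition $a_\star\lesssim 1/N^{1/4}$ up to the logarithmic factor. Substituting $a=1/\sqrt{d}$ gives $f(1/\sqrt{d})=\frac{\sqrt{N}}{\eta\sqrt{d}}+\frac{2\sqrt{d}}{\eta}\log(d/\varepsilon)$, which matches the first case above.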

## Appendix B Experiment details and additional results

### B.1 Experiment details

We report the architectures used for each task\.

- Single-Index Model (SIM): The link function for SIM is the degree-3 Hermite polynomial, and the dimension is $n\in\{40,50\}$. The default MLP in the experiments has 2 layers and hidden dimension 64, with ReLU activation (GELU for SIM and sigmoid for parity give similar results), and batch size 128 for mini-batch training. We also train Transformers (encoder-type, i.e., no causal masking) with 2 layers, 4 heads, and embedding dimension 64 at a fixed batch size of 128; in this setup, a simple 2-phase repeat schedule accelerates training.
- Sparse parity: MLP experiments default to a 2-layer MLP with hidden dimension 64 and ReLU activation. For [Figure 6](https://arxiv.org/html/2605.20314#S5.F6), each heatmap was created from roughly 42 million training runs on a single A100 in four hours. Transformer experiments use an encoder-only architecture (i.e., without causal masking) to preserve the permutation-invariant structure of the parity task. The default Transformer has 2 layers, dimension 256, and 8 heads.
- Modular addition: Modular addition runs use a 4-phase training strategy, where the dataset size increases in each phase. Experiments are performed with 2-layer decoder-only Transformers (i.e., with causal masking).
- In-context linear regression: We use Transformers with 2 layers, 4 heads, and embedding dimension 64 on the in-context linear regression task, with $k=15$ context examples and dimension $n=4$. During training, the loss is computed on the last token.
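The two synthetic tasks that appear most often above can be generated along the lines of the following sketch. The text specifies only the task families and dimensions; the choice of the first $k$ coordinates as the parity support, the Gaussian inputs for SIM, and the probabilists' normalization $\mathrm{He}_3(z)=z^3-3z$ are our assumptions, and the function names are hypothetical.

```python
import math
import random


def sparse_parity_batch(n_samples, d, k, rng):
    """(d, k)-sparse parity: x uniform on {-1,+1}^d, y = product of the first k coordinates.

    Any fixed size-k support gives an equivalent task up to relabeling coordinates;
    k = d recovers the full (dense) parity.
    """
    xs, ys = [], []
    for _ in range(n_samples):
        x = [rng.choice((-1, 1)) for _ in range(d)]
        xs.append(x)
        ys.append(math.prod(x[:k]))
    return xs, ys


def he3(z):
    """Probabilists' Hermite polynomial of degree 3."""
    return z ** 3 - 3 * z


def sim_batch(n_samples, theta, rng):
    """Single-index model: x ~ N(0, I_n) and y = He_3(<theta, x>) for a unit vector theta."""
    xs, ys = [], []
    for _ in range(n_samples):
        x = [rng.gauss(0.0, 1.0) for _ in theta]
        xs.append(x)
        ys.append(he3(sum(t * xi for t, xi in zip(theta, x))))
    return xs, ys


rng = random.Random(0)
xs, ys = sparse_parity_batch(4, d=20, k=6, rng=rng)
print(ys)
```

With $\mathrm{He}_3$ the label is orthogonal to every polynomial of degree below 3 in $\langle\theta,x\rangle$ under Gaussian inputs, which is what makes this link a standard hard case for gradient-based learning.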

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/aw_ratio_MLP_GD_d20k6_m64.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/mlp_gd_m64_sim_a_over_w_norm.png)

Figure 12: Layer norm ratio $\|a\|_2/\|W\|_F$ increases. Results are shown for MLP on (20, 6)-parity and SIM trained with gradient descent.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/mlp_gd_m64_sim_first_layer_norm.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/mlp_gd_m64_sim_second_layer_norm.png)

Figure 13: Layer norm growth during training. Results are shown for MLP on SIM trained with gradient descent.

##### Hyperparameters

We sweep the hyperparameters separately for each task and training setup. We tune the learning rate for SGD, and both the learning rate and $\beta_2$ for AdamW, with a fixed $\beta_1=0.9$. Learning rates are swept on a multiplicative grid whose ratio between consecutive values is at most 1.2x. We consider $\beta_2\in\{0.8,0.9,0.95,0.999\}$. For mini-batch training, MLP experiments use batch size 128 and Transformer experiments use batch size 32, unless specified otherwise. Parity and SIM experiments use test sets of size 4096 and 5000, respectively.
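A grid satisfying the 1.2x spacing rule can be built as in the sketch below. The sweep endpoints are hypothetical (the text does not state them); only the ratio bound, $\beta_1=0.9$, and the $\beta_2$ values come from the description above. The `betas` key mirrors the argument name of PyTorch's AdamW, but nothing here calls PyTorch.

```python
import math


def lr_grid(lo, hi, max_ratio=1.2):
    """Geometric grid from lo to hi (inclusive) with consecutive ratio at most max_ratio."""
    steps = math.ceil(math.log(hi / lo) / math.log(max_ratio))
    ratio = (hi / lo) ** (1.0 / steps)
    return [lo * ratio ** i for i in range(steps + 1)]


BETA1 = 0.9
BETA2_GRID = (0.8, 0.9, 0.95, 0.999)

# Hypothetical sweep ranges, for illustration only.
sgd_configs = [{"lr": lr} for lr in lr_grid(1e-3, 1.0)]
adamw_configs = [
    {"lr": lr, "betas": (BETA1, b2)} for lr in lr_grid(1e-5, 1e-2) for b2 in BETA2_GRID
]
print(len(sgd_configs), len(adamw_configs))
```

Rounding the number of steps up and then spreading the range evenly in log space keeps every gap at or below 1.2x while hitting both endpoints exactly.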

For model initialization, all weights use PyTorch defaults. In particular, linear weights are initialized as $W_{ij}\sim\mathrm{Unif}[-1/\sqrt{d_{\text{in}}},1/\sqrt{d_{\text{in}}}]$. For MLP, we additionally experiment with Gaussian initialization, $W_{ij}\sim\mathcal{N}(0,1/d_{\text{in}})$ (i.e., standard deviation $1/\sqrt{d_{\text{in}}}$), whose standard deviation differs from that of the uniform distribution by a factor of $\sqrt{3}$; we reach the same conclusions. For Transformer, attention is computed as $a_{i,j}\propto\exp\big(q_i^\top k_j/\sqrt{d}\big)$, where $q_i,k_j\in\mathbb{R}^d$. For experiments with RMSNorm, we use RMSNorm with a learnable scale parameter.
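The $\sqrt{3}$ factor follows from $\mathrm{Var}(\mathrm{Unif}[-b,b])=b^2/3$: with $b=1/\sqrt{d_{\text{in}}}$, the uniform initialization has standard deviation $1/\sqrt{3d_{\text{in}}}$, while the Gaussian variant has $1/\sqrt{d_{\text{in}}}$. A quick numerical check, using Python's standard `random` module rather than PyTorch, with an arbitrary $d_{\text{in}}$:

```python
import math
import random


def std(xs):
    """Population standard deviation."""
    mean = sum(xs) / len(xs)
    return math.sqrt(sum((x - mean) ** 2 for x in xs) / len(xs))


def uniform_init(d_in, n, rng):
    """Unif[-1/sqrt(d_in), 1/sqrt(d_in)], the weight range of PyTorch's default Linear init."""
    bound = 1.0 / math.sqrt(d_in)
    return [rng.uniform(-bound, bound) for _ in range(n)]


def gaussian_init(d_in, n, rng):
    """N(0, 1/d_in), i.e. standard deviation 1/sqrt(d_in)."""
    return [rng.gauss(0.0, 1.0 / math.sqrt(d_in)) for _ in range(n)]


rng = random.Random(0)
d_in, n = 64, 100_000
ratio = std(gaussian_init(d_in, n, rng)) / std(uniform_init(d_in, n, rng))
print(f"Gaussian/uniform std ratio: {ratio:.3f} (sqrt(3) = {math.sqrt(3):.3f})")
```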

##### Data use

We provide the phase schedule used in multi\-phase training\.

- SIM uses a 2-phase schedule. For GD on MLP, the first phase takes 100 steps on a dataset of size 8000 and the second phase takes 900 steps on a dataset of size 64000. Ablation results with other dataset sizes are shown in [Figure 18](https://arxiv.org/html/2605.20314#A2.F18). For SGD on MLP, the first phase uses a 0.01 fraction of the total amount of data seen during the online run for 100 steps, and the second phase is online training. For Transformers, the first phase uses a 0.005 fraction of the total amount of data seen during the online run for 800 steps, and the second phase is online training.
- Parity uses a 6-phase schedule, where the phases use $\{0.001,0.002,0.005,0.01,0.02,0.1\}$ of the total amount of data seen during the online run, with $\{100,50,20,10,10,4\}$ epochs, respectively. Ablation results with auto-scheduling are shown in [Figure 19](https://arxiv.org/html/2605.20314#A2.F19).
- In-context linear regression uses a 4-phase schedule, where the phases use $\{0.005,0.02,0.05,0.1\}$ of the total amount of data seen during the online run, running for $\{1500,1500,1500,10500\}$ steps, respectively.
- Mod addition uses a 4-phase schedule, where the phases use $\{0.005,0.02,0.05,0.1\}$ of the total amount of data seen during the online run, with $\{20,5,4,6\}$ epochs, respectively.
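A short arithmetic check shows that both epoch-based schedules above process exactly as many samples in total as the online run (each phase processes its dataset once per epoch), so, assuming a constant per-sample cost, the comparison with online training is at matched compute. The sketch encodes each schedule as (fraction, epochs) pairs; this encoding is ours.

```python
from fractions import Fraction as F

# (fraction of the online run's data, epochs) for each phase, from the schedules above.
PARITY = [(F(1, 1000), 100), (F(2, 1000), 50), (F(5, 1000), 20),
          (F(1, 100), 10), (F(2, 100), 10), (F(1, 10), 4)]
MOD_ADDITION = [(F(5, 1000), 20), (F(2, 100), 5), (F(5, 100), 4), (F(1, 10), 6)]


def sample_passes(schedule):
    """Total samples processed, as a fraction of the samples seen in the online run."""
    return sum(frac * epochs for frac, epochs in schedule)


for name, schedule in (("parity", PARITY), ("mod addition", MOD_ADDITION)):
    print(name, sample_passes(schedule))  # prints 1 for both schedules
```

Using `Fraction` keeps the sums exact, so the equality with the online budget is not an artifact of floating-point rounding.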

Some of our experiments with data repetition use sampling with replacement for faster data loading, which differs from the common multi-epoch training where each epoch samples without replacement. Note that sampling with and without replacement correspond to different algorithms. For example, Lin et al. [[2025](https://arxiv.org/html/2605.20314#bib.bib1)] showed that for linear regression, the former is equivalent (in terms of sample complexity) to gradient descent, whereas the latter is closer to online SGD, which can be better or worse than GD depending on the problem structure [Wu et al., [2025](https://arxiv.org/html/2605.20314#bib.bib29)]. In our experiments, however, we do not observe an empirical difference between the two, so we use them interchangeably.
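The two index samplers being contrasted can be sketched as follows; the function names are hypothetical, and real data loaders add batching, workers, and collation on top of this logic.

```python
import random


def with_replacement(dataset_size, batch_size, num_steps, rng):
    """Each batch draws indices i.i.d. uniformly; a sample may repeat or be skipped."""
    for _ in range(num_steps):
        yield [rng.randrange(dataset_size) for _ in range(batch_size)]


def without_replacement(dataset_size, batch_size, num_epochs, rng):
    """Standard multi-epoch training: reshuffle once per epoch and visit every sample once."""
    indices = list(range(dataset_size))
    for _ in range(num_epochs):
        rng.shuffle(indices)
        for start in range(0, dataset_size, batch_size):
            yield indices[start:start + batch_size]


rng = random.Random(0)
n = 10_000
draws = [i for batch in with_replacement(n, 100, n // 100, rng) for i in batch]
print(f"distinct samples in one epoch's worth of draws: {len(set(draws)) / n:.3f}")
```

With replacement, one epoch's worth of $N$ draws covers only about $1-(1-1/N)^N\approx 1-1/e\approx 63\%$ of the distinct samples, whereas each reshuffled epoch covers all of them exactly once.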

### B\.2Additional empirical results

#### B.2.1 More setups with the small-vs-large gap

We report more setups where the small\-vs\-large gap is observed\.

##### Full parity

We consider learning the full parity, where $d=k$. This is a trivial task in the SQ sense and does not have a sparse structure, hence the explanations in [Section 4.1.1](https://arxiv.org/html/2605.20314#S4.SS1.SSS1) and [Section 4.1.3](https://arxiv.org/html/2605.20314#S4.SS1.SSS3) do not apply. However, the small-vs-large gap is still present for both MLP and Transformers ([Figure 14](https://arxiv.org/html/2605.20314#A2.F14)).

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/MLP_GD_d20k20.png)\(a\)\(20, 20\)\-parity, MLP
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/gd_transformer_d10_k10.png)\(b\)\(10, 10\)\-parity, Transformer

Figure 14: Small-vs-large gap exists for dense parity. Results are shown for (Left) $(20,20)$-parity with MLP and (Right) $(10,10)$-parity with Transformer. Both are trained with full-batch gradient descent.
##### Transformer using full\-batch updates

The small-vs-large gap is also observed for Transformers trained with full-batch AdamW updates ([Figure 15](https://arxiv.org/html/2605.20314#A2.F15)), demonstrating that the gradient variance explanation in [Section 4.1.2](https://arxiv.org/html/2605.20314#S4.SS1.SSS2) is insufficient. Due to memory constraints, we use a smaller input dimension ($d=10$) than in the MLP experiments ($d=20$ in [Figure 2](https://arxiv.org/html/2605.20314#S3.F2)).

##### Transformer using mini-batch updates, with dataset input bias removed

As discussed in [Section 4.1](https://arxiv.org/html/2605.20314#S4.SS1), the small-vs-large gap cannot be explained by a strong *input* bias from a smaller dataset, as the gap persists even when the input bias is removed ([Figure 3(a)](https://arxiv.org/html/2605.20314#S4.F3.sf1)). The same conclusion holds for Transformers ([Figure 16](https://arxiv.org/html/2605.20314#A2.F16)).

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/gd_transformer_d10_k6.png)\(a\)\(10, 6\)\-parity
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/gd_transformer_d10_k10.png)\(b\)\(10, 10\)\-parity

Figure 15: Small-vs-large gap is observed in Transformer full-batch training. Results are on (10, 6)-parity (left) and (10, 10)-parity (right).

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/sgd_transformer_d20_k6_remove_input_bias.png) (a) $(20,6)$-parity, accuracy
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/transformer_sgd_d50_whiten_sim.png) (b) SIM with $d=50$, loss

Figure 16: Repetition remains superior with dataset bias removed. Results are based on Transformer with mini-batch updates and are consistent with the MLP results in [Figure 3(a)](https://arxiv.org/html/2605.20314#S4.F3.sf1).

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/sgd_transformer_d20_k6_online_biased.png)

Figure 17: Biasing online training does not bridge the speed gap. Results are based on Transformer with mini-batch updates and are consistent with the MLP results ([Figure 3(b)](https://arxiv.org/html/2605.20314#S4.F3.sf2)). For sparse parity ($d=20,k=6$), biasing the Bernoulli distribution with the empirical mean of $2^i$ samples (for $i\in\{2,3,4,5,6\}$) makes online training faster for certain values of $i$ (best at $i=3$). However, to match the speedup from training on smaller datasets (marked as "4-phase"), online (i.e., large-set) training would need a bias estimated from an extremely small dataset.

#### B.2.2 Ablation studies

##### Choosing the small dataset size

We ablate the size of the dataset used for small-set training. A properly sized dataset provides a learning speedup without incurring severe overfitting, as shown in [Figure 18](https://arxiv.org/html/2605.20314#A2.F18) for parity and the single-index model (SIM). Note that for SIM, using a smaller dataset can lead to a speedup initially but a worse loss at convergence. Hence our main results on SIM (e.g., [Figure 2](https://arxiv.org/html/2605.20314#S3.F2)) adopt 2-phase training, where we first train on a small dataset (of size 8192) and then switch to a larger dataset (of size 256000). We recommend such multi-phase training in general, to obtain both the learning speedup and the generalization benefits of using the full dataset.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/gd_mlp_d20_k6_vary_data_size.png)\(a\)\(20, 6\)\-parity
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/mlp_gd_m64_sim_vary_data_size.png)\(b\)SIM

Figure 18: Varying the small dataset size. Results shown on MLP with full-batch training, for parity (left) and single-index model (right).
##### Auto\-scheduling for multi\-phase training

As mentioned in …, we also consider an alternative auto-scheduling scheme, in which the phase sizes and durations are determined automatically. Specifically, there are 6 phases, where the first and last phases have sizes 1/320 and 1/50 of the amount of data seen during the online run, respectively. The intermediate dataset sizes are spaced geometrically. Training advances to the next phase either when the training accuracy reaches 75% or when 50 epochs have elapsed. As shown in [Figure 19](https://arxiv.org/html/2605.20314#A2.F19), this auto-scheduling achieves performance comparable to the 6-phase schedule described above and is much faster than online training.
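The auto-scheduling rule can be sketched as below. The online data budget passed in is a hypothetical value, rounding phase sizes to whole samples is our choice, and `train_epoch` is a hypothetical callback standing in for one epoch of training that returns the training accuracy; only the 6 phases, the 1/320 and 1/50 endpoints, geometric spacing, and the 75% / 50-epoch advance rule come from the description above.

```python
def auto_schedule_sizes(online_total, num_phases=6, first=1 / 320, last=1 / 50):
    """Phase dataset sizes spaced geometrically from first*online_total to last*online_total."""
    ratio = (last / first) ** (1.0 / (num_phases - 1))
    return [round(online_total * first * ratio ** i) for i in range(num_phases)]


def should_advance(train_acc, epochs_in_phase, acc_threshold=0.75, max_epochs=50):
    """Move to the next phase once training accuracy reaches 75% or 50 epochs have elapsed."""
    return train_acc >= acc_threshold or epochs_in_phase >= max_epochs


def run_auto_schedule(sizes, train_epoch):
    """Drive the phases; train_epoch(size) trains one epoch on `size` samples and returns
    the training accuracy (hypothetical callback). Returns (size, epochs) per phase."""
    log = []
    for size in sizes:
        epochs = 0
        while True:
            acc = train_epoch(size)
            epochs += 1
            if should_advance(acc, epochs):
                break
        log.append((size, epochs))
    return log


sizes = auto_schedule_sizes(online_total=3_200_000)
print(sizes)
```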

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/sgd_transformer_d20_k6_with_auto.png)

Figure 19: Auto-scheduling for multi-phase learning, where the dataset sizes across phases are spaced geometrically and the phase duration is determined automatically based on the training accuracy. Such auto-scheduling (red) is comparable to manually selected phase scheduling (yellow), and both are much faster than online training.
##### Speedup from random\-label training

Our work attributes the small-vs-large gap to the layer-balancing effects enabled by small-set repetition. As discussed in [Section 5.1](https://arxiv.org/html/2605.20314#S5.SS1), one strong piece of empirical evidence for this is that training on a small set of samples with *random labels* also accelerates learning. [Figure 20](https://arxiv.org/html/2605.20314#A2.F20) provides additional evidence on mod addition, learned by Transformers with mini-batch AdamW updates. We first train the model on a small subset whose labels are randomly permuted, and then switch to online training for the remaining time. Specifically, the number of samples in the first phase is 0.5% of the number seen in the second phase, repeated for 50 epochs, i.e., the random-label phase takes up 20% of the total training time. As shown in [Figure 20](https://arxiv.org/html/2605.20314#A2.F20), this initial small-set random-label training speeds up learning compared to training directly on online batches with true labels.
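The random-label phase can be set up as in the sketch below, which also reproduces the 20% figure: the first phase processes $0.005\times 50=0.25$ second-phase equivalents, and $0.25/(1+0.25)=0.2$. Permuting labels within the sampled subset is our reading of "randomly permuted", and the function names are hypothetical.

```python
import random


def random_label_subset(inputs, labels, fraction, rng):
    """Take a random subset and shuffle its labels: this breaks the input-label link
    while keeping the subset's label histogram unchanged."""
    n = max(1, int(len(inputs) * fraction))
    idx = rng.sample(range(len(inputs)), n)
    sub_x = [inputs[i] for i in idx]
    sub_y = [labels[i] for i in idx]
    rng.shuffle(sub_y)
    return sub_x, sub_y


def random_label_time_share(subset_fraction=0.005, epochs=50):
    """Share of total sample passes spent in the random-label phase, when the subset has
    `subset_fraction` times as many samples as the online phase and is repeated `epochs` times."""
    phase1 = subset_fraction * epochs
    return phase1 / (phase1 + 1.0)


print(f"random-label share of training: {random_label_time_share():.0%}")
```

Shuffling rather than drawing fresh uniform labels preserves the label marginal, so any speedup cannot come from the phase-1 labels being easier to fit on average.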

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/mod_addition_rand_phase1.png)

Figure 20: Small-set training with random labels speeds up learning for mod addition, complementing results in [Figure 4](https://arxiv.org/html/2605.20314#S5.F4). Compared to training directly on online batches with true labels (blue curve), adding an initial phase of repeating a small set with *random* labels helps speed up learning. The gray dashed line marks the switch from training on a small set with random labels to training on online batches with true labels. Since random labels provide no learning signal, this result confirms that layer balancing is the main effect of small-set repetition. Results are shown on mod addition learned using Transformer with mini-batch updates.
##### MLP initialization across widths

Recall from [Section 5.2.1](https://arxiv.org/html/2605.20314#S5.SS2.SSS1) that proper initialization can shrink or even eliminate the small-vs-large gap. [Section 5.2.1](https://arxiv.org/html/2605.20314#S5.SS2.SSS1) discusses two alternatives to the default standard initialization, namely $\mu P$ and a 1-dimensional simplification with $\alpha$-scaling (i.e., dividing the first layer's initialization standard deviation by $\alpha$ and multiplying the second layer's by $\alpha$). [Figure 7](https://arxiv.org/html/2605.20314#S5.F7) shows the results at width 64, where both $\mu P$ and $\alpha$-scaling help narrow the small-vs-large gap. [Figure 21](https://arxiv.org/html/2605.20314#A2.F21) shows that for parity, $\mu P$ cannot close the gap at $m=32$ but shows no gap for width 64 or above; this may be partly due to increasing width, which reduces the gap even under the standard parameterization ([Figure 11(b)](https://arxiv.org/html/2605.20314#S6.F11.sf2)). We hypothesize that $m=32$ is so small that it deviates too much from the infinite-width limit that $\mu P$ is designed for. However, $\mu P$ does not close the gap for SIM even at the maximum width we tested (1024) ([Figure 22](https://arxiv.org/html/2605.20314#A2.F22)). Further, we find that the optimal $\alpha$-scaling stays constant across widths ([Figure 23](https://arxiv.org/html/2605.20314#A2.F23)).
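Because ReLU is positively homogeneous ($\mathrm{relu}(cz)=c\,\mathrm{relu}(z)$ for $c>0$), the $\alpha$-scaling described above leaves a bias-free two-layer ReLU network's function at initialization unchanged and only rebalances the layers, multiplying the ratio $\|a\|_2/\|W\|_F$ from Figure 12 by $\alpha^2$. The sketch below checks this on a toy network; it omits biases (with first-layer biases, the invariance would also require dividing them by $\alpha$), and the sizes and $\alpha$ are arbitrary.

```python
import math
import random


def relu(z):
    return z if z > 0 else 0.0


def mlp(x, W, a):
    """Bias-free two-layer ReLU MLP: f(x) = sum_j a_j * relu(<W_j, x>)."""
    return sum(aj * relu(sum(w * xi for w, xi in zip(Wj, x))) for aj, Wj in zip(a, W))


def alpha_scale(W, a, alpha):
    """Divide first-layer weights by alpha and multiply second-layer weights by alpha."""
    return [[w / alpha for w in row] for row in W], [aj * alpha for aj in a]


def norm_ratio(W, a):
    """Layer norm ratio ||a||_2 / ||W||_F."""
    return math.sqrt(sum(v * v for v in a)) / math.sqrt(sum(w * w for row in W for w in row))


rng = random.Random(0)
d, m, alpha = 20, 64, 4.0
W = [[rng.uniform(-1, 1) / math.sqrt(d) for _ in range(d)] for _ in range(m)]
a = [rng.uniform(-1, 1) / math.sqrt(m) for _ in range(m)]
W_s, a_s = alpha_scale(W, a, alpha)
x = [rng.choice((-1.0, 1.0)) for _ in range(d)]
print(mlp(x, W, a), mlp(x, W_s, a_s))            # identical outputs
print(norm_ratio(W_s, a_s) / norm_ratio(W, a))   # equals alpha**2
```

The function is invariant but the training dynamics are not: the gradient with respect to $W$ scales with $a$, so rebalancing the layers changes how fast the first layer moves.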

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/muP_gd_mlp_d20_k6_width32.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/muP_gd_mlp_d20_k6_width64.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/muP_gd_mlp_d20_k6_width256.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/muP_gd_mlp_d20_k6_width1024.png)

Figure 21: $\mu P$ across model widths for parity. Results are for 2-layer MLP on (20, 6)-parity trained with (full-batch) GD from $\mu P$ initialization, at various widths $m\in\{32,64,256,1024\}$. $\mu P$ suffices to close the small-vs-large gap for width $\geq 64$.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/sim_mlp_mup_w64.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/sim_mlp_mup_w1024.png)

Figure 22: $\mu P$ across model widths for SIM. Results are for 2-layer MLP on SIM trained with (full-batch) GD from $\mu P$ initialization, at widths $m\in\{64,1024\}$. $\mu P$ does not close the small-vs-large gap for SIM.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/mlp_gd_init_scale_d20_k6_N16384_m32.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/mlp_gd_init_scale_d20_k6_N16384_m64.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/mlp_gd_init_scale_d20_k6_N16384_m256.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/mlp_gd_init_scale_d20_k6_N16384_m1024.png)

Figure 23: Optimal initialization scale holds constant across widths. Results are for MLP on (20, 6)-parity trained with (full-batch) GD on $N=2^{14}$ samples, at various widths $m\in\{32,64,256,1024\}$.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/mlp_gd_m64_sim.png) (a) Width 64
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/sim_mlp_width256.png) (b) Width 256
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/sim_mlp_width1024.png) (c) Width 1024

Figure 24: Increasing width reduces the small-vs-large gap. Results are from 2-layer MLP with full-batch updates on SIM, where we vary the model width.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ezra_figs/1layer_transformer_gap.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ezra_figs/2layer_transformer_gap.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ezra_figs/4layer_transformer_gap.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ezra_figs/8layer_transformer_gap.png)

Figure 25: Transformer on (10, 6)-parity, across varying depths (2, 4, 6, 8).

Figure 26: Increasing depth widens the small-vs-large gap. Results are shown on Transformer with mini-batch Adam updates, complementing results in [Figure 11(a)](https://arxiv.org/html/2605.20314#S6.F11.sf1).
##### Effect of Transformer QK normalization

In [Section 5.2.2](https://arxiv.org/html/2605.20314#S5.SS2.SSS2), we showed that QK normalization shrinks the small-vs-large gap for both parity and SIM. Specifically, QK normalization significantly improves large-set training with both mini-batch ([Figure 8(c)](https://arxiv.org/html/2605.20314#S5.F8.sf3)) and full-batch ([Figure 27(a)](https://arxiv.org/html/2605.20314#A2.F27.sf1)) updates. However, this improvement is not universal.

First, QK normalization worsens 4-phase training ([Figure 8(c)](https://arxiv.org/html/2605.20314#S5.F8.sf3)). A closer investigation suggests that this is due to stronger overfitting: as shown in [Figure 28](https://arxiv.org/html/2605.20314#A2.F28), training with QK normalization fits the training set more quickly, while the validation accuracy remains low. Moreover, QK normalization can even worsen *online* training for some tasks, such as ICL ([Figure 9(d)](https://arxiv.org/html/2605.20314#S5.F9.sf4)) and mod addition ([Figure 27(b)](https://arxiv.org/html/2605.20314#A2.F27.sf2)). A mechanistic understanding of these effects is left for future work.
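For reference, the sketch below shows one common form of QK normalization, assuming RMSNorm without a learnable gain is applied to queries and keys before the attention logits are computed (the exact variant and its placement in the paper's Transformer are assumptions here; the function names are hypothetical). It illustrates one property relevant to the layer-scale view: with QK normalization, the attention pattern no longer depends on the overall scale of $W_q$ or $W_k$, and each logit is bounded in magnitude by $\sqrt{d_h}$. It does not explain the effects described above.

```python
import numpy as np

def rms_norm(z, eps=1e-6):
    """Rescale each row of z to unit root-mean-square (no learnable gain)."""
    return z / np.sqrt(np.mean(z ** 2, axis=-1, keepdims=True) + eps)

def attention(X, Wq, Wk, Wv, qk_norm=False):
    """Single-head softmax attention over a sequence X of shape (T, d).

    Returns (outputs, attention probabilities, raw logits).
    """
    q, k, v = X @ Wq, X @ Wk, X @ Wv
    if qk_norm:
        # QK normalization: normalize queries and keys before the dot product.
        q, k = rms_norm(q), rms_norm(k)
    logits = q @ k.T / np.sqrt(q.shape[-1])
    shifted = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    probs = np.exp(shifted)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return probs @ v, probs, logits

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, d, dh = 5, 16, 8
    X = rng.standard_normal((T, d))
    Wq, Wk, Wv = (rng.standard_normal((d, dh)) / np.sqrt(d) for _ in range(3))
    for scale in (1.0, 10.0):
        _, p_plain, _ = attention(X, scale * Wq, Wk, Wv)
        _, p_norm, _ = attention(X, scale * Wq, Wk, Wv, qk_norm=True)
        # Without QK norm the first row's attention sharpens as Wq grows; with it, the row is unchanged.
        print(scale, p_plain[0].round(3), p_norm[0].round(3))
```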

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/gd_transformer_d10_k6_qk_rmsnorm.png) (a) (20, 6)-parity, full-batch updates
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/mod_addition_qk_norm.png) (b) Mod addition

Figure 27: Additional results on Transformer with QK normalization. QK normalization (Left) removes the small-vs-large gap for parity with full-batch training, and (Right) worsens the training of mod addition for both online (“large”) and repeated (“small”) samples.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ablation_sgd_transformer_d20_k6_compare_qk_norm_train_acc.png) (a) Train accuracy
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ablation_sgd_transformer_d20_k6_compare_qk_norm_val_acc.png) (b) Test accuracy

Figure 28: QK normalization slows down small-set training. Results are shown for (20, 6)-parity with mini-batch updates. As shown in the train accuracy plot (left), QK normalization overfits to the training set quickly in the first two phases, but struggles to fit later phases, where the training sets are larger.
##### Effect of adaptive optimizers

We view the small-vs-large gap as related to the relative balance across layers, as supported both theoretically in [Section 4.2](https://arxiv.org/html/2605.20314#S4.SS2) and empirically in [Section 5](https://arxiv.org/html/2605.20314#S5). As an implication, the gap should be less pronounced with adaptive optimizers such as AdamW, which are much less sensitive to layer scale than plain (stochastic) gradient descent. Indeed, we find that AdamW closes the gap on MLP across tasks and depths ([Figure 29](https://arxiv.org/html/2605.20314#A2.F29)). However, AdamW does not close the gap in Transformers: all Transformer experiments were conducted using AdamW, and yet the small-vs-large gap persists. Hence, a full characterization of the gap is more intricate than our current explanation and likely needs to be architecture-aware.
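The claim that adaptive optimizers are less sensitive to layer scale can be checked on the update rule itself. The sketch below implements the standard Adam update (AdamW additionally applies decoupled weight decay, which does not depend on the gradient and is omitted here) and plain gradient descent, applied to a fixed, hypothetical sequence of gradients for one layer. Multiplying every gradient by a constant $c$ multiplies the GD steps by $c$, whereas Adam's steps are unchanged up to the small $\epsilon$ term, because each coordinate is divided by a running estimate of its own gradient magnitude. This says nothing about why the gap persists in Transformers.

```python
import numpy as np

def sgd_updates(grads, lr=0.1):
    """Plain gradient descent: each step is proportional to the gradient."""
    return [-lr * g for g in grads]

def adam_updates(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Standard Adam steps for a fixed sequence of gradients (no weight decay)."""
    m = np.zeros_like(grads[0])
    v = np.zeros_like(grads[0])
    steps = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g          # running mean of gradients
        v = beta2 * v + (1 - beta2) * g ** 2     # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)             # bias corrections
        v_hat = v / (1 - beta2 ** t)
        steps.append(-lr * m_hat / (np.sqrt(v_hat) + eps))
    return steps

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    gs = [rng.standard_normal(5) for _ in range(10)]
    scaled = [100.0 * g for g in gs]             # same layer, gradients 100x larger
    print(np.abs(sgd_updates(scaled)[-1]).max() / np.abs(sgd_updates(gs)[-1]).max())
    print(np.abs(adam_updates(scaled)[-1]).max() / np.abs(adam_updates(gs)[-1]).max())
```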

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/adam_MLP_GD_d20k6_m64.png) (a) (20, 6)-parity, depth 2
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/adam_MLP_GD_d20k6_m64_L4.png) (b) (20, 6)-parity, depth 4
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/adam_MLP_GD_d20k6_m64_L6.png) (c) (20, 6)-parity, depth 6
![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/jingwen_figs/sim_mlp_adam.png) (d) SIM, depth 2

Figure 29: Adam removes the small-vs-large gap in MLP, across tasks and model depths. Results are shown for MLP with GD updates.
##### Which Transformer parameter benefits more from small\-set training?

We perform an ablation on mini-batch training in which part of the model is updated using online data, while the rest is updated using batches repeatedly sampled from a fixed, small dataset. [Figure 30](https://arxiv.org/html/2605.20314#A2.F30) shows results for (20, 6)-parity using the default 2-layer Transformer ([Appendix B](https://arxiv.org/html/2605.20314#A2)). The parameter $W_v$ seems to benefit the most from small-set training, as switching its updates to online data hurts performance the most. For $W_q, W_k$, switching either one alone has only a mild effect, but switching both jointly leads to a significant slowdown.

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ablation_sgd_transformer_d20_k6_online_param_single.png)

![Refer to caption](https://arxiv.org/html/2605.20314v1/figs/ablation_sgd_transformer_d20_k6_online_param_pair.png)

Figure 30: Parameter-wise ablations of small-set training. We train a Transformer on (20, 6)-parity using 6-phase mini-batch updates, except for specific parameters, which are updated using online batches. Among single parameters (Left), $W_v$ relies on small-set training the most, whereas the effects on $W_q, W_k$ are mild. When using online updates on a pair of parameters (Right), online updates on $W_q, W_k$ jointly lead to a significant slowdown.
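The parameter-wise ablation above can be sketched as routing gradients per parameter: at each step, one gradient is computed on a fresh online batch and one on a batch resampled from the fixed small set, and each parameter takes its update from one of the two. This is a minimal sketch under assumptions: `grad_fn`, the parameter names, the small-set size, the batch size, and the use of the first $k$ coordinates as the parity support are all hypothetical, and the multi-phase schedule used in the experiments is not modeled.

```python
import numpy as np

def sample_parity(rng, num, n=20, k=6):
    """(n, k)-sparse parity: x uniform on {-1, +1}^n, label = product of k coordinates.

    Using the first k coordinates as the support is an assumption for illustration.
    """
    X = rng.choice([-1.0, 1.0], size=(num, n))
    return X, np.prod(X[:, :k], axis=1)

def draw_batches(rng, X_small, y_small, batch_size):
    """Return (online_batch, small_batch) for one training step."""
    online = sample_parity(rng, batch_size, n=X_small.shape[1])
    idx = rng.integers(0, len(y_small), size=batch_size)  # repeated sampling from the small set
    return online, (X_small[idx], y_small[idx])

def mixed_gradients(grad_fn, params, online_batch, small_batch, online_params):
    """Hypothetical routing: parameters named in online_params use the online-batch
    gradient; all other parameters use the small-set gradient."""
    g_online = grad_fn(params, *online_batch)
    g_small = grad_fn(params, *small_batch)
    return {name: g_online[name] if name in online_params else g_small[name]
            for name in params}

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X_small, y_small = sample_parity(rng, 1024)   # fixed small training set (size assumed)
    online, small = draw_batches(rng, X_small, y_small, batch_size=32)

    def toy_grad(params, X, y):
        # Stand-in for a Transformer backward pass: fills each gradient with mean(y).
        return {name: np.full_like(p, y.mean()) for name, p in params.items()}

    params = {"W_q": np.zeros(3), "W_k": np.zeros(3), "W_v": np.zeros(3)}
    grads = mixed_gradients(toy_grad, params, online, small, online_params={"W_v"})
    print({name: g[0] for name, g in grads.items()})
```

Computing both gradients at every step keeps the routing independent of the model; a real implementation would replace `toy_grad` with the model's own backward pass.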
