用于定位 Grokking 相变的分布谱诊断方法

arXiv cs.LG 论文

摘要

本文提出了一种分布谱诊断方法,用于在测试准确率上升之前定位 Transformer 模型中的 Grokking 相变。该方法利用经验分布和汉克尔动态模态分解(Hankel DMD)创建监测信号,以区分发生 Grokking 和未发生 Grokking 的训练运行。

arXiv:2605.08237v1 公告类型:新论文 摘要:在 Grokking 现象中,模型首先拟合训练数据,而测试准确率保持较低,直到后来才开始泛化。我们探讨是否可以在测试准确率上升之前,通过观察到的训练轨迹来定位这一相变,并将 Grokking 相变定位公式化为一个具有显式阈值/假阳性率(FPR)/提前量权衡的诊断问题。依赖于任务的观测值被总结为经验分布,映射到 Wasserstein/分位数坐标,并通过汉克尔动态模态分解(DMD)进行分析;由此产生的重建残差,结合谱和有效秩,构成了诊断输出。在保留的模加 Transformer 运行中,残差在运行级别对 Grokking 与非 Grokking 的区分达到了 AUROC \(\approx \) 0.93;在固定持续阈值操作规则下,真阳性警报可以早于相变发生,并报告了提前量、假警报率及不确定性区间。扰动实验表明,在测试的 \(wd=1\) 池中,高残差窗口的短视界扰动偏差约为低残差窗口的 \(3\times\)。在同一数据规范窗口对照实验中,扰动敏感度与残差排序一致,而非与总参数范数排序一致,这表明在所研究的 \(wd=1\) 动力学中,残差在窗口级别不仅仅是总范数的代理指标。范数信号仍然是强有力的运行级别机制指标,并且在当前协议下,对数概率在所有测试的观测值中表现最佳。我们将残差定位为所研究的模算术 Transformer 设置中的窗口级别监测和定位信号,而非通用的早期预警预测器或干预规则。
查看原文 导出为 Word 导出为 PDF
查看缓存全文

缓存时间: 2026/05/12 07:10

# 分布谱诊断用于定位泛化突破(Grokking)相变

**来源**: https://arxiv.org/html/2605.08237  
**作者**: Ziyue Wang¹, Yufeng Ying², Takafumi Kanamori¹  
¹东京科学大学  
²中国科学技术大学  

###### 摘要

在泛化突破(grokking)现象中,模型首先拟合训练数据,而测试准确率保持低位,直到较晚阶段才开始泛化。我们探讨能否在测试准确率上升之前,从观测到的训练轨迹中定位这一相变过程,并将泛化突破相变定位公式化为一个具有明确阈值/假阳性率(FPR)/提前量权衡的诊断问题。任务依赖的可观测量被总结为经验分布,映射到 Wasserstein/分位数坐标,并通过汉克尔动态模态分解(Hankel DMD)进行分析;由此产生的重构残差,连同谱信息和有效秩,构成了诊断输出。在留出的模加 Transformer 实验中,该残差在运行级别上的“泛化突破 vs 非泛化突破”判别任务中实现了 AUROC≈0.93;在固定的持续阈值操作规则下,真阳性警报可以出现在相变 onset 之前,并报告了提前量与假阳性率及不确定性区间。扰动实验表明,在测试的 $w_d=1$ 池子中,高残差窗口表现出比低残差窗口大约 $3\times$ 更大的短期扰动偏差。在同数据规范窗口对照实验中,扰动敏感性与残差排序一致,而非总参数范数排序,这表明在所研究的 $w_d=1$ 动力学中,残差在窗口级别上不仅仅是一个总范数的代理。范数信号仍然是强大的运行级别状态指标,并且在当前协议下,对数概率(log-probability)在所有测试的可观测量中表现最佳。我们将残差定位为在所研究的模算术 Transformer 设置中的窗口级别监控和定位信号,而非通用的早期预警预测器或干预规则。

## 1 引言

泛化突破是标准训练总结的一种显著失效模式:在某些算法任务上,模型早期拟合训练集,但直到长时间延迟后才开始泛化,期间测试准确率曲线几乎保持平坦(Power et al., 2022)。这种平坦性构成了难点——即使是在最终发生泛化突破的运行中,仅凭损失和准确率无法直接指示相变何时发生。我们解决一个由此差距引发的狭窄问题:能否使用通用分布可观测量和明确的阈值/假阳性率/提前量诊断协议,从观测到的训练轨迹中定位相变窗口?这与解释*为什么*发生泛化突破是不同的问题。

越来越多的工作通过机制解释泛化突破。机制可解释性识别了新兴的计算电路(Nanda et al., 2023; Varma et al., 2023);隐偏差分析将相变归因于零损失流形上的晚期范数最小化(Liu et al., 2022; Lyu et al., 2024; Musat, 2026);基于稳定性的解释将泛化突破与 logits 缩放和 softmax 崩溃联系起来(Thilak et al., 2022; Prieto et al., 2025)。范数增长、logits 缩放、基于 AGOP 的特征涌现(Radhakrishnan et al., 2024)以及电路形成观点是我们工作的自然参考点。我们的目标是互补的:从任务依赖的可观测量计算阈值定位信号,并报告明确的假阳性/提前量权衡,旨在标记相变窗口而非解释它们。

我们在每个训练步长 $t$ 将选定的任务依赖可观测量 $o_t$ 总结为经验分布 $\mu_t$。因为诊断是从选定的输出分布而非隐藏单元坐标计算得出,所以它不依赖于隐藏单元的索引。Wasserstein/分位数坐标将 $\mu_t$ 转换为向量观测 $z_t$,窗口汉克尔动态模态分解(DMD)提供局部动力学近似(Schmid, 2010; H. Tu et al., 2014; Arbabi and Mezić, 2017)。重构残差 $\text{Res}^{(r)}$ 是主要诊断指标;谱信息和有效秩 $r_{0.99}$ 是辅助描述符,主要在低残差窗口中可解释。

我们的主要设置取自在模加任务上训练的 Transformer 中固定探测集上正确答案对数概率的经验分布;次要可观测量和全连接网络(FCN)比较作为范围检查稍后出现。经验上,残差在模加 Transformer 设置的泛化突破相变附近上升。在带有新鲜种子(seeds)的留出测试集上,残差在明确的阈值/假阳性率/提前量权衡下表现出非平凡的运行级别检测行为(具体数字见 §3.2 和 表 3)。扰动实验显示,在匹配噪声下,高残差窗口比低残差窗口表现出更大的短期偏差。同数据范数窗口对照通过将相同运行重新标记为总参数范数百分位,产生了相反的脆弱性排序,这表明在所研究的 $w_d=1$ 动力学中,残差在窗口级别上不仅仅是一个总范数代理;尽管如此,范数信号仍然是强大的运行级别状态指标。可观测量消融实验发现,对数概率是当前协议下测试的最佳性能可观测量。

#### 贡献。
- (i) 我们提出了一种用于训练动力学的窗口分布诊断方法。该方法将任务依赖的可观测量映射到经验分布,用 Wasserstein/分位数坐标表示它们,并应用汉克尔-DMD 来计算谱、有效秩和重构残差。
- (ii) 我们评估重构残差作为泛化突破相变定位信号的性能。配合持续阈值操作规则,它在明确的阈值/假阳性率/提前量权衡下给出了留出检测行为;对数概率是我们当前协议下测试的最佳性能可观测量。
- (iii) 我们提供了基于扰动的证据,表明高残差窗口对应于脆弱的训练时期,支持将残差作为监控/定位信号,而非干预规则。
- (iv) 我们通过模型规模、任务族、范数基线、AGOP、干预、CIFAR-10 和 FCN 检查来评估范围和边界;这些作为范围检查呈现,而非普遍鲁棒性声明。

#### 声明范围。
我们不声称这是泛化突破的通用预测器、与架构无关的诊断或自动干预规则。我们不声称残差取代基于范数的状态分类器,也不声称扰动对齐建立了因果机制。我们的声明更窄:在所研究的模算术 Transformer 设置中,重构残差是相变定位和脆弱性监控的窗口级别线索,在明确的阈值/假阳性率/提前量权衡下进行评估。与机制、谱、Koopman/DMD 和 Wasserstein 训练动力学诊断的扩展比较推迟到附录 U。

## 2 方法

该流水线有三个阶段:(i) 将选定的任务依赖可观测量总结为经验分布;(ii) 通过 Wasserstein-分位数表示将每个分布嵌入希尔伯特坐标;(iii) 通过汉克尔-DMD 分析固定步长窗口上产生的向量值轨迹,并读取一小组窗口化量。流水线依赖于可观测量:$o_t$ 的选择决定了诊断能检测到什么。图 1 总结了整体诊断流水线,表 1 列出了主要符号。

**表 1: 分布汉克尔-DMD 诊断中使用的符号。**

### 2.1 可观测量与 Wasserstein-分位数坐标

#### 可观测量与分布状态。
对于模加任务上的 Transformer(主要设置),我们固定一个探测集 $\mathcal{P}=\{(x_i, y_i^\star)\}_{i=1}^M$($M=100$ 个示例),其中 $y_i^\star$ 是输入 $x_i$ 的正确答案标记。在训练步长 $t$,每个样本的可观测量是标量正确答案对数概率 $o_{t,i}=\log p_{\theta_t}(y_i^\star \mid x_i)$,分布状态是这 $M$ 个标量的经验分布:
$$
\mu_t := \frac{1}{M}\sum_{i=1}^M \delta_{o_{t,i}}. \quad (1)
$$
因此,诊断跟踪固定探测集上正确答案对数概率的经验分布;平均损失或准确率会将此分布坍缩为单个标量。因为构建使用选定的输出分布而非隐藏单元坐标,所以它不依赖于隐藏单元索引。Wasserstein/分位数坐标提供了表示:它们将分布值状态转换为向量值观测。汉克尔-DMD 提供局部动力学近似:它分析这些向量在短训练窗口中的演变。FCN 可观测量作为次要低残差描述符,定义在附录 T 中。

#### Wasserstein-分位数坐标。
对于具有有限二阶矩的一维测度,分位数映射是 $\mathcal{W}_2(\mathbb{R}, d_W)$ 与 $L^2(0,1)$ 的一个闭凸子集之间的全局等距映射(Villani, 2009):在固定参考 $\mu^\star$ 处,$\log_{\mu^\star}(\mu)=F_\mu^{-1}\circ F_{\mu^\star}^{-1}-\text{id}$ 将 Wasserstein 切空间标识为希尔伯特子空间。我们在固定分位数网格 $p_1,\dots,p_d$($d=19$ 个层级,0.05–0.95)上评估 $F_{\mu_t}^{-1}$:
$$
z_t = \bigl(F_{\mu_t}^{-1}(p_1),\,\dots,\,F_{\mu_t}^{-1}(p_d)\bigr) \in \mathbb{R}^d. \quad (2)
$$
多维类比需要嵌入(核均值嵌入、MDS),这些嵌入不保留 Wasserstein 几何结构;因此我们限制在一维任务依赖的可观测量上。完整的 Wasserstein 背景,包括 $\mathcal{W}_2(D)$ 的 Hadamard 结构,见附录 B。

### 2.2 窗口汉克尔-DMD 诊断

#### 延迟状态与快照矩阵。
在长度为 $m+q$ 的步长窗口上,我们形成延迟嵌入向量
$$
\xi_t = \bigl[z_t^\top,\,z_{t+1}^\top,\,\ldots,\,z_{t+q-1}^\top\bigr]^\top \in \mathbb{R}^{qd}, \quad (3)
$$
以及汉克尔快照矩阵 $H_-=[\xi_0\,\cdots\,\xi_{m-1}]$, $H_+=[\xi_1\,\cdots\,\xi_m] \in \mathbb{R}^{qd\times m}$。

#### 汉克尔-DMD 估计器。
一个 Koopman/DMD 近似(Schmid, 2010; H. Tu et al., 2014; Arbabi and Mezić, 2017; drmač2017datadrivenmodaldecompositions)求解普通最小二乘问题
$$
A_H^\star = \arg\min_A \|H_+ - A H_-\|_F^2, \quad (4)
$$
然后我们通过截断拟合算子的前 $r$ 个特征对 $(\lambda_j, w_j)_{j=1}^r$ 将其截断至秩 $r$。秩-$r$ DMD 重构为 $\hat{\xi}_t^{(r)} = W \Lambda^t b$,其中 $b = W^\dagger \xi_0$。快照构建细节、降秩投影和非正常敏感性的讨论见附录 C。

#### 重构残差。
$$
\text{Res}^{(r)} := \frac{\bigl(\sum_t \|\xi_t - \hat{\xi}_t^{(r)}\|_2^2\bigr)^{1/2}}{\bigl(\sum_t \|\xi_t\|_2^2\bigr)^{1/2}}. \quad (5)
$$
较小的 $\text{Res}^{(r)}$ 表明窗口轨迹在选定坐标中允许准确的低秩线性描述;较大的值表明偏离该状态。

#### 有效秩。
令 $H_-=U\Sigma V^\top$ 且奇异值为 $\sigma_1 \geq \sigma_2 \geq \cdots$,
$$
r_{0.99} := \min\Bigl\{r \geq 1:\ \frac{\sum_{i=1}^r \sigma_i^2}{\sum_i \sigma_i^2} \geq 0.99\Bigr\}. \quad (6)
$$

#### 有效性门控。
如图 1 总结所示,残差作为辅助描述符的有效性门控。在低残差窗口中,$A_H^\star$ 的谱和 $r_{0.99}$ 可解读为局部线性演化状态的经验摘要。在高残差窗口中,低秩线性近似较差;此时残差本身应被解释为相变或脆弱性信号,不应过度解读谱点和有效秩为稳定状态描述符。

$$
\underbrace{o_t}_{\substack{\text{selected}\\\text{observable}}} \longrightarrow \underbrace{\mu_t}_{\substack{\text{distributional}\\\text{state}}} \xrightarrow{\;\text{fixed quantile map}\;} \underbrace{z_t \in \mathbb{R}^d}_{\substack{\text{Wasserstein}\\\text{coordinate}}} \xrightarrow{\;\text{windowed Hankel-DMD}\;} \underbrace{\bigl(\{\lambda_j\},\,r_{0.99},\,\text{Res}^{(r)}\bigr)}_{\substack{\text{spectrum / effective rank}\\\text{/ residual}}}
$$

*如何解读诊断。*
**低 $\text{Res}^{(r)}$**:将谱和 $r_{0.99}$ 解释为局部状态描述符;**高 $\text{Res}^{(r)}$**:将残差本身解释为相变/脆弱性信号。

**图 1: 提议的诊断流水线。** 每个训练步长的选定任务依赖可观测量 $o_t$ 被总结为经验分布 $\mu_t$。Wasserstein/分位数坐标将每个 $\mu_t$ 转换为向量观测 $z_t \in \mathbb{R}^d$。窗口汉克尔-DMD 随后分析 $\{z_t\}$ 在固定步长窗口上的局部时间演变,并返回谱、有效秩和重构残差。低残差支持将谱和有效秩解释为局部描述符;高残差被视为...

相似文章