大规模线性自编码器中学习机制的三棱柱层次结构
摘要
本文系统性地识别了大型权重共享线性自编码器中所有定性不同的极端学习机制,推导了与三棱柱面相关的五种机制的显式损失演化。
arXiv:2606.05335v1 Announce Type: new
摘要:机器学习模型的理论研究通常考虑不同的极限机制,在这些机制下梯度下降的学习动态在理论上是可解的。然而,对于特定类型的模型,系统地获取所有定性不同的极端学习机制的图景是可取的。在本文中,我们针对由输入和潜在维度、初始化幅度及训练集大小表征的大型权重共享线性自编码器,提出了这样一个图景。该模型在权重上是非线性的,其梯度流没有通用的理论解。我们表明,在形式损失展开层次结构层面,其极端机制自然地与三棱柱的面相关联。具体而言,存在与棱柱二维面相关的五种基本极端机制:(1) 大数据,(2) 小数据,(3) 平均场,(4) 窄潜在,以及 (5) 自由。对于机制 (1,2,3,4),我们推导了梯度流下训练损失和总体极限损失演化的显式表达式,与实验结果吻合良好。
查看缓存全文
缓存时间: 2026/06/05 08:10
# 大线性自编码器中学习机制的三棱柱层次结构 来源:https://arxiv.org/html/2606.05335
Eugene Golikov
应用人工智能研究所
莫斯科,俄罗斯
e\.golikov@applied\-ai\.ru
&Yaroslav Gusev
应用人工智能研究所
莫斯科,俄罗斯
i\.gusev@applied\-ai\.ru
&Dmitry Yarotsky
应用人工智能研究所 & 俄罗斯科学院斯捷克洛夫数学研究所
莫斯科,俄罗斯
yarotsky@gmail\.com
###### 摘要
机器学习模型的理论研究通常考虑不同的极限机制,在这些机制中,梯度下降的学习动力学变得理论上可处理。然而,理想情况下,对于特定类型的模型,应该系统性地获得所有定性不同的极端学习机制的完整图景。在本文中,我们针对具有输入维度和潜在维度、初始化幅度以及训练集大小特征的大权重共享线性自编码器提出了这样一个图景。该模型在权重上是非线性的,其梯度流没有通用的理论解。我们表明,在形式损失展开层次上,其极端机制自然地与一个三棱柱的面相关联。特别地,存在五个与棱柱的二维面相关的基本极端机制:(1) 大数据,(2) 小数据,(3) 平均场,(4) 窄潜在,以及 (5) 自由。对于机制 1,2,3,4,我们推导了梯度流下训练损失和总体损失极限演化过程的显式表达式,并与实验结果取得了非常好的一致性。
## 1 引言
机器学习中的一大挑战是准确理解大型预测模型中学习轨迹的理论。通常,模型训练是通过某种梯度下降变体进行的,这并非立即可以处理。动力学的理论分析通常涉及一些简单的、可解的模型,或者在适当假设下将复杂模型简化为更简单的模型,或者两者兼有(Simon et al., 2026 (https://arxiv.org/html/2606.05335#bib.bib2))。可解模型的重要例子是具有适当对齐初始化的线性模型和多层线性神经网络(Saxe et al., 2013 (https://arxiv.org/html/2606.05335#bib.bib51), 2019 (https://arxiv.org/html/2606.05335#bib.bib85))。简化为可处理模型的重要例子是“懒惰训练”场景中的模型线性化(Chizat et al., 2019 (https://arxiv.org/html/2606.05335#bib.bib73)),例如 NTK 机制(Jacot et al., 2018 (https://arxiv.org/html/2606.05335#bib.bib52); Lee et al., 2018 (https://arxiv.org/html/2606.05335#bib.bib90))。可解的线性网络可以从合适的极端机制中的随机初始化产生(Tu et al., 2024 (https://arxiv.org/html/2606.05335#bib.bib33))。其他一些实现可处理性的著名通用方法包括随机矩阵理论的应用(Pennington and Worah, 2017 (https://arxiv.org/html/2606.05335#bib.bib1))、统计物理方法如复制法(Zdeborová and Krzakala, 2016 (https://arxiv.org/html/2606.05335#bib.bib112)) 以及平均场理论(Sirignano and Spiliopoulos, 2020 (https://arxiv.org/html/2606.05335#bib.bib61); Rotskoff and Vanden-Eijnden, 2018 (https://arxiv.org/html/2606.05335#bib.bib62); Chizat and Bach, 2018 (https://arxiv.org/html/2606.05335#bib.bib63))。
通常,这类研究假设一个特定的极端学习机制(当然,这对于一大类模型可能都是相关的——参见 NTK 机制)。然而,我们问一个不同的问题:*给定一个特定的模型族,其所有极端机制是什么,其中哪些机制的学习动力学在理论上是可处理的?* 在本文中,我们表明对于*权重共享线性自编码器*(Baldi and Hornik, 1989 (https://arxiv.org/html/2606.05335#bib.bib10); Baldi, 2012 (https://arxiv.org/html/2606.05335#bib.bib4)),这个问题可以得到相当完整的回答。这是一个相对简单的模型,但它在权重上是非线性的,并且没有通用的理论解。理论研究通常考虑自编码器的两种不同机制:欠完备和过完备。前者具有比输入维度更小的隐藏维度;这样,有人认为模型被迫执行降维;特别是,欠完备线性自编码器的所有全局最小值对应于执行 PCA(Baldi and Hornik, 1989 (https://arxiv.org/html/2606.05335#bib.bib10))。另一方面,过完备自编码器总是能够完美地拟合恒等映射,因此不必执行任何形式的 PCA。然而,如果训练数据集小于输入维度,那么除了完美的恒等映射之外,还有多种方法可以将训练数据映射到自身。我们的训练算法选择的确切权重配置直接影响所学模型的泛化性能。这尤其凸显了在该模型中研究训练和总体学习轨迹的重要性。虽然自编码器已经在欠完备(Refinetti and Goldt, 2022 (https://arxiv.org/html/2606.05335#bib.bib111)) 和过完备(Nguyen, 2021 (https://arxiv.org/html/2606.05335#bib.bib109)) 机制中得到了分析,但据我们所知,除了 Yarotsky 等人 (2026 (https://arxiv.org/html/2606.05335#bib.bib106)) 的工作(有关训练机制分类研究的概述,请参见 B (https://arxiv.org/html/2606.05335#A2))之外,还没有对其机制全貌的系统性研究。该工作甚至解决了更广泛的、不同阶的张量模型族,但只考虑了群体层面的学习,因此没有解决任何与泛化相关的问题。Yarotsky 等人 (2026 (https://arxiv.org/html/2606.05335#bib.bib106)) 提出的机制分类依赖于损失演化的“图展开”。不同的机制对应于不同的图族,这些图族可以根据模型超参数之间的缩放关系系统地推导出来。
##### 我们的贡献。
1. 1. 我们将 Yarotsky 等人 (2026 (https://arxiv.org/html/2606.05335#bib.bib106)) 的基于图的方法推广到有限训练集上的训练,从而允许分别研究训练和总体学习轨迹。我们通过在图谱中引入和分析数据相关的边和节点来实现这一点。
2. 2. 使用这种方法,我们从理论上推导出线性自编码器中以输入维度和潜在维度、权重初始化幅度以及训练集大小为特征的极端学习机制层次结构。特别是,通过检查形式损失展开层次结构,我们认为该模型具有与*三棱柱*的二维面相关联的*五个基本理论极端*:(1) *大数据*,(2) *小数据*,(3) *平均场*,(4) *窄潜在*,(5) *自由*。每种机制都由四个超参数之间的特定缩放关系来表征。棱柱的边和顶点对应于通过组合五个基本机制而获得的更退化的机制。
3. 3. 在五个基本极端中的四个中(除了自由机制),我们推导了训练损失和总体损失演化的显式极限描述:大数据机制中的闭式公式、平均场机制中的 Marchenko-Pastur 积分公式、窄潜在机制中的有限维 ODE 表征以及小数据机制中的矩层次结构。这些解与实验非常吻合。在大数据机制 (1) 中,解是已知的(例如,在 Yarotsky 等人 (2026 (https://arxiv.org/html/2606.05335#bib.bib106)) 中通过另一种方法发现),但在其他三个机制 (2)、(3)、(4) 中,据我们所知,这些解是新的。
## 2 问题陈述
我们考虑一个薄的线性权重共享自编码器,其*输入维度*为 \(p\),*潜在维度*为 \(n\):
\[
f(\bm{x}) = \bm{U}^\top \bm{U} \bm{x},
\]
其中 \(\bm{x} \in \mathbb{R}^p\),\(\bm{U} \in \mathbb{R}^{n \times p}\)。假设数据分布是各向同性高斯分布,考虑一个大小为 \(m\) 的训练集,并定义该模型的*训练*和*总体*平方损失:
\[
\widehat{L}(\bm{U}) = \frac{\left\|\bm{X} - \bm{U}^\top \bm{U} \bm{X}\right\|_F^2}{2pm},\qquad
L(\bm{U}) = \mathbb{E}_{\bm{x}}\left[\frac{\left\|\bm{x} - \bm{U}^\top \bm{U} \bm{x}\right\|^2}{2p}\right] = \frac{\left\|\bm{I} - \bm{U}^\top \bm{U}\right\|_F^2}{2p},
\]
其中 \(\bm{x} \sim \mathcal{N}(0, \bm{I}_p)\),\(\bm{X} \in \mathbb{R}^{p \times m}\) 的每一列独立地从 \(\mathcal{N}(0, \bm{I}_p)\) 中采样,并且 \(\|\bm{A}\|_F^2 = \operatorname{Tr}[\bm{A}^\top \bm{A}]\)。
我们使用学习率为 \(\eta\) 的梯度流来训练我们的模型:
\[
\frac{d\bm{U}}{dt} = -\eta \frac{\partial \widehat{L}(\bm{U})}{\partial \bm{U}},\qquad
\bm{U}(0) \sim \mathcal{N}(0, \sigma^2 \bm{I}_n \otimes \bm{I}_p).
\]
\(\eta\) 和 \(\sigma^2\) 都可能依赖于 \(p, m, n\)。定义时间 \(t\) 时的损失值:
\[
\widehat{L}(t) = \widehat{L}(\bm{U}(t)),\qquad L(t) = L(\bm{U}(t)).
\]
我们将对大 \(p, n, m\) 极限下的*平均*损失演化感兴趣:
\[
\widehat{\mathcal{L}}(t) = \lim_{p,n,m\to\infty} \mathbb{E}[\widehat{L}(t)],\quad {\mathcal{L}}(t) = \lim_{p,n,m\to\infty} \mathbb{E}[L(t)].
\]
此处以及后续内容中,期望值是针对 \(\bm{U}(0)\) 和 \(\bm{X}\) 取的。我们将看到,根据 \(p\)、\(n\)、\(m\) 和 \(\sigma^2\) 的联合缩放,存在多个极限*学习机制*,并将描述这些机制的完整层次结构。接下来,我们将检查 \(\widehat{\mathcal{L}}(t), {\mathcal{L}}(t)\) 是否存在显式解。虽然目前看来一般来说没有显式解,但我们将看到(在不同程度的显式性下),在极端学习机制中可以找到它们。
## 3 学习机制分类
我们大致遵循 Yarotsky 等人 (2026 (https://arxiv.org/html/2606.05335#bib.bib106)) 提出的学习机制分类的小 \(t\) 展开方法。该方法间接地在展开系数的层面上处理学习轨迹,但它是系统性的,并产生了学习机制的一致且清晰的几何图景。更多讨论请参见 C.3 (https://arxiv.org/html/2606.05335#A3.SS3)。
##### 损失展开 (C.1 (https://arxiv.org/html/2606.05335#A3.SS1))。
我们的起点是在 \(t=0\) 处损失的幂级数展开:
###### 命题 3.1。
平均总体损失和训练损失允许幂级数展开:
\[
\mathbb{E}[L(t)] \sim \frac{1}{2} + \sum_{s=0}^{\infty} \left(\frac{-\eta}{pm}\right)^s Y_s \frac{t^s}{s!},\quad
\mathbb{E}[\widehat{L}(t)] \sim \frac{1}{2} + \sum_{s=0}^{\infty} \left(\frac{-\eta}{pm}\right)^s \widehat{Y}_s \frac{t^s}{s!},
\]
其中
\[
Y_s = \sum_{\mathbf{q} \in Q_s} c_{\mathbf{q};s} p^{q_p} n^{q_n} m^{q_m} \sigma^{q_\sigma}
\]
和
\[
\widehat{Y}_s = \sum_{\mathbf{q} \in \widehat{Q}_s} \widehat{c}_{\mathbf{q};s} p^{q_p} n^{q_n} m^{q_m} \sigma^{q_\sigma}
\]
是关于 \(p, n, m\) 和 \(\sigma^2\) 的多项式。多项式 \(Y_s, \widehat{Y}_s\) 是复杂的,但我们勾勒出如何用合适的*图*来构造性地描述它们(详见 C.1 (https://arxiv.org/html/2606.05335#A3.SS1))。
首先注意,可以通过对损失进行时间微分来找到 \(Y_s, \widehat{Y}_s\) 的值,例如对于总体损失和 \(s \geq 1\),有
\[
Y_s \eta^s = \mathbb{E}\left[\tfrac{d^s L}{dt^s}(t=0)\right].
\]
这些时间导数可以使用梯度流方程递归计算:
\[
\frac{d^s L}{dt^s} = -\eta \left\langle \bm{\nabla}_{\bm{U}} \frac{d^{s-1} L}{dt^{s-1}}, \bm{\nabla}_{\bm{U}} \widehat{L} \right\rangle,\quad
\frac{d^s \widehat{L}}{dt^s} = -\eta \left\langle \bm{\nabla}_{\bm{U}} \frac{d^{s-1} \widehat{L}}{dt^{s-1}}, \bm{\nabla}_{\bm{U}} \widehat{L} \right\rangle.
\]
在 \(s=0\) 时,\(L\) 和 \(\widehat{L}\) 可以写成各种矩阵乘积 \(\bm{U}, \bm{U}^\top, \bm{X}, \bm{X}^\top\) 的迹的线性组合:
\[
L = \frac{\frac{1}{2}D - R + \frac{p}{2}}{p},\quad
\widehat{L} = \frac{\frac{1}{2}\widehat{D} - \widehat{R} + \frac{1}{2}\widehat{F}}{pm},
\]
其中
\[
D = \operatorname{Tr}[(\bm{U}^\top \bm{U})^2],\quad R = \operatorname{Tr}[\bm{U}^\top \bm{U}],
\]
\[
\widehat{D} = \operatorname{Tr}[(\bm{U}^\top \bm{U})^2 \bm{X} \bm{X}^\top],\quad \widehat{R} = \operatorname{Tr}[\bm{U}^\top \bm{U} \bm{X} \bm{X}^\top],\quad \widehat{F} = \operatorname{Tr}[\bm{X} \bm{X}^\top].
\]
我们用*环图*来描述每一个迹,见图 1 (https://arxiv.org/html/2606.05335#S3.F1) 左上角。然后,计算 (6) 中梯度的标量积相当于“合并”对应于 \(\widehat{L}\) 和 \(\frac{d^{s-1} L}{dt^{s-1}}\) 或 \(\frac{d^{s-1} \widehat{L}}{dt^{s-1}}\) 的图(图 1 (https://arxiv.org/html/2606.05335#S3.F1) 右上角)。合并后的图是更大的环,也表示乘积的迹。最后,为了得到多项式 \(Y_s, \widehat{Y}_s\),只需在 \(t=0\) 时找到 \(\frac{d^s L}{dt^s}\) 和 \(\frac{d^s \widehat{L}}{dt^s}\) 的期望值。由于 \(\bm{U}(0)\) 和 \(\bm{X}\) 具有独立的高斯元素,这可以使用 Wick 定理来完成。为了找到图 \(G\) 的期望值 \(\mathbb{E}[G]\),我们考虑匹配类型 \(U\) 或 \(X\) 的边之间的所有可能*配对*,并*收缩*相应的节点(图 1 (https://arxiv.org/html/2606.05335#S3.F1) 底部)。然后每个配对贡献给 \(\mathbb{E}[G]\) 一个单项式 \(p^{q_p} n^{q_n} m^{q_m} \sigma^{q_\sigma}\),其中 \(q_\sigma\) 是 \(U\) 边的数量,\(q_p, q_n, q_m\) 分别是收缩节点的数量。初始项 \(\frac{1}{2} = \frac{1}{2p} \operatorname{Tr}[\bm{I}] = \frac{1}{2pm} \operatorname{Tr}[\bm{X} \bm{X}^\top]\) 是目标恒等矩阵 \(\bm{I}\) 的损失;它是常数相似文章
PRISM:一种将漂移分解为尺度、形状和头部的几何风险界
本文介绍了 PRISM,这是一种几何风险界,将训练后大型语言模型(LLM)变体中的模型漂移分解为尺度、形状和头部三个维度,以诊断量化误差或灾难性遗忘等特定故障模式。
PRISM:面向多层薄膜设计的位置编码回归逆光谱模型
PRISM是一种仅解码器的自回归变换器,通过联合预测材料选择和厚度来解决多层薄膜光学涂层设计的逆问题,以更小的模型实现了最先进的性能。
从权重到特征:SAE引导的激活正则化用于LLM持续学习
本文提出了一种用于大语言模型的持续学习方法,该方法使用预训练的稀疏自编码器(SAEs)在激活空间而非权重空间中进行正则化,从而在无需存储先前数据的同时避免灾难性遗忘,并实现了更好的内存效率和更强的基准性能。
稀疏自编码器中概念学习与神经元解释的几何视角
本文提出了一个统一的几何框架,用于理解稀疏自编码器中的概念学习和神经元解释,将概念形式化为集合,并定义了检测、分离和近似。它提供了误差界、容量约束,并与形式概念分析建立了联系,同时在合成数据上进行了实验。
揭示SciML中的多模态模式:不同的失败模式与模态特定优化
本文识别了科学机器学习模型中一致的三模态结构,表明优化效果是模态特定的,并可能挑战传统的损失景观解释。它提出了一个模态感知的诊断框架,并在PINN、神经算子以及神经ODE上得到验证。