CRUMB:基于分布匹配上下文批处理的高效先验拟合网络推理

arXiv cs.LG 论文

摘要

本文提出CRUMB,一种三阶段推理封装方法,通过聚类测试查询并利用最小化最大均值差异(MMD)选择分布匹配的训练子集,从而实现对大规模数据集的高效先验拟合网络推理。在51个TabArena数据集上,该方法在上下文选择方面达到了最先进水平。

arXiv:2606.11473v1 公告类型:新论文 摘要:先验拟合网络(PFN)是一类有前景的表格基础模型,它们通过上下文学习,将整个带标签的训练集作为上下文提供,并在单次前向传播中为测试查询生成预测。然而,许多PFN架构中二次方缩放的自注意力机制使得对非常大的训练数据集进行推理变得不可行。我们提出CRUMB(基于最小化MMD聚类的检索),一种三阶段推理封装方法:(i)对测试查询进行聚类,(ii)通过贪婪最小化最大均值差异(MMD)为每个聚类选择一个小型、分布匹配的训练子集,(iii)在每个缩减上下文的批次上执行精确的PFN推理。CRUMB与架构无关,无需重新训练。在51个数据集的TabArena基准测试中,针对三种PFN架构(TabPFNv2、TabICLv1、TabICLv2)进行评估,我们证明CRUMB优于类似的最新上下文选择策略。我们还表明CRUMB对协变量偏移具有鲁棒性,因为MMD最小化步骤自然有助于对齐训练上下文分布以匹配当前测试批次分布。
查看原文
查看缓存全文

缓存时间: 2026/06/11 13:47

# 基于分布匹配上下文批处理的高效先验拟合网络推理
来源: https://arxiv.org/html/2606.11473
Jamie Heredge Mattia J. Villani* Pranav Deshpande Akshay Seshadri Niraj Kumar
全球技术应用研究部,摩根大通,纽约,NY 10001,美国
共同第一作者。邮箱: {jamie.heredge, mattia.villani}@jpmchase.com。
首席研究员。邮箱: [email protected]

###### 摘要

先验拟合网络(PFNs)是一类有前景的表格基础模型,能够进行上下文学习:将整个带标签的训练集作为上下文提供,并在单次前向传播中为测试查询生成预测。然而,许多 PFN 架构中的自注意力机制具有二次方复杂度,使得对于非常大的训练数据集,推理变得不可行。我们提出了 **CRUMB**(基于最小化 MMD 的聚类检索批处理),这是一个三阶段推理封装器:(i) 对测试查询进行聚类,(ii) 为每个聚类选择一个小的、分布匹配的训练子集,通过贪心最小化最大均值差异 (MMD) 来实现,(iii) 在每个缩减上下文的批次上执行精确的 PFN 推理。CRUMB 与架构无关,且无需重新训练。在包含 51 个数据集的 TabArena 基准测试中,我们在三种 PFN 架构(TabPFNv2, TabICLv1, TabICLv2)上评估,结果表明 CRUMB 优于类似的现有最优上下文选择策略。我们还展示了 CRUMB 对协变量漂移具有鲁棒性,因为 MMD 最小化步骤自然有助于使训练上下文分布与当前测试批次分布对齐。

## 1 引言

先验拟合网络 (PFNs) [12 (https://arxiv.org/html/2606.11473#bib.bib47), 13 (https://arxiv.org/html/2606.11473#bib.bib46)] 是一类日益流行 [28 (https://arxiv.org/html/2606.11473#bib.bib19)] 的表格基础模型,它们通过上下文学习解决监督学习任务:将整个带标签的训练集作为上下文提供,并在单次前向传播中为测试点生成预测。这些 PFN 模型已在各种数据集上取得成功,甚至超越了诸如 CatBoost [21 (https://arxiv.org/html/2606.11473#bib.bib39)] 和 XGBoost [3 (https://arxiv.org/html/2606.11473#bib.bib40)] 等梯度提升决策树方法,这些方法此前被认为是表格数据集分类任务中最具竞争力的模型之一。一个关键问题是,PFN 相对于当前最优方法的胜出报道通常只涉及小数据集 [13 (https://arxiv.org/html/2606.11473#bib.bib46), 32 (https://arxiv.org/html/2606.11473#bib.bib2), 5 (https://arxiv.org/html/2606.11473#bib.bib14)]。部分原因是注意力层随训练数据集大小呈二次方缩放,这使得 PFN 模型在处理大型数据集时时间和内存成本过高。

已经提出了多种方法来解决这个问题。可能的解决方案包括对 PFN 本身的架构进行更改,例如切换到线性注意力,如 TabFlex [33 (https://arxiv.org/html/2606.11473#bib.bib21)] 中所采用的。还有一些技术专注于针对特定数据集微调模型权重 [27 (https://arxiv.org/html/2606.11473#bib.bib56)],以及上下文调整 [8 (https://arxiv.org/html/2606.11473#bib.bib26)],后者对训练点本身进行变分自适应。在这项工作中,我们不关注架构调整,而是将研究范围限制在 TabPFNv2、TabICLv1 和 TabICLv2 上。

一个自然的补救措施是*上下文选择*:用一个小得多的子集 $\mathcal{S} \subset \mathcal{D}_{\text{train}}$ 替换完整的训练集 $\mathcal{D}_{\text{train}}$,这样选择使得预测受到的影响最小。均匀子采样是最简单的选择,尽管它实际上随机丢弃信息以换取速度提升。依赖于查询的方法(如 $k$-近邻 ($k$NN) 检索)可以通过为每个测试点定制上下文来提高准确性,但牺牲了批处理测试查询的能力:因为每个测试点接收到不同的上下文,每个都需要单独的前向传播,导致 $T$ 次独立的 PFN 评估,而不是少量批处理调用。

我们提出了 **CRUMB**(基于最小化 MMD 的聚类检索批处理),该方法解决了上下文质量与批处理效率之间的矛盾。关键思想是对*测试*查询进行聚类,并通过最小化测试聚类和选定训练子集之间的最大均值差异 (MMD) 来为每个测试聚类选择一个训练上下文。这使训练上下文与我们想要评估的测试查询点对齐,同时允许一次对一批测试查询进行高效的 PFN 推理。图 1 (https://arxiv.org/html/2606.11473#S1.F1) 给出了 CRUMB 的概览。我们的贡献如下:

- **基于 MMD 的上下文检索**。我们提出贪心 MMD 最小化(核 herding)作为选择训练子集的机制,并证明在相同的聚类框架内,基于 MMD 的选择优于中心近邻和 Voronoi 均匀检索。我们还展示了与其他方法相比,这种 MMD 最小化的上下文检索为测试数据中的协变量漂移提供了额外的鲁棒性。
- **通过测试端聚类实现批处理上下文选择**。我们展示了对测试查询进行聚类并为每个聚类共享一个训练上下文,能够在保持查询相关上下文的同时实现高效批处理。
- **在 TabArena 上的强结果**。在包含 51 个数据集的 TabArena 基准测试 [6 (https://arxiv.org/html/2606.11473#bib.bib30)] 上,我们在三种 PFN 架构下进行评估,发现 CRUMB 的性能接近逐查询的 $k$NN,同时只需要固定 $K$ 次前向传播,而不是每个测试点一次前向传播。在相同上下文预算下,CRUMB 显著优于混合上下文提示器 (MICP) 技术和均匀子采样。我们还展示,在协变量漂移下,当漂移从无漂移加剧到完全样本外协变量漂移时,CRUMB 相对于 MICP 的优势从 +4.9% 增长到 +17.1%。

训练 $\mathcal{D}_{\text{train}}$,$N$ 个点
测试 $\mathcal{D}_{\text{test}}$,$T$ 个点

阶段 1: 在 $\mathcal{D}_{\text{test}}$ 上运行 $k$-means
$C_1$ $C_2$ $\cdots$ $C_K$

阶段 2: MMD Herding
$\mathcal{S}_1$ $\mathcal{S}_2$ $\cdots$ $\mathcal{S}_K$ ($n \ll N$)

阶段 3: PFN($\mathcal{S}_1, C_1$) PFN($\mathcal{S}_2, C_2$) $\cdots$ PFN($\mathcal{S}_K, C_K$)
输出: $\hat{\mathbf{y}}_{C_1}$ $\hat{\mathbf{y}}_{C_2}$ $\cdots$ $\hat{\mathbf{y}}_{C_K}$

其中
$\displaystyle \mathcal{S}_k^* = \arg \min_{|\mathcal{S}|=n} \mathrm{MMD}^2\!\big(\hat{P}_{C_k},\, \hat{P}_{\mathcal{S}}\big)$

图 1: CRUMB 概述。阶段 1: 通过 $k$-means 将测试查询划分为 $K$ 个聚类。阶段 2: 对于每个聚类 $C_k$,通过贪心 MMD 最小化(核 herding)从完整训练池(蓝色虚线箭头)中选择大小为 $n \ll N$ 的训练子集 $\mathcal{S}_k$。阶段 3: PFN 运行 $K$ 次独立的前向传播,每次具有一个小的、几何上相关的上下文。总注意成本从 $T \cdot N$ 降至 $T \cdot n$,从而能够对具有 $N > 50,000$ 个训练点的数据集进行推理。

## 2 相关工作

#### 高效先验拟合网络。
PFN 范式由 [13 (https://arxiv.org/html/2606.11473#bib.bib46)] 引入,他们展示了一个在从先验采样的合成数据集上训练的 transformer 可以对真实表格数据进行上下文学习。TabPFNv2 [25 (https://arxiv.org/html/2606.11473#bib.bib55)] 和 TabPFNv2.5 [10 (https://arxiv.org/html/2606.11473#bib.bib25)] 将这个想法扩展到更大、更多样化的先验,在中小型数据集上取得了强结果 [19 (https://arxiv.org/html/2606.11473#bib.bib15)]。TabICL 系列工作 [22 (https://arxiv.org/html/2606.11473#bib.bib41), 23 (https://arxiv.org/html/2606.11473#bib.bib42)] 探索了相同范式下的替代架构和训练过程。所有这些模型都共享推动我们工作的基本缩放限制:它们的推理时间成本随上下文中包含的训练样本数量呈二次方缩放。

有几项工作引入了架构更改以提高速度、性能和标记效率,包括 TabFlex [33 (https://arxiv.org/html/2606.11473#bib.bib21)]、MITRA [34 (https://arxiv.org/html/2606.11473#bib.bib7)] 以及相关的 TabPFN 变体 [15 (https://arxiv.org/html/2606.11473#bib.bib6)]。分块 TabPFN [25 (https://arxiv.org/html/2606.11473#bib.bib55)] 引入了一种分块块注意力策略,将注意力张量划分为块,以增量方式计算完整注意力,从而加快注意力计算。此外,一些方法试图降低数据集的有效维度 [7 (https://arxiv.org/html/2606.11473#bib.bib5)],提供了与推理时间上下文子选择正交的方向。

已经提出了具有抗漂移能力的 TabPFN 变体,以提高对时间分布变化的鲁棒性 [11 (https://arxiv.org/html/2606.11473#bib.bib9)]。这些模型可以处理协变量和概念漂移,但主要是通过修改 PFN 的预训练先验以包含时变数据生成机制,而不是通过推理时的上下文选择。

#### 上下文学习中的上下文选择。
关于 PFN 的先前工作表明,基于 $k$NN 的上下文选择是一个有效的选择 [13 (https://arxiv.org/html/2606.11473#bib.bib46), 27 (https://arxiv.org/html/2606.11473#bib.bib56)]。然而,$k$NN 方法的一个关键缺点是测试查询无法高效批处理,因为每个查询可能引发不同的检索上下文。MixturePFN [30 (https://arxiv.org/html/2606.11473#bib.bib29)] 提出了混合上下文提示器 (MICP),它将附近的测试点路由到共享的局部训练上下文,从而在保持局部性的同时实现高效批处理。其他与上下文子选择密切相关的工作包括 [18 (https://arxiv.org/html/2606.11473#bib.bib3), 17 (https://arxiv.org/html/2606.11473#bib.bib4)];然而,这些方法不使用基于 MMD 的启发式方法。

## 3 背景

我们考虑一个标准的表格监督学习设置。设 $\mathcal{D}_{\text{train}} = \{(\mathbf{x}_i, y_i)\}_{i=1}^N$ 表示带标签的训练集,其中 $\mathbf{x}_i \in \mathbb{R}^d$,对于分类任务 $y_i \in \{1, \dots, C\}$,对于回归任务 $y_i \in \mathbb{R}$;设 $\mathcal{D}_{\text{test}} = \{\mathbf{x}_j^*\}_{j=1}^T$ 表示测试查询。PFN 将完整训练集作为上下文,并在单次前向传播中为所有测试查询生成预测 $\mathbf{\hat{y}} = \text{PFN}(\mathcal{D}_{\text{train}}, \mathcal{D}_{\text{test}})$,无需对 $\mathcal{D}_{\text{train}}$ 进行任何数据集特定的训练。相反,PFN 模型已经在大范围合成数据集上进行了预训练,以近似后验预测分布,这使得它们能够在单次前向传播中执行上下文学习 [12 (https://arxiv.org/html/2606.11473#bib.bib47)]。然而,由于训练样本之间的自注意力 [12 (https://arxiv.org/html/2606.11473#bib.bib47), 22 (https://arxiv.org/html/2606.11473#bib.bib41)],这次前向传播的成本通常与训练样本数量成二次方关系,当 $N$ 很大时变得不可行。

#### 最大均值差异。
最大均值差异 (MMD) 是一种基于核的概率分布距离 [9 (https://arxiv.org/html/2606.11473#bib.bib53)]。给定两个分布 $P$ 和 $Q$ 以及一个具有核 $\kappa$ 的再生核希尔伯特空间 (RKHS),平方 MMD 定义为:
$\text{MMD}^2(P, Q) = \mathbb{E}_{\mathbf{x}, \mathbf{x}' \sim P}[\kappa(\mathbf{x}, \mathbf{x}')] - 2\mathbb{E}_{\mathbf{x} \sim P, \mathbf{z} \sim Q}[\kappa(\mathbf{x}, \mathbf{z})] + \mathbb{E}_{\mathbf{z}, \mathbf{z}' \sim Q}[\kappa(\mathbf{z}, \mathbf{z}')].$ (1)
直观上,$\text{MMD}^2(P, Q) = 0$ 当且仅当 $P = Q$(对于特征核如高斯 RBF),这使其成为衡量所选训练子集对目标分布表示程度的一个自然标准。在我们的设置中,$P$ 是测试聚类 $C_k$ 上的经验分布,$Q$ 是候选训练子集 $\mathcal{S}_k \subset \mathcal{D}_{\text{train}}$ 上的经验分布。

#### 协变量漂移。
标准假设是训练输入和测试输入来自相同的边际分布 $P(\mathbf{x})$。协变量漂移是指这个假设被违反的情况:测试输入来自一个漂移的分布 $P_{\text{test}}(\mathbf{x}) \neq P_{\text{train}}(\mathbf{x})$,而标签机制 $P(y|\mathbf{x})$ 保持不变 [26 (https://arxiv.org/html/2606.11473#bib.bib24), 31 (https://arxiv.org/html/2606.11473#bib.bib23)]。在这种漂移下,从 $\mathcal{D}_{\text{train}}$ 均匀采样的训练上下文将反映 $P_{\text{train}}$ 而不是 $P_{\text{test}}$,可能会在测试查询集中的特征空间区域降低预测质量。这促使我们采用显式地将训练上下文与测试分布对齐的上下文选择方法。

#### 混合上下文提示器 (MICP)。
与 CRUMB 最相关的基线是 MixturePFN [30 (https://arxiv.org/html/2606.11473#bib.bib29)] 中的混合上下文提示器 (MICP) 组件。MICP [30 (https://arxiv.org/html/2606.11473#bib.bib29)] 对训练数据进行聚类,为每个聚类构建一个固定的支持集,并将测试点路由到最近的训练中心。完整的 MixturePFN 系统还包括一个上下文感知微调 (CAPFN) 阶段,该阶段训练每个数据集的适配器层;我们全程省略 CAPFN 以隔离上下文选择的效果,并确保与同样不修改模型权重的 CRUMB 进行公平比较。

## 4 方法

我们提出的方法包含三个阶段:(i) 对测试查询进行聚类,(ii) 通过最小化训练点和测试点之间的最大均值差异 (MMD) 为每个聚类检索训练子集,以及 (iii) 在每个(聚类,子集)对上运行 PFN 前向传播,其中给定对中的所有点可以批量处理在一起。我们将此方法称为**基于最小化 MMD 的聚类检索批处理 (CRUMB)**。我们还描述了可选的增强功能,例如一种自适应的训练上下文选择方法,我们选择点直到看到 MMD 最小化没有显著改进,从而通过避免没有性能提升的大上下文来节省时间。

**阶段 1: 对测试查询进行聚类。** 我们使用 $k$-means 在输入特征 $\{\mathbf{x}_j^*\}_{j=1}^T$ 上将测试集划分为 $K$ 个聚类 $\{C_1, \dots, C_K\}$ [16 (https://arxiv.org/html/2606.11473#bib.bib38), 2 (https://arxiv.org/html/2606.11473#bib.bib32)]。具体数量

相似文章

面向上下文LLM级联的在线Pandora's Box

arXiv cs.AI

本文介绍了一种面向自适应查询和选择LLM API的在线上下文Pandora's Box模型,提出了一种结合GMM估计与UCB风格置信区间的学习方法,并证明了维度相关的遗憾界。