修正表格分类中先验数据拟合网络的类别不平衡问题

arXiv cs.LG 论文

摘要

本文将经典的类别不平衡技术应用于表格分类的先验数据拟合网络(PFNs),发现由于PFNs的校准特性和有限数据能力,阈值法和降采样法表现良好。

arXiv:2605.21742v1 公告类型:新内容 摘要:先验数据拟合网络(PFNs)在表格分类任务中取得了卓越的性能。然而,与其他分类器一样,其性能可能受到类别不平衡的影响,导致稀有类别的表现不佳。已有多种技术试图减轻类别不平衡对分类性能的有害影响,但PFNs的上下文学习(ICL)动态使得基于损失的策略不可行,其他技术也未经验证。我们改编了几种处理类别不平衡的经典技术,并分析了它们在PFN分类中的性能。我们观察到,由于PFNs的校准特性,阈值法表现尤为出色,而由于PFNs在有限数据下的卓越性能,降采样法表现相仿,且额外带来了推理计算成本的降低。
查看原文
查看缓存全文

缓存时间: 2026/05/22 08:51

# 修正先验数据拟合网络中的类别不平衡问题以实现表格分类
来源: https://arxiv.org/html/2605.21742
###### 摘要

先验数据拟合网络(PFN)在表格分类任务中取得了卓越的性能。然而,与其他分类器一样,其性能可能受到类别不平衡的影响,导致少数类别的表现不佳。现有多种技术试图减轻类别不平衡对分类性能的不利影响,但 PFN 的上下文学习(ICL)动态使得基于损失的方法不可行,而其他技术尚未得到验证。我们改编了几种经典的处理类别不平衡的技术,并分析了它们在 PFN 分类上的表现。我们观察到,由于 PFN 的校准特性,阈值法表现异常出色;而由于 PFN 在有限数据下的出色表现,下采样法也表现出可比的性能,并且额外降低了推理的计算成本。

## I. 引言

先验数据拟合网络(PFN)已在表格数据推理 [2] (https://arxiv.org/html/2605.21742#bib.bib1) 和时间序列预测 [7] (https://arxiv.org/html/2605.21742#bib.bib4) 等关键领域变得无处不在。具体而言,TabPFN(及其变体)在表格分类和回归任务中展现了非凡的性能——无需更新任何权重,且仅需极少的任务特定训练数据。PFN 通过在海量*合成*数据上进行预训练,专门为了上下文学习任务,从而实现了这种卓越的效率。

通过仅从高度结构化的因果模型生成的合成数据中进行训练,PFN 隐式地学习估计后验预测分布(PPD),该分布基于由数据生成模型类别确定的先验。预训练模型类别与最终预测之间的关系至关重要,因为分类器的预测将遵循训练中使用的数据生成模型的结构。通过这种方式,控制预训练先验也就控制了分类的假设空间。

PFN 的另一个关键架构特性是其对上下文学习(ICL)的依赖。ICL 是一种元学习形式,其中 PFN 模型经过预训练,不是为了预测固定任务的输出,而是为了预测带标签的上下文集与未标签的查询之间的关系。这导致模型的权重在面临新任务时完全不更新;相反,预测依赖于提供给模型的相对较少的带标签示例(上下文)。这与经典的权重内学习不同,后者使用任务特定示例来更新模型权重。ICL 通常使用 Transformer 模型执行,因此上下文与查询之间的关系通过注意力机制捕获。

虽然 PFN 在表格分类上达到了最先进的性能,但与大多数分类模型一样,它们在类别不平衡的数据上表现不佳,即某些类别的样本数量远多于其他类别。事实上,这些模型可能达到合理甚至出色的平均准确率,但由于样本有限,少数类别的性能可能会显著下降。这反过来又限制了稀有类别(例如,罕见疾病或网络攻击检测)的检测。

解决类别不平衡影响的方法主要分为三类:基于损失、基于数据或基于决策的方法。损失重加权是经典方法,属于第一类,涉及提高少数类别样本的损失权重,使其与多数类别对模型具有相同的影响。数据级方法包括对多数类别进行下采样以匹配少数类别的规模,以及生成合成的少数类别样本。决策级方法涉及操作模型的输出,例如缩放/倾斜分类器的软分数。所有这些方法都有其优点;然而,由于 ICL 的学习动态与权重内学习不同,其中一些技术,特别是损失重加权,无法应用于此场景。

最近,[4] (https://arxiv.org/html/2605.21742#bib.bib2) 通过评估生成合成少数样本的效果,解决了 TabPFN 在稀有类别上的性能问题。这种技术受限于用于生成合成样本的方法的有效性,因为合成样本分布的任何失真都会影响下游分类。此外,它的计算成本也很高,因为每个在代表性过多类别中额外添加的样本都需要生成一个样本。

我们首先研究 PFN 独特的校准特性。这促使我们进行实证评估,包括阈值法、下采样、过采样和合成上采样(所有方法均在第二节-E 小节中正式定义)等有理论依据且实用的修正方法。对于二分类任务,我们基于理论的实验结果表明,阈值法取得了最佳性能,大幅提升了少数类别的性能,而多数类别的性能下降极小。下采样也表现良好,在实现最高最差类别准确率的同时,平衡性能仅略有下降,并且通过减少上下文样本数量额外降低了推理计算成本。

## II. 问题设置

### II-A 先验数据拟合网络 (PFN)

PFN 是一类模型,它们在监督学习任务的贝叶斯先验上进行训练,通常使用大规模合成数据集,以便它们能够利用上下文学习 (ICL) 直接预测后验预测分布 (PPD) [5] (https://arxiv.org/html/2605.21742#bib.bib5)。在训练期间,这些模型被给予一组从某个分布中抽取的上下文点,并训练它们预测从同一分布中抽取的查询点的掩码标签。在各种分布上最小化标准交叉熵损失,使模型学习 \( P(y \mid x, D) \),即上下文 \( D = \{x_c, y_c\}_{i=1}^{n} \) 和查询 \( x \) 的 PPD。形式上,PPD 可以写成 [3] (https://arxiv.org/html/2605.21742#bib.bib3):

\[
P(y | x, D) \propto \int_{\Phi} P(y | x, \phi) P(D | \phi) P(\phi) d\phi,
\]
其中 \( y \) 是数据点 \( x \) 的类别标签,\( D \) 是与 \( x \) 同分布的一个带标签数据集,而 \( \Phi \) 是数据生成函数的集合。

数据生成函数的集合是一个先验,它定义了分类器的假设空间。一个常见的选择是结构因果模型 (SCM) 的集合,其中特征之间的因果关系由有向无环图的边表示 [3] (https://arxiv.org/html/2605.21742#bib.bib3)。这训练模型通过*隐式*预测最符合数据的 SCM,并使用该模型预测 \( \hat{y} \),从而直接估计 PPD。

### II-B 上下文学习 (ICL)

在 ICL 期间,预训练模型被给予一组 \( n \) 个带标签的上下文点 \( \{(x_c, y_c)\}_{i=1}^{n} \) 和一个查询 \( x_q \),并被要求预测 \( \hat{y}_q \),即查询点的标签。关键的是,这并不涉及更新模型权重,因此与标准的权重内学习不同,在权重内学习中,带标签的数据用于在传入查询预测标签之前更新模型权重。

ICL 的想法起源于大型语言模型 (LLM),当它们展现了执行任务无关的少样本分类的能力时 [1] (https://arxiv.org/html/2605.21742#bib.bib6)。然而,自那以后,ICL 的想法已经扩展,模型被训练在各种设置中明确执行 ICL,包括图像 [8] (https://arxiv.org/html/2605.21742#bib.bib7) 和表格数据 [3] (https://arxiv.org/html/2605.21742#bib.bib3)。

ICL 模型的学习完全发生在上下文和查询数据的潜在表示中,而不是模型的权重中,这使得操作模型变得困难。如果不执行计算成本高昂的模型微调操作,影响其性能的唯一方法就是操作输入数据或下游预测。

### II-C 接收者操作特征 (ROC)

对于类别为‘0’和‘1’(少数类别)的二分类,ROC 曲线定义了分类器在特定漏检 (MD) 率与虚警 (FA) 率下的可能操作点。一个完美的分类器 MD 和 FA 率都为 0;即 \( P(\hat{y}=0 \mid y=1) = 0 \) 且 \( P(\hat{y}=1 \mid y=0) = 0 \)。

在实践中,MD 和 FA 率之间存在权衡:当 MD 降低时,FA 增加,反之亦然。根据漏检和虚警的成本,可以选择分类器的操作点以最小化期望成本。一个常见的选择是最小化错误概率:

\[
P_e = \pi_1 P(\hat{y}=0 \mid y=1) + \pi_0 P(\hat{y}=1 \mid y=0),
\]
其中 \( \pi_0 \) 和 \( \pi_1 \) 分别是类别‘0’和‘1’的先验概率。事实上,这将是利用未加权损失最小化所实现的目标,在实践中通过经验风险最小化 (ERM) 实现。

然而,这并不一定能产生良好的下游分类器。当在不平衡数据集上执行 ERM 时,分类器会达到一个具有高漏检率或高虚警率的操作点。此后,我们假设‘0’和‘1’分别是多数类别和少数类别。

### II-D 校准

令 \( f_\theta(x) \) 为一个预训练的分类器,参数化为 \( \theta \)。我们将分类器输出视为软分数,即 \( f_\theta(x) \in [0,1] \)。如果对于所有 \( p \in [0,1] \),有 \( P[Y=y \mid f_\theta(X)=p] = p \),则模型是*完美校准*的。也就是说,样本 \( x \) 具有标签 \( y \) 的预测概率与真实后验概率相同。由于 \( Y \) 是伯努利随机变量,这个概率就是期望值(或经验上,样本均值)。

评估模型校准的标准方法是绘制观测频率 \( p \in [0,1] \) 作为预测概率的函数图,即 \( E[y \mid f_\theta(x)=p] \)。在实践中,x 轴的值是通过对选定范围的 \( p \) 进行经验估计得到的。y 轴的值是在相同范围内标签的经验频率。完美校准的曲线是一条通过原点的直线。如果对于所有 \( p \),预测概率相对于真实概率偏向于 0.5(或背离 0.5),则模型是*欠自信*(相应地,*过自信*)的。另一方面,如果一个模型始终预测高于或低于观测频率,则它偏向于某一类(参见图 1(b) (https://arxiv.org/html/2605.21742#S2.F1.sf2))。校准信息允许对 \( f_\theta(x) \) 进行修正,从而做出更准确的下游预测。

参考图注
(a) 欠自信
参考图注
(b) 类别 0 有偏

图 1:示例校准曲线

### II-E 数据级策略 (采样)

**下采样** 涉及从上下文集中移除多数样本,使得每个类别的样本数量相等。这实现了 \( \pi_0 = \pi_1 \),但代价是减少了关于多数类别的可用信息。

**过采样** 涉及在上下文集中多次包含少数类别的样本,以使每个类别的样本数量相等。该技术同样实现了 \( \pi_0 = \pi_1 \),但扭曲了少数类别的分布,使其看起来比真实分布更尖锐。

**合成上采样** 涉及使用上下文集生成少数类别的人工样本,并用足够多的这些样本补充上下文集,使其成为类别平衡的。合成上采样类似于过采样,但少数类别分布的扭曲来自生成器所学习到的分布中的任何不准确性。

### II-F 决策级策略

分类的贝叶斯最优规则是选择软分数最高的类别。对于二分类,这简化为设置一个 0.5 的阈值来进行硬决策。

**阈值法** 涉及将决策边界从 0.5 移开。这可以解释为当假阴性(将少数类别错误分类为多数类别)和假阳性(将多数类别错误分类为少数类别)的成本不相等时最小化风险 \( \mathcal{R} \),即:

\[
\mathcal{R} = C_{01} \pi_1 P(\hat{y}=0 \mid y=1) + C_{10} \pi_0 P(\hat{y}=1 \mid y=0).
\]

如果我们有一个以 \( \pi_0 \neq \pi_1 \) 优化的分类器,我们可以通过定义 \( C_{01}/C_{10} = \pi_0/\pi_1 \) 使其等价于在 \( \pi_0 = \pi_1 \) 时达到的分类器。确定分类阈值 \( \tau \),使得我们的结果是 \( 1[f_\theta(x, D) > \tau] \),我们有:

\[
\tau = \frac{C_{10}}{C_{10} + C_{01}} = \pi_1.
\]

因此,通过将分类阈值从 0.5 调整开,我们能够抵消数据不平衡的影响。

### II-G 指标

我们关注的指标是每个类别的准确率、平衡准确率和最差类别准确率 (WCA)。我们使用一个留出的测试集来计算准确率。每个类别的准确率简单来说是测试数据中每个类别正确分类的点的比例,即 \( P(\hat{y}=i \mid y=i), i \in \{0,1\} \)。平均测试准确率是两个类别上的经验准确率,即每个准确率按其经验先验进行缩放。另一方面,平衡准确率是两个类别准确率的平均值(等价于 \( 1 - P_e \),其中 \( P_e \) 在公式 2 (https://arxiv.org/html/2605.21742#S2.E2) 中定义,且 \( \pi_0 = \pi_1 \))。最后,WCA 是两个类别准确率中的最小值。

## III. 实验结果与分析

在我们的实验中,我们从基准数据集集合 OpenML-CC18 中选择二分类任务。该集合包含 72 个表格分类任务,我们从中选出满足以下约束条件的 11 个数据集:

- **任务**:二分类
- **测试集大小**:每类 500 个样本
- **训练集大小**:少数类别 500 个样本,多数类别 950 个样本(允许不平衡程度达到 \( \pi_1 = 0.05 \),同时保持总共 1000 个训练样本)

查询(测试)集始终保持平衡,因此平均测试准确率也可以视为平衡准确率。

在这些实验中,我们利用了 TabPFN-2.5 [2] (https://arxiv.org/html/2605.21742#bib.bib1),它在表格 PFN 模型中提供了最先进的性能。我们在表 I (https://arxiv.org/html/2605.21742#S3.T1) 中列出了实验中使用的全部数据集及其总样本数和自然不平衡度。

表 I:实验中使用的数据集摘要。

### III-A 校准

在图 2 (https://arxiv.org/html/2605.21742#S3.F2) 中,我们展示了当训练上下文是平衡或不平衡时,TabPFN 在几个数据集上的经验校准曲线。我们观察到一种持续的趋势:在平衡数据集上校准良好,而在不平衡设置下则存在多数类别偏向。图 2 (https://arxiv.org/html/2605.21742#S3.F2) 显示了校

相似文章

当表格基础模型遇到策略性表格数据:一种先验对齐方法

arXiv cs.AI

本文研究了基于预训练先验数据拟合网络的表格基础模型是否能够泛化到个体在部署后修改特征的策略性表格数据。提出了策略性先验数据拟合网络(SPN),这是一个无需重新训练即可将PFN预测与操纵后分布对齐的推理时框架。

TabPFN-3:技术报告

arXiv cs.LG

TabPFN-3 是一个新的表格数据基础模型,在合成数据上预训练,可扩展到 100 万训练行,同时减少训练和推理时间,在表格预测、时间序列和关系数据上实现了最先进的性能。

PriorLabs/TabPFN

GitHub Trending (daily)

PriorLabs 推出了 TabPFN,这是一种专为表格数据设计的基座模型。