利用记忆引导的数据集去偏方法缓解虚假相关性

arXiv cs.LG 论文

摘要

本文提出一种通过两阶段样本评分函数分离核心特征与虚假特征学习动态的方法,仅需10%的训练数据即可实现最先进的去偏性能。

arXiv:2606.02830v1 公告类型:新论文 摘要:现实世界的数据集通常包含与目标标签无因果关系的虚假相关性。当此类相关性主导大多数训练样本时,模型倾向于依赖它们,导致对不具备相同虚假模式的少数样本分类错误。虽然一种潜在方法是选择数据子集以更好地代表少数样本,但这可能需要访问通常未知的组标签。此外,如我们所示,在不变子集或核心集选择文献中广泛使用的样本评分函数在很大程度上依赖于虚假特征,因此无法准确捕捉核心因果相关特征的重要性或难度。据此,我们提出通过开发一种两阶段样本评分函数来缓解虚假相关性,该函数分离核心特征与虚假特征的学习动态,并分别评估它们的难度。基于我们提出的度量,我们引入了一种新算法,用于寻找并优先处理带有和不带有虚假相关性的信息样本。大量实验表明,在所选样本上训练的标准经验风险最小化(ERM)模型相比最先进的去偏技术实现了更优的性能,而仅需原始训练数据的10%。
查看原文
查看缓存全文

缓存时间: 2026/06/03 09:40

# 基于记忆引导的伪相关去偏数据集选择方法

**来源:** https://arxiv.org/html/2606.02830

**Arda Fazla, Abolfazl Hashemi**

作者单位:普渡大学电气与计算机工程学院,美国印第安纳州西拉法叶市,邮编47907。

###### 摘要

现实世界的数据集通常包含与目标标签无因果关系的伪相关。当这些伪相关主导了训练样本中的大多数时,模型倾向于依赖它们,导致对那些不呈现相同伪模式的少数样本分类错误。虽然一种潜在的方法是选择数据子集以更好地代表少数样本,但这可能需要访问通常未知的群体标签。此外,正如我们所展示的,在不变子集或核心集选择文献中广泛使用的样本评分函数在很大程度上依赖于伪特征,因此无法准确捕捉核心因果相关特征的重要性或难度。因此,我们提出通过开发一个两阶段样本评分函数来缓解伪相关,该函数能够分离核心特征和伪特征的学习动态,并分别评估它们的难度。基于我们提出的度量,我们引入了一种新算法,用于寻找并优先处理带有或不带有伪相关性的信息样本。大量实验表明,在我们选择的样本上训练的标准ERM模型,其性能优于最先进的去偏技术,同时所需的训练数据仅为原始训练数据的10%。

## 1 引言

![图1](待添加图片描述)
**图1:标准样本分数在Waterbirds数据集上的比较。** (A) 我们可视化了Waterbirds数据集中两张示例图像的EL2N分数,该分数是在包含和不包含背景的数据集上训练的模型上计算得出的。结果表明,背景的存在显著改变了EL2N分数。(B) 我们展示了基于在Waterbirds上训练的ResNet50模型以及CLIP模型提取的特征嵌入,具有高相似度的代表性图像对。在这两种情况下,高相似度主要由共享的背景特征驱动,即使鸟类属性差异很大。

![图2](待添加图片描述)
**图2:我们提出的核心集选择算法概述。** (A) 我们首先训练一个两阶段模型,以准确分离伪特征和核心特征的学习过程,并分别计算每个组成部分的样本分数(`TCSL_s` 和 `TCSL_c`)。然后,基于计算出的分数,我们构建核心集选择算法——两阶段累积样本损失(TCSL)引导的核心集选择(TCSL-CS)。(B) 我们说明了TCSL分数相对于现有样本评分函数的优势,展示了其能够分离图像中不同特征组成部分的难度。(C) 在仅使用10%总训练数据的情况下,在TCSL-CS选择的核心集上训练一个ERM模型,在Waterbirds数据集上达到了最先进的WGA。

现实世界的数据集通常包含大量具有伪相关的样本,这些伪相关在类别内高度一致,但对真实的类别标签没有预测能力。深度学习模型倾向于基于这些更简单的伪特征而非更复杂的核心特征进行预测[42, 39],导致在测试数据上,当伪相关不成立时,最差群体准确率(WGA)很差。因此,最近的大量工作集中在开发专门算法来缓解由伪相关引起的偏差。然而,这些方法偏离了实践中仍然广泛使用的标准ERM训练。这引出了一个问题:我们能否在保持标准ERM训练的同时,减少模型对伪相关的依赖,而不引入复杂的、专门设计的优化技术?

一个有助于回答上述问题的直观方法是,利用现有的不变数据或样本选择算法,这些算法旨在构建能够代表完整训练分布的高质量数据子集。利用这个想法,我们可以形成一个信息丰富的样本集合,即*核心集*,它能很好地代表数据集中的所有群体。在这样的核心集上训练一个标准ERM模型,可能会在所有群体上取得有竞争力的性能。

然而,我们认为常用的核心集选择算法无法在伪相关数据集上构建强健的核心集,因为它们并非明确设计用于确保高最差群体准确率,而是确保高平均测试准确率。因此,数据集中的某些群体在选择的核心集中可能代表性不足,即使总体平均准确率仍然很高。因此,这些核心集选择方法不能直接应用于伪相关数据集。

最近的工作[6, 37]表明,基于常用样本分数(如EL2N[23]和SelfSup[34])构建的核心集选择策略,在已知伪相关的数据集上无法持续达到高最差群体准确率。常用的样本评分函数以及基于它们构建的核心集选择算法,通常为单个样本分配分数以反映其难度,并据此推导选择策略。我们认为,由于深度学习中广泛观察到的“简单性偏差”现象[36, 32, 19, 39],模型倾向于先学习简单的伪特征,然后才捕捉更复杂的核心特征。因此,依赖于模型输出或损失值的样本评分函数的行为被伪特征所主导。结果,没有伪相关的样本通常被分配高分并被归类为“困难”,而具有伪相关的样本则获得低分并被当作“简单”。这种偏差导致核心集选择中常用的评分函数无法充分捕捉底层核心特征的强度。我们在第2.1节和附录B中从理论上分析了简单性偏差对核心和伪特征学习速度以及基于损失的样本分数的影响。

我们在图1中说明了简单性偏差的影响。在图1(A)中,我们展示了当背景特征从数据集中移除时,EL2N分数会显著变化,这表明一只普通的鸟类图像可能仅仅因为背景特征而获得比一只独特的鸟类图像更高的分数。除了基于损失的样本评分函数,某些核心集选择算法还结合了基于提取的特征嵌入的样本相似度。因此,我们在图1(B)中展示了来自Waterbirds数据集的代表性高相似度图像对,这些图像对是通过ResNet模型和基础模型CLIP提取的特征嵌入的余弦相似度识别的。在这两种情况下,相似度主要由共享的背景特征驱动,即使鸟类属性差异很大。

此外,最近一项研究[18]表明,即使少量具有简单伪特征和复杂核心特征的样本,也可能导致模型主要依赖伪特征进行预测。受这些观察的启发,我们认为常用的样本评分函数被伪特征所主导;因此,基于它们构建的核心集选择算法可能无法区分简单和困难的核心特征,导致次优的核心集和较差的最差群体准确率。

作为补救措施,我们提出了两阶段累积样本损失(TCSL)和TCSL引导的核心集选择(TCSL-CS)算法。如图2(A)所示,我们的方法建立在广泛使用的两阶段训练方法之上(这些方法来自关于伪相关学习的文献[16, 20, 35, 2]),并在无需访问伪属性(群体标签)的情况下,区分核心和伪特征的分数计算。因此,TCSL由两个分数组成:`TCSL_s` 和 `TCSL_c`,分别表示计算出的伪特征和核心特征的难度。如图2(B)所示,传统的样本评分函数为每张图像分配一个单一的分数,该分数被伪特征(例如背景)所主导,而我们的TCSL框架则分别评估核心(鸟)和伪(背景)组成部分的难度。然后,我们基于TCSL设计TCSL-CS算法,以有效选择既能实现(1)高平均准确率又能实现(2)高最差群体准确率的核心集。如图2(C)所示,仅使用10%训练数据由TCSL-CS选择的核心集,在Waterbirds数据集上将标准ERM模型的最差群体准确率提高了11.33%,优于需要群体标签或复杂优化过程的基线方法。

**我们的贡献和范围:**

*   我们提出了TCSL分数,它分别量化了具有伪相关的数据集中核心和伪特征的学习难度。
*   我们提出的TCSL分数使我们能够将核心集选择作为一种原则性工具,引入TCSL-CS算法,该算法为具有伪相关的数据集选择有效的核心集,同时实现高平均准确率和高最差群体准确率。
*   我们对核心和伪特征的独特学习动态提供了强有力的理论分析。
*   通过在具有伪相关的数据集上进行大量实验,我们表明TCSL-CS在无需访问群体标签的情况下,优于现有的去偏和样本评分基线方法。

## 2 问题形式化

### 2.1 问题设置

令 $\mathcal{D} = \{(x_1, y_1), \dots, (x_n, y_n)\}$ 表示大小为 $n$ 的训练数据集,其中对于每个数据样本,我们观察到一个输入特征向量 $x_i \in \mathbb{R}^d$ 及其对应的标签 $y_i$。为简单起见且不失一般性,在整个分析中,我们专注于二元分类设置,其中 $y_i \in \{\pm1\}$。每个数据样本还关联着一个未观察到的伪属性 $a_i \in \{\pm1\}$。我们假设在每个类别内,训练数据被划分为一个多数群体和一个少数群体,其中多数群体包含所有 $a_i = y_i$ 的样本,少数群体包含所有 $a_i = -y_i$ 的样本。每个类别内多数群体样本的比例记为 $\alpha_y$,其中在实践中通常观察到 $\alpha_y > 0.5$。我们将整个数据集中多数样本的比例记为 $\alpha$。我们考虑这样一种设置:伪属性以及因此的群体标签对我们未知。

我们假设每个样本 $x_i$ 由核心和伪成分组成,$x_i = [x_i^c, x_i^s]$,其中 $x_i^c$ 与 $y_i$ 相关,$x_i^s$ 与 $a_i$ 相关。例如,在Waterbirds[30]数据集中,鸟类区域作为核心特征 $x_i^c$,而背景作为伪特征 $x_i^s$。

深度学习模型通常通过经验风险最小化(ERM)进行训练,其中给定一个具有概率输出 $f(x; \mathbf{W})$ 和权重 $\mathbf{W}$ 的模型,我们最小化
$$\mathcal{L}(\mathbf{W}) = \frac{1}{N} \sum_{i=1}^N \ell(y_i, f(x_i; \mathbf{W}))$$
其中 $\ell$ 可以是任何合适的损失函数,例如交叉熵损失或Sigmoid损失。通常使用基于梯度的优化技术,例如随机梯度下降(SGD),其中模型参数在每次迭代时更新为
$$\mathbf{W}_{t+1} = \mathbf{W}_t - \eta_t \nabla \mathcal{L}(\mathcal{B}_t, \mathbf{W}_t),$$
其中 $\eta_t$ 表示第 $t$ 次迭代的学习率,$\mathcal{B}_t$ 是在第 $t$ 步从 $\mathcal{D}$ 中使用的样本小批量。根据上下文,我们交替使用 $t$ 来表示更新步和轮次步,$T$ 表示总轮数。这里,$\nabla \mathcal{L}(\cdot, \cdot)$ 表示批次损失 $\mathcal{L}(\mathcal{B}_t, \mathbf{W}_t)$ 的随机梯度,其定义为批次上的平均加权损失:
$$\mathcal{L}(\mathcal{B}_t, \mathbf{W}_t) = \sum_{i=1}^{|\mathcal{B}_t|} \lambda_i \, \ell(y_i, f(x_i; \mathbf{W}_t)),$$
(1)
其中 $\lambda_i$ 表示与批次中第 $i$ 个样本相关的权重。通常,$\lambda_i$ 设置为 $1/|\mathcal{B}_t|$ 以确保均匀平均。在存在类别不平衡数据的情况下,$\lambda_i$ 通常被选择为属于同一类别的样本数量的倒数,这有助于防止模式崩溃,即模型退化为仅预测多数类别。为了保持批次间的一致性,应用了一个额外的归一化步骤,使得每个批次内样本权重的总和等于1,确保整体优化问题保持不变。

为了评估算法性能,我们使用平均测试准确率(ACC),定义为
$$\mathrm{ACC}(\mathbf{W}) = \mathbb{P}_{(x,y) \sim \mathcal{D}_{\text{test}}} \big[ \operatorname{arg\,max}(f(x;\mathbf{W})) = y \big].$$
对于伪相关,研究人员通常更关注最差群体准确率(WGA):
$$\mathrm{WGA}(\mathbf{W}) = \min_{\substack{y \in \{\pm1\} \\ a \in \{\pm1\}}} \mathbb{P}_{(x,y,a) \sim \mathcal{D}_{\text{test}}} \big[ \operatorname{arg\,max}(f(x;\mathbf{W})) = y \big],$$
它衡量模型在由 $y$ 和 $a$ 组合定义的所有群体中的最差准确率。

我们的总体目标是,给定核心集选择比率 $r$,从训练数据集 $\mathcal{D}$ 中选择 $r|\mathcal{D}|$ 个样本,使得在未见测试集上的WGA最大化。如果在其上训练的标准ERM模型达到最优WGA,我们认为核心集是成功的。因此,我们的目标是消除由伪相关引起的数据集偏差,从而在不牺牲整体泛化性能的前提下提高WGA。等价地,该任务可以被视为选择大小为 $c|\mathcal{D}|$ 的核心集。

相似文章