用于结构和结果预测的因果基础模型
摘要
TabPFN-CFM是一个因果基础模型,可从观测数据中预测因果结构和因果结果,支持Pearl因果层次的所有三个层级,并实现了优于基线的性能。
arXiv:2606.26467v1 公告类型:新
摘要:我们介绍了TabPFN-CFM,一个能够处理多种因果问题的因果基础模型。TabPFN-CFM可从观测数据中预测因果结构和因果结果,支持对Pearl因果层次所有三个层级的查询,并在已知图结构时利用它来改进预测。TabPFN-CFM在合成数据集上训练,并泛化到真实数据集,在结构和结果预测基线上均展现出更优的性能。
查看缓存全文
缓存时间: 2026/06/26 05:19
# 一个用于结构和结果预测的因果基础模型
来源:https://arxiv.org/html/2606.26467
###### 摘要
我们提出了 TabPFN\-CFM,一个能够处理多种因果问题的因果基础模型。TabPFN\-CFM 能够从观测数据中同时预测因果结构和结果,支持 Pearl 因果层次结构所有三个层次的查询,并在已知图结构可用时利用它来改进预测。TabPFN\-CFM 在合成数据集上训练,并能泛化到真实数据集,在结构预测和结果预测上均优于现有基线方法。
机器学习,ICML
## 1 引言
过程之间的因果关系决定了干预系统如何影响其行为。然而,从观测数据中确定因果结构和因果结果是一个具有挑战性的问题 (Pearl, 2000 (https://arxiv.org/html/2606.26467#bib.bib11); Imbens and Rubin, 2015 (https://arxiv.org/html/2606.26467#bib.bib24))。最近,机器学习模型已被开发出来,用于在最少数据的情况下高效地对干预措施的结果进行因果预测 (Künzel 等,2019 (https://arxiv.org/html/2606.26467#bib.bib18); Hollmann 等,2023 (https://arxiv.org/html/2606.26467#bib.bib15); Robertson 等,2025 (https://arxiv.org/html/2606.26467#bib.bib16); Balazadeh 等,2025 (https://arxiv.org/html/2606.26467#bib.bib17)),以及进行结构预测 (Lorch 等,2022 (https://arxiv.org/html/2606.26467#bib.bib14); Ke 等,2023 (https://arxiv.org/html/2606.26467#bib.bib25))。我们训练了一个因果基础模型,能够从观测数据中同时预测因果结构和结果。与现有方法相比,TabPFN\-因果基础模型 (TabPFN\-CFM) 提高了准确性,支持 Pearl 因果层次结构 (Pearl, 2000 (https://arxiv.org/html/2606.26467#bib.bib11)) 的所有三个层次,并在已知图结构可用时加以利用。此外,我们还改进了训练过程,使训练效率提高了近 4 倍。
### 1.1 问题设置
我们假设底层系统遵循一个 SCM ψ = {U, V, F},由未观测变量U、观测变量V和未知的结构方程F组成。未知变量包括未观测的协变量和噪声源。观测变量V被分为协变量X、一个二元处理变量T和结果Y。结构方程F定义了每个变量的父节点,并形成一个图G。我们有一个观测数据集 D^obs = {x_i, y_i, t_i}_{i=1}^n,包含从ψ中抽取的n个独立同分布样本,以及可选的因果图先验知识 G^est。如果没有先验,则 G^est = ∅。
我们的模型针对两个目标。首先,如果真实因果图未知,我们的模型可以对潜在的因果图结构进行估计。其次是三种因果查询下的结果预测。给定一个样本的协变量 x*、自然处理值 t* 和结果 y*,观测查询 P(y* | x*, T=t*, D^obs, G^est) 预测当处理变量取其观测值时的结果。干预查询 P(y* | x*, do(T=1-t*), D^obs, G^est) 预测在外部指定一个不同于观测到的处理值时的结果。反事实查询 P(y* | x*, do(T=1-t*), y=y_{t*}, D^obs, G^est) 预测在原始处理下观测到的结果条件下,不同处理时的假设结果。通过在单个模型上联合训练所有任务,该模型对因果过程有了更广泛的理解,并且比针对单一任务训练的模型学习效率更高。
为了解决这个问题,我们遵循贝叶斯 PFN 框架 (Robertson 等,2025 (https://arxiv.org/html/2606.26467#bib.bib16); Balazadeh 等,2025 (https://arxiv.org/html/2606.26467#bib.bib17))。一个关于 SCM 的先验 p(ψ) 生成一个真实的 SCM ψ_true ~ p(ψ)。观测数据产生后验 P(ψ | D^obs, G^est)。给定一个 SCM,结果可以估计为 P(y | x*, do(T=t*), ψ)。一个因果查询(例如干预查询)的目标分布是:
P(y* | x*, do(T=1-t*), D^obs, G^est) = (1)
∫ P(y* | x*, do(T=1-t*), ψ) P(ψ | D^obs, G^est) dψ. (2)
如果图结构 G 未知,其后验为:
P(G* | D^obs) = ∫ P(G* | ψ) P(ψ | D^obs) dψ. (3)
可选的图允许更准确地估计后验,因为它提供了变量之间因果关系的信息。在附录 A (https://arxiv.org/html/2606.26467#A1) 中,我们证明将图作为输入输入只能提高后验精度,如果图不改变 ψ 的后验分布(其中 y 有支撑),则没有改进。
实际上,我们不是显式地估计 P(ψ | D^obs),而是训练一个模型直接估计 P(y* | x*, do(T=t*), D^obs) 和 P(G | D^obs),方法是从先验 p(ψ) 中抽取样本,并基于观测值优化对数似然条件:L = -E[log p̂_θ(y* | x*, do(T=1-t*), D^obs, G)]。我们在附录 B (https://arxiv.org/html/2606.26467#A2) 中证明,这个损失等价于优化模型估计分布与真实后验分布之间的 KL 散度。
现有的因果结构学习深度学习方法不预测由未观测因素引起的混杂变量。我们的方法通过使用有向无环混合图 (ADMG) 来解决这个问题,ADMG 通过双向边表示未观测的混杂;详细信息见附录 C (https://arxiv.org/html/2606.26467#A3)。混杂可能导致不可识别的图,因此 P(ψ | D^obs, G^est) 是不确定的,这反映在估计的分布中。通过预测混杂,我们的模型不仅揭示了潜在的因果结构,还允许用户在解释结果时考虑混杂因素。
## 2 数据生成
我们的模型在由 SCM 先验分布生成的合成数据上进行训练,以近似该先验的贝叶斯推断。首先,我们采样具有随机图结构、缺失节点、噪声分布和结构方程的 SCM。使用具有随机权重和非线性激活的神经网络以及随机噪声分布生成随机函数。
为了生成单个训练样本,对于每个 SCM,我们从 SCM 中采样观测数据 D^obs,并生成另一个用于预测的观测数据点 D^pred = {x*, t*, y*}。由于真实因果图已知,我们生成一个反事实(具有固定噪声)D^causal = {x*_t, t=do(1-t*), y*_t}。最后,我们有对应于 SCM 结构的 ADMG,G^est = {A, C},包含邻接矩阵 A 和双向混杂矩阵 C。
我们的设置扩展了 Robertson 等人 (2025 (https://arxiv.org/html/2606.26467#bib.bib16)) 中使用的先验。先验中的这种多样性确保我们的模型学会对各种因果系统进行推断,从而提高其对真实数据的泛化能力。数据生成过程的完整描述见附录 D (https://arxiv.org/html/2606.26467#A4)。
## 3 模型架构
### 3.1 架构概述
Fit G G Pred Row Attn Col Attn MHA MLP Row Attn Col Attn MLP GG-Emb KV KV KV Row MHA Col MHA Graph MLP × n blocks Matrix Decoder ŷ Decoder context graph target
图 1:模型架构图。
我们的模型扩展了 Do-PFN 架构,以支持带有显式图条件的干预和反事实预测。详细信息见附录 E (https://arxiv.org/html/2606.26467#A5)。
编码器嵌入拟合集 D^fit、预测查询 D^pred 和先验 ADMG 结构。协变量、结果和处理变量像 TabPFNv2 一样使用随机列位置嵌入以及查询类型进行嵌入。ADMG 通过邻接矩阵 A 和双向矩阵 C 以及导出的祖先矩阵进行编码,使模型能够直接访问父节点、子节点、祖先节点和后代节点关系 (Ke 等人,2023 (https://arxiv.org/html/2606.26467#bib.bib25))。
主 transformer 由行注意力、列注意力和前馈层组成。预测还额外关注图嵌入,使模型预测能够整合图的先验信息。
输出从最终隐藏状态计算得出。对应于结果变量的最终隐藏状态通过一个 MLP 头传递,以产生离散结果桶上的 logits,从而得到预测分布 p̂_θ(y)。图结构预测使用每个特征变量对应的最终隐藏状态生成。一个解码器计算图中每条可能边的逐元素预测。对定向邻接矩阵 Â、双向相关矩阵 Ĉ 和祖先矩阵 R̂ 进行预测。
在我们的架构中,注意力从 D^fit 和图流向 D^pred,但不反向流动。除了像 TabPFN 那样允许从 KV 缓存进行高效推理外,这还使得预测的图结构独立于预测查询,包括输入的图先验。
与 TabPFNv2 和 Do-PFN 相比,我们省略了特征分组、集成、随机增强和随机特征乘积,因为图显式地索引了变量,并且在特征混合下难以保持。此外,我们还引入了几个用于稳定性和效率的训练和骨干网络改进,这些改进共同将训练速度提高了 3-4 倍,并且最终损失显著降低。这些变化的评估见附录 F (https://arxiv.org/html/2606.26467#A6) 中的消融研究。
### 3.2 训练过程
每个训练数据点包含数据集 D^fit、D^pred = {x*, t*, y*}、D^causal = {x*_t, t=do(1-t*), y*_t} 和 G^est。每个数据点用于生成所有三种查询类型的样本。在观测查询中,模型给定 x* 和 t*,目标是 y*。在干预查询中,模型给定 x* 和 do(1-t*),目标是 y*_t。在反事实查询中,模型给定 x*_t、do(1-t*) 和 y*,目标是 y*_t。在每种情况下,令 y_targ 和 p̂_θ(y) 分别为目标和预测分布。预测损失为:
L_pred = CE(y_targ, p̂_θ(y)). (4)
这是 PFN (Hollmann 等人,2023 (https://arxiv.org/html/2606.26467#bib.bib15); Balazadeh 等人,2025 (https://arxiv.org/html/2606.26467#bib.bib17); Robertson 等人,2025 (https://arxiv.org/html/2606.26467#bib.bib16)) 的标准预测损失。模型总是给定 D^fit。为了让模型学会在有和没有真实因果图的情况下进行预测,一半的时间将 G^est 设为零。
结构预测损失在预测 Â、R̂、Ĉ 和真实矩阵之间是逐元素二元交叉熵:
L_graph = 1/((k+2)^2) ∑_{i=1}^{k+2} ∑_{j=1}^{k+2} BCE(M_{i,j}, M̂_{i,j}), (5)
M ∈ {A, R, C}, M̂ ∈ {Â, R̂, Ĉ}. (6)
由于结构预测独立于查询,该损失每个数据点只计算一次。总损失为:
L = L_pred + λ · L_graph. (7)
最后,模型使用批填充中的虚拟特征进行训练。尽管样本是独立处理的,但模型可以关注这些虚拟特征;经验上,在推理时添加它们可以提高性能,表明它们可能支持中间计算。
## 4 评估
### 4.1 合成玩具示例
首先,我们在工具变量 (IV) 问题上展示 TabPFN-CFM,其 SEM 为:Z → T → Y, U → T, U → Y。生成一个线性 IV SEM,具有已知参数和噪声分布,从而可以计算精确的观测、干预和反事实分布。从该 SEM 中抽取一个样本,我们评估模型在该样本上的预测,包括有和没有图结构作为输入的情况。详见附录 G.1 (https://arxiv.org/html/2606.26467#A7.SS1)。图 2 (https://arxiv.org/html/2606.26467#S4.F2) 和图 5 (https://arxiv.org/html/2606.26467#A7.F5) 分别显示了在没有和有先验结构的情况下三种查询类型的预测分布和精确分布。模型预测在所有情况下都吻合得很好。
参见标题
参见标题
参见标题
图 2:IV 示例的观测(左)、干预(中)和反事实(右)分布。蓝色为精确解,橙色为当 U 被观测时模型的预测,绿色为当 U 未观测时的预测。模型没有获得图结构。表 3 (https://arxiv.org/html/2606.26467#A0.T3) 显示了预测的邻接矩阵。预测的邻接矩阵与真实邻接矩阵紧密匹配,表明模式相似文章
当表格基础模型遇到策略性表格数据:一种先验对齐方法
本文研究了基于预训练先验数据拟合网络的表格基础模型是否能够泛化到个体在部署后修改特征的策略性表格数据。提出了策略性先验数据拟合网络(SPN),这是一个无需重新训练即可将PFN预测与操纵后分布对齐的推理时框架。
TabPFN-3:技术报告
TabPFN-3 是一个新的表格数据基础模型,在合成数据上预训练,可扩展到 100 万训练行,同时减少训练和推理时间,在表格预测、时间序列和关系数据上实现了最先进的性能。
GOTabPFN:从特征排序到紧凑标记化——面向高维数据的表格基础模型
本文介绍了GOTabPFN,一种结合了图引导排序与局部精炼(GO-LR)及神经启发子单元压缩(NSC)的方法,使得小型表格基础模型能够在无需重新训练大型骨干网络的情况下,有效进行高维低样本量预测。
FoundCause: 从观测数据中发现存在潜在混杂因素的因果关系
FoundCause 是一种摊销式因果关系发现模型,能够显式处理潜在混杂因素和缺失数据,在真实数据集上通过单次前向传播即可超越15种现有方法。
迈向连续时间因果基础模型
提出了一个连续性准则,用于将离散时间因果先验数据拟合网络扩展到连续时间,利用随机微分方程(SDE)。引入了分类体系和细网格积分方法,在不规则观测时间表上优于朴素积分方法。