基于语义损失的微调方法以防止因果推理中的模型崩溃
摘要
本文指出了标准微调在因果推理任务中存在的“模型崩溃”问题,并提出了一种结合基于图的逻辑约束的语义损失函数来防止该现象。
arXiv:2605.05438v1 公告类型:新论文
摘要:在对因果推理任务进行标准微调时,Transformer 模型会出现灾难性的模型崩溃现象,即模型学会了无关输入结构的平凡解(例如始终预测“是”或“否”)。我们证明,若不使用语义损失而在传递性和 d-分离(d-separation)任务上对 Gemma 270M 进行微调,会导致 100% 的崩溃率;此时模型虽然表现出误导性的较高准确率(73.9%),但实际上并未学习到任何因果推理能力。我们提出了一种结合基于图的逻辑约束和动态 lambda 调度机制的语义损失函数,以阻止这种崩溃。我们的方法在传递性任务上达到了 70.4% 的准确率,在 d-分离任务上达到了 68.6% 的准确率,并实现了稳定且依赖上下文的预测,相比崩溃基线提升了 42.7%。在 1,000 个结构推理样本上的对抗性评估显示,语义模型的准确率为 67%-70%,而崩溃模型的表现则灾难性地降至 43%-71%。我们通过在五个模型变体上的 20 万多个评估样本进行的全面基准测试验证了上述发现,证明在 Transformer 中实现稳定的因果推理时,语义损失至关重要而非可有可无。
查看缓存全文
缓存时间: 2026/05/08 07:22
# 关于防止因果推理中模型崩溃的语义损失微调方法
来源:https://arxiv.org/html/2605.05438
###### 摘要
在因果推理任务上对 Transformer 模型进行标准微调会导致灾难性的模型崩溃,其中模型学习到的是琐碎的解决方案,例如无论输入结构如何,始终预测“是”或“否”。我们证明,在对 Gemma 270M 进行传递性和 d-分离任务的微调时,如果不使用语义损失,会导致 100% 的崩溃率,模型虽然取得了误导性的高准确率(73.9%),但并未学习到任何因果推理能力。我们提出了一种带有基于图的逻辑约束和动态 Lambda 调度的语义损失函数,以防止这种崩溃。我们的方法在传递性任务中达到了 70.4% 的准确率,在 d-分离任务中达到了 68.6% 的准确率,并实现了稳定、依赖于上下文的预测,相比崩溃的基线提升了 42.7%。在对 1,000 个结构推理样本的对抗性评估中,语义模型实现了 67-70% 的准确率,而崩溃的模型则灾难性地失败,准确率仅为 43-71%。我们通过五个模型变体在 200,000+ 评估样本上的全面基准测试验证了我们的发现,表明对于 Transformer 中稳定的因果推理而言,语义损失是必不可少的,而非可有可无的。
## 1 引言
因果推理——理解并对因果关系进行推理的能力——是人类认知的基础,对于开发强大的 AI 系统也日益关键\[2 (https://arxiv.org/html/2605.05438#bib.bib2)\]。最近的进展表明,Transformer 可以通过在因果公理的合成演示上进行公理化训练来学习因果推理\[1 (https://arxiv.org/html/2605.05438#bib.bib1)\]。然而,通过系统性实验,我们识别出一种关键且此前未记录的故障模式:在因果推理任务上进行标准微调会导致发生率高达 100% 的灾难性模型崩溃。
### 1.1 崩溃问题
我们将模型崩溃定义为一种退化的学习结果,其中模型的预测分布 $P(y|x)$ 变得独立于输入结构 $x$,收敛为固定输出(始终为“是”或始终为“否”),而不管因果图拓扑如何。通过对 Gemma 270M 模型\[4 (https://arxiv.org/html/2605.05438#bib.bib4)\]的全面实验,我们证明了:
- • 传递性崩溃:模型对所有输入输出“是”(10,000/10,000 次预测),准确率为 27.7%
- • d-分离崩溃:模型对几乎所有输入输出“否”,达到了误导性的高准确率(73.9%),但 F1 分数极低(7.6%)
在不使用语义损失的情况下,100% 的微调尝试都会发生这种崩溃,使得标准方法在因果推理任务中根本不可靠。
### 1.2 我们的贡献
1. 1. 问题识别:首次系统记录了因果推理微调中的灾难性模型崩溃,在传递性和 d-分离任务中发生率均为 100%
2. 2. 理论框架:预测偏差崩溃的形式化定义,以及仅交叉熵损失为何在因果推理中失败的分析
3. 3. 解决方案方法:结合基于图的逻辑约束和动态 Lambda 调度($\lambda:0.05\rightarrow 0.30$)的语义损失函数
4. 4. 全面评估:在 200,000+ 样本上的基准测试,显示相比崩溃基线提升了 42.7%,并在两个不同的因果推理任务中得到验证
5. 5. 对抗性验证:新颖的测试套件证明语义模型学习到了结构推理(67-70% 准确率),而崩溃模型则灾难性地失败(43-71%)
## 2 相关工作
### 2.1 神经网络中的因果推理
因果推理在因果发现\[2 (https://arxiv.org/html/2605.05438#bib.bib2)\]、效应估计和反事实推理的背景下得到了广泛研究。最近的工作探索了通过各种方法向神经网络传授因果概念:符号演示\[1 (https://arxiv.org/html/2605.05438#bib.bib1)\]、因果图生成和基于干预的学习。
Vashishtha 等人\[1 (https://arxiv.org/html/2605.05438#bib.bib1)\]证明,从零开始在公理化演示上训练的 6700 万参数 Transformer 可以泛化到复杂的因果结构。他们的工作表明,在具有足够架构容量下从零开始训练时,在传递性和 d-分离任务上表现出强大的性能。我们的工作通过识别微调预训练模型时的关键故障模式并开发防止崩溃的解决方案,对此进行了扩展。
### 2.2 语义损失与神经符号集成
语义损失函数通过可微约束满足将符号知识纳入神经网络训练\[3 (https://arxiv.org/html/2605.05438#bib.bib3)\]。核心方法使用权重模型计数来计算关于逻辑公式满足度的梯度。应用包括半监督学习、结构化预测和知识库补全。
我们的工作专门针对因果图约束调整了语义损失,开发了动态调度机制,以在微调期间平衡稳定性和结构学习。
### 2.3 模型崩溃现象
模式崩溃在生成对抗网络(GANs)\[5 (https://arxiv.org/html/2605.05438#bib.bib5)\]中得到了广泛研究,其中生成器学习产生有限的多样性。表示崩溃发生在对比学习\[6 (https://arxiv.org/html/2605.05438#bib.bib6)\]中,当嵌入收敛到常数向量时。最近的工作已识别出大语言模型在指令微调和基于人类反馈的强化学习(RLHF)期间的崩溃现象\[7 (https://arxiv.org/html/2605.05438#bib.bib7)\]。
我们识别出的崩溃存在根本区别:它发生在具有清晰地面真值的明确定义的推理任务的监督微调期间,并表现为极端的预测偏差,而不是表示退化。据我们所知,这是首次系统记录因果推理微调中的崩溃现象。
### 2.4 因果推理的评估
最近的基准测试评估了语言模型中的因果推理能力,包括用于因果阶梯问题的 CLADDER\[8 (https://arxiv.org/html/2605.05438#bib.bib8)\]和用于从相关性推断因果关系的 Corr2Cause\[9 (https://arxiv.org/html/2605.05438#bib.bib9)\]。这些基准主要评估预训练或提示模型,而不是微调系统。
我们的对抗性评估方法专门针对结构理解与表面启发式之间的区别,提供了识别崩溃的诊断工具。
## 3 问题公式化
### 3.1 因果推理任务
我们关注基于 Pearl 因果框架\[2 (https://arxiv.org/html/2605.05438#bib.bib2)\]的两个基本因果推理任务:
##### 传递性
给定表示因果关系的有向无环图(DAG)$G=(V,E)$,确定从节点 $A$ 到节点 $B$ 是否存在有向路径。形式上,传递性公理表述为:
$$\forall A,B,C\in V:(A\rightarrow C)\wedge(C\rightarrow B)\implies(A\rightarrow B) \quad (1)$$
##### d-分离
确定在因果 DAG $G$ 中,给定条件集合 $Z$,节点 $X$ 和 $Y$ 是否条件独立,遵循 Pearl 的 d-分离准则。如果 $X$ 和 $Y$ 之间的所有路径都被 $Z$ 阻塞,则节点 $X$ 和 $Y$ 被 $Z$ d-分离。
### 3.2 形式化问题设置
令 $\mathcal{D}=\{(p_i,h_i,y_i)\}_{i=1}^N$ 表示训练数据集,其中:
- • $p_i$: 描述因果图结构的文本前提
- • $h_i$: 关于因果关系的二元假设查询
- • $y_i\in\{\text{Yes},\text{No}\}$: 地面真值标签
模型 $f_\theta:(p,h)\rightarrow\mathbb{R}^2$ 将前提-假设对映射到 logits,我们通过 softmax 从中计算预测概率:$P_\theta(y|p,h)=\text{softmax}(f_\theta(p,h))$。
### 3.3 模型崩溃:形式化定义
###### 定义 1(预测偏差崩溃)。
如果存在固定预测 $\bar{y}$,使得对于评估数据集 $\mathcal{D}_\text{eval}$:
$$\frac{1}{\|\mathcal{D}_\text{eval}\|}\sum_{(p,h,y)\in\mathcal{D}_\text{eval}}\mathbb{1}[\arg\max P_\theta(y|p,h)=\bar{y}]>0.95 \quad (2)$$
则模型 $f_\theta$ 在任务 $\mathcal{T}$ 上表现出预测偏差崩溃。
崩溃指标:
- • 极端预测偏差:>95% 的预测属于同一类别
- • 分布独立性:预测对图结构变化保持不变
- • 指标发散:在偏差数据集上准确率高,F1 分数接近零
## 4 方法论
### 4.1 用于因果图的语义损失
我们用语义组件增强标准交叉熵损失,以强制与因果图结构的逻辑一致性:
$$\mathcal{L}_\text{total}=\mathcal{L}_\text{CE}(y,\hat{y})+\lambda(t)\cdot\mathcal{L}_\text{semantic}(p,h,\hat{y}) \quad (3)$$
其中 $\mathcal{L}_\text{CE}$ 是交叉熵,$\hat{y}=P_\theta(y|p,h)$ 是预测概率,$\lambda(t)$ 是时间依赖的权重因子。
#### 4.1.1 基于图的一致性
对于传递性任务,我们解析前提 $p$ 以提取因果图 $G=(V,E)$ 并计算逻辑一致性:
$$c(p,h,\hat{y})=\begin{cases}P_\theta(y=\text{Yes}|p,h) & \text{if path exists in }G\\ P_\theta(y=\text{No}|p,h) & \text{otherwise}\end{cases} \quad (4)$$
语义损失惩罚与图结构的不一致性:
$$\mathcal{L}_\text{semantic}=-\frac{1}{N}\sum_{i=1}^N\log(c(p_i,h_i,\hat{y}_i)+\epsilon) \quad (5)$$
其中 $\epsilon=10^{-8}$ 防止数值不稳定。
对于 d-分离,一致性基于路径阻塞计算:如果节点未 d-分离,则 $c(p,h,\hat{y})=P(y=\text{Yes})$,否则为 $P(y=\text{No})$。
#### 4.1.2 动态 Lambda 调度
为了防止崩溃同时保持训练稳定性,我们采用动态 Lambda 调度:
$$\lambda(t)=\lambda_\text{start}+\frac{t}{T}(\lambda_\text{end}-\lambda_\text{start}) \quad (6)$$
其中 $t$ 是当前训练步骤,$T$ 是总步骤,$\lambda_\text{start}=0.05$,$\lambda_\text{end}=0.30$。
设计原理:
- • 初始 $\lambda$ 较低:防止在早期训练阶段与交叉熵信号冲突
- • 逐渐增加:允许模型在学习严格结构约束之前学习基本模式
- • 最终强度:足以防止退化解决方案,同时保持梯度流
算法 1:使用语义损失训练
1: 输入:数据集 $\mathcal{D}$,模型 $f_\theta$,轮数 $E$,批量大小 $B$
2: 参数:$\lambda_\text{start}=0.05$,$\lambda_\text{end}=0.30$
3: $T \leftarrow$ 总训练步骤
4: for 轮次 $e=1$ 到 $E$ do
5: for $\mathcal{D}$ 中的每个批量 $(p,h,y)$ do
6: $t \leftarrow$ 当前步骤
7: $\lambda \leftarrow \lambda_\text{start}+\frac{t}{T}(\lambda_\text{end}-\lambda_\text{start})$
8: $\hat{y} \leftarrow f_\theta(p,h)$
9: $\mathcal{L}_\text{CE} \leftarrow -\sum y\log\hat{y}$
10: $\mathcal{L}_\text{sem} \leftarrow$ ComputeSemanticLoss$(p,h,\hat{y})$
11: $\mathcal{L} \leftarrow \mathcal{L}_\text{CE}+\lambda\cdot\mathcal{L}_\text{sem}$
12: 通过对 $\mathcal{L}$ 进行梯度下降更新 $\theta$
13: end for
14: end for
### 4.2 训练配置
表 1:训练超参数
### 4.3 评估方法
我们在六个测试分布上评估模型,每个分布包含 10,000 个样本,对抗性测试除外(1,000 个样本):
##### 标准泛化测试
- • 长度:7-15 个节点的因果链(训练:3-6 个节点)
- • 分支:分支因子为 1.4-2.0 的 DAG
- • 反转:所有有向边反转
- • 打乱:前提语句随机排序
- • 长名称:8-10 个字符的变量名(训练:1-3 个字符)
##### 对抗性结构测试
新颖的评估集(1,000 个样本),旨在区分结构理解与启发式方法:
- • 无关节点(30%):没有路径通向查询变量的额外节点
- • 断链(30%):缺少单个边的传递性链
- • 长链(40%):需要多次应用公理的扩展传递性
##### 评估指标
除标准准确率外,我们还计算:
- • F1 分数、精确率和召回率
- • 预测分布分析(是/否计数)
- • 混淆矩阵
- • 每项任务的性能细分
## 5 实验结果
### 5.1 实验设置
所有实验均使用 Gemma 3 270M Instruct-tuned 模型作为基础。我们训练了五个模型变体:
1. 1. 标准 Gemma:零样本基线(无微调)
2. 2. 传递性 V1:在传递性上微调,无语义损失
3. 3. d-分离 V1:在 d-分离上微调,无语义损失
4. 4. 传递性语义 V4:使用动态语义损失进行微调
5. 5. d-分离语义 V2:使用动态语义损失进行微调
训练数据包括每个任务 50,000 个合成生成的示例,遵循\[1 (https://arxiv.org/html/2605.05438#bib.bib1)\]的公理化训练方法,并在图结构上增强了多样性。
### 5.2 标准微调中的模型崩溃
表 2 (https://arxiv.org/html/2605.05438#S5.T2) 展示了在没有语义损失训练的所有模型中发生的灾难性崩溃(100%)。
表 2:每个模型在 50,000 个评估样本上的模型崩溃证据。预测模式显示是/否计数(以千计)。V1 模型表现出 100% 的崩溃率,具有极端的预测偏差。
#### 5.2.1 崩溃分析:传递性 V1
传递性 V1 表现出完全崩溃,始终预测“是”:
- • 预测分布:在所有五个测试集上均为 10,000 个“是” / 0 个“否”
- • 准确率方差:0.15%(打乱)到 100%(长度)——完全由标签分布决定
- • 结构独立性:预测不受图拓扑、边反转或节点添加的影响
- • F1 悖论:尽管准确率为 27.7%,但平均 F1 为 31.9%,表明召回率为 100% 但精确率差
#### 5.2.2 崩溃分析:d-分离 V1
d-分离 V1 表现出相反的崩溃(始终为“否”):
- • 预测分布:0-1,889 个“是” / 8,111-10,000 个“否”
- • 误导性准确率:73.9% 的平均准确率掩盖了灾难性的失败
- • F1 揭示真相:7.6% 的 F1 分数暴露了极端的召回率失败(平均 8.6%)
- • 测试集偏差:高准确率源于偏向“否”的标签分布,而非学习到的推理
关键见解:仅凭准确率是不够的——F1、精确率、召回率和预测分布分析对于检测崩溃至关重要。
### 5.3 语义损失防止崩溃
表 3 (https://arxiv.org/html/2605.05438#S5.T3) 展示了证明防止崩溃的全面结果。
表 3:每项任务的准确率比较(每项 10,000 个样本。显示传递性任务;d-分离...相似文章
Vernier: 探究因果推理中词汇缺口背后的表征错位
本文探究了为何指令调优的语言模型在将变量名替换为占位符后,对因果推理问题给出不同答案,发现问题源于表征错位而非信息丢失。作者引入了Vernier方法,通过配对视图权重更新和机制检查,揭示出答案相关内容在占位符视图中仍然存在但错位。
当推理收敛时停止:保留语义的推理模型提前退出
本文介绍 PUMA,一个即插即用框架,通过检测思维链推理中的语义冗余实现提前退出,在多个模型和基准测试中平均减少 26.2% 的 Token,同时保持准确性和推理质量。
为什么推理模型会失去覆盖率?数据与路径分岔的作用
本文研究了推理模型在监督微调过程中失去覆盖率的原因,将这一现象与训练数据中存在多个有效路径的决策点联系起来,并提出数据合成和多样性感知解码作为缓解措施。
使用合成理由数据的监督微调损害了现实世界疾病预测
本文证明,与仅使用标签的微调相比,在阿尔茨海默病检测中,使用合成理由数据进行监督微调在多种配置和模型家族中始终损害预测性能。尽管理由质量很高,这种退化仍然存在,并归因于叙事合理性与判别优化之间的结构性冲突。
风险链条:大型推理模型中的安全失效及通过自适应多原则引导进行缓解
本文研究了大型推理模型中的安全失效问题,即尽管最终答案安全,但推理轨迹中仍会出现有害内容,并提出了一种自适应多原则引导方法来缓解这些风险。