通过混合层蒸馏和关键信息的逐步注意力改进小模型的推理能力

arXiv cs.CL 论文

摘要

本文提出一种新颖的思维链蒸馏框架,通过混合层模块的动态层对齐,将教师模型对关键信息的逐步注意力转移到学生模型中。该方法通过明确指导学生模型在推理过程中逐步聚焦关键信息,在数学和常识推理基准测试中实现了一致的性能提升。

arXiv:2604.15701v1 公告类型:新发布 摘要:大型语言模型计算需求的增加引发了人们对通过思维链(CoT)蒸馏将推理能力转化为小模型的关注。当前的 CoT 蒸馏方法主要关注于将教师生成的复杂推理步骤转移到学生模型中。然而,这些方法未能充分探索教师在推理过程中对关键信息的动态注意。我们发现语言模型在推理过程中会逐步改变对关键信息的注意力,这暗示了得出结论的必要线索。基于这一观察和分析,我们引入了一种新颖的 CoT 蒸馏框架,将教师对关键信息的逐步注意力转移到学生模型中。这为学生在推理过程中逐步专注于关键信息提供了结构化指导。更重要的是,我们开发了一个混合层模块,实现动态对齐,能够适应教师和学生之间不同的层。我们的方法在多个数学和常识推理数据集上实现了一致的性能提升。据我们所知,这是首个在 CoT 蒸馏中利用逐步注意力来改进小模型推理能力的方法。
查看原文 导出为 Word 导出为 PDF
查看缓存全文

缓存时间: 2026/04/20 08:28

# 通过分层混合蒸馏和关键信息逐步注意力改进小型模型的推理能力

来源: https://arxiv.org/html/2604.15701

陈瑶1,2, 盛佳伟1, 张文远1,2, 刘庭文1,2
1中国科学院信息工程研究所 2中国科学院大学网络安全学院
{chenyao2023, shengjiawei, zhangwenyuan, liutingwen}@iie.ac.cn

###### 摘要

大型语言模型的巨大计算需求激发了人们对通过思维链(CoT)蒸馏将推理能力转移到小型模型的兴趣。当前的CoT蒸馏方法主要关注传输教师生成的复杂推理理由给学生模型。然而,它们没有充分探索教师在推理过程中对关键信息的动态注意。我们发现语言模型在推理过程中表现出对关键信息的逐步注意力转移,这意味着推导结论的重要线索。基于这一观察和分析,我们引入了一种新颖的CoT蒸馏框架,将教师对关键信息的逐步注意力转移给学生模型。这为学生在推理过程中对关键信息的逐步集中提供了结构化指导。更重要的是,我们开发了一个分层混合(MoL)模块,支持动态对齐以适应教师和学生之间不同层数的差异。我们的方法在多个数学和常识推理数据集上实现了一致的性能改进。据我们所知,这是第一种在CoT蒸馏中利用逐步注意力来改进小型模型推理的方法。

通过分层混合蒸馏和关键信息逐步注意力改进小型模型的推理能力

陈瑶1,2, 盛佳伟1, 张文远1,2, 刘庭文1,2††脚注:表示通讯作者。
1中国科学院信息工程研究所
2中国科学院大学网络安全学院
{chenyao2023, shengjiawei, zhangwenyuan, liutingwen}@iie.ac.cn

## 1 介绍

复杂推理能力是人类智能的基石,在解决问题、决策制定和世界理解中起着至关重要的作用。最近的进展表明大型语言模型的少样本推理能力得到了显著改进。然而,这些模型的巨大规模需要大量的内存和计算资源,使得在边缘设备上部署变得极其昂贵,阻碍了应用。为了解决这一挑战,CoT蒸馏已成为一种有前景的方法。在复杂推理中,CoT蒸馏方法通常将教师模型生成的逐步理由转移到学生模型,作为知识蒸馏的有效手段。

![参考](a) SVAMP数据集的样本。蒸馏后的学生模型未能充分利用数值信息,导致错误的结果,而教师模型在逐步推理过程中有效利用所有数值信息得出正确的最终结果。

![参考](b) 数学推理中的数值vs非数值标记:水平轴代表推理步骤,垂直轴显示数值和非数值标记分别接收的逐步注意力的相对比例(详见附录B.1)。

![参考](c) Llama3-8B教师模型第13层对图1(a)样本中数值标记的逐步注意力可视化(详见附录D)。水平轴代表数值标记的索引(图1(a)中以红色突出显示的标记),垂直轴代表步骤的索引(图1(a)中的灰色Sx标签)。

图1:关键标记上的逐步注意力隐含编码推理线索:综合分析。

现有的CoT蒸馏方法通常平等对待所有标记,往往忽视复杂推理中的关键信息。我们观察到通过现有方法蒸馏的学生模型在多步推理中难以充分利用关键信息(图1(a))。值得注意的是,语言模型在推理过程中对关键标记分配更多的平均注意力,隐含编码逐步推理的关键线索。例如,数值标记对数学推理直观来说很重要,我们的分析结果表明在这个过程中,教师和学生模型都确实对数值标记比非数值标记分配明显更多的注意力(图1(b))。更重要的是,我们探索了教师模型对这些关键标记的注意力如何在逐步推理中演变,发现注意力分布表现出逐步变化,具有更高的注意力分数分配给每个推理步骤相关的关键标记(图1(c)与图2)。这突显了教师模型在推理过程中逐步捕获关键信息的能力。然而,当前的CoT蒸馏方法直接向学生提供教师模型生成的理由。这种方法未能充分利用上述现象,导致学生逐步捕获和利用关键信息的能力改进不足。

基于上述见解,我们介绍MoLSAKI,一个新颖的CoT蒸馏框架,通过分层混合对齐策略捕获和转移教师模型对关键信息的逐步注意力,以增强学生模型的推理能力。具体来说,我们将每个推理步骤分配给每个关键标记的注意力权重定义为对关键标记的逐步注意力。通过连接这些逐步分布,我们捕获整个推理过程中模型对关键信息的不断演变的关注。基于这一概念,我们随后在CoT蒸馏过程中从教师和学生模型的每一层提取这些逐步注意力映射。对于蒸馏中的层映射,我们设计了分层混合(MoL),灵感来自混合专家(MoE)。MoL促进了教师和学生层之间的自适应加权对齐,从而克服了层数不匹配的蒸馏挑战。总结我们的贡献如下:

- •我们引入了一个新的观点:在推理过程中,大型语言模型表现出对某些关键标记的逐步注意力模式,这种模式隐含编码了逐步推理的宝贵线索。
- •我们提出了一个新颖的思维链蒸馏框架MoLSAKI,引入了对关键标记的逐步注意力概念,并将教师模型对关键信息的逐步动态关注转移给学生模型,从而增强其有效推理的能力。
- •我们设计了MoL以自适应的加权和动态方式对齐不同深度的教师和学生模型层,从而成功克服了它们层数不匹配的挑战。
- •我们的方法在不同教师-学生模型规模下,在数学和常识推理基准测试的域内和域外设置中产生了性能收益。

## 2 相关工作

### 2.1 思维链蒸馏

大型语言模型展示了强大的推理能力,然而它们的巨大规模阻碍了实际部署。最近的工作通过CoT知识转移将推理能力蒸馏到较小的模型中。关键方法包括Fine-tune-CoT的零样本理由提取和DSS的推理/答案预测的多任务分离。后续改进引入了互信息最大化(MMI loss)和基于辅助模型的蒸馏(Mentor-KD)(详见附录A.1)。现有方法忽视推理中的关键信息,并面临logit蒸馏需求的结构约束。我们的方法引入了对关键标记的逐步注意力蒸馏,无需标记器对齐或投影层。

![参考](a) CommonSenseQA数据集的样本。

![参考](b) Qwen2.5-32B教师模型第32层对图2(a)样本中关键标记的逐步注意力可视化。水平轴代表关键标记的索引,垂直轴代表步骤的索引。

图2:对关键标记的逐步注意力模式(详见附录D)。

### 2.2 自注意力蒸馏

先前的方法通过层映射转移自注意力模式:TinyBERT使用统一映射,MobileBERT假设层数相同,MiniLM仅蒸馏最后的层(详见附录A.2)。这些方法需要匹配的注意力维度和固定的层对应关系。我们通过以下方式克服这些局限:1)将蒸馏集中在推理步骤中的关键标记而不是完整的注意力矩阵上,2)通过MoL模块的动态层路由自动选择最优的教师-学生层对,优于刚性映射方法。

![参考]图3:MoLSAKI框架包含三个组件。在示例中,问题和理由总共有13个数值标记和5个步骤。因此教师和学生模型中数值标记的逐步注意力为5×13。

## 3 方法论

MoLSAKI引入了一个新颖的知识蒸馏框架,通过CoT蒸馏和逐步注意力指导的协同整合来增强学生模型的推理能力。具体来说,我们首先准备由教师模型标注的CoT数据并进行CoT蒸馏(§3.1),随后在CoT蒸馏过程中从教师和学生模型中提取关键标记上的逐步注意力(§3.2),最后实施自适应MoL层对齐(§3.3)。

### 3.1 思维链蒸馏

我们通过对教师模型进行少样本提示,为原始数据集D中的每个问题-答案对{q,â}获取CoT数据(详见附录F.3)。教师对每个问题q的回应分为两个组件:理由r和答案a(见图3中的样本)。标记的数据集{q,r,a | q∈D, a=â}将用于学生模型的后续CoT蒸馏。

继续采用Hsieh等人(2023)的方法,我们执行由两个任务组成的CoT蒸馏(图3中的CoT蒸馏模块):1)给定问题q预测最终答案a,2)为相同输入q生成理由r。相应的损失函数如下:

L_pre = E_{q∈D}[L_ce(f(q),a)],
L_exp = E_{q∈D}[L_ce(f(q),r)],

其中f表示学生模型,L_ce表示模型预测和目标标记之间的交叉熵损失。

### 3.2 关键标记上的逐步注意力

相信蒸馏教师在推理过程中对关键标记的逐步注意力比仅转移理由更有影响力,我们在CoT蒸馏过程中引入关键标记上逐步注意力的损失L_att(在等式6中),以指导学生对关键信息的逐步关注。

为了计算损失L_att,我们首先从教师和学生模型中提取关键标记上的逐步注意力(图3中的"提取关键标记上的逐步注意力"模块)。在我们的设计中,"逐步"表示包含问题的推理步骤。如图3中的示例所示,我们基于句号将由问题和理由组成的输入序列分割成推理步骤,得到5个步骤。

教师模型的标记器将由问题和理由组成的输入序列转换为标记序列{x_1^t, x_2^t, ..., x_M^t}。M_1表示按推理步骤分割的所有标记的索引集。其元素特别表示单个推理步骤内所有标记的索引集。利用正则表达式匹配和标记器的映射,我们从标记序列中获得关键标记的索引集,记为M_2。其元素表示原始文本中特定关键词在标记化后对应的关键标记的索引集(详见附录C.1)。

教师模型的第l层随后构造自注意力矩阵I_l^t ∈ R^{M×M}。为了计算关键标记上的逐步注意力,我们首先从I_l^t在关键标记索引处提取列,其中每列代表...

相似文章

AtManRL: 通过可微分注意力显著性实现忠实推理

arXiv cs.CL

AtManRL 是一种通过可微分注意力操作和强化学习来训练大语言模型的方法,旨在确保推理令牌因果地影响最终预测,从而生成更忠实的思维链推理。在 GSM8K 和 MMLU 上使用 Llama-3.2-3B 进行的实验表明,该方法能够识别具有影响力的推理令牌并提高推理透明度。