# Improving Reasoning Capabilities in Small Models through Mixture-of-Layers Distillation with Stepwise Attention on Key Information
Source: https://arxiv.org/html/2604.15701
Yao Chen1,2, Jiawei Sheng1, Wenyuan Zhang1,2, Tingwen Liu1,2
1Institute of Information Engineering, Chinese Academy of Sciences
2School of Cyber Security, University of Chinese Academy of Sciences
{chenyao2023, shengjiawei, zhangwenyuan, liutingwen}@iie.ac.cn
###### Abstract
The significant computational demands of large language models have increased interest in distilling reasoning abilities into smaller models via Chain-of-Thought (CoT) distillation. Current CoT distillation methods mainly focus on transferring teacher-generated rationales for complex reasoning to student models. However, they do not adequately explore teachers' dynamic attention toward critical information during reasoning. We find that language models exhibit progressive attention shifts towards key information during reasoning, a pattern that implicitly encodes essential clues for drawing conclusions. Building on this observation, we introduce a novel CoT distillation framework that transfers the teacher's stepwise attention on key information to the student model. This establishes structured guidance for the student's progressive concentration on key information during reasoning. Moreover, we develop a Mixture-of-Layers module that enables dynamic alignment between teacher and student layers of different depths. Our method achieves consistent performance improvements across multiple mathematical and commonsense reasoning datasets. To our knowledge, it is the first method to leverage stepwise attention within CoT distillation to improve small model reasoning.
## 1 Introduction
The ability to perform complex reasoning is a cornerstone of human intelligence, playing a crucial role in problem-solving, decision-making, and world understanding (Cobbe et al. 2021; Chu et al. 2024; Plaat et al. 2024). Recent advances have shown substantial improvements in the few-shot reasoning abilities of large language models. However, the immense scale of these models demands enormous memory and computational resources, making them prohibitively expensive to deploy on edge devices and impeding practical application (Liu et al. 2024; Hu et al. 2024). To address this challenge, CoT distillation (Ho et al. 2023; Fu et al. 2023; Li et al. 2023; Hsieh et al. 2023) has emerged as a promising approach. For complex reasoning, CoT distillation methods typically transfer the step-by-step rationales generated by the teacher model to the student model, serving as an effective means of knowledge distillation.
**Figure 1a:** A sample from the SVAMP dataset. The distilled student model fails to adequately utilize numerical information, leading to erroneous results, whereas the teacher model, during stepwise reasoning, effectively utilizes all numerical information to arrive at the correct final result.
**Figure 1b:** Numerical vs. Non-Numerical Tokens in Mathematical Reasoning: The horizontal axis represents the reasoning steps, and the vertical axis shows the relative proportion of stepwise attention received by numerical and non-numerical tokens, respectively (details in Appendix B.1).
**Figure 1c:** Visualization of stepwise attention on numerical tokens from the 13th layer of the teacher model Llama3-8B for the sample in Figure 1a (details in Appendix D). The horizontal axis represents the indices of numerical tokens (the tokens highlighted in red in sample Figure 1a), and the vertical axis represents the indices of steps (the grey Sx labels in Figure 1a).
**Figure 1:** Stepwise attention on critical tokens implicitly encodes reasoning clues: A comprehensive analysis.
Existing CoT distillation methods typically treat all tokens equally, often neglecting information that is critical for complex reasoning. We observe that student models distilled via existing methods struggle to fully utilize key information across multi-step reasoning (Figure 1a). Notably, language models allocate more attention, on average, to critical tokens during reasoning, implicitly encoding key clues for stepwise reasoning. For example, numerical tokens are intuitively crucial for mathematical reasoning, and our analysis indicates that they indeed receive significantly more attention than non-numerical tokens in both teacher and student models (Figure 1b). More importantly, we examine how the teacher model's attention to these critical tokens evolves during stepwise reasoning, and find that the attention distribution changes step by step, with higher attention scores assigned to the critical tokens relevant to each reasoning step (Figure 1c & Figure 2). This highlights the teacher model's ability to progressively capture key information during reasoning. However, current CoT distillation methods directly provide the rationales generated by the teacher model to the student. This fails to exploit the phenomena described above and therefore does not improve the student's ability to progressively capture and utilize key information.
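To make the analysis behind Figure 1b concrete, the following is a minimal sketch of how the per-step attention share of numerical tokens could be measured. The model checkpoint, the toy question/rationale text, the choice of layer, and the period-based step segmentation are illustrative assumptions here, not the exact protocol of Appendix B.1.

```python
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"   # assumed teacher checkpoint
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Made-up SVAMP-style example (not the sample from Figure 1a).
question = "Jack had 8 pens and bought 4 more pens. How many pens does he have now?"
rationale = "Jack starts with 8 pens. He buys 4 more pens. 8 + 4 = 12. The answer is 12."
text = question + " " + rationale

enc = tok(text, return_tensors="pt", return_offsets_mapping=True)
offsets = enc.pop("offset_mapping")[0].tolist()

with torch.no_grad():
    attns = model(**enc, output_attentions=True).attentions  # one (1, heads, M, M) tensor per layer
A = attns[12][0].mean(dim=0)    # head-averaged attention at, e.g., layer 13

# Tokens whose character span contains a digit count as "numerical".
is_num = torch.tensor([bool(re.search(r"\d", text[s:e])) for s, e in offsets])

# Assign each token to a reasoning step, delimited by sentence-ending periods.
step_ends = [m.end() for m in re.finditer(r"\.", text)]
step_id = torch.tensor([sum(int(s >= end) for end in step_ends) for s, _ in offsets])

for step in range(int(step_id.max()) + 1):
    rows = A[step_id == step]                     # attention paid by this step's query tokens
    share = rows[:, is_num].sum() / rows.sum()    # fraction going to numerical tokens
    print(f"step {step}: numerical-token attention share = {share.item():.3f}")
```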
Building on the above insights, we introduce **MoLSAKI**, a novel CoT distillation framework that captures and transfers the teacher model's **S**tepwise **A**ttention on **K**ey **I**nformation to enhance the student model's reasoning capabilities via a **M**ixture-**of**-**L**ayers alignment strategy. Specifically, we define stepwise attention on critical tokens as the attention weights assigned to each critical token at each reasoning step. By concatenating these per-step distributions, we capture the model's evolving focus on key information throughout the entire reasoning process. With this definition in place, we extract these stepwise attention maps from every layer of both the teacher and student models during CoT distillation. For layer mapping in distillation, we design the Mixture-of-Layers (MoL) module, drawing inspiration from Mixture-of-Experts (MoE) (Zhou et al. 2022; Jin et al. 2024). MoL facilitates adaptive weighted alignment between teacher and student layers, thereby overcoming the distillation challenge of mismatched layer counts. In summary, our contributions are as follows:
- We introduce a new perspective: during the reasoning process, large language models exhibit a progressive attention pattern towards certain critical tokens, a pattern that implicitly encodes valuable clues for stepwise reasoning.
- We propose a novel chain-of-thought distillation framework, MoLSAKI, which introduces the concept of stepwise attention on critical tokens and transfers the teacher model's progressive, dynamic focus on key information to the student model, thereby enhancing its capacity for effective reasoning.
- We design MoL to adaptively align layers between teacher and student models of different depths in a weighted and dynamic manner, thereby successfully overcoming the challenge of their mismatched layer counts.
- Our method yields performance gains in in-domain and out-of-domain settings across varying teacher-student model scales on mathematical and commonsense reasoning benchmarks.
## 2 Related Work
### 2.1 Chain-of-Thought Distillation
Large language models (LLMs) demonstrate strong reasoning capabilities (Kojima et al. 2022; Wei et al. 2022), yet their massive scale hinders practical deployment. Recent work distills reasoning abilities into smaller models through CoT knowledge transfer (Ho et al. 2023; Hsieh et al. 2023; Fu et al. 2023; Li et al. 2023). Key approaches include Fine-tune-CoT's zero-shot rationale extraction (Ho et al. 2023) and DSS's multi-task separation of reasoning and answer prediction (Hsieh et al. 2023). Subsequent improvements introduce mutual information maximization (MMI loss, Chen et al. 2024) and auxiliary model-based distillation (Mentor-KD, Lee et al. 2024) (details in Appendix A.1). However, existing methods neglect key information in reasoning and face structural constraints from logit distillation requirements (Lee et al. 2024; Zhang et al. 2024b). Our approach instead distills stepwise attention on critical tokens, without requiring tokenizer alignment or projection layers.
### 2.2 Self-Attention Distillation
Prior methods transfer self-attention patterns via layer mapping: TinyBERT (Jiao et al. 2020) uses uniform mapping, MobileBERT (Sun et al. 2020) assumes identical layer counts, and MiniLM (Wang et al. 2020) distills only final layers (details in Appendix A.2). These methods require matched attention dimensions and fixed layer correspondences. We overcome these limitations by 1) focusing distillation on critical tokens in reasoning steps instead of full attention matrices, and 2) using dynamic layer routing via MoL modules to automatically select optimal teacher-student layer pairs, outperforming rigid mapping approaches.
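As a rough illustration of what such dynamic routing can look like, the sketch below implements MoE-style softmax gating in which each student layer is aligned to a learned weighted mixture of the teacher layers' stepwise-attention maps. The input-independent gating parameterization and the MSE alignment loss are assumptions made for this sketch; the actual MoL module is defined in §3.3.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoLRouter(nn.Module):
    """MoE-style layer-routing sketch: each student layer holds a softmax gate
    over teacher layers and is aligned to the resulting weighted mixture of the
    teacher's stepwise-attention maps (each map has shape S x K)."""

    def __init__(self, n_student_layers: int, n_teacher_layers: int):
        super().__init__()
        # Assumption: one input-independent gating vector per student layer.
        self.gate_logits = nn.Parameter(torch.zeros(n_student_layers, n_teacher_layers))

    def forward(self, student_maps: torch.Tensor, teacher_maps: torch.Tensor) -> torch.Tensor:
        # student_maps: (L_s, S, K); teacher_maps: (L_t, S, K)
        gates = F.softmax(self.gate_logits, dim=-1)                 # (L_s, L_t)
        targets = torch.einsum("st,tik->sik", gates, teacher_maps)  # mixed teacher maps
        return F.mse_loss(student_maps, targets)                    # alignment loss (assumed MSE)

# Toy usage: a 12-layer student aligned against a 32-layer teacher on 5 x 13 maps.
router = MoLRouter(12, 32)
loss = router(torch.rand(12, 5, 13), torch.rand(32, 5, 13))
```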
**Figure 2a:** A sample from the CommonSenseQA dataset.
**Figure 2b:** Visualization of stepwise attention on critical tokens from the 32nd layer of the teacher model Qwen2.5-32B for the sample in Figure 2a. The horizontal axis represents the indices of critical tokens, and the vertical axis represents the indices of steps.
**Figure 2:** Progressive attention pattern on critical tokens (details in Appendix D).
## 3 Methodology
MoLSAKI introduces a novel knowledge distillation framework that enhances the reasoning of the student model through synergistic integration of CoT distillation and stepwise attention guidance. Specifically, we first prepare CoT data annotated by the teacher model and conduct CoT distillation (§3.1), subsequently extract stepwise attention on critical tokens from the teacher and student models in the process of CoT distillation (§3.2), and finally implement adaptive MoL layer alignment (§3.3).
### 3.1 CoT Distillation
We obtain CoT data for each question-answer pair {q, â} in a raw dataset D by few-shot prompting the teacher model (details in Appendix F.3). The teacher's response to each question q is divided into two components: rationale r and answer a (see the sample in Figure 3). The labeled dataset {q, r, a | q ∈ D, a = â} will be used for the subsequent CoT distillation of the student model.
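A minimal sketch of this labeling step is given below, assuming a generic `teacher_generate` text-completion wrapper (not part of the paper) and a naive "The answer is X." convention for extracting the answer a; the actual few-shot prompts are given in Appendix F.3.

```python
import re

def label_with_teacher(teacher_generate, few_shot_prompt, raw_dataset):
    """teacher_generate: callable str -> str (assumed wrapper around the teacher LLM).
    raw_dataset: iterable of (question, gold_answer) pairs."""
    labeled = []
    for question, gold in raw_dataset:
        response = teacher_generate(few_shot_prompt + "\nQ: " + question + "\nA:")
        # Split the teacher's response into rationale r and final answer a.
        match = re.search(r"The answer is\s*(.+?)\.?$", response.strip())
        if match is None:
            continue
        answer = match.group(1).strip()
        rationale = response[: match.start()].strip()
        # Keep only samples whose predicted answer matches the gold answer (a = â).
        if answer == str(gold):
            labeled.append({"question": question, "rationale": rationale, "answer": answer})
    return labeled
```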
Following Hsieh et al. (2023), we perform CoT distillation comprising two tasks (CoT Distillation module in Figure 3): 1) final answer prediction a given a question q and 2) rationale r generation for the same input q. The respective loss functions are as follows:
ℒ_pre = 𝔼_{q∈D}[ℒ_ce(f(q), a)],
ℒ_exp = 𝔼_{q∈D}[ℒ_ce(f(q), r)],
where f denotes the student model and ℒ_ce denotes the cross-entropy loss between model predictions and target tokens.
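A minimal sketch of these two objectives for a causal student LM follows, assuming simple "[answer]" / "[rationale]" task prefixes and an equal-weight combination of the two losses; the exact prompt format and loss weighting follow Hsieh et al. (2023) and are assumptions in this sketch.

```python
import torch

def seq2seq_ce_loss(student, tok, prompt: str, target: str) -> torch.Tensor:
    """Cross-entropy of the student on `target` conditioned on `prompt`
    (prompt positions are masked out of the loss with label -100)."""
    prompt_ids = tok(prompt, return_tensors="pt").input_ids
    target_ids = tok(target, return_tensors="pt", add_special_tokens=False).input_ids
    input_ids = torch.cat([prompt_ids, target_ids], dim=1)
    labels = input_ids.clone()
    labels[:, : prompt_ids.shape[1]] = -100            # ignore prompt tokens in the loss
    return student(input_ids=input_ids, labels=labels).loss

def cot_distillation_loss(student, tok, q: str, r: str, a: str, alpha: float = 0.5):
    loss_pre = seq2seq_ce_loss(student, tok, "[answer] " + q, a)      # L_pre
    loss_exp = seq2seq_ce_loss(student, tok, "[rationale] " + q, r)   # L_exp
    return alpha * loss_pre + (1 - alpha) * loss_exp                  # weighting is an assumption
```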
### 3.2 Stepwise Attention on Critical Tokens
We posit that distilling the teacher's stepwise attention on critical tokens during reasoning is more impactful than simply transferring rationales. We therefore introduce a stepwise-attention loss ℒ_att (Eq. 6) during CoT distillation to guide the student's progressive focus on key information.
To compute the loss ℒ_att, we first extract stepwise attention on critical tokens from both the teacher and student models (the Extract Stepwise Attention on Critical Tokens module in Figure 3). In our design, "stepwise" refers to reasoning steps over an input that incorporates the question. As shown in the example in Figure 3, we segment the input sequence composed of the question and rationale into reasoning steps based on periods, resulting in 5 steps.
The teacher model's tokenizer converts the input sequence composed of the question and rationale into a token sequence {x₁^t, x₂^t, ..., x_M^t}. We denote by M₁ the collection of token index sets partitioned by reasoning steps, where each element is the index set of all tokens within a single reasoning step. Using regular expression matching and the tokenizer's mapping, we obtain M₂, the collection of critical-token index sets, where each element is the index set of the tokens produced by tokenizing a specific critical word in the original text (details in Appendix C.1).
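A minimal sketch of how M₁ and M₂ could be constructed with a fast HuggingFace tokenizer's character offset mapping, using numbers as the critical words; the regex and boundary-handling choices here are assumptions, and the exact matching rules are in Appendix C.1.

```python
import re
from transformers import AutoTokenizer

def build_index_sets(tok, text: str, critical_pattern: str = r"\d+(?:\.\d+)?"):
    """Return M1 (token indices per reasoning step) and M2 (token indices per
    critical word), mapping character spans to token indices via offsets."""
    enc = tok(text, return_offsets_mapping=True, add_special_tokens=False)
    offsets = enc["offset_mapping"]                    # (char_start, char_end) per token

    # M1: period-delimited reasoning steps -> indices of the tokens starting inside them.
    step_spans = [(m.start(), m.end()) for m in re.finditer(r"[^.]+\.?", text)]
    M1 = [[i for i, (s, e) in enumerate(offsets) if a <= s < b]
          for a, b in step_spans]

    # M2: each critical word (here: a number) -> indices of its overlapping sub-tokens
    # (overlap handles leading-space BPE tokens).
    crit_spans = [(m.start(), m.end()) for m in re.finditer(critical_pattern, text)]
    M2 = [[i for i, (s, e) in enumerate(offsets) if s < b and e > a]
          for a, b in crit_spans]
    return M1, M2

# Example usage (tokenizer checkpoint is an assumption):
# tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# M1, M2 = build_index_sets(tok, question + " " + rationale)
```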
The l-th layer of the teacher model subsequently constructs the self-attention matrix I_l^t ∈ ℝ^(M×M). To compute stepwise attention on critical tokens, we first extract columns from I_l^t at the indices of critical tokens, where each column represents attention received by a specific critical token. We then aggregate the attention weights by reasoning steps using mask matrices derived from M₁, resulting in stepwise attention on critical tokens A_l^t ∈ ℝ^(S×K), where S denotes the number of reasoning steps and K denotes the number of critical tokens (details in Appendix C.1).
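For illustration, a minimal sketch of this extraction for a single layer, assuming a head-averaged attention matrix and a per-step mean over the querying tokens (the paper's exact mask-based aggregation is specified in Appendix C.1):

```python
import torch

def stepwise_attention(attn: torch.Tensor, M1: list, M2: list) -> torch.Tensor:
    """attn: self-attention matrix (M x M) of one layer, averaged over heads.
    Returns A of shape (S x K): attention each reasoning step pays to each critical token."""
    crit_idx = [i for word in M2 for i in word]          # one column per critical token (K total)
    cols = attn[:, crit_idx]                             # columns at critical-token indices, (M, K)
    # Aggregate query rows by reasoning step; a mean per step is assumed here.
    A = torch.stack([cols[step].mean(dim=0) for step in M1])   # (S, K)
    return A
```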
**Figure 3:** The MoLSAKI framework consists of three components. In the example, the question and rationale contain 13 numerical tokens and 5 steps in total. Thus, the stepwise attention on numerical tokens in both the teacher and student models has shape 5 × 13.