AtManRL: Towards Faithful Reasoning via Differentiable Attention Saliency

arXiv cs.CL Papers

Summary

AtManRL is a method that uses differentiable attention manipulation and reinforcement learning to train LLMs to generate more faithful chain-of-thought reasoning by ensuring reasoning tokens causally influence final predictions. Experiments on GSM8K and MMLU with Llama-3.2-3B demonstrate the approach can identify influential reasoning tokens and improve reasoning transparency.

arXiv:2604.16158v1 Announce Type: new

Cached at: 04/20/26, 08:29 AM

# AtManRL: Towards Faithful Reasoning via Differentiable Attention Saliency
Source: https://arxiv.org/html/2604.16158
Max Henning Höth, Aleph Alpha Research, Lab1141

Kristian Kersting, TU Darmstadt, Hessian\.AI, Lab1141

Björn Deiseroth, Aleph Alpha Research, Lab1141

Letitia Parcalabescu, Aleph Alpha Research, Lab1141

###### Abstract

Large language models \(LLMs\) increasingly rely on chain\-of\-thought \(CoT\) reasoning to solve complex tasks\. Yet ensuring that the reasoning trace both contributes to and faithfully reflects the processes underlying the model’s final answer, rather than merely accompanying it, remains challenging\. We introduce AtManRL, a method that leverages differentiable attention manipulation to learn more faithful reasoning through reinforcement learning\. By training an additive attention mask that identifies tokens in the CoT crucial for producing correct answers, we derive a saliency reward signal that encourages the model to generate reasoning traces that genuinely influence its final predictions\. We integrate this saliency reward with outcome\-based rewards within the GRPO framework to jointly optimize for correctness and interpretability\. Experiments on GSM8K and MMLU with Llama\-3\.2\-3B\-Instruct demonstrate that our approach can identify influential reasoning tokens and enable training more transparent reasoning models\.

## 1 Introduction

Chain\-of\-thought \(CoT\) prompting \(Wei et al\., 2022 (https://arxiv.org/html/2604.16158#bib.bib6)\), together with supervised learning and reinforcement learning \(RL\) approaches that elicit reasoning traces \(Yang et al\., 2025 (https://arxiv.org/html/2604.16158#bib.bib23); OpenAI et al\., 2024 (https://arxiv.org/html/2604.16158#bib.bib24); Guo et al\., 2025 (https://arxiv.org/html/2604.16158#bib.bib25)\), has improved the reasoning abilities of large language models \(LLMs\)\. By generating intermediate reasoning steps before the final answer, models often reach higher accuracy on complex tasks\.

The presence of a reasoning trace, however, does not guarantee that the model actually uses it to arrive at its answer\. Consequently, a central question is: *Does the generated CoT causally influence the model’s final prediction and have explanatory power, or does it merely accompany it as a stylistic artifact?* This question relates to the notion of *faithfulness*, which asks whether an explanation reflects the model’s true decision\-making process \(Jacovi and Goldberg, 2020 (https://arxiv.org/html/2604.16158#bib.bib22)\)\. An unfaithful reasoning trace may appear plausible and logically coherent while the model reaches the correct answer through shortcuts that bypass the stated reasoning \(Agarwal et al\., 2024 (https://arxiv.org/html/2604.16158#bib.bib29)\)\. Prior work shows that LLMs can produce plausible\-sounding CoT explanations that do not align with the mechanisms that drive their predictions \(Turpin et al\., 2023 (https://arxiv.org/html/2604.16158#bib.bib26); Lanham et al\., 2023 (https://arxiv.org/html/2604.16158#bib.bib20); Barez et al\., 2025 (https://arxiv.org/html/2604.16158#bib.bib27)\)\.

To investigate this gap, we distinguish between *saliency* and *faithfulness*\. We define saliency as the measurable causal contribution of individual reasoning tokens to the final answer logits\. Faithfulness requires more: the reasoning trace must accurately reflect the latent reasoning that produces the answer\. Saliency therefore constitutes a necessary but not sufficient condition for faithfulness\. Ensuring saliency of the reasoning trace prevents CoT from degenerating into lengthy yet weakly relevant narratives\. Without such constraints, reasoning traces risk functioning as post\-hoc rationalizations rather than interpretable evidence of the model’s computation\.

Guided by this distinction, we propose AtManRL, a method that enforces reasoning\-trace saliency by explicitly training models to produce salient reasoning traces using reinforcement learning\. Our approach builds on AtMan \(Deiseroth et al\., 2023 (https://arxiv.org/html/2604.16158#bib.bib5)\), an attention manipulation technique that allows targeted modification of attention weights through a predefined mask\. Whereas prior work uses AtMan for post\-hoc interpretability, we instead treat the attention manipulation mask as a learnable, differentiable object\. This allows us to: \(i\) *efficiently* identify which tokens in the reasoning trace are truly influential for the final answer, \(ii\) derive a saliency\-based reward signal from these contributions, and \(iii\) incorporate this signal into reinforcement learning to encourage the generation of salient reasoning steps while discouraging extraneous or weakly relevant explanatory content\.

Overall, our contributions are: \(1\) We introduce a **saliency reward** derived from optimizing a differentiable attention mask that identifies salient tokens in the CoT\. \(2\) We combine this saliency reward with outcome\-based rewards in the GRPO framework to **jointly train for correctness and reasoning quality** in terms of saliency\. \(3\) We evaluate our method on GSM8K and MMLU using Llama\-3\.2\-3B\-Instruct and show that we can **reduce extraneous reasoning while preserving accuracy**\.

## 2 Related Work

**CoT / Reasoning Traces.** CoT prompting \(Wei et al\., 2022 (https://arxiv.org/html/2604.16158#bib.bib6)\) and RL methods such as GRPO \(Shao et al\., 2024 (https://arxiv.org/html/2604.16158#bib.bib19)\) encourage LLMs to generate reasoning traces\. RL improves reasoning performance by optimizing outcome\-based rewards\. However, outcome rewards focus on answer correctness and do not enforce that the reasoning trace causally influences the final prediction\. In contrast, we explicitly reward causal dependency between CoT tokens and the answer\.

**Reasoning Trace Faithfulness.** The faithfulness of model explanations has been studied extensively in interpretability research\. Prior work has demonstrated and argued that CoT explanations can be unfaithful, with models sometimes reaching correct answers through reasoning that contradicts their stated logic \(Turpin et al\., 2023 (https://arxiv.org/html/2604.16158#bib.bib26); Lanham et al\., 2023 (https://arxiv.org/html/2604.16158#bib.bib20); Parcalabescu and Frank, 2024 (https://arxiv.org/html/2604.16158#bib.bib28); Barez et al\., 2025 (https://arxiv.org/html/2604.16158#bib.bib27)\)\. Process reward models assign rewards to intermediate reasoning steps using external supervision \(Lightman et al\., 2023 (https://arxiv.org/html/2604.16158#bib.bib21)\) and improve the plausibility of CoT\. However, plausibility reflects consistency with an external evaluator, not alignment with the model’s internal computation\. Faithfulness instead reflects the model’s mechanisms that causally produce the answer\. Therefore, in our method, we learn an attention mask for each sample to verify the causal influence of each token\.

**Critical Reasoning Tokens.** Prior work has shown that individual CoT tokens \(called *critical tokens*\) can have an outsized influence on LLM outputs \(Lin et al\., 2025 (https://arxiv.org/html/2604.16158#bib.bib2)\)\. Vassoyan et al\. \(2025 (https://arxiv.org/html/2604.16158#bib.bib8)\) encourage exploration on such tokens to improve RL fine\-tuning efficiency\. Yan et al\. \(2024 (https://arxiv.org/html/2604.16158#bib.bib7)\) intervene on attention weights to mitigate over\-reliance on misleading tokens in few\-shot examples\. Unlike these methods, which analyze or manipulate reasoning tokens post hoc, we use differentiable attention manipulation to learn token\-level saliency and incorporate it into RL training\.

**Attention Manipulation.** AtMan \(Deiseroth et al\., 2023 (https://arxiv.org/html/2604.16158#bib.bib5)\) introduced memory\-efficient attention manipulation for transformer interpretability, enabling targeted suppression of individual tokens to estimate their influence\. We frame AtMan as a differentiable attention mask and optimize it toward correct answers with gradient\-based optimization to identify salient reasoning tokens\.

## 3 Differentiable attention manipulation for faithful reasoning

In the following, we introduce our method AtManRL to train models to produce salient reasoning traces by framing AtMan as a differentiable attention mask\. Specifically, we \(1\) recap AtMan, \(2\) describe how we learn the mask, \(3\) derive a saliency measure from the optimized mask, and finally, \(4\) integrate saliency as an RL reward during training\.

### 3\.1 Background: AtMan Attention Manipulation

First, we review the additive AtMan\-attention manipulation introduced in Deiseroth et al\. \(2023 (https://arxiv.org/html/2604.16158#bib.bib5)\)\. In a standard transformer, attention outputs are computed as: O = softmax\(H\) · V, where · denotes matrix multiplication and where the pre\-softmax attention scores are given by H = QK^T / sqrt(d)\. Here, Q, K, V ∈ R^(h×s×d) denote the query, key, and value tensors with h attention heads, sequence length s, and head dimension d\. AtMan manipulates the pre\-softmax scores H with an additive mask H^(AtMan) ∈ R^(s×s):

H = Q · K^T / sqrt(d) + H^(AtMan) (1)

Applying the mask H^(AtMan) before the softmax ensures that the resulting attention weights still sum to one\. Additionally, unlike other perturbation methods in XAI \(e\.g\., Shapley values\), this does not introduce a shift in the input distribution or positional embeddings; it only manipulates how much attention the model pays to each individual token\. Positive mask values increase attention to selected tokens, whereas negative values suppress their influence\. For autoregressive models, we additionally apply a lower\-triangular causal mask T and compute H_M = H ∘ T, where ∘ denotes the Hadamard product\. Deiseroth et al\. \(2023 (https://arxiv.org/html/2604.16158#bib.bib5)\) used H^(AtMan) to suppress the attention to individual tokens by assigning a fixed negative value, treated as a hyperparameter, to the corresponding columns of H^(AtMan), in order to analyze each individual token’s impact on the output logits of the LLM\.
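As a minimal illustration of Equation (1), the sketch below applies an additive mask to the pre-softmax scores of a single query row. It is a toy in pure Python, not the AtMan implementation: the function names, the score values, and the suppression value of -4.0 are all assumptions chosen for illustration, and the causal mask T is left out because a single query row only sees earlier tokens anyway.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def atman_attention(scores, mask):
    """Attention weights of one query row after adding an AtMan-style mask.

    scores: pre-softmax scores H for one query over the visible tokens.
    mask:   additive values H^(AtMan) for the same tokens (0.0 = untouched).
    """
    return softmax([s + m for s, m in zip(scores, mask)])

# Toy scores for one query attending to four earlier tokens (made-up numbers).
scores = [1.0, 0.5, 2.0, 0.0]
plain = atman_attention(scores, [0.0, 0.0, 0.0, 0.0])
# Suppress the third token with a fixed negative value (hypothetical -4.0).
suppressed = atman_attention(scores, [0.0, 0.0, -4.0, 0.0])
```

Because the mask is added before the softmax, `suppressed` is still a valid distribution: the weight removed from the third token is redistributed over the remaining tokens.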

### 3\.2 Training an H^(AtMan) mask for measuring saliency

Because the mask enters the pre\-softmax attention scores additively, it remains fully differentiable\. We restrict H^(AtMan) to tokens within the reasoning trace \(CoT\) and do not modify attention over prompt tokens or final answer tokens\. The prompt remains fixed and outside the model’s control, and therefore does not constitute a target for reward shaping\. Conversely, we require that the final answer depends causally on the reasoning trace\. If the reasoning trace is salient, perturbing its attention should affect the probability of the correct answer\.

We initialize all CoT\-related mask entries with a negative constant c = -0\.4\. This initialization uniformly suppresses attention to reasoning tokens and produces a flatter post\-softmax distribution\. From this suppressed state, we optimize the mask to restore the probability of the correct answer\.

Specifically, to train the mask, we minimize the cross\-entropy loss of the logits of the predicted answer tokens y_1:N under teacher forcing, as depicted in Figure 2 (https://arxiv.org/html/2604.16158#S3.F2):

L_mask = - (1/N) ∑_(n=1)^N log P(y_n | c_1:T, y_1:n-1, H^(AtMan)), (2)

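where c_1:T denotes the CoT tokens\. An all\-zero mask leaves the attention scores unmodified, so the loss then equals that of the unperturbed model on its own predicted answer tokens, which is close to 0 when the model is confident in its answer\. The mask is the only trainable object at this stage; optimizing it identifies attention configurations that preserve answer likelihood under suppressed reasoning\.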
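The optimization loop can be sketched on a toy problem. Everything below is an illustrative assumption rather than the authors' code: a single query attends to one unmasked prompt token and three CoT tokens, each token carries a scalar "value", the attention-weighted sum is passed through a sigmoid to give the probability of the correct answer, and plain gradient descent with a projection onto non-positive values stands in for AdamW. The initialization of -0.4, the 200 steps, the clamp at 0, and the scale factor of 10 follow the text; the learning rate and token values are made up.

```python
import math

def softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def forward(mask, values, scale=10.0):
    """Toy answer head: returns (loss, attention, output, p_correct)."""
    # Position 0 is a prompt token (never masked); the rest are CoT tokens.
    scores = [0.0] + [scale * min(m, 0.0) for m in mask]
    attn = softmax(scores)
    out = sum(a * v for a, v in zip(attn, [0.0] + values))
    p_correct = 1.0 / (1.0 + math.exp(-out))
    return -math.log(p_correct), attn, out, p_correct

def optimize_mask(values, init=-0.4, steps=200, lr=0.1, scale=10.0):
    """Gradient descent on the mask only; the toy 'model' stays fixed."""
    mask = [init] * len(values)
    for _ in range(steps):
        _, attn, out, p = forward(mask, values, scale)
        for t, v in enumerate(values):
            # dL/dscore_t = (p - 1) * a_t * (v_t - out); chain rule adds `scale`.
            grad = scale * (p - 1.0) * attn[t + 1] * (v - out)
            # Plain gradient step, then project back onto mask <= 0.
            mask[t] = min(mask[t] - lr * grad, 0.0)
    return mask

# Hypothetical CoT token values: helpful, neutral, harmful for the answer.
cot_values = [3.0, 0.0, -1.0]
initial_loss = forward([-0.4] * 3, cot_values)[0]
learned = optimize_mask(cot_values)
final_loss = forward(learned, cot_values)[0]
```

In this toy setting, the mask entry of the helpful token is driven back toward 0 while the other entries stay suppressed, which is the kind of per-token separation the learned mask is meant to expose.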

We stop optimizing the mask after a fixed number of steps\. We then normalize the mask by dividing by the initialization constant, Ĥ^(AtMan) = H^(AtMan)/c, and compute the average normalized mask value over the lower\-triangular \(causal\) region:

R_Faithfulness(a_i) = (1/|I_v|) ∑_(w∈I_v) Ĥ^(AtMan)_(w,v), I_v = { w ∈ {1, ..., n} | w ≥ v }. (3)

This quantity serves as our saliency measure and reward for rollout a_i\. Intuitively, it measures how strongly the reasoning tokens must be re\-enabled to preserve the correct answer probability\.
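A small sketch of Equation (3) for one mask column follows. The function name and the example mask are hypothetical, indices are 0-based rather than 1-based, and the sketch stops at the per-column average because the text does not spell out how columns are combined into a single rollout-level value.

```python
def column_saliency(mask, v, init=-0.4):
    """Average normalized mask value over rows w >= v of column v (0-indexed).

    mask: square list-of-lists holding the optimized H^(AtMan); only the
          lower-triangular (causal) entries are read.
    init: the initialization constant c used for normalization.
    """
    rows = range(v, len(mask))
    return sum(mask[w][v] / init for w in rows) / len(rows)

# Made-up 3x3 mask: rows are query positions, columns are attended tokens.
mask = [
    [-0.4,  0.0, 0.0],
    [-0.2, -0.4, 0.0],
    [ 0.0, -0.4, 0.0],
]
per_column = [column_saliency(mask, v) for v in range(3)]
```

After normalization, a column that stayed at its initial value of c averages to 1, and a column that was fully restored to 0 averages to 0.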

Figure 1: We initialize the additive attention mask H^(AtMan) with a negative value to suppress attention over CoT tokens\. We then optimize the mask for 200 steps to restore the correct answer probability\.

Figure 2: To identify non\-salient tokens, we optimize the mask per rollout with the goal of restoring the original label log probabilities\.

### 3\.3 Optimizing Saliency via Reinforcement Learning \(RL\)

For RL, we combine the saliency reward with a standard outcome reward:

R_Outcome(a_i) = { 0 if a_i = j, -1 otherwise }, (4)

where j denotes the ground\-truth answer and a_i the rollout prediction\.

Thus, the total reward is R_total(a_i) = R_Outcome(a_i) + R_Faithfulness(a_i)\. Following GRPO \(Shao et al\., 2024 (https://arxiv.org/html/2604.16158#bib.bib19)\), we compute the group\-mean reward \hat{R}_total = (1/N) ∑_(i=1)^N R_total(a_i) over the N rollouts of a query, and define the advantage A(a_i) = R_total(a_i) - \hat{R}_total\. We then update the policy using the clipped GRPO objective:

L_GRPO(θ) = -(1/N) ∑_(i=1)^N min( (π_θ(a_i|q) / π_θ_old(a_i|q)) A(a_i), clip(π_θ(a_i|q) / π_θ_old(a_i|q), 1-ε, 1+ε) A(a_i) ). (5)
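The reward combination, the mean-baseline advantage, and the clipped objective of Equation (5) can be traced with plain numbers. This is a sketch with sequence-level probability ratios, as the equation is written; the helper names and the reward and ratio values are made up for illustration.

```python
def advantages(outcome, saliency):
    """Total reward minus the group mean, one value per rollout."""
    total = [o + s for o, s in zip(outcome, saliency)]
    mean = sum(total) / len(total)
    return [r - mean for r in total]

def clipped_loss(ratios, adv, eps=0.2):
    """Negative mean of the clipped surrogate over the rollouts of one query."""
    terms = []
    for r, a in zip(ratios, adv):
        clipped = max(1.0 - eps, min(r, 1.0 + eps))
        # Take the more pessimistic of the unclipped and clipped terms.
        terms.append(min(r * a, clipped * a))
    return -sum(terms) / len(terms)

# Four hypothetical rollouts: outcome reward is 0 (correct) or -1 (wrong).
outcome = [0.0, -1.0, 0.0, -1.0]
saliency = [0.2, 0.5, 0.1, 0.4]
adv = advantages(outcome, saliency)
# With ratio 1 for every rollout, the loss is minus the mean advantage.
loss_on_policy = clipped_loss([1.0] * 4, adv)
```

Because the advantages are centered on the group mean, they sum to zero, so the loss vanishes when every ratio equals 1. The clipping only matters once the updated policy moves away from the old one, for example `clipped_loss([1.5], [1.0])` caps the ratio at 1 + ε.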

### 3\.4 Implementation Details

**Mask Optimization:** We use AdamW \(Loshchilov and Hutter, 2019 (https://arxiv.org/html/2604.16158#bib.bib15)\) with a learning rate of 1e-3, betas of 0\.6 and 0\.9999, and a weight decay of 0\.05\. We train for 200 gradient steps to update H^(AtMan)\.

**Value Scaling:** Before adding H^(AtMan) to the attention scores, we clamp its values to an upper bound of 0\. This prevents positive values from amplifying attention to any token, since we only want to detect non\-salient ones\. We then scale the clamped values by a factor of 10 for faster convergence\.

**RL Training Details:** We fine\-tune the model for 8 epochs using GRPO with 8 rollouts per query and a maximum generation length of 1024 tokens\. For each update, we compute the saliency reward over batches of 8 queries\. We perform two gradient passes per batch using a mini\-batch size of 2 to recompute policy log\-probabilities for the clipped RL objective\. We use a fixed learning rate of 1×10^(-6) and the standard clipping parameter ε = 0\.2\. We conducted all experiments on 48 NVIDIA A100 GPUs\.

Table 1: Comparison between the baseline and AtManRL
