Resolving the bias-precision paradox with stochastic causal representation learning for personalized medicine

arXiv cs.AI Papers

Summary

This paper introduces a stochastic causal representation learning framework to resolve the bias-precision paradox in personalized medicine, demonstrating improved accuracy and interpretability in ICU clinical decision support.

arXiv:2605.05706v1 Announce Type: new Abstract: Estimating individualized treatment effects from longitudinal observational data is central to data-driven medicine, yet existing methods face a fundamental limitation: reducing confounding bias often suppresses clinically informative heterogeneity, degrading patient-specific predictions. Here, we identify this tension as a bias-precision paradox in causal representation learning and introduce sampling-based maximum mean discrepancy (sMMD), a stochastic alignment strategy that replaces global adversarial balancing with subset-level matching. We instantiate this approach in a framework for counterfactual outcome prediction with attribution-grounded interpretability. Across two large-scale ICU cohorts (n = 27,783), our framework improves accuracy under distribution shift, reducing error by up to 11.5% and substantially increasing recall in high-risk tasks. Mechanistic analyses show that sMMD selectively preserves clinically decisive variables. In human-AI evaluation, our method outperforms clinicians-in-training and large language models, and improves clinician accuracy by 14.7% while reducing decision time, enabling interpretable, real-time clinical decision support.

Cached at: 05/08/26, 08:37 AM

# Resolving the bias–precision paradox with stochastic causal representation learning for personalized medicine
Source: [https://arxiv.org/html/2605.05706](https://arxiv.org/html/2605.05706)
Dianbo Liu (affiliations 1, 2)

1. School of Medicine, NUS
2. College of Design and Engineering, NUS
3. Aier Academy of Ophthalmology, Central South University
4. Division of Dermatology, Thammasat University
5. National University Heart Center, National University Health System
6. Beijing Tongren Eye Center, Beijing Tongren Hospital, Capital Medical University
7. School of Medicine, Stanford University
8. College of Computer Science, Zhejiang University
9. Duke-NUS Medical School, NUS
10. Department of Computer Science, Nanjing University
11. School of Medicine, Yale University
12. DeepMind, Google
13. Department of Neurosurgery, State Key Laboratory of Oncology in South China
14. Collaborative Innovation Center for Cancer Medicine, Sun Yat-sen University Cancer Center
15. Chengdu OrganoidMed Medical Laboratory
16. Department of Geriatric Comprehensive Surgery and International Medicine, Sichuan Provincial People’s Hospital
17. Harvard Medical School, Harvard University

###### Abstract

Estimating individualized treatment effects from longitudinal observational data is central to data\-driven medicine, yet existing methods face a fundamental limitation: reducing confounding bias often suppresses clinically informative heterogeneity, degrading patient\-specific predictions\. Here, we identify this tension as a bias–precision paradox in causal representation learning and introduce sampling\-based maximum mean discrepancy \(sMMD\), a stochastic alignment strategy that replaces global adversarial balancing with subset\-level matching\. We instantiate this approach in a framework for counterfactual outcome prediction with attribution\-grounded interpretability\. Across two large\-scale ICU cohorts \(n = 27,783\), our framework improves accuracy under distribution shift, reducing error by up to 11\.5% and substantially increasing recall in high\-risk tasks\. Mechanistic analyses show that sMMD selectively preserves clinically decisive variables\. In human–AI evaluation, our method outperforms clinicians\-in\-training and large language models, and improves clinician accuracy by 14\.7% while reducing decision time, enabling interpretable, real\-time clinical decision support\.

###### keywords:

Medical error reduction, Confounding bias, Maximum mean discrepancy, Intensive care unit, Clinical decision support, Personalized medicine, Open\-source healthcare AI

## 1 Introduction

![Refer to caption](https://arxiv.org/html/2605.05706v1/x1.png)

Figure 1: Resolving the bias-precision paradox in causal AI for critical care. (a) AI integration within the ICU clinical workflow. The system aggregates multimodal patient data, including demographics, vital signs, treatments, and laboratory results, to provide decision support for treatment selection. (b) The bias-precision dilemma and its causal origin. Left: the causal directed acyclic graph (DAG) illustrates how time-varying confounding arises in longitudinal treatment settings: covariates ($X$) simultaneously influence treatment assignment ($A$) and outcomes ($Y$), creating confounding pathways (red arrows) that balancing methods aim to block (red crosses). Right: the bias-precision scatter plot shows that existing adversarial balancing methods are constrained by a technological frontier (dashed line), where aggressive bias removal ($x$-axis) sacrifices patient specificity ($y$-axis). In the over-balancing regime, generic representations fail to capture individual physiological dynamics. GITO (green star) transcends this frontier via sampling-based maximum mean discrepancy (sMMD), mitigating confounding while preserving the information essential for individualized prediction. (c) The sMMD balancing strategy. Representations $\boldsymbol{\mathcal{B}}$ are grouped according to treatment assignment. Random subsets are sampled from each group, and the maximum mean discrepancy (MMD) between the sampled treatment groups is minimized, yielding balanced representations without global homogenization. (d) Human-AI comparison and collaboration. Patient trajectories are presented in modality-appropriate formats: text prompts for LLMs, vital-sign charts for medical students, and structured matrices for GITO. GITO achieved 75.6% prediction accuracy, outperforming LLMs (best: 67.2%) and unassisted medical students (55.6%). Bottom: feature attribution scores are combined with LLM-driven interpretation to support human-AI collaboration; clinician performance improved when assisted by GITO’s explanations (right).

![Refer to caption](https://arxiv.org/html/2605.05706v1/x2.png)
![Refer to caption](https://arxiv.org/html/2605.05706v1/x3.png)

Figure 2: The GITO framework and clinical decision-support interface. (a) Patient longitudinal data (vitals $\boldsymbol{V}$, covariates $\boldsymbol{X}_t$, outcomes $\boldsymbol{Y}_t$, and treatments $\boldsymbol{A}_t$) are encoded into balanced representations $\boldsymbol{\mathcal{B}}_t$ by the representation module $\Theta_{\mathcal{B}}$. The outcome predictor $\Theta_Y$ generates multi-step counterfactual predictions $\hat{\boldsymbol{Y}}$ under alternative treatment scenarios. Sampling-based MMD (sMMD) aligns treatment-group distributions while preserving patient-level heterogeneity (right: distribution matching), in contrast to adversarial methods that enforce global invariance (left: scatter plots before vs. after balancing). An attribution module computes per-variable contributions, which are translated into natural-language clinical rationales by an LLM-based explanation module (upper right). (b) Web-based clinical interface. Clinicians upload patient data and select a prediction model (top). The dashboard displays variable-level attribution scores, temporal contribution patterns, and counterfactual treatment trajectories (bottom right). An LLM-generated explanation provides a structured clinical rationale with the treatment preference distribution (bottom left). Data sources: MIMIC-III ($N = 25{,}186$, United States) and AmsterdamUMCdb ($N = 2{,}597$, the Netherlands).

Data-driven personalized medicine aims to tailor interventions to individual patient characteristics by estimating treatment effects from longitudinal observational data [journal/aim1997/757Rubin, journal/naturemed2024/30Feuerriegel, journal/epidemiology2000/Robins, journal/npjdm2020/17Sutton].
A critical challenge, however, limits this endeavour: observational data are inherently confounded, because treatment decisions reflect patient severity rather than random allocation\[journal/biometrika1983/41Rosenbaum,journal/jasa1984/516Rosenbaum\]\. To address this, state\-of\-the\-art frameworks enforce distributional alignment between treated and untreated populations in the learned representation space \(Figure[1](https://arxiv.org/html/2605.05706#S1.F1)a\)\[conference/nips2018/31Lim,conference/iclr2020/Bica,conference/icml2022/Melnychuk,conference/icml2024/wang\]\. This deconfounding process introduces an under\-examined trade\-off: the patient features that drive treatment assignment, and thus differ most between groups, are often the most clinically informative\. Aggressive alignment can neutralize the clinically informative heterogeneity \(severity indicators, disease subtypes, comorbidity profiles, and temporal trajectories\) that is essential for individualized prediction\[journal/ije2016/45Dahabreh,journal/biopsy2020/88Eric,journal/lancetDH/1Forte\]\. We term this the bias\-precision dilemma: global deconfounding improves average causal estimates at the expense of individual\-level specificity\. This dilemma is not confined to a single clinical domain; it arises wherever observational data guide treatment, from oncology dose optimization to chronic disease management\[journal/bmj2018/363Kent,journal/medrxiv2026/Soltanifar\]\.

Current representation-learning approaches for treatment effect estimation, including Counterfactual Recurrent Networks (CRN) [conference/iclr2020/Bica], Causal Transformers (CT) [conference/icml2022/Melnychuk], the Adversarial Counterfactual Temporal Inference Network (ACTIN) [conference/icml2024/wang], and their variants [conference/nips2018/31Lim, conference/icdm2022/1053Li, conference/kdd2024/12Wu, conference/nips2024/37Bouchattaoui], adopt a shared encoder that maps patient covariates into a latent space, from which treatment-specific outcome heads generate predictions. To remove confounding, these methods converge on a common paradigm: adversarial balancing. The encoder is trained to produce representations from which treatment assignment cannot be recovered, while an adversarial objective simultaneously attempts to recover it. This competition drives the learned representations toward global distributional invariance across treatment groups, effectively simulating the balance of a randomized trial. The resulting suppression, however, is indiscriminate: the adversarial objective neutralizes any distributional difference between treatment groups, regardless of clinical relevance (Figure [1](https://arxiv.org/html/2605.05706#S1.F1)a). Consider vasopressor therapy in sepsis. Sicker patients are more likely to receive vasopressors, so blood pressure trajectories (the signals that determine vasopressor need) differ systematically between treated and untreated groups. Adversarial balancing suppresses precisely these trajectories, rendering the model unable to distinguish patients who require intervention from those who do not [conference/nips2022/Moayeri, conference/icml2024/Huang]. The result is “over-balancing”: representations that are deconfounded on average but uninformative at the individual level, leading to poor generalization across clinically distinct subpopulations [journal/crt2024/4Curth, journal/cer2017/288Li].
A second barrier compounds the first: existing causal inference frameworks remain opaque\. Standard explainability techniques yield feature\-level importance scores but not the contextual, clinically grounded rationale that physicians need to trust and act on a recommendation\[journal/naturemi/421Liu\]\. Together, over\-balanced representations and opaque predictions constrain the real\-world utility of current methods \(Figure[1](https://arxiv.org/html/2605.05706#S1.F1)a\)\.
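The adversarial balancing paradigm described above is often summarized as a min-max game between the encoder and a treatment classifier. The objective below is a generic sketch in the notation of Figure 2, not the exact loss of CRN, CT, or ACTIN; the treatment classifier $\Theta_A$, the losses $\mathcal{L}_Y$ and $\mathcal{L}_A$, and the weight $\lambda > 0$ are introduced here for illustration only.

```math
\min_{\Theta_{\mathcal{B}},\,\Theta_Y}\;\max_{\Theta_A}\;
\mathcal{L}_Y\!\left(\Theta_Y(\boldsymbol{\mathcal{B}}_t, \boldsymbol{A}_t),\, \boldsymbol{Y}_{t+1}\right)
\;-\;\lambda\,\mathcal{L}_A\!\left(\Theta_A(\boldsymbol{\mathcal{B}}_t),\, \boldsymbol{A}_t\right),
\qquad
\boldsymbol{\mathcal{B}}_t = \Theta_{\mathcal{B}}\!\left(\boldsymbol{X}_{\le t}, \boldsymbol{A}_{<t}, \boldsymbol{Y}_{\le t}\right)
```

Maximizing over $\Theta_A$ is equivalent to training the classifier to predict $\boldsymbol{A}_t$ from $\boldsymbol{\mathcal{B}}_t$ (minimizing $\mathcal{L}_A$), while the encoder is rewarded for making that classifier fail. At an idealized equilibrium, $\boldsymbol{\mathcal{B}}_t$ carries no information about treatment at all, which is why treatment-correlated signals such as blood pressure trajectories get suppressed along with the confounding.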

The bias\-precision dilemma is particularly consequential in intensive care units \(ICUs\), where interventions are time\-critical and the margin for error is narrow\[journal/jbiei2016/3Roughead,journal/cricare2008/12Moyen,journal/bmj2019/366Panagioti\]\. Therapies such as vasopressor administration and mechanical ventilation demand continuous recalibration to a patient’s evolving physiology\[journal/ccm2024/1633Bauer,journal/aim2019/285Cox\], and treatment assignment is strongly confounded by illness severity\. ICU environments also generate rich, high\-resolution longitudinal data while carrying substantial adverse\-event burden; unsafe care contributes to an estimated 3 million deaths annually worldwide\[journal/bmj2016/353Makary,who2023patientsafety\]\. High stakes, strong confounding, and data availability make the ICU an ideal proving ground for resolving this dilemma\. Data\-driven models for personalized ICU interventions have shown the potential to reduce mortality by up to 20%\[conference/prmlhc2017/68Raghu,journal/eswa2021/169Akash,journal/aim2019/285Cox\], yet realizing this potential demands models that are both accurate and trustworthy\.

We propose GITO \(Generalized and Interpretable Treatment Outcome\), a framework that replaces adversarial balancing with a fundamentally different de\-confounding strategy \(Figure[1](https://arxiv.org/html/2605.05706#S1.F1)b,c\)\. In longitudinal treatment settings, the encoder produces representations across many time steps, each associated with a treatment assignment\. Rather than training a discriminator to enforce global invariance over all these representations simultaneously, GITO employs sampling\-based maximum mean discrepancy \(sMMD\): at each training iteration, small random subsets are drawn from each treatment group and aligned via MMD\. This stochastic, sample\-level alignment provides a softer distributional constraint that mitigates confounding without forcing the entire representation space into a single homogenized distribution\. As a result, the model retains the patient\-level heterogeneity \(severity indicators, disease subtypes, comorbidity profiles, and temporal trajectories\) that adversarial methods inadvertently discard through their pursuit of global invariance\. Accuracy alone, however, does not earn clinical trust; physicians must understand why a model recommends a course of action before they act on it\[journal/mlhc2019/106Tonekaboni\]\. Existing explainability methods such as SHAP values or attention\-weight visualization yield numerical feature attributions but lack the clinical context that supports bedside reasoning\[journal/lancetDH2021/3Ghassemi\]\. To bridge this gap, GITO incorporates an attribution\-grounded interpretability pipeline that translates the model’s per\-feature contributions into natural\-language clinical narratives via a Large Language Model, constrained to reason over model\-derived evidence to mitigate hallucination risk \(Figure[1](https://arxiv.org/html/2605.05706#S1.F1)b\)\.
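A minimal sketch of the subset-level alignment described above, written in plain Python for readability. It assumes a Gaussian (RBF) kernel, the biased V-statistic estimator of squared MMD, hashable treatment labels, and pairwise averaging when there are more than two treatment groups. The function names, `subset_size`, and `bandwidth` are hypothetical and not taken from GITO's implementation; in training, the same computation would run on differentiable tensors inside the loss.

```python
import math
import random
from typing import Dict, Hashable, List, Optional, Sequence

Vector = Sequence[float]


def rbf_kernel(x: Vector, y: Vector, bandwidth: float = 1.0) -> float:
    """Gaussian RBF kernel k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2))."""
    sq_dist = sum((xi - yi) ** 2 for xi, yi in zip(x, y))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mmd2(xs: Sequence[Vector], ys: Sequence[Vector], bandwidth: float = 1.0) -> float:
    """Biased (V-statistic) estimate of the squared MMD between two samples."""
    k_xx = sum(rbf_kernel(a, b, bandwidth) for a in xs for b in xs) / len(xs) ** 2
    k_yy = sum(rbf_kernel(a, b, bandwidth) for a in ys for b in ys) / len(ys) ** 2
    k_xy = sum(rbf_kernel(a, b, bandwidth) for a in xs for b in ys) / (len(xs) * len(ys))
    return k_xx + k_yy - 2.0 * k_xy


def smmd_penalty(
    reps: Sequence[Vector],
    treatments: Sequence[Hashable],
    subset_size: int = 32,
    bandwidth: float = 1.0,
    rng: Optional[random.Random] = None,
) -> float:
    """Sampling-based MMD: align small random subsets of each treatment group.

    Representations are grouped by treatment label, at most `subset_size`
    points are drawn from each group, and the squared MMD is averaged over
    all pairs of groups. Returns 0.0 if only one group is present.
    """
    rng = rng or random.Random()
    groups: Dict[Hashable, List[Vector]] = {}
    for rep, a in zip(reps, treatments):
        groups.setdefault(a, []).append(rep)
    labels = list(groups)
    subsets = {a: rng.sample(groups[a], min(subset_size, len(groups[a]))) for a in labels}
    total, n_pairs = 0.0, 0
    for i, a in enumerate(labels):
        for b in labels[i + 1:]:
            total += mmd2(subsets[a], subsets[b], bandwidth)
            n_pairs += 1
    return total / n_pairs if n_pairs else 0.0
```

Because each call only touches `subset_size` points per group, the cost is roughly `subset_size**2` kernel evaluations per group pair, and every training iteration sees a different random subset. Averaged over training, this pulls the group distributions together without demanding that all representations be indistinguishable at once, which is one way to read the "softer distributional constraint" described above.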

We evaluate GITO on two real-world ICU databases, MIMIC-III (25,186 patients, United States) and AmsterdamUMCdb (2,597 patients, the Netherlands), spanning populations with distinct demographic and ethnic compositions. GITO maintains robust performance when transferring from the training population (White, $N = 3{,}560$) to held-out Asian ($N = 119$), African ($N = 383$), and Latino ($N = 143$) descendant subgroups, achieving a 13% average RMSE reduction relative to unbalanced baselines. We validate the framework on two clinically distinct tasks: blood pressure trajectory prediction and ventilator re-intubation risk assessment. For re-intubation, GITO-augmented predictions increased recall on high-risk patients by 42% (from 0.506 to 0.719), substantially reducing missed re-intubations, with a concurrent AUC improvement from 0.711 to 0.756. In a human-AI benchmark (Figures [1](https://arxiv.org/html/2605.05706#S1.F1)d and [7](https://arxiv.org/html/2605.05706#S2.F7)b), GITO achieved 75.6% accuracy, outperforming all four tested LLMs by 8.4 to 19.0 percentage points and unassisted medical students ($n = 4$). In a separate cooperation study with practicing clinicians ($n = 3$, Figure [7](https://arxiv.org/html/2605.05706#S2.F7)c), GITO’s explanations improved accuracy by 14.7 percentage points, reduced decision-making time by 74%, and raised the safety rate from 82.4% to 89.8%. To our knowledge, GITO is the first framework to jointly resolve the bias-precision trade-off through sample-level distributional alignment and to close the interpretability gap with attribution-grounded LLM reasoning in longitudinal causal inference. We release GITO as a freely accessible, open-source, web-based clinical tool (Figure [2](https://arxiv.org/html/2605.05706#S1.F2)) with sub-50 ms inference latency on standard CPU hardware, enabling deployment within secure hospital intranets.
Both the sMMD alignment strategy and the interpretability pipeline are *domain-agnostic*, applicable wherever individualized treatment effects must be estimated from observational data.

Table 1: Demographic and clinical characteristics of the study cohort. A random subset of 5,000 patients with ICU stays ≥ 30 hours was drawn from MIMIC-III ($N = 25{,}186$). Values are presented as median (IQR) or n (%). Abbreviations: SD, standard deviation; SVR, systemic vascular resistance; GCS score, Glasgow Coma Scale total; PEEP, positive end-expiratory pressure; FiO2, fraction of inspired oxygen. (a) For time-varying vital signs, mean values were computed over the first 24 hours following ICU admission. (b) For treatments, the average number of hours of continuous or intermittent interventions was computed across all patients.

## 2 Results

We developed GITO, a framework for individualized treatment outcome prediction in the ICU that integrates a sampling\-based maximum mean discrepancy \(sMMD\) alignment strategy with attribution\-grounded interpretability \(Figure[1](https://arxiv.org/html/2605.05706#S1.F1)\)\. To facilitate clinical adoption, GITO is implemented as an open\-source, web\-based decision\-support tool \(Figure[2](https://arxiv.org/html/2605.05706#S1.F2)\)\.

GITO was developed and validated on two large-scale ICU cohorts from the United States (MIMIC-III; $n = 25{,}186$) and the Netherlands (AmsterdamUMCdb; $n = 2{,}597$), comprising 27,783 individuals in total. First, we evaluated generalization across geographic and demographic distribution shifts (Table [2](https://arxiv.org/html/2605.05706#S2.T2), Table [3](https://arxiv.org/html/2605.05706#S2.T3), Figure [3](https://arxiv.org/html/2605.05706#S2.F3), Figure [4](https://arxiv.org/html/2605.05706#S2.F4)). Next, we assessed downstream clinical utility through ventilator weaning prediction (Figure [5](https://arxiv.org/html/2605.05706#S2.F5)). Then, we evaluated model interpretability via attribution analysis and LLM-based explanations in a septic shock case study (Figure [6](https://arxiv.org/html/2605.05706#S2.F6)). Moreover, we benchmarked GITO against ICU clinicians (Figure [7](https://arxiv.org/html/2605.05706#S2.F7)). In addition, we conducted human-AI collaboration experiments to quantify the effect of explanation-enhanced outputs on clinician performance (Figure [7](https://arxiv.org/html/2605.05706#S2.F7)). We further validated robustness to confounding on a fully synthetic dataset (Table [5](https://arxiv.org/html/2605.05706#S2.T5), Figure [8](https://arxiv.org/html/2605.05706#S2.F8), Figure [9](https://arxiv.org/html/2605.05706#S2.F9)). Finally, we quantified bias-accuracy trade-offs across representation-balancing strategies (Figure [10](https://arxiv.org/html/2605.05706#S2.F10)).

### 2.1 Patient cohort

We analyze the effectiveness and robustness of GITO on three patient cohorts that span varying levels of complexity and real\-world variability: AmsterdamUMCdb, MIMIC\-III ICU cohort, and a fully synthetic tumor growth dataset\.

MIMIC-III electronic medical record data. The MIMIC-III database [journal/scidata2016/Johnson] is a large and widely used ICU patient cohort comprising detailed electronic health records. In this study, we included patients whose ICU stays lasted between 30 and 60 hours to ensure sufficient temporal coverage for treatment-outcome modeling. A total of 25,186 patients met these criteria (56.3% male, 43.7% female; mean age 62.9 years). The cohort included patients from 41 ethnic groups, with a mean ICU stay of 44.93 hours. Among the included patients, vasopressor therapy was administered for an average of 7.74 ± 15.02 hours and mechanical ventilation for 10.39 ± 17.49 hours. The baseline demographic and treatment characteristics are summarized in Table [1](https://arxiv.org/html/2605.05706#S1.T1).

AmsterdamUMCdb electronic medical record data. The AmsterdamUMCdb database [journal/ccm2021/49Thoral] is a large, openly accessible intensive care dataset containing detailed electronic health records from two university medical centers in the Netherlands. In this study, we included adult patients whose ICU stays lasted between 30 and 60 hours to ensure sufficient temporal coverage for treatment-outcome modeling. A total of 2,597 patients met these criteria, comprising 1,614 (62.2%) males and 983 (37.8%) females, with the largest age group being 70-79 years (25.6%). The mean ICU stay was 42.9 ± 7.1 hours. Among the included patients, vasopressor therapy was administered for an average of 12.91 ± 15.76 hours and mechanical ventilation for 12.34 ± 15.40 hours. Baseline demographic and treatment characteristics are summarized in Appendix Table [8](https://arxiv.org/html/2605.05706#A2.T8).

Synthetic patient cohort for controlled confounding evaluation. To enable controlled evaluation of counterfactual prediction, we simulated a synthetic patient cohort ($n = 10{,}000$) using a pharmacokinetic-pharmacodynamic (PKPD) tumor growth model [journal/scirep2017/7Geng]. This model simulates individualized treatment responses with known ground-truth counterfactual outcomes, allowing precise quantification of prediction accuracy under varying degrees of treatment selection bias [conference/iclr2020/Bica, conference/icml2022/Melnychuk, conference/icml2024/wang]. The synthetic cohort includes patients with diverse baseline tumor characteristics (volume and growth rate) and treatment scenarios spanning 30-day observation periods. The confounding strength $\gamma$ was varied systematically from 0 (randomized treatment) to 7.0 (strong selection bias) to evaluate model robustness in scenarios that mimic clinical practice, where treatment assignment depends on patient severity and prognosis. Although tumor growth differs from acute ICU conditions, the underlying framework for treatment effect estimation and confounding control carries over to ICU trajectory prediction tasks. Full simulation parameters, including treatment assignment mechanisms and validation protocols, are detailed in Appendix [B.1](https://arxiv.org/html/2605.05706#A2.SS1).
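To make the role of the confounding strength $\gamma$ concrete, the sketch below simulates a single patient with a simplified Gompertz-style tumor growth model in which the probability of treatment rises with recent tumor diameter. It loosely follows the sigmoid assignment scheme commonly used in prior counterfactual-prediction benchmarks (e.g., CRN), as we understand it, but keeps only a fixed-dose chemotherapy term, ignores drug decay and radiotherapy, and uses illustrative parameter values (`rho`, `k`, `beta_c`, `dose`, `d_max`, `window`) that are assumptions, not the paper's. The actual simulation parameters are given in Appendix B.1.

```python
import math
import random
from typing import List, Optional, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def diameter_from_volume(volume: float) -> float:
    """Diameter of a sphere with volume V, from V = pi * d**3 / 6."""
    return (6.0 * volume / math.pi) ** (1.0 / 3.0)


def treatment_probability(mean_diameter: float, gamma: float, d_max: float = 13.0) -> float:
    """Confounded assignment: with gamma > 0, larger tumours are treated more often.
    gamma = 0 gives p = 0.5, i.e. randomized treatment. d_max is illustrative."""
    return sigmoid(gamma / d_max * (mean_diameter - d_max / 2.0))


def simulate_patient(
    gamma: float,
    days: int = 30,
    v0: float = 1.0,
    rho: float = 0.01,     # growth rate (illustrative)
    k: float = 30.0,       # carrying capacity (illustrative)
    beta_c: float = 0.03,  # chemotherapy effect per unit dose (illustrative)
    dose: float = 5.0,     # fixed dose on treated days (illustrative)
    noise: float = 0.01,   # std. dev. of noise on the daily growth rate
    window: int = 15,      # days of history seen by the assignment policy (assumed)
    rng: Optional[random.Random] = None,
) -> Tuple[List[float], List[int]]:
    rng = rng or random.Random()
    volumes: List[float] = [v0]
    treatments: List[int] = []
    for _ in range(days):
        recent = volumes[-window:]
        mean_d = sum(diameter_from_volume(v) for v in recent) / len(recent)
        a = 1 if rng.random() < treatment_probability(mean_d, gamma) else 0
        v = volumes[-1]
        # Gompertz growth minus the treatment effect, with noise on the growth rate.
        growth = rho * math.log(k / v) - beta_c * dose * a + rng.gauss(0.0, noise)
        volumes.append(max(v * (1.0 + growth), 1e-3))  # keep volume positive
        treatments.append(a)
    return volumes, treatments
```

With `gamma=0.0` every day is a coin flip; as `gamma` grows, treatment increasingly tracks tumor size, so treated and untreated days differ systematically in severity. That systematic difference is the selection bias the balancing methods are meant to correct, and because the simulator can be rerun with the treatment sequence changed, ground-truth counterfactuals are available.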

### 2.2 GITO enables robust generalization across geographic and demographic shifts

Table 2: Multi-step-ahead prediction results on the MIMIC-III and AmsterdamUMCdb cohorts. The IID setting refers to standard train-test splits within the same population distribution. In contrast, the OOD setting assesses the ability of models to generalize to previously unseen patient subgroups. For the MIMIC-III cohort, OOD evaluation is conducted on patients of non-European descent, while models are trained only on patients of European descent. For AmsterdamUMCdb, OOD evaluation corresponds to a cross-dataset generalization setting, in which models are trained exclusively on MIMIC-III and evaluated directly on AmsterdamUMCdb without any fine-tuning. This setting reflects a more realistic deployment scenario, in which a model trained in one hospital system is applied to a different clinical environment with distinct patient characteristics and data distributions. \* denotes a statistically significant improvement ($p \leq 0.05$).

Table 3: Multi-step-ahead prediction results on the MIMIC-III dataset in out-of-distribution settings across three ethnic groups. White patients formed the training set, and patients of Asian, African, and Latino descent formed the respective test sets. Shown: RMSE as mean ± standard deviation over ten runs.

![[Uncaptioned image]](https://arxiv.org/html/2605.05706v1/img/experiment_results_diseases.png)

Figure 3: Disease-stratified out-of-distribution prediction performance ($\tau = 6$). Multi-step-ahead RMSE on MIMIC-III patients stratified by ethnicity and disease category. Models were trained exclusively on patients of European descent (N = 3,560) and evaluated on held-out Asian (N = 119), Black (N = 383), and Hispanic (N = 143) patients across three disease groups: cardiovascular (N = 75), neurological (N = 87), and infectious/inflammatory (N = 85). Values are RMSE (mean ± s.d.; $n = 10$ independent runs). Statistical significance was assessed by two-sided paired t-tests comparing each sMMD-enhanced model against its adversarial balancing counterpart; \*$p < 0.05$.

![Refer to caption](https://arxiv.org/html/2605.05706v1/img/reconstruction_signals_loss.png)

Figure 4: Per-variable information loss ($\Delta R^2$) on the MIMIC-III cohort comparing sMMD and MINE (adversarial) balancing. (a) Clinically decisive variables directly involved in treatment decisions in the ICU. (b) General physiological vitals and laboratory values. Positive $\Delta R^2$ indicates information lost relative to an unbalanced encoder; values near zero indicate full preservation. Error bars denote 95% confidence intervals over ten independent runs.

To evaluate whether the sMMD balancing strategy enables predictive models to generalize across hospitals and patient populations, we assessed performance on large-scale ICU cohorts from the United States (MIMIC-III; $n = 25{,}186$) and the Netherlands (AmsterdamUMCdb; $n = 2{,}597$). We examined generalization along three axes: geographic (cross-hospital), demographic (cross-ethnicity), and disease-specific. We then investigated the mechanistic basis of the observed differences through a per-variable information-preservation analysis.

Institutional generalization (geographic shift). In the IID setting, GITO achieved parity with baseline models (CRN, CT, and ACTIN) on local test sets (Table [2](https://arxiv.org/html/2605.05706#S2.T2)). Performance divergence emerged under geographic shift: when models trained on the U.S.-based MIMIC-III cohort were deployed on the European AmsterdamUMCdb cohort without fine-tuning, baseline models showed increased error rates, whereas GITO maintained lower RMSE. Relative to each corresponding unbalanced baseline, sMMD-enhanced models reduced RMSE by 2.7-11.5% across prediction horizons $\tau = 1$ to $6$.
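The horizon-wise comparison above can be reproduced from raw forecasts with a few lines. This sketch assumes each trajectory contributes one $\tau$-step-ahead prediction per horizon and that RMSE is pooled over trajectories at each horizon; the paper's exact aggregation (for example, over time steps or in normalized units) is not specified in this section, and the function names are hypothetical.

```python
import math
from typing import Dict, List, Sequence


def rmse(pred: Sequence[float], true: Sequence[float]) -> float:
    """Root-mean-square error between two equal-length sequences."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true))


def rmse_by_horizon(preds: List[List[float]], truths: List[List[float]]) -> Dict[int, float]:
    """preds[i][tau - 1] is the tau-step-ahead forecast for trajectory i."""
    n_horizons = len(preds[0])
    return {
        tau: rmse([p[tau - 1] for p in preds], [t[tau - 1] for t in truths])
        for tau in range(1, n_horizons + 1)
    }


def relative_reduction(baseline_rmse: float, model_rmse: float) -> float:
    """Percentage by which model_rmse improves on baseline_rmse."""
    return 100.0 * (baseline_rmse - model_rmse) / baseline_rmse
```

For example, a baseline RMSE of 1.00 at some horizon and an sMMD-enhanced RMSE of 0.885 at the same horizon correspond to an 11.5% reduction, the upper end of the range reported above.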

Demographic generalization (subpopulation shift). We next examined whether the model generalizes across patient ethnicities within the MIMIC-III cohort. When models trained on patients of European descent were evaluated on Asian, African, and Latino populations, GITO consistently yielded lower error rates (Table [3](https://arxiv.org/html/2605.05706#S2.T3)). The model reduced RMSE by 3.5% for single-step predictions and by up to 8.8% for multi-step forecasts compared to baselines. These gains were most pronounced in populations with higher baseline prediction errors: in patients of African descent, GITO reduced RMSE by 10.95%, and in patients of Asian descent by 8.8% at $\tau = 6$.

Disease-specific subgroup analysis. We stratified performance by disease type to assess reliability across clinical contexts (Figure [3](https://arxiv.org/html/2605.05706#S2.F3)). GITO yielded the lowest RMSE across cardiovascular, neurological, and infectious conditions. For cardiovascular disease, RMSE reductions reached 11.9%, 11.06%, and 19.21% for patients of Asian, African, and Latino descent, respectively (paired t-test, all $p < 0.05$; $n = 10$ splits). These disease-stratified results support the clinical applicability of GITO across heterogeneous ICU populations.

Per-variable information preservation. To characterise how each balancing strategy affected individual clinical variables, we quantified per-variable information change. Any representation learning inherently involves information compression due to finite model capacity and dimensionality reduction, so some reconstruction error is unavoidable regardless of the balancing strategy applied. To isolate the *additional* effect of distribution alignment on each clinical variable, we computed the per-variable $\Delta R^2 = R^2_{\text{unbalanced}} - R^2_{\text{balanced}}$ on the MIMIC-III cohort (Figure [4](https://arxiv.org/html/2605.05706#S2.F4)). A positive $\Delta R^2$ indicates that balancing incurred additional information loss beyond baseline compression; values near zero indicate no added cost; negative values indicate that balancing improved reconstruction relative to the unbalanced encoder. Both sMMD and the adversarial MINE objective showed positive $\Delta R^2$ on strongly treatment-correlated variables such as PaO2 and heart rate (Figure [4](https://arxiv.org/html/2605.05706#S2.F4)(a)). In contrast, variables with minimal direct influence on acute treatment selection, such as magnesium, cholesterol, and pH (Figure [4](https://arxiv.org/html/2605.05706#S2.F4)(b)), remained near zero for both methods, suggesting that balancing predominantly affected treatment-correlated dimensions while leaving weakly treatment-related physiological signals intact. Despite this shared pattern, the two strategies differed substantially in both magnitude and consistency. Across all seven clinically decisive variables in Figure [4](https://arxiv.org/html/2605.05706#S2.F4)(a) (PaO2, heart rate, respiratory rate, FiO2, GCS, PEEP, and PaCO2), sMMD achieved lower $\Delta R^2$ than MINE, indicating uniformly better information preservation.
The gap was most pronounced for PaO2 ($\Delta R^2$: 0.09 for MINE vs. 0.04 for sMMD), heart rate (0.09 vs. 0.06), and glucose (0.11 vs. 0.01; Figure [4](https://arxiv.org/html/2605.05706#S2.F4)(b)). Beyond magnitude, the two methods diverged in direction on several clinically critical variables. For respiratory rate, FiO2, and GCS, MINE showed positive $\Delta R^2$ (information lost relative to the unbalanced baseline), whereas sMMD achieved negative $\Delta R^2$ (reconstruction better than that of the unbalanced encoder). Conversely, MINE produced a large negative $\Delta R^2$ for HCO3 ($\approx -0.12$), yet incurred substantial positive $\Delta R^2$ on PaO2 and glucose. Taken together, MINE exhibited high variance across variables, with both the largest gains and the largest losses, whereas sMMD produced a consistently near-zero $\Delta R^2$ profile across the full variable set, with selective improvements on variables directly involved in ventilator and consciousness assessment.
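The sign convention for $\Delta R^2$ is easy to misread, so here is a minimal sketch of the computation. It assumes that, for each clinical variable, reconstructions are already available from an unbalanced encoder and from a balanced one (for example, via a decoder or probe fitted on the learned representations). How the paper obtains these reconstructions is not described in this section, and the dictionary layout and function names are hypothetical.

```python
from typing import Dict, Sequence


def r_squared(true: Sequence[float], pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(true) / len(true)
    ss_tot = sum((t - mean) ** 2 for t in true)
    ss_res = sum((t - p) ** 2 for t, p in zip(true, pred))
    return 1.0 - ss_res / ss_tot


def delta_r2(
    truth: Dict[str, Sequence[float]],
    recon_unbalanced: Dict[str, Sequence[float]],
    recon_balanced: Dict[str, Sequence[float]],
) -> Dict[str, float]:
    """Delta R^2 = R^2_unbalanced - R^2_balanced, per clinical variable.

    Positive: balancing lost information; near zero: no added cost;
    negative: balancing improved reconstruction.
    """
    return {
        var: r_squared(truth[var], recon_unbalanced[var])
        - r_squared(truth[var], recon_balanced[var])
        for var in truth
    }
```

If the unbalanced encoder reconstructs a variable perfectly ($R^2 = 1$) and the balanced encoder can only recover its mean ($R^2 = 0$), $\Delta R^2 = 1$ for that variable; swapping the two gives $-1$.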

### 2.3 Case study: improving ventilator re-intubation prediction with GITO

To evaluate clinical utility in a high-stakes scenario, we applied GITO to predict re-intubation risk within six hours of mechanical ventilation weaning in a cohort of 205 ICU patients (Figure [5](https://arxiv.org/html/2605.05706#S2.F5)). Standard predictive models were augmented with GITO-projected six-hour physiological trajectories and compared against a historical-only baseline using 12 hours of retrospective data.

The GITO-augmented model improved performance across all four metrics (Figure [5(a)](https://arxiv.org/html/2605.05706#S2.F5.sf1)): accuracy increased from 0.668 to 0.756, precision from 0.652 to 0.719, recall from 0.506 to 0.719, and F1-score from 0.570 to 0.719. ROC analysis confirmed consistent improvement across decision thresholds, with the AUC increasing from 0.711 to 0.756 (Figure [5(b)](https://arxiv.org/html/2605.05706#S2.F5.sf2)). The gains were most pronounced in the low-to-moderate false-positive region, where clinical decisions typically operate.

Integrating GITO-generated trajectories addressed a key limitation of historical-only models (Figure [5(a)](https://arxiv.org/html/2605.05706#S2.F5.sf1)). The baseline model achieved a recall of only 0.506 (95% CI: 0.398–0.612), missing nearly half of patients requiring re-intubation. Incorporating GITO projections increased recall to 0.719 (95% CI: 0.621–0.811), a 42% relative improvement.
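The relationships among the reported metrics can be checked with a few lines of arithmetic. The sketch below uses only the precision and recall values reported above to recompute each model's F1-score, the relative recall gain, and the baseline's missed fraction (1 − recall); the helper names are hypothetical.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

def relative_change(before: float, after: float) -> float:
    """Relative change (after - before) / before, e.g. 0.42 for a 42% gain."""
    return (after - before) / before

# Precision and recall reported for the re-intubation case study
baseline = {"precision": 0.652, "recall": 0.506}
gito = {"precision": 0.719, "recall": 0.719}

print(round(f1_score(**baseline), 3))  # 0.57
print(round(f1_score(**gito), 3))      # 0.719
print(round(relative_change(baseline["recall"], gito["recall"]), 3))  # 0.421
print(round(1 - baseline["recall"], 3))  # 0.494, the "nearly half" missed by the baseline
```

Note that the 42% figure is a relative gain in recall; it is not the same quantity as the relative reduction in the false-negative rate, which is computed from 1 − recall.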

Calibration analysis further demonstrated that the GITO-augmented model produced more reliable risk estimates (Figure [5(c)](https://arxiv.org/html/2605.05706#S2.F5.sf3)). The expected calibration error (ECE) decreased from 0.335 for the baseline to 0.169 for the GITO-augmented model. The baseline calibration curve exhibited pronounced non-monotonicity, with observed re-intubation rates near zero in the 0.7–0.9 predicted-probability range, indicating severe overconfidence in its high-risk predictions. In contrast, the GITO-augmented model maintained a broadly monotonic relationship between predicted probabilities and observed event rates, with its calibration curve tracking closer to the diagonal across most risk strata.
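Expected calibration error summarises the gap between predicted probabilities and observed event rates across probability bins. The binning scheme behind the reported ECE values is not stated in this section, so the sketch below assumes ten equal-width bins over the predicted probability of re-intubation, with each bin weighted by its share of patients.

```python
import numpy as np

def expected_calibration_error(y_true, y_prob, n_bins=10):
    """Equal-width-bin ECE for binary outcomes.

    Sum over bins of (fraction of samples in bin) * |observed event rate - mean predicted probability|.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitizing against the interior edges maps each probability to a bin index 0..n_bins-1
    bins = np.digitize(y_prob, edges[1:-1])
    ece = 0.0
    for b in range(n_bins):
        in_bin = bins == b
        if not in_bin.any():
            continue
        observed = y_true[in_bin].mean()
        predicted = y_prob[in_bin].mean()
        ece += in_bin.mean() * abs(observed - predicted)
    return ece

# An overconfident model that predicts 0.8 for patients who are never re-intubated
print(round(expected_calibration_error([0, 0, 0, 0], [0.8, 0.8, 0.8, 0.8]), 3))  # 0.8
```

This toy case mirrors the failure mode described above: high predicted risk paired with near-zero observed events contributes the full probability gap to the ECE.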

Error decomposition analysis revealed a clinically favourable shift in the prediction error profile (Figure [5(d)](https://arxiv.org/html/2605.05706#S2.F5.sf4)). Of the 67 patients who required re-intubation, the GITO-augmented model correctly identified 52 (25.4% of the cohort) while misclassifying only 15 as low-risk false negatives (7.3%). Among the 138 patients who did not require re-intubation, 100 (48.8%) were correctly classified and 38 (18.5%) were false positives. Compared with the baseline model, which missed nearly half of re-intubation cases (recall 0.506), the GITO-augmented model increased recall by 42% in relative terms, concentrating the majority of residual errors in the lower-acuity false-positive category.

![Refer to caption](https://arxiv.org/html/2605.05706v1/img/ForestPlot_ventilator_weaning.png) (a) Forest plot comparing accuracy, precision, recall, and F1-score between the historical baseline (12-hour vitals) and the GITO-augmented model (6-hour predicted trajectories); point estimates with 95% confidence intervals.
![Refer to caption](https://arxiv.org/html/2605.05706v1/img/compare_ROC.png) (b) ROC curves: GITO augmentation (AUC = 0.756) versus baseline (AUC = 0.711).
![Refer to caption](https://arxiv.org/html/2605.05706v1/img/calibration_curve.png) (c) Calibration curves with expected calibration error (GITO-augmented ECE = 0.169 vs. baseline ECE = 0.335).
![Refer to caption](https://arxiv.org/html/2605.05706v1/img/Clinical_Error_Distributoin.png) (d) Confusion-matrix decomposition on the validation cohort ($N=205$), showing clinically relevant error patterns.

Figure 5: Performance evaluation of the GITO framework for ventilator re-intubation prediction, demonstrating improved discriminative ability, probability calibration, and clinically meaningful risk stratification.
### 2.4 Case study: enhancing AI model interpretability in a patient with septic shock

![Refer to caption](https://arxiv.org/html/2605.05706v1/img/interpretability_case.png)

Figure 6: Interpretable treatment outcome prediction for an individual ICU patient with septic shock. A single MIMIC-III patient (N = 1) was analyzed using GITO (ACTIN-sMMD) to predict mean blood pressure (MBP) trajectories over 12 time steps ($\tau=12$, corresponding to 12 hours) under four hypothetical treatment strategies (None, Vasopressor, Ventilation, Both). (Upper left) Top-five variables by integrated-gradient magnitude at the current time step. (Upper right) Temporal attribution: the five variables with the largest absolute integrated gradient at each historical time point, indicating their relative contribution to the predicted MBP. (Bottom) Observed MBP (black) and predicted MBP trajectories under the four treatment strategies. The LLM-generated clinical narrative is presented separately in Box [2.4](https://arxiv.org/html/2605.05706#S2.SS4).

We examined the interpretability of GITO through an individual-level case study of an ICU patient with septic shock, a life-threatening condition in which vasopressor therapy is administered to maintain MBP within a target range that ensures adequate organ perfusion without overtreatment. We analyzed 12-hour MBP trajectory predictions under four alternative treatment strategies (Figure [6](https://arxiv.org/html/2605.05706#S2.F6)). Variable attribution analysis identified respiratory rate (contribution 0.25), heart rate (0.22), tidal volume (0.20), systemic vascular resistance (SVR; 0.18), and oxygen saturation (0.15) as the five most influential variables at the current time point (Figure [6](https://arxiv.org/html/2605.05706#S2.F6), upper left). Temporal attribution revealed that contribution magnitudes increased at later time points (upper right), with the largest effects concentrated in the final hours of the observation window. 
Counterfactual trajectory analysis under the four hypothetical strategies (no treatment, vasopressor only, ventilation only, both) revealed distinct MBP profiles (Figure [6](https://arxiv.org/html/2605.05706#S2.F6), bottom). The no-treatment scenario projected gradual recovery toward the target range, indicating that the patient’s underlying physiology was trending toward stabilization without additional intervention. Vasopressor administration accelerated this recovery but produced the most pronounced MBP rise of the single-intervention strategies, with the trajectory eventually exceeding the normal physiological range under sustained use. Ventilation alone yielded moderate improvement, and combined therapy projected the highest MBP, further above the target ceiling. The framework’s LLM-based explanation module (Box [2.4](https://arxiv.org/html/2605.05706#S2.SS4)) integrated historical MBP dynamics, variable attributions, and comparative trajectory analysis into a structured natural-language summary. Consistent with these trajectory patterns, the LLM assigned preference scores of 40% to vasopressor use (rapid target attainment), 30% to conservative management (gradual but sufficient recovery), 20% to ventilation, and 10% to combined therapy (risk of overshooting the target range). The treating clinicians chose not to escalate vasopressor therapy, and the patient’s blood pressure subsequently recovered to stable levels. Additional details about this case study are provided in Appendix [E](https://arxiv.org/html/2605.05706#A5).
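The attribution scores above come from integrated gradients, which accumulate a model's gradient along a straight path from a baseline input to the actual input. The sketch below illustrates the technique on a toy one-layer model with an analytic gradient, using a midpoint Riemann sum and an all-zeros baseline; both the toy model and the baseline are assumptions, since the GITO network and its baseline choice are not described in this section. For PyTorch models, libraries such as Captum provide an `IntegratedGradients` implementation that computes gradients by autodiff.

```python
import numpy as np

def integrated_gradients(grad_fn, x, baseline, steps=100):
    """Midpoint Riemann approximation of integrated gradients.

    attribution_i = (x_i - baseline_i) * mean_k dF/dx_i(baseline + alpha_k * (x - baseline))
    """
    x = np.asarray(x, dtype=float)
    baseline = np.asarray(baseline, dtype=float)
    alphas = (np.arange(steps) + 0.5) / steps  # midpoints of [0, 1]
    grads = np.stack([grad_fn(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)

# Toy differentiable "model" (hypothetical): F(x) = tanh(w . x), with its analytic gradient.
w = np.array([0.5, -1.0, 2.0])

def model(x):
    return np.tanh(w @ x)

def model_grad(x):
    return (1.0 - np.tanh(w @ x) ** 2) * w

x = np.array([1.0, 0.5, 0.3])
attr = integrated_gradients(model_grad, x, np.zeros(3), steps=200)
# Completeness: the attributions sum to F(x) - F(baseline)
print(np.isclose(attr.sum(), model(x) - model(np.zeros(3)), atol=1e-4))  # True
```

The completeness property is what makes the per-variable scores comparable: they partition the change in the predicted outcome relative to the baseline input.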

Example: LLM-generated clinical rationale

Patient context: MAP fluctuating below the 65–85 mmHg target range with transient dips consistent with septic shock. Top contributing variables: respiratory rate (0.25), heart rate (0.22), tidal volume (0.20), systemic vascular resistance (0.18), SpO₂ (0.15).

Counterfactual analysis: Vasopressor-only (Vaso) predicts rapid MAP recovery into the target range; no-treatment (None) projects slower improvement; combined therapy (Both) risks overshooting.

Treatment preference: Vaso 40%, None 30%, Vent 20%, Both 10%.

Full output: Supplementary Box [E.2](https://arxiv.org/html/2605.05706#A5.SS2)

### 2.5 Specialized GITO AI model outperforms medical students and rule-guided LLMs in ventilator weaning prediction

![Refer to caption](https://arxiv.org/html/2605.05706v1/x4.png)

Figure 7: Human-AI comparison and cooperation for ventilator weaning assessment based on predicted treatment outcomes. (a) Study design: medical students ($n=3$), four large language models (GPT-4o, GPT-5.1, Gemini-3, Grok-4.1), and the GITO model each received patient vitals to predict 6-hour post-extubation trajectories, which were then used to assess re-intubation risk for $n=205$ mechanically ventilated MIMIC-III patients (re-intubation prevalence, 43%). (b) Prediction accuracy (Top-1) across all agents; error bars denote standard deviation over $n=4$ student participants or $n=5$ independent model runs, as applicable. (c) Human-AI cooperation outcomes from a two-period crossover study with $n=3$ clinicians: predictive accuracy, decision-making time, and clinically acceptable safety rate ($1-\text{FNR}$), compared between unassisted and GITO-assisted conditions (within-subject paired comparison).

To benchmark GITO against human practitioners and general-purpose AI systems, we designed a prediction study in which medical students and junior clinicians ($n=4$) and four leading LLMs independently predicted re-intubation outcomes for the same 205-patient ventilator weaning cohort (Figure [7](https://arxiv.org/html/2605.05706#S2.F7)a; experimental details in Methods [4.6](https://arxiv.org/html/2605.05706#S4.SS6)). GITO achieved a prediction accuracy of 75.6%, outperforming all LLMs and human participants (Figure [7](https://arxiv.org/html/2605.05706#S2.F7)b). Among LLMs equipped with expert clinical reasoning prompts, GPT-4o scored highest at 67.2%, followed by Gemini-3 (60.0%), Grok-4.1 (56.6%), and GPT-5.1 (55.6%). Unassisted medical students achieved 58.7%, marginally above random chance (50.1%). 
When medical students were provided with GITO’s predictions and attribution-based explanations, their accuracy increased from 58.7% to 73.4%, a 14.7-percentage-point improvement (Figure [7](https://arxiv.org/html/2605.05706#S2.F7)b). However, collaborative performance remained below GITO’s standalone accuracy (75.6%), with students occasionally overriding correct model predictions.

### 2.6 Interpretable causal rationales of GITO models improved human clinicians’ performance

To assess whether GITO’s interpretable outputs can enhance clinical decision-making, we conducted a controlled crossover experiment in which nine clinicians (three attending physicians, three residents, and three medical students) predicted re-intubation risk for all 205 mechanically ventilated patients in our test cohort (43% requiring re-intubation; crossover design detailed in Methods [4.6](https://arxiv.org/html/2605.05706#S4.SS6)). GITO assistance improved clinician performance across all three evaluated dimensions (Figure [7](https://arxiv.org/html/2605.05706#S2.F7)c). Predictive accuracy increased from 59.5% to 73.2%; decision-making time decreased from 205.6 to 52.6 minutes per case batch, a 74% reduction; and the clinically acceptable safety rate ($1-\text{FNR}$, i.e., sensitivity) rose from 82.4% to 89.8%. In the crossover analysis, clinicians who first worked without AI assistance and then received GITO’s predictions and attribution-based explanations revised a substantial proportion of their initially incorrect predictions in the AI-assisted round.
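The within-subject comparison in this crossover design can be summarised by each clinician's paired difference between the assisted and unassisted conditions. The sketch below computes the mean paired difference and a paired t statistic. It is an illustration under stated assumptions (a simple paired t-test that ignores period and order effects), not the analysis described in Methods 4.6, and the input accuracies are illustrative values rather than study data.

```python
import math
import statistics

def paired_summary(unassisted, assisted):
    """Mean within-subject difference (assisted - unassisted) and paired t statistic."""
    diffs = [a - u for u, a in zip(unassisted, assisted)]
    mean_diff = statistics.mean(diffs)
    std_err = statistics.stdev(diffs) / math.sqrt(len(diffs))
    return mean_diff, mean_diff / std_err

# Illustrative per-clinician accuracies (hypothetical, not study data)
unassisted = [0.5, 0.6, 0.7]
assisted = [0.6, 0.8, 0.8]
mean_diff, t_stat = paired_summary(unassisted, assisted)
print(round(mean_diff, 4), round(t_stat, 2))  # 0.1333 4.0
```

Pairing each clinician with themselves removes between-clinician differences in baseline skill, which matters when participants range from medical students to attending physicians.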

### 2.7 Computational efficiency of GITO enables accessible and real-time clinical deployment

To assess whether GITO’s architectural simplification translates into practical deployment advantages, we compared the computational cost of GITO (ACTIN-sMMD) against the adversarial baseline (ACTIN) on identical hardware (Table [4](https://arxiv.org/html/2605.05706#S2.T4)). Replacing the discriminator with the sMMD module eliminated an auxiliary network, reducing the total parameter count by 3.25% (120.0K to 116.1K) and converting the min-max optimization into a single-objective problem. Per-epoch training time decreased by 21.1% (5.12 s to 4.04 s), and the resulting stability improvement yielded a 9.0% reduction in total training time (23.44 to 21.34 min). At inference, prediction latency was 32.81 ms per patient on CPU hardware. This low-latency, CPU-compatible inference enabled us to deploy GITO as an open-source, web-based clinical interface (Figure [2](https://arxiv.org/html/2605.05706#S1.F2)) that integrates longitudinal vital-sign visualisation, counterfactual treatment simulation, and attribution-based explanation into a unified dashboard supporting real-time clinician interaction. The platform is publicly available at [https://huggingface.co/spaces/peisongzhang/TreatmentOutcomePredictionSystem](https://huggingface.co/spaces/peisongzhang/TreatmentOutcomePredictionSystem) and can be deployed within secure hospital intranets.
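A per-patient latency such as the 32.81 ms reported here is typically measured by timing repeated forward passes after a few warm-up calls. The harness below is a sketch under that assumption (the exact measurement protocol is not described in this section); `toy_predict` is a hypothetical stand-in for a trained model's forward pass on one patient's inputs.

```python
import statistics
import time

def median_latency_ms(predict, sample, warmup=5, repeats=50):
    """Median wall-clock latency in milliseconds of predict(sample), measured after warm-up calls."""
    for _ in range(warmup):
        predict(sample)  # warm caches and lazy initialisation before timing
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        predict(sample)
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)

def toy_predict(vitals):
    # Hypothetical stand-in for one patient's forward pass through a trained model.
    return sum(v * 0.01 for v in vitals)

print(f"{median_latency_ms(toy_predict, [80.0, 95.0, 18.0]):.4f} ms")
```

The median is used rather than the mean so that occasional scheduler or garbage-collection pauses do not dominate the reported figure.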

Table 4: Computational cost comparison. GITO (ACTIN-sMMD) eliminates the discriminator network, reducing parameter count and inference latency. Although sMMD calculation adds slight per-epoch overhead, the model converges significantly faster, reducing total training time.
### 2.8 sMMD effectively disentangles treatment bias from physiological heterogeneity in predicting treatment outcome

Table 5: One-step-ahead prediction results ($\tau=1$) on the synthetic tumor growth dataset under varying levels of time-varying confounding ($\gamma$). Values denote RMSE (mean $\pm$ standard deviation) over ten independent runs. Lower is better; best results are highlighted in bold.

![Multi-step prediction comparison](https://arxiv.org/html/2605.05706v1/img/multi_step_comparison.png)

Figure 8: Comparison of multi-step prediction performance (2-step, 4-step, and 6-step) across models on the synthetic tumor growth dataset under the single-treatment sliding window setting. Results are reported as the average RMSE over ten independent runs across increasing levels of time-varying confounding strength $\gamma$.

![Refer to caption](https://arxiv.org/html/2605.05706v1/img/reconstruction_signals_loss_curves.png)

Figure 9: Reconstruction loss during training for different balancing objectives. An independent decoder was trained to reconstruct the original patient covariates from the balanced representations produced by each method. Lower reconstruction loss indicates greater preservation of patient-specific information. Shaded regions denote $\pm 1$ standard deviation over ten independent runs.

To evaluate sMMD under controlled confounding, we used a synthetic tumor growth dataset ($n=10{,}000$) with an adjustable time-varying confounding parameter ($\gamma$). Baseline architectures (CRN, CT, ACTIN) were compared against their sMMD-enhanced counterparts across confounding levels $\gamma = 0$–$7$. For one-step-ahead predictions ($\tau=1$), sMMD conferred only marginal improvements at low confounding levels (Table [5](https://arxiv.org/html/2605.05706#S2.T5)). At $\gamma \geq 4$, sMMD-enhanced models began to show consistent gains, with ACTIN-sMMD achieving the lowest RMSE at six of eight confounding levels. The advantage of sMMD became more pronounced at longer prediction horizons and higher confounding (Figure [8](https://arxiv.org/html/2605.05706#S2.F8)). At low confounding ($\gamma < 3$), all models performed comparably across all horizons. Beyond this threshold, baseline RMSE rose steeply with increasing $\gamma$ and $\tau$, whereas sMMD variants exhibited a flatter degradation trajectory. ACTIN-sMMD yielded the largest RMSE reductions relative to ACTIN at moderate-to-high confounding: at $\gamma=5$, reductions were 14.1% ($\tau=2$), 16.7% ($\tau=4$), and 13.5% ($\tau=6$); at $\gamma=7$, the short-horizon gain reached 18.2% ($\tau=2$), though the margin narrowed at longer horizons (10.5% and 4.8% at $\tau=4$ and $\tau=6$, respectively) as both models degraded under extreme confounding. 
The benefit of sMMD extended to other architectures: CT-sMMD reduced RMSE relative to CT by 17.1% at $\gamma=5$, $\tau=6$, although CRN-sMMD showed more variable gains across settings.
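The role of $\gamma$ can be illustrated with a sketch in the style of the pharmacokinetic-pharmacodynamic tumor growth simulator commonly used to benchmark CRN and CT. In that style of simulator, the probability of treatment at each step is a sigmoid function of the patient's recent average tumor diameter, scaled by $\gamma$. The constants here (a maximum diameter `d_max` of 13 cm, a threshold of `d_max / 2`) follow that common setup but are assumptions, because this section does not give the exact simulator configuration.

```python
import math

def treatment_probability(mean_recent_diameter, gamma, d_max=13.0):
    """Sigmoid treatment-assignment probability with confounding strength gamma.

    gamma = 0 gives random assignment (p = 0.5). Larger gamma makes treatment
    increasingly determined by recent tumor size, i.e. stronger time-varying
    confounding between the patient's state and the treatment received.
    """
    logit = (gamma / d_max) * (mean_recent_diameter - d_max / 2)
    return 1.0 / (1.0 + math.exp(-logit))

for gamma in (0.0, 2.0, 7.0):
    # Large tumors become ever more likely to be treated as gamma grows
    print(gamma, round(treatment_probability(10.0, gamma), 3))
```

As $\gamma$ grows, treated and untreated patients come from increasingly different regions of covariate space, which is the regime in which the sMMD variants showed their largest advantage.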

To quantify how much patient-specific information each balancing strategy retains, we trained an independent decoder to reconstruct the original covariates from balanced representations (Figure [9](https://arxiv.org/html/2605.05706#S2.F9)). Among the balancing objectives tested, namely domain confusion (CT), gradient reversal (CRN), the mutual-information-based loss (ACTIN/MINE), and sMMD, the sMMD-balanced representations achieved the lowest reconstruction error throughout training, with the gap widening as training progressed.
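The reconstruction probe can be sketched with a linear decoder fitted by least squares on frozen representations. The decoder architecture used in the paper is not specified in this section, so the linear probe, the function name, and the toy data are assumptions; the idea is the same: the lower the reconstruction error, the more covariate information the representation still carries.

```python
import numpy as np

def probe_reconstruction_mse(z, x):
    """Fit a linear decoder x ≈ [z, 1] @ W by least squares and return the reconstruction MSE."""
    z1 = np.hstack([z, np.ones((z.shape[0], 1))])  # append a bias column
    w, *_ = np.linalg.lstsq(z1, x, rcond=None)
    return float(np.mean((z1 @ w - x) ** 2))

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 4))                # toy "covariates"
z_informative = x @ rng.normal(size=(4, 8))  # representation that retains all covariate information
z_uninformative = rng.normal(size=(200, 8))  # representation unrelated to the covariates
print(probe_reconstruction_mse(z_informative, x) < probe_reconstruction_mse(z_uninformative, x))  # True
```

A linear probe only detects linearly decodable information; a neural decoder, like the independent decoder trained here, can also recover non-linear structure.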

### 2.9 GITO balances predictive accuracy with fairness and bias mitigation

![Refer to caption](https://arxiv.org/html/2605.05706v1/img/woBRM_sample_visualization.png) (a) Without balancing: treatment-specific clustering.
![Refer to caption](https://arxiv.org/html/2605.05706v1/img/mmd_sample_visualization.png) (b) With sMMD: well-mixed treatment groups.
![Refer to caption](https://arxiv.org/html/2605.05706v1/img/balancing_rmse_comparison.png) (c) Multi-step RMSE comparison ($\gamma=10$).
![Refer to caption](https://arxiv.org/html/2605.05706v1/img/mmd_time_step_visualization.png) (d) Temporal evolution of sMMD-balanced representations across time steps, showing that the learned representations remain balanced and consistent throughout the sequence.
![Refer to caption](https://arxiv.org/html/2605.05706v1/img/tsne_by_Age.png) (e) Colored by age. A mild gradient structure is observed, suggesting that age may influence physiological trajectories and potentially affect treatment outcomes.
![Refer to caption](https://arxiv.org/html/2605.05706v1/img/tsne_by_Gender.png) (f) Colored by gender. Male and female patients are well mixed in the embedding space, indicating no observable gender bias in the learned representations.
![Refer to caption](https://arxiv.org/html/2605.05706v1/img/tsne_by_Ethnicity.png) (g) Colored by ethnicity. No distinct clustering is observed among ethnic groups, suggesting that the model captures clinically relevant features rather than demographic-specific patterns.

Figure 10: Visualization of learned representations and performance trade-offs. (a–b) t-SNE embeddings of the synthetic dataset showing treatment-specific clustering without balancing (a) versus well-mixed distributions with sMMD balancing (b). (c) Multi-step RMSE comparison ($\gamma=10$) showing that sMMD achieves the best trade-off between bias removal and accuracy. (d) Temporal stability of sMMD-balanced representations. (e–g) Embeddings of MIMIC-III data colored by age, gender, and ethnicity. The structured gradient in age (e) contrasts with the uniform mixing in gender (f) and ethnicity (g), indicating the preservation of clinically relevant physiology over demographic bias.

To examine whether sMMD removes treatment-assignment bias while preserving clinically relevant structure, we visualized learned representations using t-SNE on both synthetic and real-world data (Figure [10](https://arxiv.org/html/2605.05706#S2.F10)). Under severe confounding ($\gamma=10$), we compared three model variants across prediction horizons $\tau = 1$–$6$: adversarial balancing (ACTIN), no balancing (ACTIN-woBRM), and sMMD balancing (ACTIN-sMMD) (Figure [10(c)](https://arxiv.org/html/2605.05706#S2.F10.sf3)). Adversarial balancing yielded the highest average multi-step RMSE (4.15) and exhibited substantial instability: RMSE spiked at $\tau=2$ before partially recovering, and the variance across runs was markedly wider than for either alternative. Removing balancing entirely produced a lower average RMSE (3.26) with a smooth, monotonically increasing trajectory, but left treatment-group distributions clearly separated in the latent space (Figure [10(a)](https://arxiv.org/html/2605.05706#S2.F10.sf1)). ACTIN-sMMD achieved the lowest average RMSE (3.23) with the narrowest cross-run variance. 
Its per-horizon trajectory closely paralleled that of the unbalanced model, indicating that sMMD corrected for treatment-assignment bias without incurring an accuracy penalty. The resulting representations showed well-mixed treatment groups (Figure [10(b)](https://arxiv.org/html/2605.05706#S2.F10.sf2)). Temporal visualisation confirmed that this alignment was sustained across all sequential time steps (Figure [10(d)](https://arxiv.org/html/2605.05706#S2.F10.sf4)), with no re-emergence of treatment-specific clustering at later horizons. To assess whether the balanced representations encode demographic biases, we coloured the MIMIC-III embeddings by patient attributes (Figures [10(e)](https://arxiv.org/html/2605.05706#S2.F10.sf5)–[10(g)](https://arxiv.org/html/2605.05706#S2.F10.sf7)). Gender and ethnicity showed uniform mixing across the representation space, with no observable clustering by demographic group. In contrast, age exhibited a structured gradient, with patients over 90 years forming a distinct cluster (Figure [10(e)](https://arxiv.org/html/2605.05706#S2.F10.sf5), upper right). Taken together, the representations did not exhibit systematic demographic partitioning; the age-related structure was the sole axis of separation, consistent with preserved physiological heterogeneity rather than encoded demographic bias.
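Visual judgements of "well-mixed" t-SNE embeddings can be complemented by a simple quantitative check that is not part of the paper's analysis: the average fraction of each point's k nearest neighbours that share its group label. The sketch below is a hypothetical helper that works on any embedding matrix; values near the label base rate suggest mixing, and values near 1 suggest group-specific clustering.

```python
import numpy as np

def same_label_neighbor_fraction(emb, labels, k=10):
    """Average fraction of each point's k nearest neighbours (Euclidean) that share its label."""
    emb = np.asarray(emb, dtype=float)
    labels = np.asarray(labels)
    sq_norms = np.sum(emb ** 2, axis=1)
    dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * emb @ emb.T
    np.fill_diagonal(dist, np.inf)        # exclude each point from its own neighbourhood
    nn = np.argsort(dist, axis=1)[:, :k]  # indices of the k nearest neighbours
    return float(np.mean(labels[nn] == labels[:, None]))

rng = np.random.default_rng(0)
labels = np.repeat([0, 1], 100)
separated = np.vstack([rng.normal(0.0, 1.0, (100, 2)), rng.normal(10.0, 1.0, (100, 2))])
mixed = rng.normal(size=(200, 2))
print(same_label_neighbor_fraction(separated, labels))  # near 1: treatment-specific clusters
print(same_label_neighbor_fraction(mixed, labels))      # near 0.5: well-mixed groups
```

The same score can be computed per attribute (treatment, gender, ethnicity, binned age) and in the original representation space rather than the 2-D t-SNE map, which avoids artefacts introduced by the projection.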

## 3 Discussion

A major barrier to deploying AI in intensive care is that current causal models often sacrifice patient\-specific information to reduce confounding bias, limiting generalization across populations\. This study introduced GITO, a framework that resolves this trade\-off through a sampling\-based MMD \(sMMD\) alignment strategy\. Across synthetic and real\-world ICU datasets, GITO demonstrated robust out\-of\-distribution performance, including cross\-hospital and cross\-ethnicity generalization, and achieved predictive accuracy comparable to or exceeding that of experienced clinicians\. Its interpretable outputs enhanced clinical reasoning and improved the performance of less experienced physicians, underscoring its role as an augmentation tool for clinical expertise\.

From a methodological perspective, GITO addresses the long-standing tension between confounding removal and information preservation in treatment outcome prediction. Adversarial balancing methods enforce global distributional invariance, which suppresses treatment-related signals indiscriminately, including the clinically informative heterogeneity essential for individualized prediction. The sMMD strategy circumvents this by performing stochastic, subset-level alignment: at each iteration, small random subsets are drawn from each treatment group and aligned via MMD, imposing a softer constraint that encourages the encoder to learn treatment-invariant features without overwriting covariate-level detail. Three lines of evidence support this interpretation. First, synthetic experiments under strong confounding revealed that sMMD prevented the representation space from encoding treatment-assignment artefacts while preserving the patient-level variation needed for accurate long-range prediction, whereas adversarial objectives progressively eroded this information. Second, reconstruction analyses confirmed that sMMD-balanced representations achieved the lowest reconstruction error across all balancing objectives, indicating that downstream outcome heads retained access to richer, patient-specific features. Third, per-variable $\Delta R^{2}$ analysis on MIMIC-III data showed that adversarial MINE balancing exhibited high variance across clinical variables, achieving a large negative $\Delta R^{2}$ on some variables (e.g., HCO₃) while incurring substantial positive $\Delta R^{2}$ on others (e.g., PaO₂ and glucose), consistent with an invariance objective that non-selectively reshapes the representation space. 
In contrast, sMMD produced a consistently near-zero $\Delta R^{2}$ profile and, notably, achieved negative $\Delta R^{2}$ on respiratory rate, FiO₂, and GCS, variables that the adversarial method eroded, suggesting that stochastic subsampling acts as an implicit regularizer that enhances reconstruction fidelity on these clinically decisive variables. This selective preservation of variables central to ventilator management (PaO₂, FiO₂, respiratory rate), haemodynamic monitoring (heart rate), and consciousness assessment (GCS) provides a mechanistic account of the observed generalization advantage across hospitals, ethnicities, and disease categories.
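The subset-level alignment described above can be sketched as follows. This is a minimal NumPy illustration, not the authors' implementation: it assumes a Gaussian (RBF) kernel with a fixed bandwidth, the biased (V-statistic) estimator of squared MMD, and equal-size subsets sampled without replacement from each treatment group; `subset_size` and `bandwidth` are hypothetical hyperparameters. A trainable version would compute the same quantity inside an autodiff framework so that it can be added to the outcome loss.

```python
import numpy as np

def rbf_kernel(a, b, bandwidth):
    """Gaussian (RBF) kernel matrix between the rows of a and b."""
    sq_dist = (np.sum(a ** 2, axis=1)[:, None]
               + np.sum(b ** 2, axis=1)[None, :]
               - 2.0 * a @ b.T)
    return np.exp(-np.maximum(sq_dist, 0.0) / (2.0 * bandwidth ** 2))

def mmd2(x, y, bandwidth):
    """Biased (V-statistic) estimate of the squared MMD between samples x and y."""
    return (rbf_kernel(x, x, bandwidth).mean()
            + rbf_kernel(y, y, bandwidth).mean()
            - 2.0 * rbf_kernel(x, y, bandwidth).mean())

def sampled_mmd2(z, treatment, subset_size, rng, bandwidth=2.0):
    """Squared MMD between random equal-size subsets of the treated and control representations."""
    z_treated, z_control = z[treatment == 1], z[treatment == 0]
    m = min(subset_size, len(z_treated), len(z_control))
    idx_t = rng.choice(len(z_treated), size=m, replace=False)
    idx_c = rng.choice(len(z_control), size=m, replace=False)
    return mmd2(z_treated[idx_t], z_control[idx_c], bandwidth)

rng = np.random.default_rng(0)
treatment = np.repeat([0, 1], 200)
z_shared = rng.normal(size=(400, 4))               # both groups drawn from one distribution
z_shifted = z_shared + 3.0 * treatment[:, None]    # treated representations shifted (confounded)
same = sampled_mmd2(z_shared, treatment, 32, rng)
shifted = sampled_mmd2(z_shifted, treatment, 32, rng)
print(same < shifted)  # True
```

Because a fresh pair of subsets is drawn on every call, repeated calls give different estimates; each training step therefore aligns only a small random slice of the two groups rather than enforcing global invariance. A single-objective loss would then take a form such as outcome loss + λ · sMMD, where the weight λ is another assumed hyperparameter.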

Clinically, the ventilator weaning experiment illustrates how GITO’s predicted trajectories translate methodological gains into patient-level benefit. The 42% relative gain in recall, from 0.506 to 0.719, indicates that forward-projected physiological trajectories captured deterioration signals absent from retrospective data alone. In safety-critical settings such as ventilator weaning, where missed re-intubation events may lead to delayed intervention and adverse outcomes, this shift toward fewer missed detections is of direct clinical consequence. The concurrent improvements in AUC (0.711 to 0.756) and precision (0.652 to 0.719) indicate that these gains were not achieved at the expense of increased false alarms. Importantly, calibration analysis revealed that the GITO-augmented model not only improved discrimination but also produced substantially more reliable probability estimates (ECE = 0.169 vs. 0.335 for the baseline). The baseline’s calibration curve exhibited pronounced non-monotonicity, with observed re-intubation rates near zero in the 0.7–0.9 predicted-probability range, indicating that patients flagged as “high-risk” by the baseline were, in practice, rarely re-intubated. This form of miscalibration is particularly hazardous in clinical settings, as it may drive unnecessary interventions based on inflated risk estimates. In contrast, the GITO model maintained a broadly monotonic relationship between predicted and observed event rates, enabling clinicians to interpret its probability outputs as meaningful risk estimates rather than ordinal rankings. This property is a prerequisite for shared decision-making, where the absolute magnitude of predicted risk, not merely its relative ordering, directly informs the aggressiveness of subsequent management.

Beyond prediction accuracy, the interpretability framework enhanced clinical reasoning. The septic shock case study demonstrated how quantitative attribution, counterfactual trajectory analysis, and LLM-generated explanations jointly enabled clinicians to interpret treatment trade-offs, providing a coherent reasoning pathway that aligned model predictions with clinically meaningful narratives rather than isolated feature-importance scores. Critically, the counterfactual trajectories revealed that the patient’s blood pressure was trending toward recovery even without intervention, while sustained vasopressor use would accelerate recovery but risk overshooting the target range. This multi-scenario view enables a clinically important inference that no single trajectory would support: short-term vasopressor administration to hasten stabilization, followed by timely de-escalation to avoid overtreatment, a nuanced strategy that goes beyond binary treat-or-not decisions. That the treating clinicians independently chose conservative management, with the patient subsequently recovering, is consistent with the clinical relevance of GITO’s trajectory-based reasoning.

The human-AI comparison study further contextualizes these advantages. GITO outperformed all four general-purpose LLMs by 8.4–20.0 percentage points despite these models being equipped with expert clinical reasoning prompts. This performance gap is consistent with a fundamental limitation of prompt-based approaches: while LLMs can apply rule-based logic to static clinical snapshots (e.g., checking RSBI thresholds), they lack the capacity to model non-linear temporal dynamics, such as gradual haemodynamic drift or evolving respiratory patterns, that GITO’s sMMD-balanced representations are specifically trained to encode. When medical students were provided with GITO’s predictions and attribution-based explanations, their accuracy improved by 14.7 percentage points (from 58.7% to 73.4%), demonstrating that the framework’s explanations are sufficiently interpretable to improve novice clinical judgment. However, collaborative accuracy remained below GITO’s standalone performance (75.6%), indicating imperfect trust calibration: students occasionally overrode correct model predictions based on their own assessment. This suggests that effective deployment requires not only transparent explanations but also calibration mechanisms that help users recognize when to defer to algorithmic judgment.

The cooperation study with practicing clinicians reinforced these findings while revealing additional benefits\. GITO assistance reduced decision\-making time by 74% \(from 205\.6 to 52\.6 minutes per case batch\) and improved the clinically acceptable safety rate from 82\.4% to 89\.8%\. The efficiency gain likely reflects the role of attribution\-based explanations in directing clinicians’ attention toward the most prognostically relevant variables, reducing the cognitive burden of manually reviewing high\-dimensional temporal data\. Importantly, the crossover design revealed that clinicians actively revised their initial incorrect predictions after reviewing GITO’s explanations, rather than passively accepting the model’s output\. This distinction is clinically meaningful: it suggests that GITO’s interpretable outputs engage clinicians in a corrective reasoning process, enabling them to identify errors in their own assessment rather than merely deferring to the algorithm\. The simultaneous gains in accuracy, efficiency, and safety position GITO as a decision\-support tool that strengthens human judgment by exposing the temporal dynamics and comparative consequences of alternative interventions\.

Translating these clinical benefits into practice requires both computational feasibility and demographic fairness. On the engineering side, replacing adversarial min-max optimization with a single-objective sMMD loss yielded concrete deployment advantages: fewer trainable parameters, faster convergence, and, critically, sub-50 ms inference latency on standard CPU hardware without GPU acceleration. This low computational footprint enabled us to release GITO as an open-source, web-based platform that can be deployed within secure hospital intranets, lowering the barrier to adoption in resource-limited settings where access to both experienced intensivists and specialized computing infrastructure is constrained. On the fairness side, the representation analysis confirmed that sMMD-balanced embeddings showed no systematic partitioning by gender or ethnicity, demographic attributes that should not influence physiological predictions. The sole axis of structured separation was age, where patients over 90 years formed a distinct cluster; rather than indicating bias, this pattern is consistent with the well-established physiological distinctiveness of advanced age, including reduced organ reserve and altered pharmacokinetics, which legitimately influence treatment response. This contrast between uniform mixing across gender and ethnicity and the preserved age-related heterogeneity indicates that sMMD’s stochastic alignment selectively targets treatment-assignment confounding without collapsing the physiological variation that underlies individualized prediction. Together, the computational efficiency and demographic neutrality of the framework support its potential for broader clinical adoption. The same methodological foundation (sMMD alignment and attribution-grounded interpretability) is domain-agnostic and may extend to non-ICU settings, including emergency triage, ward-level monitoring, and chronic disease management.

Several limitations of this study should be acknowledged, each pointing toward directions for future research\. First, although GITO was evaluated on two geographically distinct cohorts \(MIMIC\-III and AmsterdamUMCdb\), both are retrospective observational datasets; prospective validation in a randomised or pragmatic clinical trial setting remains necessary to confirm real\-world benefit and to quantify the effect of GITO\-assisted decision\-making on patient outcomes\. Second, the current sMMD framework is designed for binary treatment decisions\. Extending GITO to continuous treatment variables, such as drug dosages or infusion rates, through kernel\-based propensity matching would enable dosage optimization and graduated intervention strategies\. Third, our hourly modeling intervals may provide insufficient temporal granularity for time\-sensitive interventions such as vasopressor administration, where clinical effects manifest within minutes\. Multi\-resolution temporal modeling could address this limitation while expanding applicability to additional high\-impact interventions, including antibiotic selection and renal replacement therapy timing\. Fourth, while the attribution\-grounded LLM explanations improved clinician performance in our benchmark, they were not evaluated for factual accuracy against established clinical guidelines; the risk of LLM hallucination, even when constrained by model attributions, cannot be fully eliminated, and structured evaluation against clinical knowledge bases represents an important next step\. Fifth, ensuring temporal validity poses a fundamental challenge: clinical best practices evolve continuously as new evidence emerges, yet the current framework relies on historical data\. 
Future work will explore mechanisms for safely incorporating newly generated clinical evidence while enabling privacy\-preserving model updates, transforming GITO from a static prediction tool into a continuously learning clinical decision\-support system\.

By bridging methodological innovation with clinical relevance, GITO represents a step toward trustworthy, globally accessible AI for personalized treatment optimization in critical care and beyond\.

## 4 Methods

### 4.1 Study design

We evaluated GITO on three patient cohorts spanning varying levels of complexity and real\-world variability: two real\-world ICU databases \(MIMIC\-III and AmsterdamUMCdb\) and one synthetic tumor growth dataset\.

MIMIC-III electronic medical record data. The MIMIC-III database [journal/scidata2016/Johnson] is a large, publicly available ICU dataset comprising detailed electronic health records from Beth Israel Deaconess Medical Center (Boston, U.S.A.). We included patients whose ICU stays lasted between 30 and 60 hours to ensure sufficient temporal coverage for treatment-outcome modeling. A total of 25,186 patients met these criteria, comprising 56.3% males and 43.7% females, with a mean age of 62.9 years. The cohort included patients from 41 self-reported ethnic groups, with a mean ICU stay of 44.93 hours. Among the included patients, vasopressor therapy was administered for an average of 7.74 ± 15.02 hours and mechanical ventilation for 10.39 ± 17.49 hours. All clinical variables were aggregated at hourly resolution. The complete list of vital signs, laboratory values, and treatment variables used in the model is provided in Table [1](https://arxiv.org/html/2605.05706#S1.T1).

Out-of-distribution evaluation partitions. To assess cross-ethnicity generalization, we partitioned the MIMIC-III cohort by self-reported ethnicity: patients of European descent formed the training set, while Asian, African-descent, and Latino patients served as three independent out-of-distribution (OOD) test sets. These three groups were selected as the largest non-European subpopulations with sufficient sample sizes for robust evaluation; remaining ethnic groups were excluded due to small cohort sizes. For cross-hospital evaluation, models trained on the MIMIC-III cohort were deployed on AmsterdamUMCdb without fine-tuning.
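The partitioning logic can be sketched as a simple filter. This is a minimal illustration, not the authors' preprocessing code: the field name `ethnicity_group` and the coarse labels `European`, `Asian`, `African` and `Latino` are hypothetical, and how MIMIC-III's raw ethnicity strings are mapped onto them is an assumption.

```python
from collections import defaultdict

# Hypothetical coarse ethnicity labels; mapping MIMIC-III's raw ethnicity
# strings onto these groups is an assumption made for this sketch.
TRAIN_GROUP = "European"
OOD_GROUPS = ("Asian", "African", "Latino")


def split_by_ethnicity(patients):
    """Return (train, ood), where ood maps each OOD group to its patients.

    Patients outside the training and OOD groups are excluded, mirroring
    the removal of small ethnic groups described in the text.
    """
    train, ood = [], defaultdict(list)
    for patient in patients:
        group = patient["ethnicity_group"]  # hypothetical field name
        if group == TRAIN_GROUP:
            train.append(patient)
        elif group in OOD_GROUPS:
            ood[group].append(patient)
    return train, dict(ood)


patients = [
    {"id": 1, "ethnicity_group": "European"},
    {"id": 2, "ethnicity_group": "Asian"},
    {"id": 3, "ethnicity_group": "Latino"},
    {"id": 4, "ethnicity_group": "Other"},
    {"id": 5, "ethnicity_group": "European"},
]
train, ood = split_by_ethnicity(patients)
print([p["id"] for p in train])                                       # [1, 5]
print({g: [p["id"] for p in members] for g, members in ood.items()})  # {'Asian': [2], 'Latino': [3]}
```

The cross-hospital setting needs no split at all: the model fitted on the MIMIC-III training set is simply evaluated on AmsterdamUMCdb.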

Disease-category stratification. To evaluate disease-specific robustness, patients were stratified into four clinical categories based on primary diagnosis ICD-9 codes: cardiovascular and circulatory disorders (e.g., acute myocardial infarction, congestive heart failure, coronary artery disease), neurological disorders (e.g., stroke, intracranial haemorrhage, seizure), infectious and inflammatory diseases (e.g., pneumonia, sepsis, septic shock), and gastrointestinal, hepatobiliary, and metabolic disorders (e.g., gastrointestinal bleed, pancreatitis, diabetic ketoacidosis). The full mapping of ICD-9 codes to disease categories is provided in Supplementary Table [9](https://arxiv.org/html/2605.05706#A2.T9).

Ventilator weaning sub-cohort. For the ventilator re-intubation prediction task, we identified a sub-cohort of 205 mechanically ventilated patients from the MIMIC-III dataset. Patients were selected based on ICD-9 codes for heart failure (428.x) and acute respiratory distress syndrome (ARDS; 518.82, 518.5), representing high-risk populations for extubation failure. Re-intubation was defined as the resumption of mechanical ventilation within six hours of extubation. Of the 205 patients, 67 (32.7%) required re-intubation.

Septic shock case study. The individual-level case study was selected from the MIMIC-III cohort based on ICD-9 code 785.52 (septic shock) to demonstrate the interpretability framework on a clinically representative scenario involving vasopressor therapy decisions.

AmsterdamUMCdb electronic medical record data. The AmsterdamUMCdb database [journal/ccm2021/49Thoral] is a large, openly accessible intensive care dataset from two university medical centers in the Netherlands. We included adult patients whose ICU stays lasted between 30 and 60 hours. A total of 2,597 patients met these criteria, comprising 1,614 (62.2%) males and 983 (37.8%) females, with the largest age group being 70-79 years (25.6%). The mean ICU stay was 42.9 ± 7.1 hours. Among the included patients, vasopressor therapy was administered for an average of 12.91 ± 15.76 hours, and mechanical ventilation for 12.34 ± 15.40 hours. Baseline demographic and treatment characteristics are summarized in Supplementary Table [8](https://arxiv.org/html/2605.05706#A2.T8).

Synthetic patient cohort for controlled confounding evaluation. To enable controlled evaluation of counterfactual prediction, we simulated a synthetic patient cohort ($n = 10{,}000$) using a pharmacokinetic-pharmacodynamic (PKPD) tumor growth model [journal/scirep2017/7Geng]. This model simulates individualized treatment responses with known ground-truth counterfactual outcomes, allowing precise quantification of prediction accuracy under varying degrees of treatment selection bias [conference/iclr2020/Bica,conference/icml2022/Melnychuk,conference/icml2024/wang]. The synthetic cohort includes patients with diverse baseline tumor characteristics (volume and growth rate) and treatment scenarios spanning 30-day observation periods. Confounding strength ($\gamma$) was systematically varied from 0 (randomized treatment) to 7 (strong selection bias) for prediction experiments, and extended to $\gamma = 10$ for the representation balancing analysis following an established protocol [conference/iclr2020/Bica]. Full simulation parameters, including treatment assignment mechanisms and validation protocols, are detailed in Appendix [B.1](https://arxiv.org/html/2605.05706#A2.SS1).
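To illustrate how a single parameter can move treatment assignment from randomized to strongly confounded, the sketch below uses a sigmoid of tumour size scaled by $\gamma$, in the spirit of the tumour-growth benchmarks cited above. It is a hypothetical simplification: the constant `D_MAX`, the use of a mean tumour diameter as the sole confounder, and the function names are assumptions, and the mechanism actually used is the one specified in Appendix B.1.

```python
import math
import random

# Hypothetical scale constant; the actual simulator parameters are given in
# Appendix B.1 and are not reproduced here.
D_MAX = 13.0


def treatment_probability(mean_diameter, gamma, d_max=D_MAX):
    """Confounded assignment: P(treat) = sigmoid(gamma / d_max * (D - d_max / 2)).

    gamma = 0 gives P = 0.5 for every patient (randomized treatment);
    larger gamma makes treatment increasingly determined by tumour size.
    """
    logit = (gamma / d_max) * (mean_diameter - d_max / 2.0)
    return 1.0 / (1.0 + math.exp(-logit))


def treated_fraction(mean_diameter, gamma, n=10_000, seed=0):
    """Monte Carlo estimate of how often a patient of this size is treated."""
    rng = random.Random(seed)
    p = treatment_probability(mean_diameter, gamma)
    return sum(rng.random() < p for _ in range(n)) / n


for gamma in (0, 2, 7):
    small, large = treated_fraction(3.0, gamma), treated_fraction(10.0, gamma)
    print(f"gamma={gamma}: small tumour {small:.2f}, large tumour {large:.2f}")
```

As $\gamma$ grows, the treated and untreated groups differ more and more in baseline tumour size, which is exactly the selection bias the balancing loss is meant to counteract.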

### 4.2 Problem formulation and notations

The objective of individualized treatment outcome prediction is to estimate the potential evolution of a patient’s physiological state under alternative treatment strategies, including counterfactual scenarios not observed in the historical data\.

Patient trajectories. For each patient $i$, we observe a longitudinal health trajectory spanning discrete time steps $t = 1, \ldots, T^{i}$. At each step $t$, let $\bm{X}_t^i \in \mathbb{R}^{d_x}$ denote the time-varying covariates (e.g., vital signs and laboratory values), $\bm{A}_t^i \in \{a_1, \ldots, a_{d_a}\}$ denote the treatment administered, and $\bm{Y}_t^i \in \mathbb{R}^{d_y}$ denote the outcome of interest at the subsequent step. Static covariates (e.g., age, gender, ethnicity, or comorbidities) are represented as $\bm{V}^i \in \mathbb{R}^{d_v}$. In the present study, treatments are binary ($d_a = 2$): presence or absence of vasopressor therapy, and presence or absence of mechanical ventilation. The observational dataset for $M$ patients is therefore

$$
\bm{H}_t = \Big\{ \{\bm{X}_t^i, \bm{A}_t^i, \bm{Y}_t^i\}_{t=1}^{T^i} \cup \bm{V}^i \Big\}_{i=1}^{M}. \tag{1}
$$

Patient history. Following prior work [journal/aim1997/757Rubin,journal/epidemiology2000/Robins,conference/iclr2020/Bica,conference/icml2022/Melnychuk,conference/icml2024/wang,conference/kdd2024/12Wu,conference/nips2024/37Bouchattaoui], we define the history up to time $t$ as $\bar{\bm{H}}_t = \{\bar{\bm{X}}_t, \bar{\bm{A}}_{t-1}, \bar{\bm{Y}}_t, \bm{V}\}$, where $\bar{\bm{X}}_t = (\bm{X}_1, \ldots, \bm{X}_t)$, $\bar{\bm{A}}_{t-1} = (\bm{A}_1, \ldots, \bm{A}_{t-1})$, and $\bar{\bm{Y}}_t = (\bm{Y}_1, \ldots, \bm{Y}_t)$. We condition on $\bar{\bm{A}}_{t-1}$ rather than $\bar{\bm{A}}_t$ to ensure causal consistency: the outcome at time $t$ is influenced only by treatments administered before $t$, while $\bm{A}_t$ is the decision under evaluation.
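The indexing in $\bar{\bm{H}}_t$ is easy to get wrong by one step. A minimal sketch, assuming a single patient stored as 0-indexed Python lists (element 0 holds step 1); the function and variable names are hypothetical:

```python
def patient_history(X, A, Y, V, t):
    """Return the history H̄_t = {X̄_t, Ā_{t-1}, Ȳ_t, V} for 1-indexed time t.

    X, A, Y are per-step lists for one patient (index 0 holds step 1);
    V holds the static covariates.
    """
    if not 1 <= t <= len(X):
        raise ValueError("t must lie in 1..T")
    return {
        "X": X[:t],      # X_1, ..., X_t
        "A": A[:t - 1],  # A_1, ..., A_{t-1}; A_t is the decision being evaluated
        "Y": Y[:t],      # Y_1, ..., Y_t
        "V": V,          # static covariates, shared by all time steps
    }


# Toy trajectory with T = 4 hourly steps (values are illustrative only).
X = [[80.0], [85.0], [90.0], [88.0]]   # e.g. heart rate
A = [0, 1, 1, 0]                       # e.g. vasopressor off/on
Y = [[65.0], [70.0], [72.0], [74.0]]   # e.g. mean arterial pressure
V = {"age": 67, "gender": "F"}

h3 = patient_history(X, A, Y, V, t=3)
print(len(h3["X"]), len(h3["A"]), len(h3["Y"]))  # 3 2 3
```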

Representation learning objective. Instead of conditioning directly on high-dimensional raw trajectories, we employ a representation learning network $f_{\Theta_{\mathcal{B}}}(\cdot)$ to extract a compact, patient-specific latent state:

$$
\bm{\mathcal{B}}_t = f_{\Theta_{\mathcal{B}}}(\bar{\bm{H}}_t), \tag{2}
$$
where $\bm{\mathcal{B}}_t \in \mathbb{R}^D$ summarizes the historical information available at time $t$. We denote the sequence of latent representations up to time $t$ as $\bar{\bm{\mathcal{B}}}_t = (\bm{\mathcal{B}}_1, \ldots, \bm{\mathcal{B}}_t)$. This representation trajectory is the input for estimating the expected potential outcomes at future horizons under an assigned treatment sequence $\bar{\bm{a}}_{t:t+\tau-1} = (a_t, \ldots, a_{t+\tau-1})$:

$$
\mathbb{E}\big[\hat{\bm{Y}}_{t+\tau}[\bar{\bm{a}}_{t:t+\tau-1}] \mid \bar{\bm{\mathcal{B}}}_t\big], \tag{3}
$$
where $\tau \geq 1$ denotes the prediction horizon, i.e., the number of time steps ahead. The key challenge is that treatment assignment in observational data is confounded: sicker patients may systematically receive more aggressive interventions. To mitigate this, we introduce a balancing regularizer into the representation learning objective. The total training loss comprises an outcome prediction term and a distribution alignment term:

$$
\mathcal{L} = \mathcal{L}_{\Theta_{\bm{Y}}} + \lambda \cdot \mathcal{L}_{\Theta_{\bm{\mathcal{B}}}}, \tag{4}
$$
where $\mathcal{L}_{\Theta_{\bm{\mathcal{B}}}}$ encourages the learned representations $\bm{\mathcal{B}}_t$ to be distributionally similar across treatment groups. In GITO, we instantiate this term using a sampling-based Maximum Mean Discrepancy (sMMD) objective, described in detail in Section [4.4](https://arxiv.org/html/2605.05706#S4.SS4).

Causal assumptions. To establish the identifiability of treatment effects from observational data, we follow assumptions from previous studies [journal/aim1997/757Rubin,journal/epidemiology2000/Robins,books/crc2008/Robins,conference/icml2024/wang], including consistency, sequential ignorability, and sequential overlap.

Assumption 1: Consistency (aligning potential and observed outcomes). At time step $t+1$, the observed outcome $\bm{Y}_{t+1}$ is assumed to be the potential outcome $\bm{Y}_{t+1}[a_t]$ that would have been realised under the assigned treatment $a_t$ at $t$, i.e.,

$$
\bm{Y}_{t+1} = \bm{Y}_{t+1}[a_t]. \tag{5}
$$
This assumption ensures that the observed outcome coincides with the potential outcome under the specific, well-defined treatment $a_t$. It requires the Stable Unit Treatment Value Assumption (SUTVA) to hold, specifically no interference between subjects and a single, consistent version of each treatment.

Assumption 2: Sequential overlap (positivity). For reliable estimation, we require that the probability of receiving any specific treatment $a_t$ is bounded away from zero for any patient history $\bar{h}_t$ that has a non-zero probability of occurrence:

$$
0 < P(\bm{A}_t = a_t \mid \bar{\bm{H}}_t = \bar{h}_t) < 1, \quad \text{if } P(\bar{\bm{H}}_t = \bar{h}_t) > 0, \quad \text{for all } a_t \in \bm{A}_t. \tag{6}
$$
This condition, often termed *positivity*, ensures that all treatment options remain possible given the observed clinical history.

Assumption 3: Sequential ignorability (no unmeasured confounding). The treatment assigned at any time $t$ is assumed to be conditionally independent of the potential outcome, given the observed history. Formally, for all $a_t \in \bm{A}$,

$$
\bm{A}_t \perp \bm{Y}_{t+1}[a_t] \mid \bar{\bm{H}}_t. \tag{7}
$$
This is the no unmeasured confounding (NUC) assumption, which is critical for counterfactual outcome estimation. It implies that all variables influencing both $\bm{A}_t$ and $\bm{Y}_{t+1}$ have been adequately measured and included in $\bar{\bm{H}}_t$. The key mathematical notation is summarized in Table [6](https://arxiv.org/html/2605.05706#S4.T6).

Table 6: Summary of mathematical notations used in the GITO framework.

Algorithm 1: GITO training and inference procedure

Input: historical patient data $\bar{\bm{H}}_t = \{\bar{\bm{X}}_t, \bar{\bm{A}}_{t-1}, \bar{\bm{Y}}_t, \bm{V}\}$ and treatment assignments $\bm{A}_t$.

Output: predicted multi-step outcomes $\hat{\bm{Y}}_{t+1:t+\tau}$ under the treatment sequence $\bm{A}_{t:t+\tau-1}$.

Training phase (one-step-ahead prediction with balancing):

1. Initialize the balancing weight $\lambda \leftarrow 0$.
2. For $epoch = 1, \ldots, EPOCH$:
   1. Compute the progression factor $\alpha_{epoch} = \frac{2}{1 + \exp(-10 \cdot epoch / EPOCH)} - 1$ and update $\lambda \leftarrow \alpha_{epoch}$, so that $\lambda$ increases progressively during training.
   2. Sample a mini-batch $\mathcal{M}$ from the training set.
   3. For each time step $t$ in $\mathcal{M}$: encode the input history $\bar{\bm{H}}_t = \{\bar{\bm{X}}_t, \bar{\bm{A}}_{t-1}, \bar{\bm{Y}}_t, \bm{V}\}$, learn the representation $\bm{\mathcal{B}}_t = f_{\Theta_{\mathcal{B}}}(\bar{\bm{H}}_t)$, and predict the one-step outcome $\hat{\bm{Y}}_{t+1} = g_{\Theta_Y}(\bm{\mathcal{B}}_t, \bm{A}_t)$.
   4. Partition the representations $\{\bm{\mathcal{B}}_t\}$ into treatment-specific subsets $\bm{\mathcal{D}}_k = \{\bm{\mathcal{B}}_i \mid a_i = k\}$ for each treatment type $k \in \{1, \ldots, d_a\}$; sampling is done in representation space.
   5. Compute the prediction loss using Equation [11](https://arxiv.org/html/2605.05706#S4.E11): $\mathcal{L}_{\Theta_Y} = \mathrm{MSE}(\hat{\bm{Y}}_{t+1}, \bm{Y}_{t+1})$.
   6. Compute the aggregate sMMD loss via Eq. [13](https://arxiv.org/html/2605.05706#S4.E13): $\mathcal{L}_{\Theta_{\mathcal{B}}} \leftarrow \sum_{1 \leq i < j \leq d_a} \mathrm{MMD}^2_u(\bm{S}_i, \bm{S}_j)$.
   7. Compute the total loss $\mathcal{L} = \mathcal{L}_{\Theta_Y} + \lambda \mathcal{L}_{\Theta_{\mathcal{B}}}$.
   8. Update $\Theta_{\mathcal{B}}$ and $\Theta_Y$ via backpropagation.

Inference phase (multi-step prediction via an expanding window):

3. Given the initial history $\bar{\bm{H}}_t$ and treatment assignments $\bm{A}_{t:t+\tau-1}$, initialize the cumulative contribution vector $\bm{\omega} \leftarrow \bm{0}$.
4. For $k = 1, \ldots, \tau$:
   1. Encode the history $\bar{\bm{H}}_{t+k-1}$ to obtain $\bm{\mathcal{B}}_{t+k-1} = f_{\Theta_{\mathcal{B}}}(\bar{\bm{H}}_{t+k-1})$.
   2. Predict $\hat{\bm{Y}}_{t+k} = g_{\Theta_Y}(\bm{\mathcal{B}}_{t+k-1}, \bm{A}_{t+k-1})$.
   3. Compute variable contributions via Integrated Gradients, $\bm{\omega}_{t+k} = \mathrm{IG}(\bar{\bm{H}}_{t+k-1}, \bm{A}_{t+k-1})$, and accumulate $\bm{\omega} \leftarrow \bm{\omega} + \bm{\omega}_{t+k}$; each $\bm{\omega}$ is a vector over input variables.
   4. Update the history: $\bar{\bm{H}}_{t+k} \leftarrow \bar{\bm{H}}_{t+k-1} \cup \{\hat{\bm{Y}}_{t+k}, \bm{A}_{t+k-1}\}$.
5. Compute the average contribution $\bar{\bm{\omega}} \leftarrow \bm{\omega} / \tau$.

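Algorithm 1 names Integrated Gradients (IG) for the per-variable contributions but does not specify the baseline or the integration scheme. The sketch below is a generic IG approximation on a toy scalar function, assuming an all-zero baseline, a 50-step midpoint Riemann sum, and central finite-difference gradients in place of automatic differentiation; `toy_model` is a hypothetical stand-in for the decoder output, not GITO itself. It also shows the horizon averaging $\bar{\bm{\omega}} = \bm{\omega}/\tau$.

```python
def integrated_gradients(f, x, baseline, steps=50, eps=1e-5):
    """Integrated Gradients for a scalar function f of a list of floats.

    Approximates IG_i = (x_i - x'_i) * integral_0^1 df/dx_i(x' + a (x - x')) da
    with a midpoint Riemann sum; gradients use central finite differences,
    standing in for the autograd a deep-learning framework would provide.
    """
    n = len(x)
    grad_sums = [0.0] * n
    for s in range(steps):
        alpha = (s + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i in range(n):
            up, down = list(point), list(point)
            up[i] += eps
            down[i] -= eps
            grad_sums[i] += (f(up) - f(down)) / (2.0 * eps)
    return [(xi - b) * g / steps for xi, b, g in zip(x, baseline, grad_sums)]


def average_contributions(per_step):
    """Average per-horizon attribution vectors: omega_bar = sum_k omega_{t+k} / tau."""
    tau = len(per_step)
    return [sum(column) / tau for column in zip(*per_step)]


def toy_model(v):
    # Hypothetical stand-in for a scalar decoder output.
    return v[0] ** 2 + 3.0 * v[1]


x, baseline = [2.0, 1.0], [0.0, 0.0]
omega = integrated_gradients(toy_model, x, baseline)
print([round(w, 4) for w in omega])  # [4.0, 3.0]
print(sum(omega), toy_model(x) - toy_model(baseline))  # completeness: both about 7
```

The completeness check (attributions summing to $f(x) - f(x')$) is a standard property of IG and a convenient sanity test for any implementation.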
### 4.3 GITO framework architecture

The GITO framework integrates three synergistic components to achieve reliable counterfactual prediction: (1) a representation learning module (Encoder, $\Theta_{\mathcal{B}}$) that extracts patient state embeddings; (2) an outcome prediction network (Decoder, $\Theta_Y$) that generates trajectory forecasts under arbitrary treatment plans; and (3) a sampling-based distribution alignment mechanism (sMMD) that mitigates treatment-selection confounding. The overall architecture and training workflow are illustrated in Figure [2](https://arxiv.org/html/2605.05706#S1.F2), and the procedural logic is detailed in Algorithm [1](https://arxiv.org/html/2605.05706#algorithm1).

Representation Learning (Encoder). The encoder $\Theta_{\mathcal{B}}$ maps a patient's historical record $\bar{\bm{H}}_t$ to a compact latent representation. As defined in Section [4.2](https://arxiv.org/html/2605.05706#S4.SS2), the history at time $t$ comprises four components: time-varying covariates $\bar{\bm{X}}_t$ (vital signs and laboratory values), previous outcomes $\bar{\bm{Y}}_t$, past treatments $\bar{\bm{A}}_{t-1}$, and static covariates $\bm{V}$ (demographics). At each time step, the temporal inputs $[\bm{X}_t \oplus \bm{Y}_t \oplus \bm{V}]$ are concatenated and processed jointly with the treatment history:

$$
\bm{\mathcal{B}}_t = f_{\Theta_{\mathcal{B}}}\big(\bar{\bm{X}}_t \oplus \bar{\bm{Y}}_t \oplus \bm{V},\; \bar{\bm{A}}_{t-1}\big), \tag{8}
$$
where $\oplus$ denotes feature-level concatenation. The resulting latent state $\bm{\mathcal{B}}_t \in \mathbb{R}^D$ is intended to serve as a sufficient summary of the patient's physiological history up to time $t$. Our framework is architecture-agnostic: the encoder can use various sequential modeling backbones, such as Transformers [conference/icml2022/Melnychuk], LSTMs [conference/iclr2020/Bica], or 1D-CNNs [conference/icml2024/wang]. We empirically validate this portability across all three architectures in Section [2](https://arxiv.org/html/2605.05706#S2).
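One way to lay out the encoder inputs of Eq. (8) per time step is sketched below. The representation choices are assumptions: features are plain Python lists, treatments are one-hot vectors, and the missing previous treatment at the first step is replaced by a zero vector; the text does not state how the first step is handled.

```python
def encoder_inputs(X, Y, V, A):
    """Pair each step's concatenated features [X_t ⊕ Y_t ⊕ V] with A_{t-1}.

    X, Y: per-step feature lists (length T); V: static feature list;
    A: per-step one-hot treatment vectors (length T). Step 1 has no previous
    treatment, so a zero vector is used there (an assumption of this sketch).
    """
    d_a = len(A[0])
    steps = []
    for t in range(len(X)):
        features = X[t] + Y[t] + V                 # feature-level concatenation
        prev_a = A[t - 1] if t > 0 else [0] * d_a  # lagged treatment A_{t-1}
        steps.append((features, prev_a))
    return steps


X = [[1.0, 2.0], [3.0, 4.0]]
Y = [[0.5], [0.6]]
V = [70.0]
A = [[1, 0], [0, 1]]  # one-hot over d_a = 2 treatments
steps = encoder_inputs(X, Y, V, A)
for features, prev_a in steps:
    print(features, prev_a)
# [1.0, 2.0, 0.5, 70.0] [0, 0]
# [3.0, 4.0, 0.6, 70.0] [1, 0]
```

Any of the backbones named above (Transformer, LSTM, 1D-CNN) could then consume this per-step sequence; the lag keeps the treatment being evaluated, $\bm{A}_t$, out of the encoder input.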

Outcome Prediction (Decoder) and Counterfactual Inference. The prediction network $\Theta_Y$ takes the learned representation $\bm{\mathcal{B}}_t$ and a candidate treatment action $\bm{a}_t$ (encoded as a one-hot vector over $d_a$ possible treatments) to forecast the next outcome:

$$
\hat{\bm{Y}}_{t+1} = g_{\Theta_Y}(\bm{\mathcal{B}}_t, \bm{a}_t). \tag{9}
$$
For multi-step prediction over a horizon $\tau$, the model operates autoregressively: each predicted outcome $\hat{\bm{Y}}_{t+k}$ is fed back to update the latent state, generating a continuous trajectory:

$$
\hat{\bm{Y}}_{t+k+1} = g_{\Theta_Y}\big(f_{\Theta_{\mathcal{B}}}(\bar{\bm{H}}_t, \hat{\bm{Y}}_{t+1:t+k}),\; \bm{a}_{t+k}\big), \quad k = 1, \ldots, \tau - 1. \tag{10}
$$
During training, we employ teacher forcing [journal/neco1989/1Williams]: the ground-truth outcomes $\bm{Y}_{t+1:t+k}$ are supplied as inputs at each recursive step. At inference, teacher forcing is switched off and the model autoregressively consumes its own predictions $\hat{\bm{Y}}_{t+1:t+k}$, enabling multi-step trajectory generation without access to future observations. This autoregressive mechanism also enables counterfactual trajectory generation: by substituting alternative treatment sequences $\{\bm{a}'_t, \bm{a}'_{t+1}, \ldots\}$ into Eq. ([10](https://arxiv.org/html/2605.05706#S4.E10)), clinicians can explore hypothetical physiological responses to different treatment plans from the same patient state $\bm{\mathcal{B}}_t$.
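The difference between teacher forcing and autoregressive inference, and the way counterfactual plans start from the same state, can be illustrated with a toy rollout. In this sketch `encode` and `decode` are hypothetical stand-ins for $f_{\Theta_{\mathcal{B}}}$ and $g_{\Theta_Y}$, only outcomes are fed back, and the multiplicative dynamics are invented purely for illustration.

```python
def rollout(encode, decode, history_y, treatments, observed=None):
    """Multi-step prediction over a planned treatment sequence.

    encode(ys) -> latent state, a stand-in for f_{Theta_B}(H̄_t, Ŷ_{t+1:t+k});
    decode(b, a) -> next outcome, a stand-in for g_{Theta_Y}(B, a).
    If `observed` is given, ground-truth outcomes are fed back (teacher
    forcing, as in training); otherwise predictions are fed back (inference).
    """
    ys = list(history_y)
    preds = []
    for k, a in enumerate(treatments):
        b = encode(ys)          # re-encode the expanding history
        y_next = decode(b, a)   # one-step-ahead prediction
        preds.append(y_next)
        ys.append(observed[k] if observed is not None else y_next)
    return preds


def encode(ys):
    # Hypothetical encoder: the latent state is just the latest outcome.
    return ys[-1]


def decode(b, a):
    # Hypothetical dynamics: treatment 1 lowers the outcome by 10% per step,
    # treatment 0 raises it by 10%.
    return b * (0.9 if a == 1 else 1.1)


factual = rollout(encode, decode, [100.0], [1, 1, 1])
counterfactual = rollout(encode, decode, [100.0], [0, 0, 0])  # same start, new plan
print([round(y, 1) for y in factual])         # [90.0, 81.0, 72.9]
print([round(y, 1) for y in counterfactual])  # [110.0, 121.0, 133.1]
```

Passing `observed` switches the same loop into teacher-forcing mode, which is why training and inference can share one model while differing only in what is fed back.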

Bias Mitigation via sMMD. To ensure that the learned representations $\bm{\mathcal{B}}_t$ capture true physiological states rather than treatment-assignment artifacts, we regularize the encoder with a sampling-based Maximum Mean Discrepancy (sMMD) loss that minimizes the distributional distance between treatment groups in the latent space. The formulation and computational details of this balancing objective are presented in Section [4.4](https://arxiv.org/html/2605.05706#S4.SS4).

### 4.4 Training objective and optimization

Factual Prediction Loss. To ensure the model accurately captures physiological dynamics, we minimize the mean squared error (MSE) between predicted and observed outcomes. For a batch of training samples, the prediction loss is:

$$
\mathcal{L}_{\Theta_Y} = \frac{1}{M'} \sum_{i,t} \big\lVert \bm{Y}_{t+1}^i - g_{\Theta_Y}(\bm{\mathcal{B}}_t^i, \bm{A}_t^i) \big\rVert^2, \tag{11}
$$
where $g_{\Theta_Y}$ denotes the outcome prediction network and $M'$ is the total number of transition tuples in the batch.

Balancing Loss via sMMD. A central challenge in counterfactual outcome estimation is learning latent representations that are predictive of outcomes yet invariant to treatment assignment. While adversarial training (e.g., GAN-based discriminators) can enforce such invariance, it frequently suffers from optimization instability and mode collapse [conference/icml2017/214arjovsky]. To overcome these limitations, we adopt Maximum Mean Discrepancy (MMD) [conference/nips2006/Gretton], a kernel-based distributional distance that provides a stable, closed-form regularization signal without min-max optimization.

For a set of $d_a$ possible treatments, we minimize the average pairwise discrepancy across all treatment groups:

$$
\mathcal{L}_{\Theta_{\mathcal{B}}} = \frac{1}{\binom{d_a}{2}} \sum_{1 \leq i < j \leq d_a} \mathrm{MMD}^2(\bm{\mathcal{D}}_i, \bm{\mathcal{D}}_j), \tag{12}
$$
where $\bm{\mathcal{D}}_k = \{\bm{\mathcal{B}}_t \mid \bm{A}_t = k\}$ denotes the set of latent representations associated with treatment $k$.

Computing MMD over the entire dataset scales quadratically ($O(N^2)$) in the number of representations $N$ and is therefore computationally prohibitive. We adopt a sampling-based approximation (sMMD), formally instantiated as an unbiased U-statistic estimator. At each iteration, fixed-size random subsets $\bm{S}_i \subset \bm{\mathcal{D}}_i$ and $\bm{S}_j \subset \bm{\mathcal{D}}_j$ (with $|\bm{S}_i| = |\bm{S}_j| = N_s$) are drawn to compute:

$$
\begin{aligned}
\mathrm{MMD}^2_u(\bm{S}_i, \bm{S}_j) ={}& \frac{1}{N_s(N_s-1)} \sum_{x \in \bm{S}_i} \sum_{\substack{x' \in \bm{S}_i \\ x' \neq x}} k(x, x') + \frac{1}{N_s(N_s-1)} \sum_{y \in \bm{S}_j} \sum_{\substack{y' \in \bm{S}_j \\ y' \neq y}} k(y, y') \\
&- \frac{2}{N_s^2} \sum_{x \in \bm{S}_i} \sum_{y \in \bm{S}_j} k(x, y).
\end{aligned} \tag{13}
$$
The expected value of this estimator over random subsets equals the population squared MMD, ensuring an unbiased gradient signal. We set $N_s = 200$, which offers a favorable trade-off between estimator variance and computational cost; an ablation over $N_s \in \{50, 100, 200, 500\}$ showed negligible performance variation (<0.5% RMSE), indicating that the estimator is robust to this choice.
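A direct transcription of Eq. (13) is sketched below in plain Python, assuming representations are lists of floats and an RBF kernel with a fixed bandwidth of 1 (the bandwidth selection described next is omitted). The toy data also show that the unbiased estimator can be negative when both subsets come from the same distribution, which is expected for a U-statistic.

```python
import math


def rbf(x, y, sigma=1.0):
    """RBF kernel exp(-||x - y||^2 / (2 sigma^2)) on equal-length float lists."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd2_unbiased(S_i, S_j, kernel=rbf):
    """Unbiased U-statistic estimate MMD^2_u(S_i, S_j) of Eq. (13).

    Within-group sums skip identical indices (p != q); the cross term uses
    all N_s^2 pairs. Cost is O(N_s^2) kernel evaluations per treatment pair.
    """
    n = len(S_i)
    if n != len(S_j) or n < 2:
        raise ValueError("expected two subsets of equal size N_s >= 2")
    within_i = sum(kernel(S_i[p], S_i[q]) for p in range(n) for q in range(n) if p != q)
    within_j = sum(kernel(S_j[p], S_j[q]) for p in range(n) for q in range(n) if p != q)
    cross = sum(kernel(x, y) for x in S_i for y in S_j)
    return (within_i + within_j) / (n * (n - 1)) - 2.0 * cross / n ** 2


near = [[0.0], [1.0], [2.0]]
far = [[10.0], [11.0], [12.0]]
print(round(mmd2_unbiased(near, far), 3))   # 0.899: well-separated groups
print(round(mmd2_unbiased(near, near), 3))  # -0.367: identical groups, slightly negative
```

With $N_s = 200$, each treatment pair costs on the order of $3 N_s^2 = 120{,}000$ kernel evaluations per step, independent of the dataset size $N$.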

We employ a Radial Basis Function (RBF) kernel $k(\bm{x}, \bm{y}) = \exp\!\big(-\lVert \bm{x} - \bm{y} \rVert^2 / (2\sigma^2)\big)$. Because the RBF kernel is characteristic, the population MMD is zero if and only if the two distributions are identical. The bandwidth $\sigma$ is determined dynamically at each step using the median heuristic, setting $\sigma$ to the square root of the median pairwise distance within the pooled subsets $\bm{S}_i \cup \bm{S}_j$; this yields a data-adaptive kernel scale that tracks the optimization trajectory without manual tuning. We optimize $\mathrm{MMD}^2$ rather than MMD because $\mathrm{MMD}^2$ equals the squared RKHS distance between the kernel mean embeddings of the two groups, admits the unbiased estimator of Eq. (13), and avoids the numerical instability of the square root's gradient near zero, which yields smoother gradient-based alignment. The procedure is summarized in Algorithm [2](https://arxiv.org/html/2605.05706#algorithm2).
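The median heuristic can be sketched as follows. Because "median pairwise distance" could refer to either plain or squared Euclidean distances, this sketch assumes squared distances, so that $\sigma$ is the square root of their median; the fallback bandwidth of 1.0 for coincident points is also an assumption.

```python
import math
import statistics


def median_heuristic_sigma(S_i, S_j):
    """Bandwidth from the pooled subsets S_i ∪ S_j.

    Interprets "median pairwise distance" as the median squared Euclidean
    distance, so sigma = sqrt(median ||x - y||^2) (an assumption). Falls back
    to 1.0 if all points coincide, to avoid a zero bandwidth.
    """
    pooled = list(S_i) + list(S_j)
    sq_dists = [
        sum((a - b) ** 2 for a, b in zip(pooled[p], pooled[q]))
        for p in range(len(pooled))
        for q in range(p + 1, len(pooled))
    ]
    med = statistics.median(sq_dists)
    return math.sqrt(med) if med > 0 else 1.0


def rbf(x, y, sigma):
    """RBF kernel exp(-||x - y||^2 / (2 sigma^2))."""
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)) / (2.0 * sigma ** 2))


sigma = median_heuristic_sigma([[0.0], [1.0]], [[3.0]])
print(sigma)                                # 2.0 (squared distances 1, 9, 4; median 4)
print(round(rbf([0.0], [2.0], sigma), 4))   # 0.6065, i.e. exp(-0.5)
```

Recomputing $\sigma$ on every pooled subset keeps the kernel scale matched to the current spread of the representations as the encoder changes during training.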

Joint Optimization. The entire framework is trained end-to-end by jointly optimizing the factual prediction error and the distribution balancing loss:

$$
(\hat{\Theta}_Y, \hat{\Theta}_{\mathcal{B}}) = \arg\min_{\Theta_Y, \Theta_{\mathcal{B}}} \mathcal{L}_{\Theta_Y}(\Theta_Y, \Theta_{\mathcal{B}}) + \lambda \mathcal{L}_{\Theta_{\mathcal{B}}}(\Theta_{\mathcal{B}}), \tag{14}
$$
where $\lambda$ governs the trade-off between predictive accuracy and confounding removal.

Sigmoidal Annealing Schedule. Applying strong balancing regularization during early training can suppress physiologically meaningful variability before the encoder has learned useful representations. To ensure stable convergence, we adopt a smooth sigmoidal schedule for $\lambda$. At training epoch $e$:

$$
\lambda_e = \frac{2}{1 + \exp\!\big(-10 \cdot \tfrac{e}{E}\big)} - 1, \tag{15}
$$
where $E$ denotes the total number of epochs. This schedule starts $\lambda$ near zero, allowing the encoder to first learn physiologically relevant features, and then increases it smoothly toward full regularization: with the factor of 10, $\lambda_e$ exceeds 0.9 once 30% of the epochs have elapsed and reaches about 0.99 by the midpoint of training. This progressive strategy yielded consistently stable training behavior across all three backbone architectures evaluated (Appendix [C.2](https://arxiv.org/html/2605.05706#A3.SS2)).
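Tabulating Eq. (15) makes the pace of the schedule concrete. A minimal sketch, assuming a hypothetical run of $E = 100$ epochs:

```python
import math


def balancing_weight(epoch, total_epochs):
    """Sigmoidal annealing of Eq. (15): lambda_e = 2 / (1 + exp(-10 e / E)) - 1."""
    return 2.0 / (1.0 + math.exp(-10.0 * epoch / total_epochs)) - 1.0


E = 100  # hypothetical number of training epochs
for e in (0, 10, 30, 50, 100):
    print(e, round(balancing_weight(e, E), 3))
# 0 0.0
# 10 0.462
# 30 0.905
# 50 0.987
# 100 1.0
```

Since $2/(1+e^{-2x}) - 1 = \tanh(x)$, Eq. (15) can equivalently be written as $\lambda_e = \tanh(5e/E)$, which makes clear that the steepness constant 10 controls how early the balancing term reaches full strength.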

Algorithm 2: Computation of sampling-based MMD loss

Input: batch representations $\bm{\mathcal{B}} \in \mathbb{R}^{N \times D}$, treatment labels $\bm{A} \in \{1, \ldots, d_a\}^N$, sample size $N_s$, and kernel function $k(\cdot, \cdot)$.

Output: balancing loss $\mathcal{L}_{\Theta_{\mathcal{B}}}$.

1. Group representations by treatment: partition $\bm{\mathcal{B}}$ into subsets $\{\bm{\mathcal{D}}_1, \ldots, \bm{\mathcal{D}}_{d_a}\}$, where $\bm{\mathcal{D}}_k = \{\bm{b} \in \bm{\mathcal{B}} \mid a = k\}$.
2. Initialize $\mathcal{L}_{\Theta_{\mathcal{B}}} \leftarrow 0$.
3. For each treatment pair $(i, j)$ with $1 \leq i < j \leq d_a$:
   1. If $|\bm{\mathcal{D}}_i| < N_s$ or $|\bm{\mathcal{D}}_j| < N_s$, skip this pair (insufficient samples).
   2. Sample $\bm{S}_i \sim \bm{\mathcal{D}}_i$ and $\bm{S}_j \sim \bm{\mathcal{D}}_j$, each of size $N_s$.
   3. Compute the unbiased kernel estimates (Eq. [13](https://arxiv.org/html/2605.05706#S4.E13)): $\hat{\mu}_{ii} \leftarrow \frac{1}{N_s(N_s-1)} \sum_{p \neq q} k(\bm{s}_p^i, \bm{s}_q^i)$, $\hat{\mu}_{jj} \leftarrow \frac{1}{N_s(N_s-1)} \sum_{p \neq q} k(\bm{s}_p^j, \bm{s}_q^j)$, and $\hat{\mu}_{ij} \leftarrow \frac{1}{N_s^2} \sum_{p,q} k(\bm{s}_p^i, \bm{s}_q^j)$.
   4. Set $\widehat{\mathrm{MMD}}^2_{ij} \leftarrow \hat{\mu}_{ii} + \hat{\mu}_{jj} - 2\hat{\mu}_{ij}$ and accumulate $\mathcal{L}_{\Theta_{\mathcal{B}}} \leftarrow \mathcal{L}_{\Theta_{\mathcal{B}}} + \widehat{\mathrm{MMD}}^2_{ij}$.
4. Average over all pairs: $\mathcal{L}_{\Theta_{\mathcal{B}}} \leftarrow \mathcal{L}_{\Theta_{\mathcal{B}}} / \binom{d_a}{2}$.
5. Return $\mathcal{L}_{\Theta_{\mathcal{B}}}$.

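Algorithm 2 can be sketched end to end in plain Python as below. This is an illustrative sketch rather than the released implementation: subsets are drawn without replacement with `random.Random.sample` (the algorithm only says "sample"), the RBF bandwidth is fixed at 1 rather than set by the median heuristic, and the function names are hypothetical.

```python
import math
import random
from itertools import combinations


def rbf(x, y, sigma=1.0):
    """RBF kernel on equal-length float lists."""
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)) / (2.0 * sigma ** 2))


def mmd2_u(S_i, S_j, kernel):
    """Unbiased MMD^2 estimate for two equal-size subsets (Eq. 13)."""
    n = len(S_i)
    mu_ii = sum(kernel(S_i[p], S_i[q]) for p in range(n) for q in range(n) if p != q) / (n * (n - 1))
    mu_jj = sum(kernel(S_j[p], S_j[q]) for p in range(n) for q in range(n) if p != q) / (n * (n - 1))
    mu_ij = sum(kernel(x, y) for x in S_i for y in S_j) / n ** 2
    return mu_ii + mu_jj - 2.0 * mu_ij


def smmd_loss(reps, labels, d_a, n_s, kernel=rbf, rng=None):
    """Sampling-based MMD balancing loss following Algorithm 2.

    reps: list of representation vectors; labels: treatments in 1..d_a.
    Subsets are drawn without replacement (an assumption of this sketch).
    """
    if d_a < 2 or n_s < 2:
        raise ValueError("need at least two treatments and N_s >= 2")
    rng = rng if rng is not None else random.Random()
    groups = {k: [] for k in range(1, d_a + 1)}
    for b, a in zip(reps, labels):
        groups[a].append(b)
    loss = 0.0
    for i, j in combinations(range(1, d_a + 1), 2):
        if len(groups[i]) < n_s or len(groups[j]) < n_s:
            continue  # skip pairs with too few samples
        S_i = rng.sample(groups[i], n_s)
        S_j = rng.sample(groups[j], n_s)
        loss += mmd2_u(S_i, S_j, kernel)
    return loss / math.comb(d_a, 2)  # average over all treatment pairs


reps = [[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]]
labels = [1, 1, 1, 2, 2, 2]
print(round(smmd_loss(reps, labels, d_a=2, n_s=3, rng=random.Random(0)), 3))  # 0.899
print(smmd_loss(reps, labels, d_a=2, n_s=4))  # 0.0: both groups are smaller than N_s
```

As in Algorithm 2, the sum is divided by $\binom{d_a}{2}$ even when some pairs are skipped, so skipped pairs simply contribute zero; in an autodiff framework the same structure would be written with tensor operations so gradients flow back to the encoder.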
### 4.5 Downstream clinical classifier for ventilator re-intubation

To evaluate the clinical utility of GITO-generated counterfactuals, we developed a downstream predictive model that assesses the risk of re-intubation within six hours post-extubation. The classifier takes as input a concatenated multivariate time series consisting of $T_{\text{hist}} = 12$ hours of observed history and $T_{\text{pred}} = 6$ hours of predicted trajectories from GITO.
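A minimal sketch of assembling the 18-hour classifier input, assuming each hour is a list of features and that the observed and predicted segments share the same feature order; the two features and all names are hypothetical.

```python
T_HIST, T_PRED = 12, 6  # hours of observed history and GITO-predicted trajectory


def classifier_input(observed, predicted):
    """Concatenate 12 h of observed and 6 h of predicted data along time.

    Each argument is a list of per-hour feature vectors; the result has
    T_HIST + T_PRED = 18 rows with a common feature dimension.
    """
    if len(observed) != T_HIST or len(predicted) != T_PRED:
        raise ValueError("expected 12 observed and 6 predicted hours")
    width = len(observed[0])
    if any(len(row) != width for row in observed + predicted):
        raise ValueError("observed and predicted features must align")
    return observed + predicted


# Hypothetical two-feature series (e.g. SpO2 and respiratory rate).
observed = [[95.0, 18.0] for _ in range(T_HIST)]
predicted = [[93.0, 22.0] for _ in range(T_PRED)]
series = classifier_input(observed, predicted)
print(len(series), len(series[0]))  # 18 2
```

Convolutional layers such as PyTorch's `torch.nn.Conv1d` expect inputs shaped (batch, channels, length), so in practice this time-major series would be transposed so that features become channels.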

Model architecture. We adopted a 1D residual convolutional neural network (ResNet-1D) architecture [journal/artmed2021/117Jia], which is effective at capturing local temporal dependencies in physiological signals. The backbone comprised three residual blocks with progressively increasing channel dimensions ([64, 128, 256]), each containing two 1D convolutional layers (kernel size $k=3$, stride $=1$, same padding) with batch normalization and ReLU activation, followed by an identity shortcut connection. The output of the final residual block was passed through global average pooling and a fully connected layer ($256\to 1$) with sigmoid activation. The dataset was split at the patient level into training (70%), validation (15%), and test (15%) subsets. The model parameters were optimized using the AdamW algorithm (learning rate $\eta=10^{-3}$, weight decay $\lambda=10^{-4}$). Given the inherent class imbalance in ventilator weaning outcomes (re-intubation events are the minority class), we employed two complementary loss functions to ensure robust sensitivity:

Weighted binary cross-entropy (BCE): To penalize false negatives more heavily, we applied class weights inversely proportional to class frequencies, $w_{+}=N/(2N_{+})$ and $w_{-}=N/(2N_{-})$, where $N$ is the total number of samples and $N_{+}$ and $N_{-}$ are the numbers of positive and negative samples (yielding $w_{+}\approx 1.53$ and $w_{-}\approx 0.74$ for the 43% re-intubation prevalence):

$$\mathcal{L}_{\mathrm{BCE}}=-\frac{1}{N}\sum_{i=1}^{N}\Big[w_{+}y_{i}\log(\hat{p}_{i})+w_{-}(1-y_{i})\log(1-\hat{p}_{i})\Big], \tag{16}$$

where $y_{i}\in\{0,1\}$ is the ground-truth label and $\hat{p}_{i}=\sigma(\hat{z}_{i})$ is the predicted probability, obtained by applying the sigmoid $\sigma$ to the classifier output $\hat{z}_{i}$.
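The class-weight rule and Eq. (16) can be checked with a short plain-Python sketch. The helper names are illustrative, the toy labels (25 positives out of 100) are invented, and the small `eps` clamp is an added assumption that keeps `log(0)` from occurring. With these weights $w_{+}N_{+}=w_{-}N_{-}=N/2$, so each class contributes the same total weight to the loss.

```python
import math


def class_weights(labels):
    """Inverse-frequency weights w+ = N / (2 N+) and w- = N / (2 N-)."""
    n = len(labels)
    n_pos = sum(labels)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("both classes must be present")
    return n / (2 * n_pos), n / (2 * n_neg)


def weighted_bce(probs, labels, w_pos, w_neg, eps=1e-12):
    """Weighted binary cross-entropy of Eq. (16), averaged over N samples."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp so both logs stay finite
        total += w_pos * y * math.log(p) + w_neg * (1 - y) * math.log(1.0 - p)
    return -total / len(labels)


labels = [1] * 25 + [0] * 75
w_pos, w_neg = class_weights(labels)
print(round(w_pos, 3), round(w_neg, 3))  # 2.0 0.667
```

Because the weights give each class a total mass of $N/2$, a constant prediction of 0.5 yields a loss of exactly $\log 2$ regardless of prevalence, which makes a convenient sanity check.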

Focal loss: In sensitivity analyses, we further addressed the "easy-negative" dominance problem using focal loss, which dynamically down-weights well-classified examples to focus training on hard samples:

$$\mathcal{L}_{\mathrm{focal}}=-\frac{1}{N}\sum_{i=1}^{N}\alpha(1-p_{t,i})^{\gamma}\log(p_{t,i}), \tag{17}$$

where $p_{t,i}=\hat{p}_{i}$ if $y_{i}=1$ and $p_{t,i}=1-\hat{p}_{i}$ otherwise. We set the balancing factor $\alpha=0.25$ and the focusing parameter $\gamma=2.0$ following standard practices for dense object detection and rare event prediction.
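A plain-Python sketch of Eq. (17) as written, with a single constant $\alpha$ applied to both classes. The original focal-loss formulation for dense object detection uses a class-dependent $\alpha_t$ ($\alpha$ for positives, $1-\alpha$ for negatives); this sketch follows the equation above rather than that variant. The function name and the `eps` clamp are assumptions.

```python
import math


def focal_loss(probs, labels, alpha=0.25, gamma=2.0, eps=1e-12):
    """Focal loss of Eq. (17): mean of -alpha * (1 - p_t)^gamma * log(p_t)."""
    total = 0.0
    for p, y in zip(probs, labels):
        p_t = p if y == 1 else 1.0 - p       # probability assigned to the true class
        p_t = min(max(p_t, eps), 1.0 - eps)  # keep log finite
        total += alpha * (1.0 - p_t) ** gamma * math.log(p_t)
    return -total / len(labels)


# Relative to alpha-scaled cross-entropy, an easy example (p_t = 0.9) is
# down-weighted by (1 - 0.9)^2 = 0.01, a hard one (p_t = 0.1) only by 0.81.
easy = focal_loss([0.9], [1])
hard = focal_loss([0.1], [1])
```

Setting `gamma=0.0` and `alpha=1.0` recovers the unweighted binary cross-entropy, which is a quick way to confirm the implementation.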

### 4.6 Human-AI comparison and collaboration study

Study design and participants. To evaluate GITO's impact on clinical decision-making, we employed a two-period, two-treatment crossover design involving nine healthcare professionals stratified by clinical experience: three attending physicians (senior clinicians), three residents (junior clinicians), and three medical students. All participants provided written informed consent. Participants predicted the risk of re-intubation within six hours post-extubation for the full cohort of 205 MIMIC-III ventilator weaning cases (see Section [4.1](https://arxiv.org/html/2605.05706#S4.SS1)). The prediction task was binary: whether the patient would require re-intubation (label $=1$) or not (label $=0$).

Crossover procedure. Participants were randomly assigned to one of two sequences to control for learning and order effects. Group 1 ($n=6$) first made predictions without AI assistance (control period) and then with GITO assistance (treatment period). Group 2 ($n=3$) followed the reverse order. In both periods, clinicians were provided with the patient's baseline clinical information, including demographic data (age, gender), admission diagnosis, and key vital signs from the preceding 12 hours (features consistent with those defined in Table [1](https://arxiv.org/html/2605.05706#S1.T1)). In the treatment period, this baseline information was augmented with GITO's output: a quantitative re-intubation risk prediction and a corresponding attribution-based explanation generated from the patient's time-series data. Participants were required to review this AI output before finalizing their clinical decision, enabling assessment of the effect of human-AI collaboration on both decision accuracy and the decision process.

Outcome measures. Three primary endpoints were evaluated:

1. Prediction accuracy: the proportion of cases in which the clinician's prediction matched the ground-truth re-intubation outcome (top-1 accuracy), analyzed using a linear mixed-effects model that accounts for the crossover design and controls for clinical experience tier.
2. Decision-making time: the elapsed time from case presentation to final prediction submission, recorded per batch of 205 cases.
3. Clinically acceptable safety rate ($1-\text{FNR}$): in this prediction task, the two types of error carry asymmetric clinical consequences. A false negative (predicting label $=0$ when the patient actually requires re-intubation) is the most dangerous error, as it may lead to premature extubation and subsequent respiratory failure requiring emergency re-intubation. Conversely, a false positive (predicting label $=1$ when the patient would not require re-intubation) results in a conservative decision to delay extubation, an outcome that, while suboptimal, does not pose an immediate safety risk to the patient. The clinically acceptable safety rate was therefore defined as $1-\text{FNR}$ (equivalently, recall for the positive class), reflecting the proportion of true high-risk cases correctly identified.

Accuracy and safety rate were compared between control and treatment periods using within\-subject paired comparisons\. The crossover design enabled each participant to serve as their own control, reducing inter\-individual variability\.

Foundation model benchmarking. To establish a rigorous AI baseline, we evaluated four state-of-the-art general-purpose large language models (GPT-4o, GPT-5.1, Gemini-3, and Grok-4.1) against GITO on the same 205-patient cohort. To elicit the best possible performance from these models, we did not rely on unstructured zero-shot prompting. Instead, we implemented a Structured Clinical Reasoning Pipeline that encoded expert ICU knowledge into the system prompt. The prompt instructed the LLMs to follow a three-step reasoning process:

1. Feature extraction and rule application: The models first evaluated key physiological indicators against standard weaning thresholds derived from the clinical literature. Specific criteria included:
   - Respiratory mechanics and gas exchange: Rapid Shallow Breathing Index (RSBI) < 105 breaths/min/L [journal/nejm1991/21yang, journal/atm2016/11Karthika]; tidal volume > 5 mL/kg [journal/nejm1995/332Estebon, tobin2006principles]; respiratory rate 8 ≤ RR ≤ 30 [journal/erj2007/29Boles]; PaO₂/FiO₂ ratio ≥ 200 mmHg [journal/nejm1995/332Estebon, journal/erj2007/29Boles]; PaCO₂ < 50 mmHg or within baseline range [journal/erj2007/29Boles].
   - Haemodynamic stability: mean arterial pressure (MAP) ≥ 65 mmHg [journal/erj2007/29Boles, journal/ccm2021/49Evans]; heart rate 60 ≤ HR ≤ 140 beats/min [journal/erj2007/29Boles].
   - Acid-base and metabolic status: pH 7.35-7.45 [journal/erj2007/29Boles, tobin2006principles]; lactate < 2 mmol/L [journal/ccm2015/43Thille]; bicarbonate (HCO₃⁻) 22-30 mEq/L; potassium (K⁺) 3.5-5.0 mEq/L [journal/ec2018/7Kardalas, journal/statpearls2025/Castro]; sodium (Na⁺) 135-145 mEq/L [journal/afp2015/5Braun, journal/statpearls2025/Castro].
   - Renal function: BUN/creatinine evaluated for acute deterioration; renal failure-associated fluid overload can precipitate pulmonary oedema [journal/ccm2015/43Thille].
2. Composite scoring: Based on these checks, the models synthesized a "Spontaneous Breathing Trial (SBT) Likelihood" (High/Moderate/Low).
3. Probabilistic prediction: Finally, the models predicted the probability of re-intubation, which was binarized using a decision threshold of 0.5.
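For readers who want to reproduce the rule-application step outside an LLM, the thresholds listed above can be encoded as a table of predicates. This is an illustrative sketch, not the prompt used in the study: the variable keys are hypothetical, the PaCO₂ baseline-range alternative and the qualitative renal assessment are not modeled, the mapping from the fraction of passed checks to High/Moderate/Low uses hypothetical cut-offs (0.9 and 0.6) that the text does not specify, and treating a probability of exactly 0.5 as positive is an assumption.

```python
# Weaning criteria from the list above, one predicate per variable.
WEANING_CRITERIA = {
    "rsbi": lambda v: v < 105,                 # breaths/min/L
    "tidal_volume_ml_per_kg": lambda v: v > 5,
    "resp_rate": lambda v: 8 <= v <= 30,       # breaths/min
    "pf_ratio": lambda v: v >= 200,            # PaO2/FiO2, mmHg
    "paco2": lambda v: v < 50,                 # mmHg (baseline-range option omitted)
    "map": lambda v: v >= 65,                  # mmHg
    "heart_rate": lambda v: 60 <= v <= 140,    # beats/min
    "ph": lambda v: 7.35 <= v <= 7.45,
    "lactate": lambda v: v < 2,                # mmol/L
    "bicarbonate": lambda v: 22 <= v <= 30,    # mEq/L
    "potassium": lambda v: 3.5 <= v <= 5.0,    # mEq/L
    "sodium": lambda v: 135 <= v <= 145,       # mEq/L
}


def apply_rules(patient):
    """Return {criterion: passed} for every criterion with an observed value."""
    return {name: check(patient[name]) for name, check in WEANING_CRITERIA.items() if name in patient}


def sbt_likelihood(results, high=0.9, moderate=0.6):
    """Hypothetical composite score: map the fraction of passed checks to a category."""
    if not results:
        return "Low"
    frac = sum(results.values()) / len(results)
    if frac >= high:
        return "High"
    return "Moderate" if frac >= moderate else "Low"


def binarize(prob, threshold=0.5):
    """Predict re-intubation (label 1) when the probability reaches the threshold."""
    return int(prob >= threshold)
```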

Each model received the same patient data in an identical prompt format. All predictions were generated at temperature $=0$ to maximize reproducibility. This prompt design is intended to ensure that any observed performance gap reflects a limitation of LLMs in processing temporal physiological dynamics rather than a lack of domain knowledge. The medical students in the comparison analysis ($n=3$) were the same three students enrolled in the crossover study; their unassisted predictions from the control period served as the human baseline.

### 4.7 Evaluation metrics

General protocol. For multi-step treatment outcome prediction experiments on MIMIC-III, AmsterdamUMCdb, and the synthetic tumor growth dataset, all models were trained and evaluated over $n=10$ independent runs with fixed random seeds (10, 101, 1010, 10101, 101010, 50, 505, 5050, 50505, 505050) controlling data splitting, model initialization, and sMMD sampling; results are reported as mean ± standard deviation. For downstream clinical tasks (e.g., ventilator re-intubation prediction), $n=5$ independent runs were used unless otherwise stated. Statistical significance of paired model comparisons was assessed using two-sided paired $t$-tests, with significance thresholds indicated by $^{*}$ ($p<0.05$) and $^{**}$ ($p<0.01$) in the corresponding tables.
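The paired comparison across seeds can be sketched as follows: for two models evaluated on the same seeds, the statistic is the mean per-seed difference divided by its standard error, with $n-1$ degrees of freedom. This sketch computes only the statistic; a two-sided p-value would come from the $t$ distribution, for example via `scipy.stats.ttest_rel`. The RMSE values below are made up for illustration.

```python
import math
import statistics


def paired_t(scores_a, scores_b):
    """Paired t statistic and degrees of freedom for runs matched by seed."""
    if len(scores_a) != len(scores_b) or len(scores_a) < 2:
        raise ValueError("need two equal-length lists with at least 2 runs")
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    n = len(diffs)
    se = statistics.stdev(diffs) / math.sqrt(n)  # sample SD (n - 1 denominator)
    if se == 0:
        raise ValueError("differences have zero variance")
    return statistics.mean(diffs) / se, n - 1


# Hypothetical per-seed RMSEs for a baseline and its sMMD variant.
baseline = [1.10, 1.05, 1.12, 1.08, 1.11]
variant = [1.00, 0.98, 1.03, 1.01, 0.99]
t_stat, df = paired_t(baseline, variant)
```

Pairing by seed removes run-to-run variance shared by both models (same split, same initialization seed), which is why a paired rather than an unpaired test is appropriate here.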

Model comparison \(regression\)\.For multi\-step treatment outcome prediction experiments, we evaluated numerical accuracy using the root mean squared error \(RMSE\), a standard regression metric that quantifies the discrepancy between predicted and observed outcomes:

$$\mathrm{RMSE}=\sqrt{\frac{1}{M}\sum_{i=1}^{M}(\hat{\bm{Y}}_{i}-\bm{Y}_{i})^{2}}. \tag{18}$$

Ventilator re-intubation prediction (classification). For the binary re-intubation prediction task, model performance was evaluated on an independent test cohort using accuracy, precision, recall, F1-score, and area under the receiver operating characteristic curve (AUROC) as primary metrics:

$$\mathrm{Accuracy}=\frac{\mathrm{TP}+\mathrm{TN}}{\mathrm{TP}+\mathrm{TN}+\mathrm{FP}+\mathrm{FN}}, \tag{19}$$

$$\mathrm{Precision}=\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FP}}, \tag{20}$$

$$\mathrm{Recall}=\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FN}}, \tag{21}$$

$$\mathrm{F1}=2\cdot\frac{\mathrm{Precision}\cdot\mathrm{Recall}}{\mathrm{Precision}+\mathrm{Recall}}, \tag{22}$$

where TP, TN, FP, and FN denote true positive, true negative, false positive, and false negative counts, respectively. The AUROC was computed by integrating the receiver operating characteristic (ROC) curve over all classification thresholds:

$$\mathrm{AUROC}=\int_{0}^{1}\mathrm{TPR}(\mathrm{FPR})\,d(\mathrm{FPR}), \tag{23}$$

where TPR and FPR denote the true positive and false positive rates, respectively. 95% confidence intervals for all classification metrics were estimated using bootstrap resampling (10,000 iterations) on the test cohort.
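Eqs. (19) to (23) and the bootstrap intervals can be sketched in plain Python. Recall here is the same quantity as the clinically acceptable safety rate $1-\mathrm{FNR}$ defined for the human-AI study. The AUROC uses the rank (Mann-Whitney) form, which equals the trapezoidal area under the empirical ROC curve when ties count as one half. The percentile bootstrap, the fixed seed, and the skipping of single-class resamples are assumptions, since the text does not state which bootstrap interval was used.

```python
import random


def confusion(y_true, y_pred):
    """Return (TP, TN, FP, FN) counts for binary labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    return tp, tn, fp, fn


def classification_metrics(y_true, y_pred):
    """Accuracy, precision, recall (= 1 - FNR) and F1 from Eqs. (19)-(22)."""
    tp, tn, fp, fn = confusion(y_true, y_pred)
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


def auroc(y_true, scores):
    """Probability that a random positive outranks a random negative (ties = 0.5)."""
    pos = [s for s, t in zip(scores, y_true) if t == 1]
    neg = [s for s, t in zip(scores, y_true) if t == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs both classes")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def bootstrap_ci(y_true, scores, metric, n_boot=10_000, level=0.95, seed=0):
    """Percentile bootstrap CI, resampling test cases with replacement."""
    rng = random.Random(seed)
    n = len(y_true)
    stats = []
    while len(stats) < n_boot:
        idx = [rng.randrange(n) for _ in range(n)]
        yt = [y_true[i] for i in idx]
        if len(set(yt)) < 2:
            continue  # skip resamples that contain a single class
        stats.append(metric(yt, [scores[i] for i in idx]))
    stats.sort()
    lower = stats[int((1 - level) / 2 * n_boot)]
    upper = stats[int((1 + level) / 2 * n_boot) - 1]
    return lower, upper
```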

Per-variable information preservation ($\Delta R^{2}$). To quantify how each balancing strategy affects the retention of individual clinical variables in the learned representations, we computed the per-variable $\Delta R^{2}$:

$$\Delta R^{2}_{j}=R^{2}_{j,\text{unbalanced}}-R^{2}_{j,\text{balanced}}, \tag{24}$$

where $R^{2}_{j,\text{unbalanced}}$ and $R^{2}_{j,\text{balanced}}$ denote the coefficients of determination for reconstructing variable $j$ from the unbalanced and balanced representations, respectively. A positive $\Delta R^{2}$ indicates that balancing incurred additional information loss beyond baseline compression; values near zero indicate no added cost; negative values indicate that balancing improved reconstruction relative to the unbalanced encoder. Error bars denote 95% confidence intervals over $n=10$ runs.
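Given reconstructions of variable $j$ from an unbalanced and a balanced representation, Eq. (24) reduces to a difference of two coefficients of determination. A minimal sketch, assuming $R^{2}=1-\mathrm{SS}_{\mathrm{res}}/\mathrm{SS}_{\mathrm{tot}}$ computed on held-out data; how the reconstructions are produced (for example, a probe fitted on each frozen representation) is left abstract, and the toy numbers are invented.

```python
def r_squared(y_true, y_pred):
    """Coefficient of determination 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_tot = sum((y - mean) ** 2 for y in y_true)
    ss_res = sum((y - p) ** 2 for y, p in zip(y_true, y_pred))
    return 1.0 - ss_res / ss_tot


def delta_r2(y_true, recon_unbalanced, recon_balanced):
    """Eq. (24): positive values mean balancing removed information about the variable."""
    return r_squared(y_true, recon_unbalanced) - r_squared(y_true, recon_balanced)


# Toy example: the balanced representation reconstructs variable j less accurately.
truth = [1.0, 2.0, 3.0, 4.0]
unbalanced = [1.1, 1.9, 3.1, 3.9]  # R^2 = 0.992
balanced = [1.5, 2.0, 2.5, 3.5]    # R^2 = 0.85
```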

Reconstruction loss. To evaluate how much patient-specific information each balancing objective preserves, an independent decoder network was trained to reconstruct the original patient covariates from the balanced representations. The reconstruction objective was the mean squared error (MSE):

$$\mathcal{L}_{\text{recon}}=\frac{1}{M\cdot d_{x}}\sum_{i=1}^{M}\|\hat{\bm{X}}_{i}-\bm{X}_{i}\|^{2}, \tag{25}$$

where $\hat{\bm{X}}_{i}$ is the reconstructed covariate vector, $\bm{X}_{i}$ is the original, and $d_{x}$ is the number of covariates. The decoder was trained with the encoder weights frozen, so that the reconstruction objective could not alter the learned representations and reconstruction quality reflects the information content of the balanced representations. Both training and validation reconstruction losses are reported.

Evaluation of representation quality. To validate the effectiveness of sMMD in mitigating confounding bias, we analyzed the structure of the learned latent representations using t-distributed stochastic neighbor embedding (t-SNE) [journal/jmlr2008/11Van] with perplexity $=30$ and 1,000 iterations. High-dimensional patient embeddings were projected into two dimensions to assess two properties:

1. Treatment invariance: whether the distributions of treated and control groups are indistinguishable (well mixed) in the latent space, rather than forming treatment-specific clusters.
2. Preservation of prognostic structure: whether the embeddings retain clinically meaningful heterogeneity (e.g., age-related physiological variation) while removing spurious demographic correlations (e.g., gender, ethnicity).

### 4.8 Interpretability and explainability pipeline

Gradient-based feature attribution. To quantify the contribution of individual physiological variables to the model's predictions, we employed Integrated Gradients (IG) [conference/icml2017/70Mukund], a widely adopted axiomatic attribution method. Given an input sequence $\bar{\bm{X}}_{t}$, we defined the baseline reference $\bar{\bm{X}}^{0}_{t}$ as the cohort mean for each variable. Integrating the gradients of the step-$j$ prediction along the straight-line path from $\bar{\bm{X}}^{0}_{t}$ to $\bar{\bm{X}}_{t}$ yields a raw attribution score $\phi^{(j)}_{i}$ for variable $i$ at prediction step $j$. To summarize importance across the entire prediction window $\tau$, we averaged the contributions:

$$\omega_{i}^{\text{raw}}=\frac{1}{\tau}\sum_{j=1}^{\tau}\phi^{(j)}_{i}. \tag{26}$$

For comparative analysis, these raw scores were normalized using a softmax function to produce a relative importance distribution:

$$\omega_{i}=\frac{\exp(\omega_{i}^{\text{raw}})}{\sum_{k=1}^{d_{x}}\exp(\omega_{k}^{\text{raw}})}, \tag{27}$$

where $d_{x}$ is the number of input variables. This normalization highlights the dominant physiological signals driving the forecast. In addition to the aggregated scores $\omega_{i}$, the per-step attributions $\phi^{(j)}_{i}$ are visualized individually to reveal how each variable's contribution evolves over the prediction horizon (see Figure [6](https://arxiv.org/html/2605.05706#S2.F6), upper right).
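The attribution pipeline can be sketched for a differentiable scalar function with a known gradient. This is an illustration under simplifying assumptions: a single flat input vector rather than a time series, a midpoint Riemann sum with a hypothetical `n_steps`, and a toy quadratic model standing in for GITO. In practice the gradient would come from automatic differentiation; the Captum library, for instance, provides an `IntegratedGradients` implementation for PyTorch models.

```python
import math


def integrated_gradients(grad_fn, x, baseline, n_steps=50):
    """Midpoint Riemann approximation of IG along the straight path baseline -> x."""
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            avg_grad[i] += g / n_steps
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


def aggregate_and_normalize(per_step_attributions):
    """Eq. (26): average over prediction steps; Eq. (27): softmax across variables."""
    tau = len(per_step_attributions)
    n_vars = len(per_step_attributions[0])
    raw = [sum(step[i] for step in per_step_attributions) / tau for i in range(n_vars)]
    m = max(raw)  # subtracting the max leaves the softmax unchanged but avoids overflow
    exps = [math.exp(r - m) for r in raw]
    z = sum(exps)
    return raw, [e / z for e in exps]


# Toy model f(x) = sum_i c_i * x_i^2, whose gradient is 2 * c_i * x_i.
coeffs = [1.0, -0.5, 2.0]


def grad(point):
    return [2 * c * v for c, v in zip(coeffs, point)]


x, baseline = [1.0, 2.0, 0.5], [0.2, 0.2, 0.2]  # baseline stands in for the cohort mean
phi = integrated_gradients(grad, x, baseline)
```

For this quadratic model the attributions are exactly $c_i(x_i^2-b_i^2)$ and sum to $f(\bm{x})-f(\bm{b})$ (the IG completeness property). Because the softmax in Eq. (27) acts on signed scores, variables with large negative attributions receive small weights.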

Counterfactual trajectory generation. To enable clinicians to compare alternative treatment strategies, GITO generates multi-step predicted trajectories under each candidate treatment plan. Given the learned representation $\bar{\bm{\mathcal{B}}}_{t}$ at the current time step, the model autoregressively rolls out future predictions by conditioning on a specified treatment sequence $\bar{\bm{a}}_{t:t+\tau-1}$. At each roll-out step, the predicted outcome $\hat{\bm{Y}}_{t+j}$ is fed back as input for the next step. In the present study, four scenarios were evaluated: no treatment, vasopressor only, ventilation only, and both treatments simultaneously. The resulting trajectory set provides a comparative view of expected physiological evolution under each strategy.
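The roll-out loop can be sketched as follows. `step_fn` is a hypothetical stand-in for GITO's one-step predictor, with the representation update folded into it, and the toy dynamics are invented purely to make the example executable. Each treatment vector is (vasopressor, ventilation), and each scenario here holds the plan constant over the horizon, although `rollout` accepts any per-step sequence.

```python
SCENARIOS = {
    "None": (0, 0),
    "Vaso": (1, 0),
    "Vent": (0, 1),
    "Both": (1, 1),
}


def rollout(step_fn, state, y_t, treatments):
    """Autoregressive roll-out: each predicted outcome is fed back as the next input."""
    trajectory = []
    y = y_t
    for a in treatments:
        state, y = step_fn(state, y, a)
        trajectory.append(y)
    return trajectory


def toy_step(state, y, treatment):
    """Invented dynamics for illustration only: vasopressor +2.0, ventilation -0.5 per step."""
    vaso, vent = treatment
    return state, y + 2.0 * vaso - 0.5 * vent


tau = 3
trajectories = {name: rollout(toy_step, None, 60.0, [a] * tau) for name, a in SCENARIOS.items()}
```

Feeding predictions back in, rather than observed values, is what makes the trajectories counterfactual: no future observations exist for the untaken treatment plans.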

LLM-driven interpretability and multimodal reasoning. To bridge the gap between quantitative risk scores and clinical reasoning, we developed a structured multimodal prompting pipeline that synthesizes predictions into interpretable narratives. By default, the pipeline uses the large language model (LLM) GPT-4o (version gpt-4o-2024-08-06) with temperature set to 0 to minimize output variability and a maximum token limit of 4,096; the platform also supports user-selectable alternative models. The LLM acts under a strict "critical care physician" persona to produce a three-tiered clinical summary.

To mitigate the risk of generative hallucination, we implemented a two\-stage Chain\-of\-Thought \(CoT\) framework that hybridizes explicit data extraction with scenario\-based reasoning:

- Stage I: Visual grounding and extraction. Unlike standard "black-box" generation, the pipeline first enforces a grounding step. The model is supplied with pre-computed statistics (current value, moving average, linear trend) of the top-$k$ ($k=5$) contributing features identified by Integrated Gradients. It is guided to cross-reference these structured inputs with the encoded visual charts (vital-sign trends and patient history) to validate physiological states (e.g., verifying whether MAP is trending below the 65 mmHg threshold) before narrative construction begins.
- Stage II: Structured narrative synthesis. Leveraging the grounded data, the model generates a structured explanation following a rigorous protocol: (1) primary metric analysis, assessing the target outcome's trajectory relative to historical interventions; (2) holistic vital status, integrating secondary vital-sign trends; and (3) comparative scenario reasoning, a disciplined comparison of the counterfactual prediction trajectories (None, Vaso, Vent, Both). The prompt enforces a "quantification discipline", requiring the model to cite approximate numerical deltas when comparing scenarios and prohibiting claims of superiority when differences are clinically negligible (<2% probability delta).
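Stage I depends on a few pre-computed summary statistics. The sketch below shows one way the top-$k$ features and their current value, moving average, and linear trend might be computed. The trailing-window length, the least-squares slope as the definition of "linear trend", ranking by absolute attribution, and all names are assumptions not specified in the text. The Stage II negligibility rule is included as a helper, interpreting "<2%" as an absolute probability difference of 0.02.

```python
def summarize(series, window=6):
    """Current value, trailing moving average, and least-squares slope per time step."""
    recent = series[-window:]
    n = len(series)
    t_mean = (n - 1) / 2
    y_mean = sum(series) / n
    num = sum((t - t_mean) * (y - y_mean) for t, y in enumerate(series))
    den = sum((t - t_mean) ** 2 for t in range(n))
    slope = num / den if den else 0.0
    return {"current": series[-1], "moving_avg": sum(recent) / len(recent), "trend": slope}


def top_k_features(attributions, k=5):
    """Names of the k features with the largest absolute attribution."""
    return sorted(attributions, key=lambda name: abs(attributions[name]), reverse=True)[:k]


def is_negligible(prob_a, prob_b, threshold=0.02):
    """Treat scenario differences below 2 percentage points as clinically negligible."""
    return abs(prob_a - prob_b) < threshold
```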

The LLM is additionally instructed to output a structured JSON response that includes, for each treatment scenario, a numerical preference score (as a percentage) reflecting the estimated clinical suitability based on trajectory analysis and grounded vital-sign assessment. The preference scores across all scenarios are constrained to sum to 100%, providing an interpretable ranking of treatment options. An abbreviated example of the generated output is shown in Box [2.4](https://arxiv.org/html/2605.05706#S2.SS4); the full prompt and output schema are provided in Appendix [C.5](https://arxiv.org/html/2605.05706#A3.SS5).
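Because the ranking shown to clinicians relies on the preference scores summing to 100%, it is sensible to validate the LLM's JSON before displaying it. The sketch assumes a hypothetical schema, a `scenarios` list of objects with `name` and `preference` fields; the actual schema is given in Appendix C.5 and may differ. The 0.5-point rounding tolerance is also an assumption.

```python
import json

EXPECTED_SCENARIOS = {"None", "Vaso", "Vent", "Both"}


def parse_preferences(raw_json, tolerance=0.5):
    """Parse and validate scenario preference scores from an LLM response.

    Hypothetical schema: {"scenarios": [{"name": ..., "preference": ...}, ...]}.
    Raises ValueError (json.JSONDecodeError is a subclass) on invalid output.
    """
    data = json.loads(raw_json)
    scores = {item["name"]: float(item["preference"]) for item in data["scenarios"]}
    if set(scores) != EXPECTED_SCENARIOS:
        raise ValueError(f"unexpected scenarios: {sorted(scores)}")
    if any(s < 0 for s in scores.values()):
        raise ValueError("preference scores must be non-negative")
    total = sum(scores.values())
    if abs(total - 100.0) > tolerance:
        raise ValueError(f"preference scores sum to {total}, not 100")
    # Rank treatment options from most to least preferred.
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
```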

### 4.9 Implementation details and experimental setup

Data preprocessing. Given the irregular sampling frequency inherent in ICU electronic health records, handling missing data is critical. We applied a Last Observation Carried Forward (LOCF) strategy to impute missing values in time-varying covariates, followed by Next Observation Carried Backward (NOCB) for any remaining initial gaps. This approach preserves the temporal continuity of physiological states. To facilitate stable model convergence, continuous covariates (both static and temporal) were standardized using z-score normalization:

$$x^{\prime}_{t,i}=\frac{x_{t,i}-\mu_{i}}{\sigma_{i}}, \tag{28}$$

where $\mu_{i}$ and $\sigma_{i}$ are the global mean and standard deviation of feature $i$ calculated across the entire training set.
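A plain-Python sketch of the imputation and scaling steps, with `None` marking a missing measurement. The sketch assumes that $\mu_i$ and $\sigma_i$ are fitted on the training split and reused unchanged for validation and test data; the use of the population standard deviation and the fallback for constant features are further assumptions.

```python
import statistics


def locf_nocb(series):
    """Carry the last observation forward, then fill leading gaps backward."""
    filled, last = [], None
    for v in series:
        if v is not None:
            last = v
        filled.append(last)
    # After LOCF only leading gaps remain; NOCB fills them with the first observation.
    first = next((v for v in filled if v is not None), None)
    return [first if v is None else v for v in filled]


def fit_zscore(train_values):
    """Return (mu, sigma) from training data; sigma falls back to 1.0 if constant."""
    mu = statistics.fmean(train_values)
    sigma = statistics.pstdev(train_values)
    return mu, (sigma if sigma > 0 else 1.0)


def apply_zscore(values, mu, sigma):
    """Eq. (28) with training-set statistics."""
    return [(v - mu) / sigma for v in values]
```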

Baseline comparisons. We benchmarked GITO against state-of-the-art treatment outcome estimation models:

- CRN (Counterfactual Recurrent Network) [conference/iclr2020/Bica]: uses LSTMs with domain-adversarial training to build balanced representations.
- CT (Causal Transformer) [conference/icml2022/Melnychuk]: a Transformer-based architecture that uses distinct attention heads for processing treatment and covariate history.
- ACTIN (Adversarial Counterfactual Temporal Inference Network) [conference/icml2024/wang]: the backbone architecture of our proposed method, which originally uses a GAN-based discriminator for balancing. We used ACTIN as the primary baseline to isolate the specific contribution of our sMMD module.

To construct the sMMD\-enhanced variants \(CRN\-sMMD, CT\-sMMD, and ACTIN\-sMMD\), we replaced each model’s original adversarial balancing mechanism with our proposed sMMD loss while keeping all other architectural components and hyperparameters unchanged\. This controlled substitution isolates the effect of the balancing strategy from other architectural differences\.

Computational environment. All models were implemented in Python 3.10 using the PyTorch 2.1 deep learning framework. Model training was performed on the NUS Vanda high-performance computing cluster equipped with 2× NVIDIA Tesla A40 GPUs (48 GB VRAM each) and 2× 36-core Intel Xeon 8452Y CPUs. Inference latency benchmarks (Table [4](https://arxiv.org/html/2605.05706#S2.T4)) were measured on CPU only (Intel Xeon 8452Y) without GPU acceleration, to reflect deployment conditions in resource-constrained hospital environments.

Training and evaluation protocol\.To ensure rigorous evaluation and prevent data leakage, the dataset was randomly partitioned at the patient level into training \(70%\), validation \(15%\), and test \(15%\) sets\. For the out\-of\-distribution \(OOD\) evaluation on MIMIC\-III, the training and validation sets comprised exclusively patients of European descent; the remaining non\-European subpopulations \(Asian, African\-descent, and Latino\) served as independent OOD test sets\.
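Splitting by patient means that every time window from a given patient lands in exactly one subset. A minimal sketch of the 70/15/15 patient-level split and the ethnicity-based OOD hold-out; the record layout, the `"European"` label string, the rounding of subset sizes, and the default seed are assumptions for illustration.

```python
import random


def patient_level_split(patient_ids, fractions=(0.70, 0.15, 0.15), seed=10):
    """Shuffle unique patient IDs and cut them into train/val/test groups."""
    ids = sorted(set(patient_ids))
    random.Random(seed).shuffle(ids)
    n_train = int(round(fractions[0] * len(ids)))
    n_val = int(round(fractions[1] * len(ids)))
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]


def ood_split(patients, seed=10):
    """Split European-descent patients in-distribution; other groups become OOD test sets."""
    in_dist = [p["id"] for p in patients if p["ethnicity"] == "European"]
    train, val, test = patient_level_split(in_dist, seed=seed)
    ood = {}
    for p in patients:
        if p["ethnicity"] != "European":
            ood.setdefault(p["ethnicity"], []).append(p["id"])
    return train, val, test, ood
```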

Model parameters were optimized using the Adam optimizer. Training was terminated via early stopping with a patience of 10 epochs based on validation-set RMSE, and the checkpoint with the lowest validation loss was selected for evaluation. For all baselines and for multi-step-ahead prediction, teacher forcing was used during training [journal/neco1989/1Williams]. During evaluation of multi-step-ahead prediction, teacher forcing was switched off and the models autoregressively fed back their own predictions as inputs.
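The stopping rule can be sketched as a small helper that tracks the best validation RMSE and signals a stop after `patience` consecutive epochs without improvement. The class name and the `min_delta` parameter are illustrative assumptions; saving and restoring the best checkpoint is left to the surrounding training loop.

```python
class EarlyStopping:
    """Stop when validation RMSE has not improved for `patience` epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, val_rmse):
        """Record one epoch; return True when training should stop."""
        if val_rmse < self.best - self.min_delta:
            self.best, self.best_epoch, self.bad_epochs = val_rmse, epoch, 0
            # A real training loop would save the model checkpoint here.
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```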

Key hyperparameters for the primary MIMIC-III experiments are summarized in Table [7](https://arxiv.org/html/2605.05706#S4.T7); full per-dataset configurations are provided in Supplementary Tables [12](https://arxiv.org/html/2605.05706#A3.T12) to [14](https://arxiv.org/html/2605.05706#A3.T14).

Table 7: Key hyperparameters for GITO and baselines on the MIMIC-III dataset.

## Data availability

The MIMIC-III dataset is publicly available from PhysioNet (https://mimic.physionet.org/). AmsterdamUMCdb is publicly available from the Amsterdam Medical Data Science website (https://amsterdammedicaldatascience.nl/).

## Code availability

The implementation of GITO, together with preprocessing scripts and trained models, is available at https://github.com/peisong-zhang/COEOT.

## Supplementary information

Supplementary information is available in the online version of the paper\. It includes additional figures, tables, methods, and source data supporting the findings of this study\.

## Appendix A Related works

Early methodologies for counterfactual estimation focused on static data and mainly fall into the following categories: propensity score-based approaches, covariate adjustment techniques, matching algorithms, and outcome modeling methods.

Propensity score-based methods, such as propensity score matching and inverse probability weighting, estimate the probability of receiving treatment conditional on covariates [journal/jasa1984/516Rosenbaum, journal/biometrika1983/41Rosenbaum, journal/aim1997/757Rubin]. These methods aim to balance the covariate distributions between treated and untreated groups to reduce confounding bias. However, they are sensitive to model misspecification and cannot address unobserved confounders. Furthermore, propensity score matching may reduce sample size and statistical power because some units cannot be matched. Another class of methods adjusts for covariates directly via regression models, such as linear regression and generalized linear models [journal/pads2008/1202Austin]. While these models are easy to implement and interpret, their performance relies heavily on correct model specification. Matching methods non-parametrically pair treated units with control units that have similar covariate values. These approaches are attractive due to their simplicity [journal/tkdd2021/1Yao]. However, in high-dimensional settings, good matches become increasingly difficult to find, leading to potential imbalance and loss of information due to discarded units. Moreover, all these approaches may be limited in real-world applications, particularly in healthcare, where patient conditions and treatment effects often evolve over time and decisions are made sequentially based on time-varying information.

To address these limitations, numerous methods have been developed for estimating treatment effects in time-varying settings. Unlike static approaches, these methods explicitly account for temporal dependencies, where treatments, covariates, and outcomes change over time. Methods such as marginal structural models (MSMs) [journal/mathmodel1986/7Robins, journal/epidemiology2000/Robins, journal/jasa2001/454Hernán, books/crc2008/Robins] attempt to mitigate confounding bias by reweighting or stratifying data based on estimated treatment probabilities. However, they rely on strong assumptions about model specification and often struggle to capture the complex temporal dependencies inherent in longitudinal data. To address this limitation, more recent research integrates sequential models such as recurrent neural networks (RNNs) with causal inference techniques such as inverse probability weighting (IPW) [conference/nips2018/31Lim, conference/kdd2024/12Wu] or G-computation [conference/mlh2021/282Li, conference/mlh2024/252Hong] to better account for time-varying treatment effects. These methods extend traditional balancing strategies by leveraging sequential models to capture temporal patterns in observational data. Among these, state-of-the-art machine learning methods such as the counterfactual recurrent network (CRN) [conference/iclr2020/Bica], causal transformer (CT) [conference/icml2022/Melnychuk], and adversarial counterfactual temporal inference network (ACTIN) [conference/icml2024/wang] employ adversarial training to balance representations across treatment groups and thereby mitigate bias. Specifically, the core idea of adversarial domain adaptation is to train a discriminator to distinguish between treatment groups while the feature extractor learns representations that make this discrimination difficult. This forces the learned representations to be treatment-invariant, effectively reducing the influence of confounding variables on treatment assignment. These adversarial balancing strategies provide a flexible, data-driven approach to balancing that avoids the explicit functional form assumptions required by traditional causal inference methods.

Despite their ability to remove associations between patient history and treatment assignment, these methods are highly sensitive to distribution shifts [conference/nips2022/Moayeri], meaning they may fail when applied to out-of-distribution scenarios. In addition, they often struggle with the trade-off between balancing and preserving covariate information. The aggressive removal of treatment-related signals can inadvertently cause information loss, particularly under severe confounding bias, reducing estimation accuracy [conference/icml2024/Huang]. This highlights the need for alternative approaches that achieve robust balancing while preserving the information needed for counterfactual outcome estimation. To mitigate this issue, recent work proposed a covariance decorrelation-based mechanism to achieve a better trade-off between bias reduction and prediction accuracy [conference/iclr2025/wang]. However, this mechanism is designed specifically for state-space models (SSMs) and does not readily generalize to other settings. Moreover, none of these methods has been evaluated under the distribution shifts that arise in real-world healthcare, limiting their practical applicability.

## Appendix B Dataset

This section provides supplementary cohort-level statistics for the datasets used in this study. Table [8](https://arxiv.org/html/2605.05706#A2.T8) summarizes the demographic and clinical characteristics of the AmsterdamUMCdb cohort, which served as a geographically independent validation set. Table [9](https://arxiv.org/html/2605.05706#A2.T9) lists the mapping of primary-diagnosis ICD-9 codes to the four disease categories used for disease-specific subgroup analysis on MIMIC-III.

Table 8: Demographic and clinical characteristics of the patient cohort with intensive care unit (ICU) stays between 30 and 60 hours. This cohort retains the diversity of the broader population while ensuring tractable training and analysis. Abbreviations: SD, standard deviation; SVR, systemic vascular resistance; PEEP, positive end-expiratory pressure. (a) For time-varying vital signs, mean values were computed over the first 24 hours following ICU admission. (b) Treatment durations are the number of hours for which continuous or intermittent interventions were administered, averaged across all patients.

Table 9: Mapping of primary diagnosis ICD-9 codes to disease categories used for disease-specific subgroup analysis.

### B.1 Details on synthetic tumor growth dataset

The Tumor Growth (TG) simulator [journal/scirep2017/7Geng] models the tumor volume $\bm{Y}_{t+1}$ at $t+1$ days post-diagnosis, where the outcome is one-dimensional (i.e., $d_{y}=1$). The model incorporates two binary treatment variables: (i) radiotherapy ($\bm{A}_{t}^{(r)}$) and (ii) chemotherapy ($\bm{A}_{t}^{(c)}$). Radiotherapy, with dose $d_{t}$, has an immediate effect on the subsequent outcome, whereas chemotherapy exerts an influence over multiple future time points through an exponentially decaying effect $C_{t}$. The dynamics are modeled as:

$$\bm{Y}_{t+1}=\Big(1+\rho\log\frac{K}{\bm{Y}_{t}}-\beta_{c}C_{t}-(\alpha_{r}d_{t}+\beta_{r}d_{t}^{2})+\epsilon_{t}\Big)\bm{Y}_{t}, \tag{29}$$

where $\rho, K, \beta_{c}, \alpha_{r}, \beta_{r}$ are simulation parameters and $\epsilon_{t}\sim N(0,0.01^{2})$ is sampled noise. The parameters $\beta_{c}, \alpha_{r}, \beta_{r}$ characterize individual patient responses and are drawn from a mixture of truncated normal distributions with three components. For exact parameter values, refer to the code implementation. The mixture component indices are treated as static covariates ($d_{v}=1$). Time-varying confounding is introduced through a biased treatment assignment mechanism, which is the same for both treatments, i.e.,

$$\bm{A}_t^{(c)}, \bm{A}_t^{(r)} \sim \mathrm{Bernoulli}\left(\sigma\left(\frac{\gamma}{D_{\max}}\left(\bar{D}_{15}(\bm{Y}_{t-1}) - \frac{D_{\max}}{2}\right)\right)\right), \tag{30}$$

where $\sigma$ is the sigmoid function, whose output in $(0, 1)$ serves as the probability parameter of the Bernoulli distribution, $D_{\max}$ is the maximum tumor diameter, $\bar{D}_{15}(\bm{Y}_{t-1})$ is the average tumor diameter over the last 15 days, and $\gamma$ is a confounding parameter controlling the "biasing effect" of tumor size on treatment assignment: the larger $\gamma$, the stronger the bias. This mechanism introduces confounding dynamically based on tumor growth, simulating the real-world scenario in which physicians adjust treatment strategies according to tumor size.
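Equations (29) and (30) can be sketched in Python as follows. This is a minimal sketch, not the simulator's reference implementation: it assumes the caller supplies the chemotherapy effect $C_t$, the radiotherapy dose $d_t$, and the 15-day mean diameter $\bar{D}_{15}$ (their computation is not specified here), and it draws the two treatment indicators independently from the shared probability in Eq. (30). The function names are hypothetical.

```python
import numpy as np


def tumor_step(y_t, c_t, d_t, rho, K, beta_c, alpha_r, beta_r, rng, noise_sd=0.01):
    """One step of Eq. (29): tumor volume at t+1 from the volume y_t at t."""
    eps = rng.normal(0.0, noise_sd)  # epsilon_t ~ N(0, noise_sd^2)
    factor = (1.0
              + rho * np.log(K / y_t)                  # growth term
              - beta_c * c_t                           # chemotherapy effect
              - (alpha_r * d_t + beta_r * d_t ** 2)    # radiotherapy effect
              + eps)
    return factor * y_t


def assignment_prob(mean_diameter_15, gamma, d_max):
    """Sigmoid term of Eq. (30): probability of receiving each treatment."""
    z = gamma / d_max * (mean_diameter_15 - d_max / 2.0)
    return 1.0 / (1.0 + np.exp(-z))


def assign_treatments(mean_diameter_15, gamma, d_max, rng):
    """Draw (chemotherapy, radiotherapy) indicators from the shared probability."""
    p = assignment_prob(mean_diameter_15, gamma, d_max)
    return int(rng.random() < p), int(rng.random() < p)
```

With $\gamma = 0$ the assignment probability is 0.5 regardless of tumor size (no confounding); increasing $\gamma$ makes treatment increasingly likely for tumors whose recent mean diameter exceeds $D_{\max}/2$.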

### B.2 Experiment details on MIMIC-III dataset

We used the MIMIC-Extract dataset [conference/chil2020/222Wang], which applies a standardized preprocessing pipeline to the MIMIC-III database [journal/scidata2016/Johnson] and provides intensive care unit (ICU) data aggregated on an hourly basis. Missing values are handled with both forward and backward filling, after which all continuous time-varying features are standardized. Our analysis includes 29 time-varying vital-sign and laboratory indicators, such as heart rate, respiratory rate, diastolic blood pressure, glucose, and blood urea nitrogen, as well as 3 static attributes (age, gender, and ethnicity). Categorical features are one-hot encoded. These variables, comprising both dynamic covariates and invariant characteristics, are treated as potential confounders. We examine two binary treatments: vasopressor administration and mechanical ventilation. The primary outcome of interest is diastolic blood pressure, which may either increase or decrease in response to these treatments; anticipating this direction is crucial for clinicians assessing the expected progression of patient trajectories under such interventions. From the full MIMIC-III cohort of 25,186 eligible patients (see Methods: Patient Cohort), each experiment randomly sampled 5,000 individuals whose ICU stays lasted between 30 and 60 hours. The dataset was split into training, validation, and test sets in a 70%/15%/15% ratio. The evaluation protocol was adapted to the forecast horizon $\tau$ as follows:

1. For one-step-ahead prediction, the full test-set trajectories were used.
2. For multi-step prediction ($\tau \geq 2$), we defined $\tau_{\max} \geq \tau$ as the longest projection horizon. Sub-trajectories of at least $\tau_{\max} + 1$ steps were then extracted using a rolling-origin approach, and the initial vital-sign readings up to $\tau^{(i)} - \tau_{\max} + 1$ were removed to eliminate any foresight bias in the prediction process.
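The rolling-origin extraction in step 2 can be sketched as below. This is an illustrative NumPy sketch under assumptions not stated in the text: a trajectory is stored as a `(T, F)` covariate array and a length-`T` outcome array, every origin with at least `min_history` observed steps is used, and the exact index offsets in the released code may differ. The function name is hypothetical.

```python
import numpy as np


def rolling_origin_windows(x, y, tau_max, min_history=1):
    """Cut one trajectory into (history_x, history_y, future_y) triples.

    x: (T, F) covariates, y: (T,) outcomes. At each origin t the model sees
    the first t steps and is scored on the next tau_max outcomes, so every
    sub-trajectory spans at least min_history + tau_max steps.
    """
    T = len(y)
    windows = []
    for t in range(min_history, T - tau_max + 1):
        windows.append((x[:t], y[:t], y[t:t + tau_max]))
    return windows
```

Because every target window begins at its origin `t`, no outcome from the forecast window leaks into the history used for that forecast.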

To evaluate the generalization of the model under distribution shifts, we designed two out\-of\-distribution \(OOD\) settings based on both patient demographics and admission diagnoses\.

Ethnicity-based OOD setting. In the first OOD setting, we trained the model exclusively on White patients and evaluated it on non-White subpopulations, treating each ethnicity group as an independent OOD test set. To ensure statistical reliability and a sufficient sample size, we selected Asian (N=119), Black (N=383), and Hispanic (N=143) patients, as these are the most represented non-White groups in the dataset. This setup allows us to assess the model's robustness across ethnicity subpopulations and to identify potential biases in treatment outcome prediction.

Table 10: Demographic and clinical characteristics of the patient cohort in the ethnicity-based distribution-shift settings.

| Characteristic | White (N=3,560) | Asian (N=119) | Black (N=383) | Hispanic (N=143) |
|---|---|---|---|---|
| **Age** | | | | |
| Age (≤89), mean (SD) | 64.16 (16.32) | 59.18 (19.19) | 58.30 (17.93) | 53.77 (17.38) |
| Age >89, n (%) | 209 (5.4%) | 9 (7.56%) | 18 (4.70%) | 2 (1.40%) |
| **Gender** | | | | |
| Male, n (%) | 2,001 (56.2%) | 62 (52.1%) | 180 (47.4%) | 93 (64.5%) |
| Female, n (%) | 1,559 (43.8%) | 57 (47.9%) | 203 (52.6%) | 50 (35.5%) |
| **Vitals (a)** | | | | |
| Heart rate (bpm) | 84.76 (15.23) | 84.85 (16.11) | 87.31 (16.47) | 89.08 (16.27) |
| Red blood cells (M/μL) | 3.64 (0.60) | 3.65 (0.69) | 3.74 (0.69) | 3.80 (0.65) |
| Sodium (mEq/L) | 138.46 (4.29) | 139.26 (4.66) | 138.66 (4.88) | 139.13 (3.96) |
| Mean BP (mmHg) | 77.77 (10.45) | 79.47 (11.04) | 82.21 (11.24) | 82.92 (12.23) |
| SVR (dyn·s/cm⁵) | 1,499.76 (697.57) | 1,571.20 (694.78) | 1,575.48 (696.18) | 1,666.86 (656.79) |
| Glucose (mg/dL) | 137.09 (37.48) | 137.29 (34.15) | 145.40 (47.91) | 139.41 (43.12) |
| Chloride urine (mEq/L) | 67.14 (48.43) | 64.44 (50.21) | 65.22 (47.60) | 72.71 (49.99) |
| GCS score | 13.58 (2.56) | 13.40 (2.78) | 13.88 (2.22) | 13.77 (2.47) |
| Hematocrit (%) | 32.43 (5.06) | 31.97 (5.21) | 32.47 (5.58) | 33.24 (5.77) |
| PEEP (cmH2O) | 5.15 (2.21) | 4.79 (1.45) | 5.08 (2.33) | 4.95 (1.94) |
| Respiratory rate (bpm) | 18.50 (3.90) | 17.81 (3.87) | 19.39 (4.38) | 18.28 (4.38) |
| Prothrombin time (sec) | 15.18 (5.12) | 14.16 (2.64) | 15.00 (4.02) | 14.75 (3.90) |
| Cholesterol (mg/dL) | 162.94 (48.35) | 166.28 (48.51) | 160.52 (47.05) | 161.60 (47.20) |
| Hemoglobin (g/dL) | 11.04 (1.81) | 10.77 (1.80) | 10.85 (1.94) | 11.37 (2.02) |
| Creatinine (mg/dL) | 1.25 (1.13) | 1.28 (1.27) | 1.83 (2.17) | 1.29 (1.42) |
| BUN (mg/dL) | 23.30 (18.90) | 23.83 (18.49) | 26.56 (23.00) | 21.03 (16.46) |
| Bicarbonate (mEq/L) | 24.04 (4.04) | 23.55 (3.66) | 23.80 (4.30) | 23.54 (3.49) |
| Calcium ionized (mmol/L) | 1.56 (7.69) | 1.15 (0.17) | 1.68 (10.05) | 2.85 (13.03) |
| pCO2 (mmHg) | 40.94 (8.75) | 40.19 (8.10) | 41.51 (9.96) | 39.41 (7.34) |
| Magnesium (mg/dL) | 2.01 (0.34) | 2.09 (0.55) | 1.99 (0.34) | 1.97 (0.29) |
| Anion gap (mEq/L) | 13.75 (3.13) | 13.64 (2.86) | 14.44 (3.42) | 13.84 (3.10) |
| Phosphorus (mg/dL) | 3.49 (1.12) | 3.41 (0.97) | 3.61 (1.26) | 3.60 (1.00) |
| Venous PvO2 (mmHg) | 50.69 (13.38) | 52.61 (14.21) | 50.91 (13.36) | 52.11 (13.34) |
| Platelets (K/μL) | 217.89 (104.52) | 202.96 (102.38) | 228.30 (100.73) | 215.04 (89.34) |
| Calcium urine (mg/dL) | 5.52 (9.81) | 6.56 (11.59) | 5.16 (8.93) | 5.19 (9.86) |
| Diastolic BP (mmHg) | 60.02 (10.13) | 61.77 (9.86) | 65.14 (11.05) | 65.83 (11.21) |
| **Treatments (b)** | | | | |
| Vasopressor (h) | 4.15 (0.38) | 3.97 (0.37) | 2.38 (0.30) | 2.19 (0.91) |
| Ventilation (h) | 5.52 (0.42) | 6.12 (0.44) | 5.12 (0.41) | 6.29 (2.62) |

Abbreviations: SD, standard deviation; SVR, systemic vascular resistance; GCS, Glasgow Coma Scale; PEEP, positive end-expiratory pressure. (a) For time-varying vital signs, mean values were computed over the first 24 h following ICU admission. (b) Treatment durations reflect the average number of hours of continuous or intermittent interventions, averaged across all patients.

Diagnosis-based OOD setting. The second OOD setting introduces an additional level of domain shift by restricting the OOD test set to specific broad disease categories. To ensure a sufficient sample size, we focus on three major disease groups: cardiovascular diseases, neurological disorders, and infectious and inflammatory diseases. Using broad categories avoids the data scarcity that would arise from selecting a single specific disease. By evaluating the model on these distinct diagnostic subgroups, we investigate whether shifts in underlying medical conditions affect model performance beyond demographic shifts alone. Details of this OOD setting are presented in Table [11](https://arxiv.org/html/2605.05706#A2.T11).
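The two OOD settings amount to simple cohort filters, sketched below. The sketch assumes, hypothetically, that each patient record is a dictionary with pre-normalized `"ethnicity"` labels and a `"category"` field already derived from the Table 9 ICD-9 mapping; the function names are illustrative only.

```python
def ethnicity_ood_split(patients, train_group="White",
                        ood_groups=("Asian", "Black", "Hispanic")):
    """Train on one ethnicity group; each listed group is a separate OOD test set."""
    train = [p for p in patients if p["ethnicity"] == train_group]
    ood = {g: [p for p in patients if p["ethnicity"] == g] for g in ood_groups}
    return train, ood


def diagnosis_ood_subsets(ood, categories):
    """Restrict each ethnicity OOD test set to one broad disease category at a time."""
    return {
        (group, cat): [p for p in members if p["category"] == cat]
        for group, members in ood.items()
        for cat in categories
    }
```

Training data are unchanged between the two settings; only the test sets are narrowed, so any extra degradation in the diagnosis-based setting can be attributed to the additional shift in medical conditions.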

Table 11: Details of the OOD test set in the diagnosis-based OOD setting.
### B.3 Experiment details on human-AI comparison and collaboration

This section provides supplementary details for the human-AI comparison and collaboration study described in Methods (Section [4.6](https://arxiv.org/html/2605.05706#S4.SS6)).

Patient cohort and task. The evaluation cohort comprised 205 ventilator-weaning cases from MIMIC-III (see Section [4.1](https://arxiv.org/html/2605.05706#S4.SS1)). Patients were selected based on ICD-9 diagnosis codes associated with conditions that frequently require mechanical ventilation: heart failure (428.x: 428, 4280, 4281, 42820-42843, 4289) and acute respiratory distress syndrome (518.82, 518.5). The prediction task was binary: whether the patient would require re-intubation within six hours of extubation (label = 1) or not (label = 0).
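A minimal sketch of the cohort filter and label definition follows. It assumes that ICD-9 codes may arrive with or without the decimal point (MIMIC-III's diagnosis table stores them without it) and reads "within six hours" as the interval $(0, 6]$ hours after extubation. The code sets come from the list above, with 42820-42843 expanded as an integer range; the function names and the event-time representation are hypothetical.

```python
# Code sets from the text; 42820-42843 is expanded as an integer range.
HEART_FAILURE = {"428", "4280", "4281", "4289"} | {str(c) for c in range(42820, 42844)}
ARDS = {"51882", "5185"}
TARGET_CODES = HEART_FAILURE | ARDS


def normalize_icd9(code: str) -> str:
    """'518.82' -> '51882'; codes already stored without a dot pass through."""
    return code.replace(".", "").strip()


def is_eligible(diagnosis_codes) -> bool:
    """True if any of a patient's ICD-9 codes falls in the target sets."""
    return any(normalize_icd9(c) in TARGET_CODES for c in diagnosis_codes)


def reintubation_label(extubation_h: float, intubation_times_h, window_h: float = 6.0) -> int:
    """1 if a new intubation starts within (0, window_h] hours after extubation."""
    return int(any(0.0 < t - extubation_h <= window_h for t in intubation_times_h))
```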

Data presentation to clinicians. For human participants, each patient's clinical information was presented via a structured dashboard. Vital signs were arranged by their average values and visualized in groups of five time-series panels, with the final panel displaying the prediction target and the assigned treatment history. Demographic data (age, gender), admission diagnosis, and key vital signs from the preceding 12 hours were provided in all conditions. In the GITO-assisted condition, the dashboard was augmented with GITO's quantitative risk prediction and an attribution-based interpretable explanation. Figure [11](https://arxiv.org/html/2605.05706#A2.F11) and Figure [12](https://arxiv.org/html/2605.05706#A2.F12) present examples of the clinical visualization interface in the unassisted and GITO-assisted conditions, respectively.

LLM prompt design. Four large language models (GPT-4o, GPT-5.1, Gemini-3, and Grok-4.1) were evaluated using a structured clinical reasoning pipeline. To give the foundation models the best opportunity to perform well in the human-AI comparison experiment (Section [4.6](https://arxiv.org/html/2605.05706#S4.SS6)), we used this pipeline rather than a plain zero-shot prompt. The complete system prompt and user prompt are presented in Box [B.3](https://arxiv.org/html/2605.05706#A2.SS3). All models received identical prompts, with the temperature set to 0 to minimize output variability. The pipeline enforces a three-stage reasoning process:

1. Stage A (Data Extraction): The model extracts the most recent values of 14 physiological parameters from the vitals trend image.
2. Stage B (Clinical Scoring): The extracted values are evaluated against established weaning criteria, including the Rapid Shallow Breathing Index (RSBI), oxygenation status (PaO2/FiO2 ratio), acid-base balance, hemodynamic stability, and neurological status, culminating in a composite Spontaneous Breathing Trial (SBT) likelihood assessment.
3. Stage C (Risk Prediction): The model outputs a probabilistic re-intubation risk estimate (0.0-1.0), a categorical risk level, and a 3-6 sentence clinical rationale.

All outputs were constrained to a structured JSON schema to enable automated parsing and comparison against ground-truth labels. The decision threshold was set at 0.5 (probability $\geq 0.5$ classified as requiring re-intubation).
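Parsing a response and applying the 0.5 threshold can be sketched with Python's standard `json` module. The exact JSON schema is not given here, so the sketch assumes, hypothetically, a top-level `"probability"` field; it recomputes the risk band from the Stage C cut-offs in Box 1 (reading "0.2-0.49" as $[0.2, 0.5)$) rather than trusting the model's own `risk_level`.

```python
import json


def risk_level(p: float) -> str:
    """Risk bands from the Stage C instructions in Box 1."""
    if p >= 0.5:
        return "high_risk"
    if p >= 0.2:
        return "moderate_risk"
    return "low_risk"


def parse_prediction(raw: str, threshold: float = 0.5) -> dict:
    """Parse one model response and binarize it (1 = predicted re-intubation).

    Assumes a top-level "probability" key; json.JSONDecodeError (a ValueError)
    is raised for non-JSON output.
    """
    data = json.loads(raw)
    p = float(data["probability"])
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"probability out of range: {p}")
    return {"probability": p, "label": int(p >= threshold), "risk_level": risk_level(p)}
```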

Box 1: Structured prompting protocol for foundation-model re-intubation prediction

System prompt

```
You are an ICU clinical decision support system.

STAGE A - DATA EXTRACTION
Extract the most recent approximate values for:
- Heart rate (beats/min)
- Mean blood pressure MAP (mmHg)
- Respiratory rate RR (breaths/min)
- Tidal volume VT (mL)
- PaCO2 (mmHg), pH
- Bicarbonate (mEq/L), Potassium (mEq/L)
- Sodium (mEq/L), Creatinine (mg/dL)
- BUN (mg/dL), Lactate (mmol/L)
- Glasgow Coma Scale total score
- PaO2/FiO2 ratio (PF ratio)
Return null if unreadable.

STAGE B - CLINICAL SCORING
Compute:
- RSBI = RR / (VT / 1000)
  "good" if <80; "acceptable" if 80-105; "poor" if >105
- Oxygenation_ok: PF > 200
- Acid_base_ok: pH 7.35-7.45 and PaCO2 35-45
- Hemodynamics_ok: MAP >= 65
- Neurological_ok: GCS >= 8
- Respiratory_mechanics_ok: RR 8-30, VT adequate
- SBT_likelihood: high / moderate / low

STAGE C - 6-HOUR REINTUBATION RISK
Predict:
- probability (0.0-1.0)
- risk_level: "high_risk" (>=0.5) / "moderate_risk" (0.2-0.49) / "low_risk" (<0.2)
- explanation: 3-6 sentence clinical rationale

OUTPUT FORMAT: valid JSON only.
```

User prompt (per patient)

```
Below is a 12-hour vitals trend for an ICU patient.
Perform STAGE A, STAGE B, and STAGE C as defined.

Patient demographics:
- Gender: [Male/Female]
- Age: [age]
- Diagnosis: [primary diagnosis at admission]
- Ventilator duration: [hours]

[Attached: 12-hour vitals trend image]
```
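The numeric Stage B criteria are deterministic, so they can be re-implemented directly, for example as a reference for checking the intermediate scores an LLM reports. This is an illustrative sketch, not part of the evaluated pipeline; the "VT adequate" criterion and the SBT likelihood are omitted because the prompt does not define them numerically, and the function and key names are hypothetical.

```python
def stage_b_scores(rr, vt_ml, pf_ratio, ph, paco2, map_mmhg, gcs):
    """Deterministic version of the Stage B criteria in Box 1."""
    rsbi = rr / (vt_ml / 1000.0)                    # breaths/min per liter of tidal volume
    if rsbi < 80:
        rsbi_category = "good"
    elif rsbi <= 105:
        rsbi_category = "acceptable"
    else:
        rsbi_category = "poor"
    return {
        "rsbi": rsbi,
        "rsbi_category": rsbi_category,
        "oxygenation_ok": pf_ratio > 200,
        "acid_base_ok": 7.35 <= ph <= 7.45 and 35 <= paco2 <= 45,
        "hemodynamics_ok": map_mmhg >= 65,
        "neurological_ok": gcs >= 8,
        "respiratory_mechanics_ok": 8 <= rr <= 30,   # "VT adequate" not encoded
    }
```

For example, a respiratory rate of 20 breaths/min with a 500 mL tidal volume gives an RSBI of 40, which falls in the "good" band.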

![Refer to caption](https://arxiv.org/html/2605.05706v1/img/patient_template.jpg)

Figure 11: Clinical dashboard - unassisted condition. Clinicians received patient demographics (age, sex), primary diagnosis at admission, cumulative duration of mechanical ventilation, vasopressor administration history with dosage, and 12-hour temporal trajectories of 16 physiological parameters. See Figure [12](https://arxiv.org/html/2605.05706#A2.F12) for the GITO-assisted condition.

![Refer to caption](https://arxiv.org/html/2605.05706v1/img/patient_template_w_pred.png)

Figure 12: Clinical dashboard - GITO-assisted condition. In addition to the information in Figure [11](https://arxiv.org/html/2605.05706#A2.F11), clinicians received GITO's 6-hour predicted trajectories (orange curves) appended to the observed history, together with a quantitative re-intubation risk score and an attribution-based interpretable explanation. Both dashboards were presented to $n = 4$ medical students and $n = 3$ clinicians during the two-period crossover experiment in counterbalanced order (see Methods [4.6](https://arxiv.org/html/2605.05706#S4.SS6)).

## Appendix C Baseline methods

### C.1 Details about baseline models

![Refer to caption](https://arxiv.org/html/2605.05706v1/x5.png)

Figure 13: Graphical illustration of adversarial balancing strategies employed by baseline models.

Counterfactual Recurrent Network (CRN) [conference/iclr2020/Bica]: A recurrent neural network (RNN)-based framework for counterfactual outcome estimation in longitudinal settings. The model employs a sequence-to-sequence architecture to capture patient history over time, enabling the prediction of treatment outcomes at future time points. To address time-varying confounding bias, the model incorporates an adversarial gradient reversal (ADR) strategy. As illustrated in Figure [13](https://arxiv.org/html/2605.05706#A3.F13), patient trajectories are first encoded by an RNN-based representation module $\Theta_{\mathcal{B}}$ to obtain latent representations, which are then fed into two parallel branches:

1. Outcome Prediction Branch (Green Path): The representations are combined with the assigned treatments and passed to an outcome predictor, which generates treatment outcome estimates. The model minimizes a prediction loss defined in Equation [11](https://arxiv.org/html/2605.05706#S4.E11), encouraging the representations to be informative for outcome prediction.
2. Adversarial Balancing Branch (Red Path): The same representations are simultaneously passed to a treatment discriminator $\Theta_A$, which attempts to predict the treatment assignments by minimizing a treatment loss defined in Equation [31](https://arxiv.org/html/2605.05706#A3.E31). Meanwhile, the representation module $\Theta_{\mathcal{B}}$ is trained adversarially via gradient reversal to maximize the discriminator's loss. This is designed to fool the discriminator, encouraging the learned representations to be invariant to treatment assignments and thereby mitigating confounding bias.

Through this dual\-objective design, the model learns representations that are both predictive of outcomes and balanced with respect to treatment groups, making it well\-suited for counterfactual inference in time\-varying clinical settings\.

$$\mathcal{L}_{\Theta_A} = -\sum_{j=1}^{d_a} \mathbb{I}(A_t = a_j)\log\Theta_A(\bm{\mathcal{B}}_t), \tag{31}$$
Causal Transformer (CT) [conference/icml2022/Melnychuk]: A transformer-based model designed for counterfactual outcome estimation in longitudinal healthcare settings. CT uses both self-attention and cross-attention to extract rich contextual representations from patient trajectories, capturing complex temporal dependencies. To address treatment-related confounding, CT also employs an adversarial balancing strategy; however, unlike CRN, which uses gradient reversal to directly maximize the treatment prediction loss, CT introduces a counterfactual domain confusion (CDC) loss to obtain treatment-invariant representations. This loss encourages the representation module $\Theta_{\bm{\mathcal{B}}}$ to generate embeddings that are indistinguishable across treatment groups, so that the learned representations carry no information about treatment assignment. Specifically, during adversarial training, the treatment discriminator $\Theta_A$ is optimized to classify treatment assignments accurately using the standard treatment classification loss (Equation [31](https://arxiv.org/html/2605.05706#A3.E31)), whereas the representation encoder $\Theta_{\bm{\mathcal{B}}}$ is trained to confuse the discriminator via the CDC loss (Equation [32](https://arxiv.org/html/2605.05706#A3.E32)), which pushes the predicted treatment distribution toward uniformity, simulating random guessing:

$$\mathcal{L}_{\text{conf}} = -\sum_{j=1}^{d_a} \frac{1}{d_a}\log\Theta_A(\bm{\mathcal{B}}_t), \tag{32}$$

By explicitly enforcing treatment-invariant representations through CDC, CT mitigates confounding bias while maintaining the temporal coherence of patient trajectories.
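Reading $\Theta_A(\cdot)$ as the vector of predicted treatment probabilities (so each logarithm in Eqs. (31) and (32) is taken of the $j$-th component), the two losses can be sketched in NumPy as below. Averaging over a batch and the function names are illustrative assumptions; gradients and the training loop are omitted.

```python
import numpy as np


def treatment_loss(probs, a_idx):
    """Eq. (31), batch-averaged: cross-entropy of the treatment discriminator.

    probs: (N, d_a) predicted treatment probabilities; a_idx: (N,) true indices.
    """
    probs = np.asarray(probs, dtype=float)
    a_idx = np.asarray(a_idx)
    return float(-np.mean(np.log(probs[np.arange(len(a_idx)), a_idx])))


def confusion_loss(probs):
    """Eq. (32), batch-averaged: cross-entropy against the uniform target 1/d_a."""
    probs = np.asarray(probs, dtype=float)
    d_a = probs.shape[-1]
    return float(-np.mean(np.sum(np.log(probs) / d_a, axis=-1)))
```

By Gibbs' inequality, the cross-entropy against a uniform target is minimized, with value $\log d_a$, exactly when the discriminator outputs the uniform distribution; this is why minimizing $\mathcal{L}_{\text{conf}}$ pushes the encoder toward representations from which treatment cannot be predicted.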

Adversarial Counterfactual Temporal Inference Network (ACTIN) [conference/icml2024/wang]: A temporal counterfactual inference framework that introduces a dual-module architecture to improve the estimation of treatment outcomes over time. To address confounding bias, ACTIN adopts a generative adversarial network (GAN)-based strategy, which differs fundamentally from gradient reversal (as in CRN) and domain confusion (as in CT). In this approach, the treatment discriminator $\Theta_A$ is trained to distinguish between real treatment assignments $\bm{A}$ and synthetic (or “fake”) treatments $\bm{A}_{\text{fake}}$, which are generated by randomly shuffling or sampling from the treatment distribution. These treatments are paired with the learned representations $\bm{\mathcal{B}}$ as input to the discriminator. The adversarial objective consists of two competing goals:

1. The discriminator $\Theta_A$ is optimized to accurately identify whether a given treatment-representation pair is real or fake.
2. Meanwhile, the representation module $\Theta_{\bm{\mathcal{B}}}$ is trained to fool the discriminator, which encourages the module to produce representations that obscure treatment identity and thereby reduce the mutual information between treatments and representations.

This adversarial alignment pushes $\bm{\mathcal{B}}$ toward a balanced latent space that is less predictive of treatment group, helping to mitigate confounding bias. In parallel, ACTIN minimizes a standard prediction loss to ensure that the representations remain informative for outcome estimation. By decoupling treatment information from the learned representations via a GAN-based setup, ACTIN enables more robust and less biased counterfactual outcome estimation on time-varying clinical data.
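The pairing step behind this GAN-style balancing can be sketched in NumPy. This is a generic illustration, not ACTIN's released code: fake treatments are produced here by shuffling only (the text also allows sampling), the discriminator is left abstract (any model returning one probability per pair), and the representation module's "fooling" objective uses the common flipped-label form. All names are hypothetical.

```python
import numpy as np


def make_pairs(reps, treatments, rng):
    """Stack real (representation, treatment) pairs and shuffled 'fake' pairs.

    reps: (N, d_b) representations; treatments: (N, d_a) treatment vectors.
    Returns inputs x of shape (2N, d_b + d_a) and labels y (1 = real, 0 = fake).
    """
    fake = treatments[rng.permutation(len(treatments))]
    x = np.concatenate([np.hstack([reps, treatments]), np.hstack([reps, fake])])
    y = np.concatenate([np.ones(len(reps)), np.zeros(len(reps))])
    return x, y


def bce(p, y, eps=1e-12):
    """Mean binary cross-entropy between probabilities p and labels y."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)
    return float(-np.mean(y * np.log(p) + (1.0 - y) * np.log(1.0 - p)))


def discriminator_loss(d_probs, y):
    """The discriminator learns to tell real pairs (y = 1) from fake ones."""
    return bce(d_probs, y)


def encoder_fooling_loss(d_probs, y):
    """The representation module is rewarded when the discriminator is wrong."""
    return bce(d_probs, 1.0 - y)
```

When the representations carry no treatment information, the best the discriminator can do is output 0.5 for every pair, at which point both losses equal $\log 2$.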

### C.2 Optimization properties and convergence discussion

The joint objective in Eq. [14](https://arxiv.org/html/2605.05706#S4.E14) combines the factual prediction loss $\mathcal{L}_{\Theta_Y}$ and the sampling-based MMD regularizer $\mathcal{L}_{\mathcal{B}}$. Both components are bounded below, differentiable almost everywhere, and Lipschitz continuous on compact parameter domains. These are the regularity conditions typically assumed, together with Lipschitz-continuous gradients, in standard analyses showing that stochastic gradient descent (SGD)-type algorithms converge to first-order stationary points of non-convex objectives. Although global optimality cannot be guaranteed, these properties indicate that the optimization landscape is as well-behaved as is commonly assumed for contemporary deep learning systems.

In practice, however, jointly optimizing the predictive and balancing losses introduces non-trivial challenges. A large balancing weight $\lambda$ applied too early in training may suppress physiologically meaningful variability in the learned representation, leading to underfitting or even representation collapse. To mitigate this effect, we employ a curriculum-style adaptive schedule for $\lambda$ (Algorithm [1](https://arxiv.org/html/2605.05706#algorithm1), line 5). Specifically, at training epoch $e$ of $E$ total epochs, the balancing coefficient is updated according to the sigmoidal progression:

$$\lambda_e = \frac{2}{1 + \exp\left(-10 \cdot \frac{e}{E}\right)} - 1. \tag{33}$$

This schedule starts at zero, increases monotonically, and asymptotically approaches one. Early training therefore prioritizes minimizing $\mathcal{L}_{\Theta_Y}$, allowing the encoder to learn a stable embedding of physiological dynamics; as training progresses, the increasing $\lambda_e$ progressively strengthens distributional alignment in the latent space. This scheme improves optimization stability and avoids premature over-regularization.
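Eq. (33) is a one-liner; the sketch below evaluates it at a few epochs to show how quickly the weight ramps up. The function name is illustrative, and $E = 400$ is chosen only because it matches the Stage 1 epoch count reported in Appendix C.4.

```python
import math


def balancing_weight(e: int, E: int) -> float:
    """Eq. (33): weight on the sMMD term at epoch e of E."""
    return 2.0 / (1.0 + math.exp(-10.0 * e / E)) - 1.0


# Equivalent closed form: 2 / (1 + exp(-x)) - 1 == tanh(x / 2), so lambda_e = tanh(5 e / E).
print([round(balancing_weight(e, 400), 3) for e in (0, 40, 80, 200, 400)])
# [0.0, 0.462, 0.762, 0.987, 1.0]
```

Because $\lambda_e = \tanh(5e/E)$, most of the increase happens in roughly the first third of training; by the midpoint the weight is already about 0.99.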

### C.3 Implementation details of the ventilator re-intubation classifier

The ventilator re\-intubation classifier was implemented in PyTorch\. Each input sequence comprised 12 time steps of 14 features representing vital signs, augmented with time\-invariant statistical descriptors \(mean, standard deviation, and temporal slope\) that were repeated along the temporal axis\.
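A minimal NumPy sketch of this input augmentation follows. It assumes inputs shaped `(N, 12, 14)`, the population standard deviation, an ordinary least-squares slope against the time index, and the layout `[raw features | mean | std | slope]`; the ordering, these estimator choices, and the function name are assumptions rather than details from the released code.

```python
import numpy as np


def augment_with_summary_stats(x):
    """Append per-feature mean, std and slope, repeated along the time axis.

    x: (N, T, F) vital-sign sequences -> (N, T, 4F) array laid out as
    [raw features | mean | std | slope].
    """
    t = x.shape[1]
    mean = x.mean(axis=1)                          # (N, F)
    std = x.std(axis=1)                            # (N, F), population std
    time = np.arange(t) - (t - 1) / 2.0            # centered time index
    slope = np.einsum("t,ntf->nf", time, x - mean[:, None, :]) / np.sum(time ** 2)
    stats = np.concatenate([mean, std, slope], axis=-1)       # (N, 3F)
    stats = np.repeat(stats[:, None, :], t, axis=1)            # (N, T, 3F)
    return np.concatenate([x, stats], axis=-1)
```

With 14 features this yields 56 input channels per time step.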

The model architecture, referred to as CNN1DAvg, comprised two $3 \times 1$ convolutional layers with ReLU activations, followed by a residual block of two additional $3 \times 1$ convolutions with a skip connection. A global average-pooling layer aggregated temporal information, and the resulting representation was passed through a dropout layer ($p = 0.3$) and a linear classification head that produced a single scalar logit. Training used the AdamW optimizer (learning rate $10^{-3}$, weight decay $10^{-4}$) with binary cross-entropy loss and a positive-class weight equal to the ratio of negative to positive samples. In selected runs, a focal-loss variant ($\alpha = 0.25$, $\gamma = 2.0$) was adopted to emphasize difficult cases. A weighted random sampler balanced the classes during training. The learning rate followed cosine annealing with warm restarts ($T_0 = 5$, $T_{\text{mult}} = 2$), and gradients were clipped to an L2 norm of at most 1.0 at each step to improve stability. Training ran for 100 epochs with a batch size of 256 on an NVIDIA A40 GPU. The optimal classification threshold was selected on the validation set by maximizing the F1-score. All implementation details, including data augmentation, masking, and reproducibility controls, are available in the released code repository.
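Three of the training choices above (the positive-class weight, the focal loss, and F1-based threshold selection on the validation set) are sketched in NumPy below. This is illustrative only: the focal loss follows the standard α-balanced formulation because the exact variant is not specified, candidate thresholds are taken to be the distinct validation scores, and all function names are hypothetical.

```python
import numpy as np


def positive_class_weight(y):
    """Ratio of negative to positive samples (the BCE positive-class weight)."""
    y = np.asarray(y)
    n_pos = int(y.sum())
    return (len(y) - n_pos) / n_pos


def focal_loss(logits, y, alpha=0.25, gamma=2.0):
    """Standard alpha-balanced focal loss on raw logits, averaged over samples."""
    p = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))
    y = np.asarray(y)
    p_t = np.where(y == 1, p, 1.0 - p)              # probability of the true class
    alpha_t = np.where(y == 1, alpha, 1.0 - alpha)
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(np.clip(p_t, 1e-12, 1.0))))


def f1_at(y, scores, threshold):
    """F1-score when scores >= threshold are predicted positive."""
    pred = scores >= threshold
    tp = np.sum(pred & (y == 1))
    fp = np.sum(pred & (y == 0))
    fn = np.sum(~pred & (y == 1))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def best_f1_threshold(y, scores):
    """Pick the validation threshold (among observed scores) that maximizes F1."""
    y, scores = np.asarray(y), np.asarray(scores, dtype=float)
    candidates = np.unique(scores)
    f1s = [f1_at(y, scores, c) for c in candidates]
    best = int(np.argmax(f1s))
    return float(candidates[best]), float(f1s[best])
```

With $\gamma = 0$ and $\alpha = 0.5$ the focal loss reduces to half the ordinary binary cross-entropy; larger $\gamma$ shrinks the contribution of confidently correct examples.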

### C.4 Implementation details of the reconstruction decoder

The reconstruction decoder is trained in a two-stage procedure to evaluate the information content of the learned balanced representations $\bm{\mathcal{B}}$. This design ensures that reconstruction quality reflects the fidelity of the encoder's representations rather than being confounded by joint optimization dynamics.

Stage 1: Main model training (encoder + outcome head). In the first stage, the full GITO model, comprising the TCN-based encoder, the balanced representation module, and the outcome prediction head, is trained end-to-end for 400 epochs using the joint objective $\mathcal{L}_{\Theta_Y} + \lambda_e\mathcal{L}_{\mathcal{B}}$ with the adaptive $\lambda$ schedule described in Appendix C.2 (Optimization properties and convergence discussion). The Adam optimizer is used with a learning rate of $10^{-3}$ and weight decay of $10^{-4}$. At the end of Stage 1, the best model checkpoint is selected based on validation loss.

Stage 2: Reconstruction decoder training (frozen encoder). After Stage 1, *all encoder parameters are frozen* (`requires_grad=False`). A lightweight LSTM decoder and a linear projection layer are then trained to reconstruct the original input covariates and static features from the balanced representations. The architectural details are:

- LSTM decoder: A single-layer LSTM that takes $\bm{\mathcal{B}}_t \in \mathbb{R}^{48}$ (the balanced representation at each time step) as input, with a hidden size of 25.
- Projection layer: A linear layer mapping the LSTM hidden state ($\mathbb{R}^{25}$) to the combined covariate and static-feature space ($\mathbb{R}^{d_x + d_s}$, where $d_x$ is the number of time-varying covariates and $d_s$ is the number of static features after one-hot encoding).
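A quick parameter count shows how small this decoder is. The sketch assumes PyTorch's `nn.LSTM` and `nn.Linear` with their default biases (the bias configuration is not stated in the text) and leaves the output width $d_x + d_s$ as an argument because it depends on the feature set; the helper names are hypothetical.

```python
def lstm_param_count(input_size: int, hidden_size: int) -> int:
    """Single-layer torch.nn.LSTM with default bias=True.

    weight_ih: (4H, I), weight_hh: (4H, H), bias_ih: (4H,), bias_hh: (4H,).
    """
    h4 = 4 * hidden_size
    return h4 * input_size + h4 * hidden_size + 2 * h4


def linear_param_count(in_features: int, out_features: int) -> int:
    """torch.nn.Linear with bias: weight (out, in) plus bias (out,)."""
    return in_features * out_features + out_features


def decoder_param_count(d_out: int, d_in: int = 48, hidden: int = 25) -> int:
    """Stage 2 trainable parameters for an output width d_out = d_x + d_s."""
    return lstm_param_count(d_in, hidden) + linear_param_count(hidden, d_out)
```

Under these assumptions the LSTM contributes 7,500 parameters and the projection adds $26(d_x + d_s)$.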

The reconstruction loss is mean squared error \(MSE\), masked by the active entries indicator to account for variable\-length sequences:

$$\mathcal{L}_{\text{recon}} = \frac{1}{NT}\sum_{i=1}^{N}\sum_{t=1}^{T}\left\lVert\hat{\bm{X}}_{i,t} - \bm{X}_{i,t}\right\rVert^{2}, \tag{34}$$

where $\hat{\bm{X}}_{i,t}$ denotes the reconstructed covariate vector and $\bm{X}_{i,t} = [\bm{x}_{i,t}; \bm{s}_i]$ is the concatenation of the observed time-varying covariates and the (time-expanded) static features. Stage 2 training runs for 300 epochs using the Adam optimizer with a learning rate of $10^{-3}$ and a batch size of 64. Only the LSTM decoder and projection-layer parameters are updated; they account for approximately 2% of the total model parameters. Validation loss is monitored at each epoch, and the best decoder checkpoint is retained. After training, per-variable reconstruction quality is assessed on the validation set by computing variable-specific MSE and $R^2$ scores, separately for time-varying covariates and static features. These metrics directly inform the $\Delta R^2$ analysis reported in the main text (Section [D.1](https://arxiv.org/html/2605.05706#A4.SS1)).
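Eq. (34) averages over all $N \times T$ positions, while the text notes that the loss is masked by the active-entries indicator. A minimal NumPy sketch of a masked version, together with the per-variable $R^2$ used in the evaluation, is shown below. Normalizing by the number of active time steps (so that it reduces to Eq. (34) when every step is active), the array layout, and the function names are assumptions.

```python
import numpy as np


def masked_mse(x_hat, x, mask):
    """Eq. (34) restricted to active steps.

    x_hat, x: (N, T, D); mask: (N, T) with 1 for observed steps, 0 for padding.
    """
    sq_err = np.sum((x_hat - x) ** 2, axis=-1)          # squared L2 norm per (i, t)
    return float(np.sum(sq_err * mask) / np.sum(mask))


def per_variable_r2(x_hat, x, mask):
    """Coefficient of determination for each of the D variables over active steps."""
    active = np.asarray(mask, dtype=bool)
    pred, true = x_hat[active], x[active]                 # (n_active, D)
    ss_res = np.sum((true - pred) ** 2, axis=0)
    ss_tot = np.sum((true - true.mean(axis=0)) ** 2, axis=0)
    return 1.0 - ss_res / ss_tot
```

A variable that is constant over the active steps has zero total variance, so its $R^2$ is undefined and should be excluded or reported separately.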

### C.5 Generative AI configuration and prompt engineering

Model configuration and input encoding. The interpretability module uses a multimodal LLM (GPT-4o by default; the platform supports user-selectable alternatives) configured with a temperature of 0.7 to balance creativity with adherence to clinical facts. The maximum output length is set to 800 tokens to accommodate the four-paragraph output structure. Inputs are constructed using a hybrid schema:

- Textual context: Patient demographics (age, primary diagnosis); serialized MAP predictions for four treatment scenarios (None, Vaso, Vent, Both) over five time steps; and computed statistics for the top-5 feature-attribution variables (latest value, moving average, linear trend direction).
- Visual context: Three high-resolution plots, namely (i) Vital Signs Trend, (ii) Prediction Trajectory, and (iii) Patient History with treatment markers, are rendered using matplotlib, converted to Base64 strings, and injected into the model's vision context.

Prompt structure and constraints. The prompting strategy enforces an “Extraction-to-Interpretation” logic via three components.

1\. System persona and constraints. The system prompt defines the agent as an “AI-based clinical decision-support analyst” and imposes explicit negative constraints: “You are NOT a treating physician,” “You must NOT issue definitive medical advice or prescriptive treatment orders,” and “Avoid prescriptive or guideline-based language.” This keeps the tone analytical and descriptive rather than directive.

2\. Dynamic context injection (user prompt). The user prompt is dynamically assembled at runtime to include:

- Data anchoring: A list of the top-5 key variables with their Integrated Gradients contribution scores, explicit latest values, and trend directions (e.g., “Respiratory rate (contribution 0.15): Latest 24/min, Trend: increasing”). This serves as the Stage I ground truth to prevent numerical hallucinations.
- Comparison discipline instructions: The model is instructed to (i) use approximate deltas (e.g., “∼3-5% higher”) rather than absolute precision; (ii) explicitly state when scenarios are “clinically similar” to avoid over-interpreting noise; and (iii) follow a *minimal-intervention rule*: if a less intensive strategy (especially None) is predicted to reach and remain within the diagnosis-appropriate target range, it should be treated as sufficient.
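The data-anchoring lines can be assembled deterministically before the prompt is sent. The sketch below reproduces the example format above; ranking by absolute attribution, the least-squares trend with a small tolerance for "stable", the dictionary inputs, and all names are illustrative assumptions rather than the platform's code.

```python
import numpy as np


def trend_direction(values, tol=1e-3):
    """Label the least-squares slope of a recent series."""
    slope = np.polyfit(np.arange(len(values)), np.asarray(values, dtype=float), 1)[0]
    if slope > tol:
        return "increasing"
    if slope < -tol:
        return "decreasing"
    return "stable"


def anchor_lines(attributions, latest, series, units, k=5):
    """Format the top-k variables (by |attribution|) as data-anchoring lines."""
    top = sorted(attributions, key=lambda name: abs(attributions[name]), reverse=True)[:k]
    return [
        f"{name} (contribution {attributions[name]:.2f}): "
        f"Latest {latest[name]:g}{units[name]}, Trend: {trend_direction(series[name])}"
        for name in top
    ]
```

Computing these values outside the LLM means the numbers the model is asked to interpret are fixed in advance, which is what lets them act as ground truth against numerical hallucination.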

3\. Output formatting rules. The model is constrained to produce exactly four paragraphs without markdown headers:

- Paragraph 1 (Target measurement context): Summarize the displayed measurement's range, variability, and alignment with clinically relevant reference ranges over the observed period.
- Paragraph 2 (Influential vital signs): For each of the top-5 variables, report the attribution score, approximate current value, and clinical implication.
- Paragraph 3 (Predicted trajectory interpretation): Compare the four treatment strategies in terms of magnitude, trend, and stability of the predicted MAP trajectories, explicitly linking predictions to the patient's current physiological state and historical treatment responses.
- Paragraph 4 (Model preference distribution): Report a numerical preference score for each of the four treatment scenarios as approximate percentages summing to 100%, reflecting a trade-off among (i) sufficiency in achieving the diagnosis-specific target range, (ii) trajectory stability, (iii) consistency with historical responses, and (iv) intervention intensity, following a minimal-necessary-intervention principle.

An abbreviated example of the generated output structure:

```
{
  "paragraph_1": "MAP has fluctuated between 58 and 72 mmHg ...",
  "paragraph_2": "Heart rate (contribution 0.23) is elevated ...",
  "paragraph_3": "The model predicts higher MAP under Vaso ...",
  "preference_scores": {
    "None": 15,
    "Vasopressors": 40,
    "Ventilation": 20,
    "Both": 25
  }
}
```

Fallback mechanism. To ensure system robustness in clinical settings, a deterministic fallback mechanism is implemented. If the API call fails or the output violates the formatting constraints (detected via regex parsing), the system reverts to a template-based generator that concatenates the pre-computed variable-importance rankings and statistical trends into a simplified text summary, so that decision support remains available even without LLM generation.
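The validate-then-fall-back logic can be sketched as follows. The sketch assumes that the response has already been parsed into a dictionary with the keys of the example output above. The specific checks (non-empty paragraphs, no markdown headers, exactly the four scenario keys, scores summing to 100 within one point), the template wording, and all function names are illustrative assumptions rather than the platform's actual regex rules.

```python
import re

PARAGRAPH_KEYS = ("paragraph_1", "paragraph_2", "paragraph_3")
SCENARIOS = {"None", "Vasopressors", "Ventilation", "Both"}
MARKDOWN_HEADER = re.compile(r"^\s*#{1,6}\s", re.MULTILINE)


def is_valid(output) -> bool:
    """Check the formatting constraints on a parsed LLM response."""
    if not isinstance(output, dict):
        return False
    for key in PARAGRAPH_KEYS:
        text = output.get(key)
        if not isinstance(text, str) or not text.strip() or MARKDOWN_HEADER.search(text):
            return False
    scores = output.get("preference_scores")
    if not isinstance(scores, dict) or set(scores) != SCENARIOS:
        return False
    return abs(sum(scores.values()) - 100) <= 1       # "approximately" 100%


def template_summary(ranked_vars):
    """Deterministic fallback: ranked variables and their trends as plain text."""
    lines = [f"{i}. {name} (importance {score:.2f}), trend: {trend}"
             for i, (name, score, trend) in enumerate(ranked_vars, start=1)]
    return "Key variables by model attribution:\n" + "\n".join(lines)


def explain(call_llm, ranked_vars):
    """Return the LLM output if it validates, otherwise the template summary."""
    try:
        output = call_llm()
    except Exception:              # API failure, timeout, unparsable response, ...
        return template_summary(ranked_vars)
    return output if is_valid(output) else template_summary(ranked_vars)
```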

### C.6 Hyperparameter tuning

Following ACTIN [conference/icml2024/wang], we conduct hyperparameter optimization for all baseline models using random search. The search ranges for CRN and CRN-sMMD, CT and CT-sMMD, and ACTIN and ACTIN-sMMD are provided in Tables [12](https://arxiv.org/html/2605.05706#A3.T12), [13](https://arxiv.org/html/2605.05706#A3.T13), and [14](https://arxiv.org/html/2605.05706#A3.T14), respectively. Following the original work, we tune ACTIN with two distinct base models, TCN and LSTM. Note that all sub-models within ACTIN use the same base model in our experiments.
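Random search itself is simple; a minimal standard-library sketch follows. The search space shown is a hypothetical placeholder for illustration only (the actual ranges are those in Tables 12-14), and `evaluate` stands for any routine that trains one configuration and returns its validation error.

```python
import random

# Hypothetical placeholder ranges for illustration; not the values in Tables 12-14.
SEARCH_SPACE = {
    "learning_rate": [1e-4, 5e-4, 1e-3],
    "hidden_units": [16, 32, 64],
    "dropout": [0.1, 0.2, 0.3],
    "batch_size": [64, 128, 256],
}


def random_search(evaluate, space, n_trials, seed=0):
    """Sample n_trials configurations uniformly and keep the lowest-error one."""
    rng = random.Random(seed)
    best_cfg, best_score = None, float("inf")
    for _ in range(n_trials):
        cfg = {name: rng.choice(values) for name, values in space.items()}
        score = evaluate(cfg)               # e.g. validation RMSE; lower is better
        if score < best_score:
            best_cfg, best_score = cfg, score
    return best_cfg, best_score
```

Fixing the seed makes the sampled configurations reproducible across the model variants being compared.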

Table 12: The ranges for hyperparameter tuning of CRN and CRN-sMMD for the synthetic tumor growth and MIMIC-III datasets. The symbols $\Theta_{en}$ and $\Theta_{de}$ denote the Encoder and Decoder sub-models, respectively.

Table 13: The ranges for hyperparameter tuning of CT and CT-sMMD for the synthetic tumor growth and MIMIC-III datasets.

Table 14: The ranges for hyperparameter tuning of ACTIN for the synthetic tumor growth and MIMIC-III datasets.

## Appendix D Experiment results

### D.1 Experiment results on the MIMIC-III dataset

In this section, we present additional results on factual outcome estimation under diagnosis-based distribution shift for three baseline models and their sMMD-enhanced versions. Results for the three major disease categories are presented in Table [15](https://arxiv.org/html/2605.05706#A4.T15) (cardiovascular diseases), Table [16](https://arxiv.org/html/2605.05706#A4.T16) (neurological disorders), and Table [17](https://arxiv.org/html/2605.05706#A4.T17) (infectious and inflammatory diseases). The models are trained on White patients and evaluated on three non-White ethnicity subgroups across the three disease categories.

In our experiments, both CRN and ACTIN, when combined with the sMMD strategy, showed clear performance improvements under out-of-distribution (OOD) settings in all three disease cohorts, generalising consistently better to unseen patient populations. This indicates that sMMD improves robustness for these models. In contrast, CT-sMMD improved over CT only for neurological disorders; in the other cohorts the gains were marginal or negligible. These findings suggest that sMMD is broadly effective across treatment-response prediction models, although its benefit for CT may be limited to certain patient groups or clinical conditions.

Table 15: Multi-step-ahead prediction results on the MIMIC-III dataset for patients with cardiovascular diseases. Shown: RMSE as mean ± standard deviation over ten runs.

Table 16: Multi-step-ahead prediction results on the RW dataset in out-of-distribution settings (patients with neurological disorders). Shown: RMSE as mean ± standard deviation over ten runs.

Table 17: Multi-step-ahead prediction results on the MIMIC-III dataset in out-of-distribution settings (patients with infectious and inflammatory diseases). Shown: RMSE as mean ± standard deviation over ten runs.

Table 18: Multi-step prediction results on the fully-synthetic tumor-growth dataset ($\gamma = 10$; lower values are better, with the best highlighted in bold), where woBMR denotes the model without a balancing strategy. Shown: RMSE as mean ± standard deviation over ten runs.
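The "mean ± standard deviation over ten runs" reporting used in Tables 15 to 18 can be sketched as follows: compute the RMSE at each prediction step for every run, then summarize across runs. This assumes unnormalized RMSE per horizon step and the sample standard deviation (n − 1 denominator); neither choice is stated in this section, and all function names are hypothetical.

```python
import math
import statistics


def rmse_per_step(y_true, y_pred):
    """RMSE at each horizon step; inputs are per-patient trajectories of equal length."""
    horizon = len(y_true[0])
    result = []
    for t in range(horizon):
        sq_err = [(yt[t] - yp[t]) ** 2 for yt, yp in zip(y_true, y_pred)]
        result.append(math.sqrt(sum(sq_err) / len(sq_err)))
    return result


def summarize_runs(per_run):
    """Mean and sample std of RMSE across runs (rows = runs), for each horizon step."""
    return [(statistics.mean(vals), statistics.stdev(vals)) for vals in zip(*per_run)]


def format_row(summary):
    """Render each step as a 'mean ± std' string."""
    return [f"{m:.3f} ± {s:.3f}" for m, s in summary]
```

In use, `rmse_per_step` would be called once per run (ten seeds), and the ten resulting lists passed to `summarize_runs`.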

## Appendix E Case study: patient selection and interpretability workflow

This section details the patient selection criteria and the end-to-end workflow used to generate the interpretability analysis presented in the main text (Section [2.4](https://arxiv.org/html/2605.05706#S2.SS4)).

### E.1 Patient selection

The case-study patient was selected from the MIMIC-III ventilator weaning subcohort ($N = 205$) according to the following criteria:

1. Diagnosis: The patient's primary admission diagnosis was septic shock (ICD-9 785.52), a condition in which vasopressor therapy decisions are clinically impactful and where the trade-off between treatment escalation and conservative management is well-characterised.
2. Treatment diversity: The patient's ICU trajectory included periods of both vasopressor administration and mechanical ventilation, as well as intervals without active treatment. This diversity ensured that all four counterfactual scenarios (None, Vaso, Vent, Both) were clinically plausible given the patient's history.
3. Non-trivial prediction: GITO's predicted re-intubation risk for this patient fell within an intermediate probability range (neither near 0 nor near 1), representing a clinically ambiguous case where decision-support tools provide the greatest added value.
4. Outcome availability: The patient's subsequent clinical trajectory (successful recovery without vasopressor escalation) was documented, enabling retrospective validation of the model's preference distribution.
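The four criteria above can be expressed as a simple filter over candidate records. This is an illustrative sketch only: the `PatientRecord` schema, the hourly (vasopressor, ventilation) flag encoding, and the 0.3 to 0.7 bounds for an "intermediate" risk are assumptions, since the text gives no numeric thresholds.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class PatientRecord:
    """Hypothetical per-patient summary used only for this illustration."""
    patient_id: int
    primary_icd9: str
    treatments: List[Tuple[int, int]]  # hourly (vasopressor, ventilation) flags, 0 or 1
    reintubation_risk: float  # model-predicted probability
    follow_up: Optional[str] = None  # documented subsequent trajectory, if any


def is_case_study_candidate(p, low=0.3, high=0.7):
    """Apply criteria 1-4; low/high are assumed bounds for an intermediate risk."""
    septic_shock = p.primary_icd9 == "785.52"  # criterion 1: diagnosis
    has_vaso = any(v == 1 for v, _ in p.treatments)  # criterion 2: treatment diversity
    has_vent = any(m == 1 for _, m in p.treatments)
    has_untreated = any(v == 0 and m == 0 for v, m in p.treatments)
    ambiguous = low <= p.reintubation_risk <= high  # criterion 3: non-trivial prediction
    has_outcome = p.follow_up is not None  # criterion 4: outcome availability
    return (septic_shock and has_vaso and has_vent and has_untreated
            and ambiguous and has_outcome)


cohort = [
    PatientRecord(1, "785.52", [(0, 0), (1, 0), (1, 1), (0, 1)], 0.45, "recovered"),
    PatientRecord(2, "785.52", [(1, 1), (1, 1)], 0.50, "recovered"),  # no untreated interval
    PatientRecord(3, "785.52", [(0, 0), (1, 0), (0, 1)], 0.95, "recovered"),  # risk near 1
]
selected = [p.patient_id for p in cohort if is_case_study_candidate(p)]
```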

### E.2 End-to-end interpretability workflow

The following steps describe the complete pipeline from raw patient data to the final LLM-generated explanation shown in Box [E.2](https://arxiv.org/html/2605.05706#A5.SS2):

1. Data ingestion. The patient's hourly time-series data (25 vital signs, 3 static attributes, 2 binary treatments) were loaded from the MIMIC-extract preprocessed dataset and normalized using the cohort-level z-score parameters (see Methods: Data Preprocessing).
2. Multi-step outcome prediction. GITO's encoder-decoder architecture generated MAP predictions over a 5-step horizon ($\tau = 5$, corresponding to 5 hours) under each of the four treatment scenarios. At each rollout step, the predicted outcome was fed back as input for the next step (autoregressive inference with teacher forcing disabled).
3. Integrated Gradients attribution. For each prediction step $j \in \{1, \dots, 5\}$, Integrated Gradients was applied with the cohort-mean baseline to compute per-variable attribution scores $\phi_i^{(j)}$. Scores were averaged across the prediction horizon to obtain $\omega_i^{\text{raw}}$ and normalized via softmax to produce $\omega_i$ (see Methods: Interpretability Pipeline, Eqs. [26](https://arxiv.org/html/2605.05706#S4.E26)-[27](https://arxiv.org/html/2605.05706#S4.E27)). The top-$k$ ($k = 5$) variables were selected for downstream reporting.
4. Visualisation rendering. Three charts were generated using `matplotlib`:
   - Vital Signs Trend: strip plots of the top-5 variables with normal-range shading.
   - Prediction Trajectory: historical MAP with four counterfactual projection lines.
   - Patient History: full MAP trajectory with treatment markers (shape-coded: circle = None, triangle = Vaso, square = Vent, star = Both) and baseline/follow-up separation.

   All plots were converted to Base64-encoded PNG images for injection into the LLM's vision context.
5. Statistics extraction. For each top-5 variable, the latest value, moving average, and linear trend direction were computed from the raw time series and serialized as structured text (Stage I grounding data).
6. LLM inference. The system prompt, dynamic user prompt (with data anchoring and comparison discipline instructions), and three visual inputs were assembled and sent to GPT-4o (temperature = 0.7, max tokens = 800). The model produced a four-paragraph response including the preference distribution (see Appendix [C.5](https://arxiv.org/html/2605.05706#A3.SS5) for the full prompt structure).
7. Output validation. The returned text was parsed to verify structural compliance (four paragraphs, preference scores summing to 100%). Had the response failed validation, the deterministic fallback generator would have been invoked automatically.

The total inference time for steps 2-7 was approximately 3.2 seconds on a CPU (Intel Xeon 8452Y), plus LLM API latency, consistent with the real-time decision-support requirements reported in the main text.
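Steps 2, 3, and 5 of the workflow can be illustrated with a short, dependency-free sketch. It is not GITO's implementation: `step_fn` is a hypothetical stand-in for the trained encoder-decoder, Integrated Gradients itself is not computed (the per-step attributions `phi` are taken as given), the aggregation follows the prose description (mean over the horizon, softmax, top-k) rather than a transcription of Eqs. 26-27, and the moving-average window and trend tolerance are assumptions.

```python
import math

# The four counterfactual scenarios as (vasopressor, ventilation) flags.
SCENARIOS = {"None": (0, 0), "Vaso": (1, 0), "Vent": (0, 1), "Both": (1, 1)}


def rollout(step_fn, history, treatment, horizon=5):
    """Autoregressive rollout: each prediction is appended and fed back as input."""
    seq = list(history)
    preds = []
    for _ in range(horizon):
        y = step_fn(seq, treatment)  # hypothetical one-step predictor
        preds.append(y)
        seq.append(y)  # teacher forcing disabled: reuse the model's own output
    return preds


def counterfactual_trajectories(step_fn, history, horizon=5):
    """Predicted MAP trajectory under each of the four treatment scenarios."""
    return {name: rollout(step_fn, history, tr, horizon) for name, tr in SCENARIOS.items()}


def aggregate_attributions(phi, names, k=5):
    """phi[j][i]: attribution of variable i at step j. Mean over steps, softmax, top-k."""
    horizon = len(phi)
    raw = [sum(step[i] for step in phi) / horizon for i in range(len(names))]
    m = max(raw)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(r - m) for r in raw]
    total = sum(exps)
    weights = [e / total for e in exps]
    ranked = sorted(zip(names, weights), key=lambda nw: nw[1], reverse=True)
    return ranked[:k]


def summarize_series(x, window=3, tol=1e-9):
    """Latest value, trailing moving average, and direction of the least-squares slope."""
    n = len(x)
    tail = x[-window:]
    t_mean = (n - 1) / 2
    x_mean = sum(x) / n
    num = sum((t - t_mean) * (v - x_mean) for t, v in enumerate(x))
    den = sum((t - t_mean) ** 2 for t in range(n))
    slope = num / den if den else 0.0
    if slope > tol:
        trend = "increasing"
    elif slope < -tol:
        trend = "decreasing"
    else:
        trend = "stable"
    return {"latest": x[-1], "moving_average": sum(tail) / len(tail), "trend": trend}
```

Because the softmax is taken over all variables before truncation, the reported top-k weights need not sum to 1; renormalizing over the selected variables would be a separate design choice.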

Box: LLM-generated explanation

The patient's Mean Arterial Pressure (MAP) has been fluctuating below the standard therapeutic target range of 65-85 mmHg throughout the observed period. Historical data indicate significant variability, with several transient dips likely reflective of underlying septic shock. Treatment markers indicate the use of vasopressors, aligning temporally with periods of lower MAP, although no direct causal inference is made. The trend does not demonstrate a clear monotonic recovery, highlighting ongoing instability in blood pressure regulation.

Among the top 5 influential vital signs, tidal volume (contribution 0.20) is stable at 650.00 mL, suggesting consistent ventilation settings. Respiratory rate (contribution 0.25) is currently 17/min and increasing, which may indicate evolving respiratory demands. Systemic vascular resistance (contribution 0.18) remains stable at 645.91 dyn·s/cm⁵, suggesting peripheral vasodilation. Heart rate (contribution 0.22) is 68/min, lower than average, with an increasing trend, potentially indicating compensatory responses. Oxygen saturation (contribution 0.15) is 98%, showing improvement and suggesting adequate oxygenation.

The predicted MAP trajectories under different treatment strategies show distinct patterns. The 'None' strategy projects a gradual increase in MAP, eventually reaching the target range but at a slower pace. 'Vaso' predicts a rapid increase, achieving and exceeding the target range quickly, likely due to the patient's previous responsiveness to vasopressors. 'Vent' also improves MAP, but with less magnitude compared to 'Vaso.' The 'Both' strategy predicts the highest MAP, potentially overshooting the target range. These predictions are influenced by the patient's stable tidal volume and increasing respiratory rate, reinforcing the benefit of vasopressors in this context.

Treatment suggestion: The model's preference distribution reflects these insights. 'Vaso' is preferred at 40% due to its rapid attainment of MAP within the target range, consistent with historical treatment responses. 'None' is assigned 30%, acknowledging its eventual sufficiency but slower response. 'Vent' is preferred at 20% for moderate improvement without overshooting. 'Both' receives 10%, as it may provide unnecessary elevation in MAP.

## References
