Towards Unified Models for Explainable AI in Precision Medicine

Cross-scale consistency from biological trajectories to clinical outcomes

Rui Miao
Department of Mathematical Sciences
Texas AI Research Institute
University of Texas at Dallas
Figure: baseline biology (labs, echo, organ burden) → latent state (scale-compressed patient representation) → trajectory heads (3-year biomarker evolution) and survival latent \(z_i^S\) → survival head (5-/15-year mortality risk). Explanations test whether the stories agree.

Prediction is not enough.

Fragmented workflow

Figure: biomarkers (labs, echo) → surrogates (separate curves) → risk (scalar score); biology and survival are explained after the fact.
  • Biology, trajectory, and survival are modeled separately.
  • Explanations are added after prediction.
  • Clinicians see a scalar risk score without a disease story.

Unified workflow

Figure: baseline \(x_i\) → shared latent \(z_i\) → trajectory \(\hat y_i(t)\); \(z_i\) → survival latent \(z_i^S\) → \(\hat S_i(t)\); explanations read across these layers.
Scale = baseline variables → latent disease state → clinical observables / outcomes
  • Scale means the resolution of representation: covariates, latent state, trajectories, and outcomes.
  • One representation is constrained by multiple clinical observables.
  • Explanations interrogate both raw features and latent organization.
  • Validation asks whether all layers agree.

Sickle cell disease is an ideal stress test for unified explainable AI.

Figure: SCD multiorgan heterogeneity. Hematologic: Hb, reticulocytes, LDH. Renal: eGFR, creatinine, BUN. Hepatic: ALP, bilirubin, ALT. Cardiopulmonary: TRV, RVSP, RA area.

Why this disease exposes the modeling problem

  • SCD shows complex, heterogeneous progression across hematologic, renal, hepatic, and cardiopulmonary systems.
  • Those progression patterns are associated with survival outcomes, not just cross-sectional severity.
  • Mortality is clinically decisive but statistically sparse and censored.
The modeling target is heterogeneous disease progression linked to outcome, not a treatment-response claim.

The statistical problem is not "more variables"; it is missing supervision.

Rare observed events

Mortality is decisive but sparse, censored, and noisy at fixed horizons.

Composite endpoint burden

A single event label is hard to explain biologically without progression anchors.

Weak supervision

Baseline risk models need additional structure to learn the disease state.

Prediction-time constraint

Figure: 68 baseline covariates (available at counseling) → mortality risk (5-/15-year readout).

Risk must be computable from baseline variables alone for counseling, triage, and advanced-therapy decisions.

Training-time opportunity

Figure: baseline → shared state → 12 biomarker trajectories, and shared state → \(z_i^S\) → survival.

Early biomarker trajectories are available during model development and provide progression supervision for the latent state.

Data setting: NHLBI adult SCD cohorts

Main cohort

598 adults · 218 deaths · 5-year mortality 16.5%

Median follow-up 7.1 years. Rich baseline labs and echocardiography plus auxiliary longitudinal biomarkers.

Validation cohort

383 adults · 42 deaths · 5-year mortality 22.9%

Median follow-up 2.2 years. Used only for external validation, not for model fitting or selection.

Figure: baseline snapshot (demographics, labs, echo; 68 baseline covariates) · auxiliary trajectories (12 biomarkers over 3 years; renal · hepatic · cardiopulmonary · physiology) · outcome (all-cause mortality; time-to-event with censoring).

Footnote: main and validation analytic cohorts are drawn from NHLBI protocols 01-H-0088 and 04-H-0161 but are different cohorts; auxiliary trajectories are available only in the main cohort.

Three data anchors: baseline profile, auxiliary trajectories, and survival

Baseline patient profile

68 baseline covariates.

Age and demographics · genotype · labs · echocardiography · vitals · history

Auxiliary longitudinal supervision

12 biomarkers over a 3-year window, observed irregularly and restricted to early follow-up.

Renal · hepatic · cardiopulmonary · physiology
Figure: 3-year biomarker curves (TRV, creatinine, ALP) against time from baseline.

Survival endpoint

All-cause mortality with censoring provides the hard clinical outcome.

Death event · censoring · 5-year risk · 15-year risk
Figure: survival readout \(\hat S_i(t)\) at 5 and 15 years.

Cross-scale consistency: the central object of the talk

Biological

Labs and echocardiography serve as proxies for organ-level pathophysiology.

Temporal

The same latent state should support short-term biomarker evolution.

Outcome

It should support calibrated long-horizon survival prediction.

Interpretive

Feature, subject, and latent-phenotype explanations should agree.

Figure: fine clinical scale → shared disease-state scale → outcome scale, i.e., baseline covariates \(x_i\) → shared latent state \(z_i\) → trajectories \(\hat y_{ik}(t)\); survival latent \(z_i^S\) → \(\hat S_i(t)\); explanations \(\phi_i\) sit at the explanation and audit scale. Left to right: increasingly coarse representations, with observables read from the appropriate scale.
\(y\) is not the next coarse-grained state: \(z_i\) is the shared state; the trajectory reads from \(z_i\), while survival first passes through \(z_i^S\).

The renormalization group (RG) is a disciplined analogy, not a literal claim.

Physics example: 2D Ising block-spin RG

Figure: microscopic spin lattice → majority vote → coarse block spins. Coarse block spins preserve magnetization and long-range correlations.
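A minimal sketch of the majority-vote block-spin step, to make the physics side of the analogy concrete. It assumes a square lattice of ±1 spins whose side is divisible by an odd block size `b` (so every block has a strict majority); the function names and the toy configuration are illustrative.

```python
import random

def block_spin(lattice, b=3):
    """Majority-vote block-spin step for a 2D Ising configuration.

    lattice: list of rows of +1/-1 spins; the side length must be divisible by b.
    b: block side length; an odd b guarantees a strict majority in every block.
    """
    n = len(lattice)
    if n % b != 0:
        raise ValueError("lattice side must be divisible by the block size")
    coarse = []
    for bi in range(0, n, b):
        row = []
        for bj in range(0, n, b):
            # Sum the b*b spins of this block and keep the majority sign.
            s = sum(lattice[bi + di][bj + dj] for di in range(b) for dj in range(b))
            row.append(1 if s > 0 else -1)
        coarse.append(row)
    return coarse

def magnetization(lattice):
    """Mean spin of the configuration."""
    spins = [s for row in lattice for s in row]
    return sum(spins) / len(spins)

rng = random.Random(0)
# Toy mostly-up configuration: each spin is +1 with probability 0.8.
fine = [[1 if rng.random() < 0.8 else -1 for _ in range(9)] for _ in range(9)]
coarse = block_spin(fine, b=3)
print(magnetization(fine), magnetization(coarse))
```

With `b = 3`, a 9 × 9 lattice coarse-grains to 3 × 3; a mostly-up configuration typically stays mostly up, which is the sense in which magnetization survives the step.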

Clinical analogy: readouts attach at the right scale

Figure: baseline \(x_i\) → earlier state \(h_i^{(k)}\) → progression \(\hat y_{ik}(t)\); shared state \(z_i\) → \(z_i^S\) → mortality \(\hat S_i(t)\). Progression can read from earlier patient-state layers; mortality reads from a coarser survival latent.

Patient-level coarse-graining: from clinical measurements to an effective disease state

General mathematical diagram

\[ \begin{aligned} h_i^{(0)} &= E(x_i)\\ h_i^{(k)} &= C_{1:k}\!\left(h_i^{(0)};\theta\right)\\ \hat y_{ik}(t) &= g_{y,k}^{(k)}\!\left(h_i^{(k)},t\right),\quad k\le L\\ z_i^S &= H_S\!\left(h_i^{(L)}\right)\\ \hat S_i(t) &= g_S(z_i^S,t) \end{aligned} \]

Some progression readouts can attach to intermediate layers \(h_i^{(k)}\). The survival-specific state \(z_i^S\) is a coarser downstream layer for mortality risk.

The observable \(\hat y_{ik}(t)\) is not itself the next coarse-grained state; it is a readout from whichever hidden layer best carries that biological scale.
Figure: baseline \(x_i\) → \(h_i^{(1)}\) (early organ scale, readout \(\hat y_{i1}(t)\)) → \(h_i^{(k)}\) (intermediate scale, readout \(\hat y_{ik}(t)\)) → \(z_i^S\) (coarsest, mortality risk \(\hat S_i(t)\)). Example: in disease progression, early biomarkers read from earlier layers, while mortality reads from a coarser survival layer.
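The coarse-graining equations can be sketched in plain Python. This is a toy under stated assumptions, not the talk's network: the widths, the random tanh layers, and the hypothetical `readout_layer` map (which biomarker reads from which hidden layer) are all illustrative. The point is structural: progression readouts attach to intermediate states, while the survival latent is computed from the last layer.

```python
import math
import random

rng = random.Random(1)

def dense(n_in, n_out, act=math.tanh):
    """Hypothetical dense layer with fixed random weights, returned as a function on lists."""
    W = [[rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_in)] for _ in range(n_out)]
    return lambda v: [act(sum(w * x for w, x in zip(row, v))) for row in W]

d_x, d_h, L = 6, 4, 3
E = dense(d_x, d_h)                          # h^(0) = E(x)
C = [dense(d_h, d_h) for _ in range(L)]      # coarse-graining maps C_1, ..., C_L
# Hypothetical assignment: each biomarker reads from one hidden layer (index <= L).
readout_layer = {"creatinine": 1, "TRV": 2, "ALP": 3}
# Readout g_{y,k}(h, t): identity activation on [h, t].
heads = {k: dense(d_h + 1, 1, act=lambda u: u) for k in readout_layer}
H_S = dense(d_h, 2)                          # z^S = H_S(h^(L))

def forward(x, t):
    h = [E(x)]
    for C_l in C:
        h.append(C_l(h[-1]))                 # h^(l) = C_{1:l}(h^(0))
    y_hat = {k: heads[k](h[layer] + [t])[0] for k, layer in readout_layer.items()}
    z_S = H_S(h[L])                          # survival latent from the last layer
    return h, y_hat, z_S

h, y_hat, z_S = forward([0.2, -1.0, 0.5, 0.0, 1.3, -0.4], t=1.5)
print(y_hat, z_S)
```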

Residual stream and RG: the analogy is an evolving effective patient state.

Residual update

\[ h_i^{(\ell+1)}=\operatorname{LN}\!\left(h_i^{(\ell)}+F_\ell(h_i^{(\ell)})\right) \]
Figure: \(h_i^{(\ell)}\) plus the learned correction \(F_\ell(h)\), then LN.

Skip connection preserves the previous state; the residual block adds a task-guided correction; LayerNorm stabilizes the updated state.
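A minimal sketch of the post-LN residual update above, assuming a toy elementwise correction `F_toy` in place of a trained sublayer; `gamma` and `beta` stand for the learned LN parameters and are fixed here for illustration.

```python
import math

def layer_norm(u, gamma, beta, eps=1e-5):
    """LN over the hidden coordinates of one subject's vector u."""
    mu = sum(u) / len(u)
    var = sum((x - mu) ** 2 for x in u) / len(u)
    return [g * (x - mu) / math.sqrt(var + eps) + b for x, g, b in zip(u, gamma, beta)]

def residual_block(h, F, gamma, beta):
    """Post-LN residual update: h^(l+1) = LN(h^(l) + F_l(h^(l)))."""
    u = [hi + fi for hi, fi in zip(h, F(h))]   # skip connection + learned correction
    return layer_norm(u, gamma, beta)

def F_toy(h):
    """Hypothetical correction standing in for a trained sublayer."""
    return [0.5 * math.tanh(x) for x in h]

d = 4
gamma, beta = [1.0] * d, [0.0] * d
h = [0.3, -1.2, 2.0, 0.1]
for _ in range(3):              # a stack of three blocks sharing the toy correction
    h = residual_block(h, F_toy, gamma, beta)
print(h)
```

Because the skip path carries \(h_i^{(\ell)}\) forward unchanged, each block only has to learn a correction, and LN keeps the running state on a fixed internal scale.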

Clinical-biological interpretation

Figure: biological changes and organ signals → \(h_i^{(\ell)}\) (mixed patient state) → \(h_i^{(\ell+1)}\) (clearer state) → progression readout \(\hat y_i(t)\); coarsest state \(z_i^S\) → mortality risk \(\hat S_i(t)\).

The residual stream is not physical RG. It is the network's running effective patient state. Each block rewrites the representation while preserving what matters for trajectory and survival observables.

LayerNorm is the stabilizing step after each residual correction.

Not cohort-level z-scoring

\[ \begin{aligned} u_i^{(\ell)} &= h_i^{(\ell)}+F_\ell(h_i^{(\ell)})\\ \operatorname{LN}(u_i) &= \gamma_\ell \odot {u_i-\mu_i \over \sqrt{\sigma_i^2+\epsilon}}+\beta_\ell \end{aligned} \]

\(\mu_i\) and \(\sigma_i\) are computed across hidden coordinates within subject \(i\), not across patients. \(\gamma_\ell\) and \(\beta_\ell\) are learned.

This normalizes one subject's internal hidden vector, not one clinical covariate across the cohort.
Figure: hidden vector before LN (arbitrary magnitude) and after LN (stable internal scale).
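To make the distinction concrete, here is a hedged sketch contrasting LayerNorm (each subject's hidden vector normalized across its own coordinates) with cohort z-scoring (each coordinate normalized across subjects). The toy matrix is made up, and the learned \(\gamma_\ell, \beta_\ell\) are omitted (taken as 1 and 0).

```python
import math

def mean_var(v):
    mu = sum(v) / len(v)
    return mu, sum((x - mu) ** 2 for x in v) / len(v)

def layer_norm_rows(H, eps=1e-5):
    """LayerNorm without gamma/beta: normalize each subject's hidden vector (row) on its own."""
    out = []
    for row in H:
        mu, var = mean_var(row)
        out.append([(x - mu) / math.sqrt(var + eps) for x in row])
    return out

def cohort_zscore_columns(H, eps=1e-5):
    """Cohort z-scoring, for contrast: normalize each coordinate (column) across subjects."""
    stats = [mean_var(list(col)) for col in zip(*H)]
    return [[(x - mu) / math.sqrt(var + eps) for x, (mu, var) in zip(row, stats)] for row in H]

# Toy hidden states: three subjects (rows) by four hidden coordinates (columns).
H = [[1.0, 2.0, 3.0, 4.0],
     [10.0, 20.0, 30.0, 40.0],
     [0.5, 0.1, -0.2, 0.3]]
ln = layer_norm_rows(H)
zs = cohort_zscore_columns(H)
print(ln[0], ln[1])
print(zs[0], zs[1])
```

Subjects 0 and 1 differ only by a factor of 10, so LayerNorm maps them to (almost) the same vector, whereas column-wise z-scoring keeps them apart.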

Trajectory reads from shared z; survival reads from a survival-specific latent state.

Trajectory observables

\[ \hat y_{ik}(t)=g_{y,k}(z_i,t),\qquad k=1,\ldots,12 \]
Figure: predicted short-term physiology over 0-3 years.

Trajectory loss pushes the shared \(z_i\) to carry baseline-implied progression information.
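The trajectory heads emit basis coefficients that are mapped to curves. The basis is not specified, so this sketch assumes a linear B-spline (hat) basis on hypothetical knots at 0, 1, 2, 3 years and a hypothetical linear head from \(z_i\) to coefficients; the squared-error sum over one subject's irregular visits is that subject's contribution to one \(\mathcal L_{\mathrm{traj},k}\).

```python
def hat_basis(t, knots):
    """Linear B-spline (hat) basis values at time t for strictly increasing knots."""
    vals = []
    for j, c in enumerate(knots):
        left = knots[j - 1] if j > 0 else None
        right = knots[j + 1] if j + 1 < len(knots) else None
        if t == c:
            vals.append(1.0)
        elif left is not None and left < t < c:
            vals.append((t - left) / (c - left))
        elif right is not None and c < t < right:
            vals.append((right - t) / (right - c))
        else:
            vals.append(0.0)
    return vals

def trajectory_head(z, W, b):
    """Hypothetical linear head: shared latent z -> one basis coefficient per knot."""
    return [sum(w * zj for w, zj in zip(row, z)) + bj for row, bj in zip(W, b)]

def curve(coefs, t, knots):
    """Predicted biomarker value y_hat(t) = sum_j coef_j * B_j(t)."""
    return sum(c * B for c, B in zip(coefs, hat_basis(t, knots)))

def traj_loss(obs, coefs, knots):
    """Squared error over one subject's irregularly timed observations (t, y)."""
    return sum((y - curve(coefs, t, knots)) ** 2 for t, y in obs)

knots = [0.0, 1.0, 2.0, 3.0]                           # years from baseline (assumed)
z = [0.4, -1.1]                                        # toy shared latent z_i
W = [[1.0, 0.0], [0.5, 0.2], [0.0, 1.0], [0.3, 0.3]]   # toy head weights
b = [1.0, 1.0, 1.0, 1.0]
coefs = trajectory_head(z, W, b)
visits = [(0.0, 1.5), (0.7, 1.1), (2.4, 0.2)]          # irregular (time, value) pairs
print(coefs, traj_loss(visits, coefs, knots))
```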

Survival observable

\[ \begin{aligned} z_i^S &= H_S(z_i)\\ p_i &= g_S(z_i^S)\\ F_i(t_m) &= \sum_{r\le m}p_{ir}\\ S_i(t_m) &= 1-F_i(t_m) \end{aligned} \]
Figure: survival-curve readout at 5 years.

The survival branch has its own latent processing state \(z_i^S\); it is not a scalar directly copied from \(z_i\).
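The survival readout is a discrete-time, DeepHit-style construction. A minimal sketch, assuming \(g_S\) ends in a softmax over time bins; the logits below are placeholders for \(g_S(z_i^S)\) and the five bins are hypothetical.

```python
import math

def softmax(logits):
    """Numerically stable softmax over time bins."""
    m = max(logits)
    e = [math.exp(x - m) for x in logits]
    s = sum(e)
    return [x / s for x in e]

def survival_from_logits(logits):
    """p_i = softmax(logits), F_i(t_m) = sum_{r<=m} p_ir, S_i(t_m) = 1 - F_i(t_m)."""
    p = softmax(logits)
    F, running = [], 0.0
    for p_r in p:
        running += p_r
        F.append(running)
    S = [1.0 - f for f in F]
    return p, F, S

# Placeholder logits standing in for g_S(z_i^S) over five hypothetical time bins.
p, F, S = survival_from_logits([0.4, -0.2, 1.1, 0.0, -0.7])
print([round(s, 3) for s in S])
```

Because the softmax mass sums to 1 over the bins, \(S_i(t_M)=0\) at the last bin in this sketch; how the talk's model defines the final bin is not stated here.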

Consistency is operational: losses define what the model must preserve.

Loss terms

\[ \begin{aligned} \mathcal L_{\mathrm{surv}} &=-\sum_i\Big[\delta_i\log p_i(T_i)+(1-\delta_i)\log \hat S_i(C_i)\Big]\\ \mathcal L_{\mathrm{traj},k} &=\sum_{(i,m)\in\Omega_k}\left(y_{ikm}-\hat y_{ik}(t_{ikm})\right)^2\\ \mathcal L_{\mathrm{cal}} &=\sum_m\left(\widehat{\mathrm{risk}}_m-\mathrm{KM}_m\right)^2\\ \mathcal L &=\mathcal L_{\mathrm{surv}}+\lambda_{\mathrm{traj}}\sum_k\mathcal L_{\mathrm{traj},k} +\eta\mathcal L_{\mathrm{cal}} \end{aligned} \]
The losses define relevance: survival preserves event-time discrimination and calibration; trajectories preserve short-term biological progression.
Figure: the trajectory loss (progression coherence) shapes \(z_i\); the survival loss (risk ordering + calibration) acts through the survival latent \(z_i^S\). Relevance is task-defined.
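The three loss terms can be written out for a toy batch. Assumptions, not taken from the talk: times are already discretized into bin indices; \(\widehat{\mathrm{risk}}_m\) is read as the cohort-average predicted CIF \(F_i(t_m)\) and \(\mathrm{KM}_m\) as the Kaplan-Meier risk \(1-\hat S_{\mathrm{KM}}(t_m)\); per-biomarker trajectory losses are passed in precomputed. All names are illustrative.

```python
import math

def surv_nll(p, S, time_bin, delta):
    """L_surv = -sum_i [delta_i log p_i(T_i) + (1 - delta_i) log S_i(C_i)].
    time_bin[i] is the bin of T_i if delta_i = 1, else the bin of C_i."""
    total = 0.0
    for p_i, S_i, b, d in zip(p, S, time_bin, delta):
        total -= math.log(p_i[b]) if d == 1 else math.log(S_i[b])
    return total

def km_risk(time_bin, delta, n_bins):
    """Kaplan-Meier risk 1 - S_KM(t_m) for m = 0..n_bins-1 (times given as bin indices)."""
    s, risk = 1.0, []
    for m in range(n_bins):
        at_risk = sum(1 for b in time_bin if b >= m)
        deaths = sum(1 for b, d in zip(time_bin, delta) if b == m and d == 1)
        if at_risk > 0:
            s *= 1.0 - deaths / at_risk
        risk.append(1.0 - s)
    return risk

def cal_loss(F, km):
    """L_cal = sum_m (mean_i F_i(t_m) - KM_m)^2 under the cohort-average assumption."""
    n = len(F)
    return sum((sum(F_i[m] for F_i in F) / n - km[m]) ** 2 for m in range(len(km)))

def total_loss(p, S, F, time_bin, delta, traj_losses, lam, eta):
    """L = L_surv + lam * sum_k L_traj,k + eta * L_cal."""
    km = km_risk(time_bin, delta, len(p[0]))
    return surv_nll(p, S, time_bin, delta) + lam * sum(traj_losses) + eta * cal_loss(F, km)

p = [[0.2, 0.3, 0.5], [0.1, 0.1, 0.8]]     # per-bin event probabilities, two subjects
F = [[0.2, 0.5, 1.0], [0.1, 0.2, 1.0]]     # cumulative incidence F_i(t_m)
S = [[1.0 - f for f in F_i] for F_i in F]  # survival S_i(t_m) = 1 - F_i(t_m)
time_bin, delta = [1, 1], [1, 0]           # subject 0 dies in bin 1; subject 1 censored in bin 1
print(total_loss(p, S, F, time_bin, delta, traj_losses=[0.4], lam=0.5, eta=2.0))
```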

Model architecture: shared representation, task-specific heads

Figure: input (68 covariates) → embedding (feature tokens, attention weights) → shared trunk (Block 1, Block 2, ..., each residual + LN) → \(z_i\) (last hidden layer). From \(z_i\): survival latent \(z_i^S\) → \(\hat S_i(t)\), and 12 trajectory heads (y1 to y12) whose basis coefficients are mapped to curves.

Missingness and validation logic are part of the model design.

Missingness handling

Figure: NA entries → augmented input \((\tilde x_i,m_i)\), where \(m_i\) is the missingness-indicator vector.
  • Random forest imputation separately for main and validation cohorts.
  • Missingness indicators enter the attention encoder.
  • Random missingness masks improve robustness.
  • Longitudinal timestamps are preserved.
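A sketch of the augmented input and random masking, assuming missing values are encoded as `None`. The `fill` vector is a hypothetical stand-in for the random-forest imputation listed above, which is not reproduced here.

```python
import random

def random_mask(x, rate, rng):
    """Hide each observed entry with probability `rate` (training-time augmentation;
    the same idea can simulate 10-30% missingness at evaluation)."""
    return [None if (v is not None and rng.random() < rate) else v for v in x]

def augment(x, fill):
    """Build (x_tilde, m): m_j = 1 where x_j is missing (None), and missing entries
    take fill[j], a hypothetical stand-in for random-forest-imputed values."""
    m = [1 if v is None else 0 for v in x]
    x_tilde = [fill[j] if v is None else v for j, v in enumerate(x)]
    return x_tilde, m

rng = random.Random(0)
x = [9.1, None, 1.1, 250.0]        # toy covariates; None marks a missing value
fill = [8.5, 12.0, 0.9, 180.0]     # hypothetical imputed values, one per covariate
x_tilde, m = augment(random_mask(x, rate=0.2, rng=rng), fill)
print(x_tilde, m)
```

Keeping \(m_i\) alongside the imputed values lets the encoder distinguish a measured value from an imputed one.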

Validation logic

Figure: main cohort (fit + tune) → bootstrap optimism correction → external validation.
  • C-statistic for discrimination.
  • 5-year Integrated Brier Score for prediction error.
  • Simulated 10-30% missingness without retraining.
  • Validation cohort untouched during fitting and selection.
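For reference, a minimal Harrell's C computation on toy data. The talk does not say which C-statistic variant was used (Harrell's or a time-dependent version), so treat this as one common definition: among comparable pairs, the fraction where the subject who died earlier received the higher predicted risk.

```python
def harrell_c(times, events, risk):
    """Harrell's C-statistic.

    A pair (i, j) is comparable when subject i had the event and T_j > T_i.
    It is concordant when risk[i] > risk[j]; ties in predicted risk count 1/2.
    Pairs with tied event times are skipped in this simple version.
    """
    concordant, comparable = 0.0, 0
    for i in range(len(times)):
        if not events[i]:
            continue
        for j in range(len(times)):
            if times[j] > times[i]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable if comparable else float("nan")

# Toy data: follow-up times (years), event indicators, and predicted 5-year risks.
times = [1.2, 3.4, 5.0, 6.1, 8.0]
events = [1, 1, 0, 1, 0]
risk = [0.60, 0.35, 0.40, 0.20, 0.05]
print(harrell_c(times, events, risk))   # 7 of 8 comparable pairs are concordant
```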

Central ablation: trajectory-supervised survival improves discrimination and prediction error.

Main cohort

TSML (Sachdev 2021)*: C-stat 0.739, 5y IBS 0.058
Single-task DeepHit: C-stat 0.827, 5y IBS 0.037
Multi-Task DeepHit: C-stat 0.882, 5y IBS 0.029

External validation

TSML (Sachdev 2021)*: C-stat 0.661, 5y IBS 0.085
Single-task DeepHit: C-stat 0.747, 5y IBS 0.072
Multi-Task DeepHit: C-stat 0.794, 5y IBS 0.064

* TSML benchmark from Sachdev V. et al. 2021; TSML = two-step machine learning: variable selection followed by Cox models.

Operational robustness under simulated missingness

Main cohort: Multi-Task DeepHit

Simulated missingness 0% / 10% / 20% / 30%: C-statistic 0.883 / 0.850 / 0.823 / 0.779; IBS 0.029 / 0.041 / 0.048 / 0.057. Reference: TSML C-stat 0.739 at 0% missingness.

Validation cohort: Multi-Task DeepHit

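Simulated missingness 0% / 10% / 20% / 30%: C-statistic 0.794 / 0.787 / 0.751 / 0.702; IBS 0.064 / 0.069 / 0.083 / 0.085. Reference: TSML C-stat 0.661 at 0% missingness.
Performance degrades as information is removed, but the trajectory-supervised model remains competitive under clinically plausible incompleteness: even at 30% simulated missingness, its C-statistic (0.779 main, 0.702 validation) stays above the complete-data TSML benchmark (0.739 main, 0.661 validation).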

Explainability is layered, not singular.

Population

Which baseline covariates are most influential overall?

global SHAP

Subgroup

Do risk mechanisms change between low- and high-risk strata?

subgroup SHAP

Subject

Why did the model assign this patient this risk?

waterfall

Mechanistic

Which sparse disease-state directions organize the model?

SAE
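Figure: raw features → shared residual stream (common patient state) → risk prediction; SHAP and waterfall plots explain the prediction from raw features, while the SAE probe reads the residual stream.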

SHAP recovers coherent organ-system risk signals and stage-dependent patterns.

Population-level top signals

Reticulocyte count · ALP · RA pressure · TRV · RA area

Subgroup pattern

  • Low-risk subgroup: reticulocyte count is most influential.
  • High-risk subgroup: RA pressure becomes dominant.
  • ALP and right-heart measures remain important in both groups.
  • Reticulocyte count's contribution may depend on interactions rather than on a marginal group difference.
Figure: low-risk state (hematologic signal) → high-risk state (right-heart pressure).

Waterfall example A: subject who died within 5 years

SHAP-style cumulative explanation: 5-year mortality risk

Figure: from the expected 5-year mortality risk to this patient's prediction of 97.91% (axis 0-100%). Contributions: ALP 321 U/L (+20%), creatinine 1.7 mg/dL (+16%), hemoglobin 6.8 g/dL (+11%), TRV / RVSP 3.1 m/s · 52 mmHg (+8%), other covariates (+5%). Risk-increasing factors move the prediction upward.

How to read the waterfall

  • The horizontal axis is the predicted 5-year mortality-risk percentage; this high-risk case uses a 0-100% range.
  • Each row shows the subject's observed value and the corresponding push in the 5-year mortality-risk prediction.
  • ALP, creatinine, hemoglobin, and TRV/RVSP form a coherent multiorgan burden pattern.
  • The final red line is the model's patient-specific 5-year mortality-risk prediction.
This subject died within 5 years, so the explanation should emphasize why the fitted model placed the subject far above the expected-risk baseline.
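The waterfall relies on Shapley additivity: expected risk plus the row contributions equals the patient's prediction. The sketch below computes exact Shapley values by enumerating feature subsets for a hypothetical three-feature logistic risk model, with absent features set to a single reference vector (a simplification of SHAP's background distribution). The cost is exponential in the number of features, so this is for illustration only.

```python
import math
from itertools import combinations
from math import factorial

def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x; 'absent' features take their baseline value."""
    n = len(x)

    def value(subset):
        z = [x[j] if j in subset else baseline[j] for j in range(n)]
        return f(z)

    phi = []
    for j in range(n):
        others = [k for k in range(n) if k != j]
        total = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {j}) - value(s))
        phi.append(total)
    return phi

def toy_risk(z):
    """Hypothetical 5-year risk model with one interaction term (not the talk's model)."""
    eta = -2.0 + 0.8 * z[0] + 0.5 * z[1] + 0.3 * z[0] * z[2]
    return 1.0 / (1.0 + math.exp(-eta))

x = [2.0, 1.0, 1.5]          # hypothetical standardized covariates for one subject
baseline = [0.0, 0.0, 0.0]   # reference profile defining the expected risk
phi = shapley_values(toy_risk, x, baseline)
# Waterfall check: expected risk + contributions = patient-specific prediction.
print(toy_risk(baseline), phi, toy_risk(baseline) + sum(phi), toy_risk(x))
```

The third feature acts only through its interaction with the first, yet it still receives a positive contribution, which mirrors the note that a feature's role may come from interactions rather than a marginal effect.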

Interactive app: https://ling2-scd-prediction-mtdt.share.connect.posit.cloud

Waterfall example B: low near-term risk subject with later death

SHAP-style cumulative explanation: 5-year mortality risk

Figure: from the expected 5-year mortality risk to this patient's prediction of 0.017% (axis 0-40%). Contributions: NT-proBNP 42 pg/mL (-5%), LA area 13 cm² (-3%), LV diameter 4.3 cm (-3%), stroke volume 83 mL (-2%), other covariates (-1%). Protective factors move the patient-specific 5-year mortality risk downward.

How to read the waterfall

  • The horizontal axis is the predicted 5-year mortality-risk percentage; this low-risk case uses a 0-40% range for readability.
  • Each row shows the subject's observed value and how that feature moves the 5-year prediction downward.
  • Favorable cardiac measures and low NT-proBNP contribute protective shifts for the 5-year horizon.
  • The final blue line is the model's patient-specific 5-year mortality-risk prediction.
This subject died later, so the example shows horizon specificity: a coherent low 5-year mortality-risk explanation does not imply low long-term risk.

Interactive app: https://ling2-scd-prediction-mtdt.share.connect.posit.cloud

Interpretability guardrails: what explanations can and cannot claim

  • Supported: model behavior (what moved this prediction?).
  • Hypothesis-generating: biological interpretation (does it resemble SCD pathophysiology?).
  • Not claimed: causal intervention (not a treatment-effect estimate).
  • Clinical review loop: external validation · mechanistic experiments · prospective cohorts · clinical trials.

SAE sits on the residual stream: it decomposes hidden states into sparse signals.

Sparse Autoencoder (SAE) Probe

\[ \begin{aligned} a_i &= \operatorname{TopK}\!\left(\operatorname{ReLU}(E h_i^{(L)})\right)\\ \hat h_i^{(L)} &= D a_i\\ \mathcal L_{\mathrm{SAE}} &= \left\|h_i^{(L)}-\hat h_i^{(L)}\right\|_2^2 \end{aligned} \]

SAE is post hoc: it explains the model's residual-stream state \(h_i^{(L)}\), not raw input alone and not the primary survival head.

Interpretation comes from active sparse features: only a few coordinates of \(a_i\) are nonzero for each subject.
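A minimal TopK SAE forward pass following the equations above (no encoder bias, matching \(E h_i^{(L)}\)). The 2-dimensional hidden state, the 4-feature dictionary, and the hand-picked \(E\) and \(D\) are hypothetical, chosen so that the reconstruction is exact.

```python
def relu(v):
    return [max(0.0, x) for x in v]

def topk(v, k):
    """Keep the k largest entries of v and zero the rest."""
    keep = set(sorted(range(len(v)), key=lambda j: v[j], reverse=True)[:k])
    return [x if j in keep else 0.0 for j, x in enumerate(v)]

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def sae_forward(h, E, D, k):
    """a = TopK(ReLU(E h)), h_hat = D a, loss = ||h - h_hat||_2^2."""
    a = topk(relu(matvec(E, h)), k)
    h_hat = matvec(D, a)
    loss = sum((x - y) ** 2 for x, y in zip(h, h_hat))
    return a, h_hat, loss

# Hypothetical 2-d hidden state with a 4-feature dictionary (+/- each axis).
E = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]  # encoder: 4 x 2
D = [[1.0, 0.0, -1.0, 0.0], [0.0, 1.0, 0.0, -1.0]]      # decoder: 2 x 4
a, h_hat, loss = sae_forward([2.0, -1.0], E, D, k=2)
print(a, h_hat, loss)   # only 2 of the 4 features are active
```

In a trained SAE the dictionary is much wider than the hidden state and is learned by minimizing \(\mathcal L_{\mathrm{SAE}}\); the TopK step is what keeps each subject's code sparse.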

Reference point: Anthropic's "Towards Monosemanticity" / sparse-autoencoder interpretability blog series.

Figure: hidden state \(h_i^{(L)}\) → encoder → TopK → decoder → reconstruction \(\hat h_i\); the reconstruction tests whether sparse features preserve the hidden state. Example features: right-heart stress, renal reserve, high-output remodeling.

SAE features: from polysemantic hidden units to clinical concepts

Figure: residual-stream layer \(h_i^{(L)}\) → one-layer SAE network \(a_i=\operatorname{TopK}(\operatorname{ReLU}(W_Eh_i+b_E))\) → sparse features → phenotype clusters C1 to C4.
  • Polysemanticity: one hidden unit mixes anemia, renal reserve, and right-heart stress.
  • Monosemanticity: after TopK only a few features are active, and one sparse feature tracks one clinically interpretable concept, giving clinically coherent directions such as right-heart stress, renal reserve, and high-output state.
  • Cluster grouping: group features, then audit baseline contrasts and CIF.

SAE cluster results after sparse-feature grouping

5-year CIF ladder

Cluster · interpretation: 5-year CIF in main / validation cohort
C1 · Preserved cardiopulmonary phenotype: 5.3% / 8.5%
C2 · Chronic anemia with early LA loading: 8.4% / 19.4%
C3 · Anemia-driven high-output remodeling: 22.0% / 36.9%
C4 · Cardio-renal-pulmonary vasculopathy: 50.3% / 57.1%
SAE C-index: 0.7448 (main), 0.7474 (validation)
Figure: SAE cluster cumulative incidence plots.
Read left to right: residual state to sparse SAE features, clinically similar features to clusters, then clusters audited by baseline phenotype and 5-year CIF.

Clusters are phenotype hypotheses generated from sparse hidden-state features.

C1: Preserved cardiopulmonary phenotype

Younger, lower TRV/RVSP, better e' relaxation signals, smaller chambers, lower BUN, hyperfiltration-range renal profile.

TRV down · RVSP down · eGFR up

C2: Chronic anemia with early LA loading

Mild anemia, slightly older age, larger LA size, mildly lower lateral e', without dominant pulmonary-pressure or renal-injury pattern.

Hemoglobin down · LA area up · renal injury absent

C3: Anemia-driven high-output remodeling

Lower hemoglobin/hematocrit, larger LV/LA/RA chambers, higher cardiac output and LV mass, elevated TRV.

LV/LA/RA up · cardiac output up · TRV up

C4: Cardio-renal-pulmonary vasculopathy

Higher TRV/RVSP, lower eGFR, higher creatinine/BUN, older age, lower e', lower E/A, thicker LV walls.

TRV/RVSP up · eGFR down · BUN up
Figure: sparse features (active SAE neurons) → clinical audit (baseline phenotype contrasts) → outcome audit (5-year CIF ladder: C1 5.3% to C4 50.3%). SAE clusters are audit targets, not definitive disease mechanisms.

From sparse latent signals to clinician-readable explanations

Figure: patient data (baseline labs, echo, genotype) → black-box model (\(z_i,\hat S_i(t)\): risk + hidden state) → SAE features (evidence packet) → LLM output: labels (e.g., right-heart stress, renal reserve) → LLM output: draft text (clinician and lay versions) → LLM judge (faithful? clinically plausible?), with a revise loop (relabel, rewrite, or reject). Clinician input (rank, edit, or reject explanations) feeds an RL reward model (faithful + audience-specific).

One held-out patient, two audiences: the same model evidence should survive translation.

Clinician-facing output

"The model predicts a cumulative incidence curve that rises modestly in the first 5 years and then plateaus, reaching ~40% by year 15."

"An anemia with right ventricular pressure overload subgroup carries lower Hemoglobin and Hematocrit but higher Estimated.RVSP and TR.Peak.Velocity; this patient sits on the activated side for anemia markers but not for RVSP/TRV. The per-patient Delta_F is +0.0009 / +0.0147 / +0.0039 / +0.0019 at 1/5/10/15y."

"A low right ventricular pressure/remodeling contrast carries lower RVSP, LV mass index, TRV, and BUN; this patient sits on all four features, with the largest positive push at 5y."

"Landmarks: F(1)=0.0267, F(5)=0.1500, F(10)=0.3500, F(15)=0.4000."

Layperson-facing output

"The model predicts this patient's chance of the bad outcome starts low, about 2% by 1 year, but climbs steadily through the first five years, reaching roughly 10% by year 5."

"The curve then accelerates, jumping to about 35% by year 10, before leveling off slightly in the later years, ending near 55% by year 15."

"The strongest concern is the heart and lungs working harder than they should. The right side of the heart is under extra pressure, and the blood vessels carrying blood to the lungs are showing signs of strain."

"A secondary signal is the patient's blood counts being lower than normal, which adds a smaller but steady push to the risk, particularly in the later years."

LLM setup: judge LLM = Gemma-4-26B-A4B; two generator LLMs = Mistral-3-14B; hosted on 4xH200 NVL.

Generalizable design pattern for precision medicine

Figure: baseline state \(x_i,m_i\) → shared latent \(z_i\) → intermediate physiology \(\hat y_i(t)\); \(z_i\) → outcome latent \(z_i^S\) → hard outcome \(\hat S_i(t)\). Explanations audit consistency across features · latent phenotypes · trajectories · survival.

Candidate domains

Pulmonary hypertension · HFpEF · CKD progression · transplant outcomes · rare-disease natural history · cancer survivorship / toxicity

Take-home messages

  • Trajectory supervision strengthens rare-event survival prediction from baseline data.
  • Cross-scale consistency links features, latent states, biomarker progression, survival, and explanations.
  • SAE and LLM probes turn hidden-state evidence into auditable phenotype and explanation hypotheses.

Next aim: ACCORD trial

  • Reanalyze ACCORD with a unified model for biomarker dynamics and hard outcomes.
  • Use reinforcement learning to propose optimized dynamic treatment regimes and dosage paths.
  • Keep recommendations explainable to clinicians and patients through the same cross-scale audit.

Thank You!

Acknowledgements

Gefei Lin · Xin Tian · Colin Wu

Swee Lay Thein and members of the SCD Lab

Contact

rui.miao@utdallas.edu

rui-miao.github.io

Speaker notes

Slide overview