
Add MMTM Model

Open stuti-agrawal opened this issue 1 month ago • 0 comments

Add MMTM (Multimodal Transfer Module) Model and MMTMLayer for Multimodal EHR Fusion

This PR adds MMTM (Multimodal Transfer Module) and MMTMLayer to the PyHealth model library. MMTM is a lightweight, effective cross-modal channel-attention fusion module, originally introduced in:

Joze et al., “MMTM: Multimodal Transfer Module for CNN Fusion,” CVPR 2020

https://arxiv.org/abs/1911.08670

Although developed for computer vision, MMTM provides a strong baseline for multimodal Electronic Health Records (EHR), enabling efficient fusion between:

  • diagnosis codes
  • procedure codes
  • medications
  • lab measurements
  • clinical note embeddings

or, more generally, any two modalities represented as embedding vectors.

This PR adapts MMTM to PyHealth’s BaseModel API in a way that is clean, reusable, and fully compatible with PyHealth datasets and processors.

Motivation

MMTM is increasingly used in recent multimodal EHR work, including as a baseline in:

CTPD: Cross-Modal Temporal Pattern Discovery for Enhanced Multimodal EHR Analysis (2024), where it serves as a fusion baseline for multimodal clinical time-series + notes.

By adding MMTM to PyHealth, this PR:

  • expands the library’s multimodal modeling capabilities
  • enables reproducible baselines for multimodal fusion studies
  • supports EHR benchmarks where complementary modalities must be fused
  • aligns with PyHealth’s goal of enabling reproducible AI4Health research

What’s Included

  1. MMTMLayer: a standalone fusion layer that:
  • jointly squeezes both inputs (dim_a + dim_b → bottleneck)
  • applies modality-specific excitation
  • computes channel-wise attention across both modalities
  • returns the fused representations

Works for any tensor pair shaped (batch, feature_dim).
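The fusion mechanism above can be sketched in plain PyTorch as follows. This is a minimal illustration of the MMTM idea from the paper, not the exact PyHealth implementation; the class name, `ratio` parameter, and the `2 * sigmoid` gating are assumptions based on the original CVPR formulation.

```python
import torch
import torch.nn as nn

class MMTMLayerSketch(nn.Module):
    """Minimal sketch of MMTM fusion (illustrative, not the PyHealth code)."""

    def __init__(self, dim_a: int, dim_b: int, ratio: int = 4):
        super().__init__()
        bottleneck = (dim_a + dim_b) // ratio
        # joint squeeze: both modalities are compressed together
        self.squeeze = nn.Linear(dim_a + dim_b, bottleneck)
        # modality-specific excitation branches
        self.excite_a = nn.Linear(bottleneck, dim_a)
        self.excite_b = nn.Linear(bottleneck, dim_b)

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        z = torch.relu(self.squeeze(torch.cat([a, b], dim=-1)))
        # channel-wise gates; 2*sigmoid keeps the expected gate value near 1,
        # as in the original paper
        a_out = a * 2 * torch.sigmoid(self.excite_a(z))
        b_out = b * 2 * torch.sigmoid(self.excite_b(z))
        return a_out, b_out

# any tensor pair shaped (batch, feature_dim) works
a = torch.randn(8, 128)
b = torch.randn(8, 64)
fa, fb = MMTMLayerSketch(128, 64)(a, b)
print(fa.shape, fb.shape)  # torch.Size([8, 128]) torch.Size([8, 64])
```

Note that each modality keeps its own feature dimension; only the channel weights are informed by the other modality.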

  2. MMTM model (BaseModel subclass)

A full PyHealth model that:

  • embeds two input modalities with EmbeddingModel
  • pools patient-level features with get_last_visit
  • fuses both modalities via MMTMLayer
  • performs classification using a final linear head

MMTM enforces exactly two modalities, matching typical multimodal EHR settings.
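The model's forward path can be sketched with plain-PyTorch stand-ins. The real class subclasses PyHealth's BaseModel and uses its EmbeddingModel and get_last_visit helpers; the `nn.EmbeddingBag` pooling, vocabulary sizes, and dimensions below are placeholder assumptions for illustration.

```python
import torch
import torch.nn as nn

class MMTMModelSketch(nn.Module):
    """Sketch of the two-modality pipeline: embed -> fuse -> classify."""

    def __init__(self, vocab_a: int, vocab_b: int, emb_dim: int = 128, n_classes: int = 2):
        super().__init__()
        # stand-ins for PyHealth's EmbeddingModel + patient-level pooling
        self.embed_a = nn.EmbeddingBag(vocab_a, emb_dim)
        self.embed_b = nn.EmbeddingBag(vocab_b, emb_dim)
        # MMTM-style fusion: joint squeeze + per-modality excitation
        bottleneck = (emb_dim * 2) // 4
        self.squeeze = nn.Linear(emb_dim * 2, bottleneck)
        self.excite_a = nn.Linear(bottleneck, emb_dim)
        self.excite_b = nn.Linear(bottleneck, emb_dim)
        # final linear classification head
        self.head = nn.Linear(emb_dim * 2, n_classes)

    def forward(self, codes_a: torch.Tensor, codes_b: torch.Tensor):
        a = self.embed_a(codes_a)  # (batch, emb_dim)
        b = self.embed_b(codes_b)
        z = torch.relu(self.squeeze(torch.cat([a, b], dim=-1)))
        a = a * 2 * torch.sigmoid(self.excite_a(z))
        b = b * 2 * torch.sigmoid(self.excite_b(z))
        return self.head(torch.cat([a, b], dim=-1))  # logits

logits = MMTMModelSketch(100, 50)(
    torch.randint(0, 100, (4, 6)),  # e.g. diagnosis codes
    torch.randint(0, 50, (4, 3)),   # e.g. procedure codes
)
print(logits.shape)  # torch.Size([4, 2])
```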

Unit Tests (test_mmtm.py)

Tests include:

  • layer forward behavior
  • correct model initialization
  • forward output format (loss, y_prob, y_true, logit)
  • gradient propagation
  • model parameters
  • dataset integration
  • device handling

All tests pass locally.
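A gradient-propagation check along these lines can verify that gradients reach both modalities through the fusion; the actual assertions in test_mmtm.py may differ, and the tiny layers here are hypothetical stand-ins.

```python
import torch

# tiny MMTM-style fusion: joint squeeze + per-modality excitation
squeeze = torch.nn.Linear(8, 4)
excite_a = torch.nn.Linear(4, 4)
excite_b = torch.nn.Linear(4, 4)

a = torch.randn(2, 4, requires_grad=True)
b = torch.randn(2, 4, requires_grad=True)

z = torch.relu(squeeze(torch.cat([a, b], dim=-1)))
out = (a * 2 * torch.sigmoid(excite_a(z))).sum() \
    + (b * 2 * torch.sigmoid(excite_b(z))).sum()
out.backward()

# gradients must flow back to BOTH input modalities
assert a.grad is not None and b.grad is not None
print("gradients propagate to both modalities")
```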

Example Usage

Included at the bottom of mmtm.py under `if __name__ == "__main__":`:

  • a minimal dataset with "codes" and "procedures"
  • creation of the MMTM model
  • a forward + backward demonstration
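The demo's forward + backward pattern, and the output dictionary format listed in the tests, can be sketched as follows. The `forward_step` helper and tensor shapes are hypothetical; only the (loss, y_prob, y_true, logit) keys come from the PR description.

```python
import torch

def forward_step(x: torch.Tensor, y: torch.Tensor) -> dict:
    """Stand-in forward pass returning PyHealth-style output keys."""
    logit = torch.nn.Linear(16, 2)(x)
    loss = torch.nn.functional.cross_entropy(logit, y)
    return {
        "loss": loss,
        "y_prob": torch.softmax(logit, dim=-1),
        "y_true": y,
        "logit": logit,
    }

out = forward_step(torch.randn(4, 16), torch.randint(0, 2, (4,)))
out["loss"].backward()  # backward demonstration, as in the __main__ example
print(sorted(out.keys()))  # ['logit', 'loss', 'y_prob', 'y_true']
```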

Impact

This PR adds a widely used multimodal fusion mechanism, enabling PyHealth users to:

  • run CTPD baselines directly
  • benchmark multimodal EHR models
  • perform efficient two-modality fusion
  • expand multimodal research using a known CVPR-level baseline

MMTM is simple, fast, and improves the reproducibility of multimodal EHR research inside PyHealth.

stuti-agrawal · Dec 04 '25