Add MMTM (Multimodal Transfer Module) Model and MMTMLayer for Multimodal EHR Fusion
This PR adds MMTM (Multimodal Transfer Module) and MMTMLayer to the PyHealth model library. MMTM is a lightweight, effective cross-modal channel-attention fusion module, originally introduced in:
Joze et al., “MMTM: Multimodal Transfer Module for CNN Fusion,” CVPR 2020
https://arxiv.org/abs/1911.08670
Although originally developed for computer vision, MMTM provides a strong baseline for multimodal Electronic Health Record (EHR) modeling, enabling efficient fusion of:
- diagnosis codes
- procedure codes
- medications
- lab measurements
- clinical note embeddings
or, more generally, any two modalities represented as embedding vectors.
This PR adapts MMTM to PyHealth’s BaseModel API in a way that is clean, reusable, and fully compatible with PyHealth datasets and processors.
Motivation
MMTM is increasingly used in recent multimodal EHR work, including as a baseline in:
CTPD: Cross-Modal Temporal Pattern Discovery for Enhanced Multimodal EHR Analysis (2024), where MMTM serves as a fusion baseline for multimodal clinical time series + notes.
By adding MMTM to PyHealth, this PR:
- expands the library’s multimodal modeling capabilities
- enables reproducible baselines for multimodal fusion studies
- supports EHR benchmarks where complementary modalities must be fused
- aligns with PyHealth’s goal of enabling reproducible AI4Health research
What’s Included
- MMTMLayer: a standalone fusion layer that performs:
- joint squeeze (dim_a + dim_b → bottleneck)
- modality-specific excitation
- channel-wise attention across both modalities
- returns the fused representations
It works for any tensor pair shaped (batch, feature_dim); a minimal sketch of the computation follows.
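For reference, here is a minimal sketch of the computation MMTMLayer performs, following Joze et al.; the class name, argument names, and bottleneck ratio below are illustrative assumptions, not the PR's exact API:

```python
import torch
import torch.nn as nn

class MMTMLayerSketch(nn.Module):
    """Illustrative MMTM-style fusion (Joze et al., CVPR 2020).
    Names and the bottleneck ratio are assumptions for exposition;
    see MMTMLayer in mmtm.py for the actual implementation."""

    def __init__(self, dim_a: int, dim_b: int, ratio: int = 4):
        super().__init__()
        bottleneck = (dim_a + dim_b) // ratio
        # Joint squeeze: both modalities are projected into one shared space.
        self.squeeze = nn.Linear(dim_a + dim_b, bottleneck)
        # Modality-specific excitation back to each modality's dimension.
        self.excite_a = nn.Linear(bottleneck, dim_a)
        self.excite_b = nn.Linear(bottleneck, dim_b)

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        # a: (batch, dim_a), b: (batch, dim_b)
        z = torch.relu(self.squeeze(torch.cat([a, b], dim=-1)))
        # Channel-wise gates; the factor 2 from the paper keeps the
        # expected activation scale unchanged after sigmoid gating.
        gate_a = 2.0 * torch.sigmoid(self.excite_a(z))
        gate_b = 2.0 * torch.sigmoid(self.excite_b(z))
        return a * gate_a, b * gate_b
```

Input and output shapes are preserved, so the layer can sit between any two embedding branches without other architectural changes.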
- MMTM model (BaseModel subclass)
A full PyHealth model that:
- embeds two input modalities with EmbeddingModel
- pools patient-level features with get_last_visit
- fuses both modalities via MMTMLayer
- performs classification using a final linear head
MMTM enforces exactly two modalities, matching typical multimodal EHR settings; the core forward computation is sketched below.
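Conceptually, once each modality has been embedded and pooled, the forward pass reduces to a fusion + classification step like the following; the function and argument names are illustrative, not the model's internals verbatim:

```python
import torch
from torch import nn

def fuse_and_classify(
    h_a: torch.Tensor,   # (batch, dim_a): pooled embedding of modality A
    h_b: torch.Tensor,   # (batch, dim_b): pooled embedding of modality B
    mmtm: nn.Module,     # an MMTMLayer-style fusion module
    head: nn.Linear,     # final linear classification head
) -> torch.Tensor:
    # EmbeddingModel and get_last_visit are assumed to have produced
    # h_a and h_b upstream; this sketch covers only fusion + classification.
    f_a, f_b = mmtm(h_a, h_b)
    return head(torch.cat([f_a, f_b], dim=-1))
```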
Unit Tests (test_mmtm.py)
Tests include:
- layer forward behavior
- correct model initialization
- forward output format (loss, y_prob, y_true, logit)
- gradient propagation
- model parameters
- dataset integration
- device handling
All tests pass locally. A representative check is sketched below.
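As an illustration, the output-format and gradient checks look roughly like this (the model and batch fixtures are hypothetical stand-ins for what test_mmtm.py builds from a sample dataset):

```python
import torch

def test_forward_output_format(model, batch):
    # Hypothetical fixtures: `model` is an MMTM instance, `batch` is a
    # dict produced by a PyHealth dataloader over a small SampleDataset.
    out = model(**batch)
    assert {"loss", "y_prob", "y_true", "logit"} <= set(out)
    assert torch.isfinite(out["loss"])
    out["loss"].backward()  # gradients propagate through both modalities
```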
Example Usage
Included at the bottom of mmtm.py under if __name__ == "__main__"::
- a minimal dataset with "codes" and "procedures"
- creation of an MMTM model
- a forward + backward demonstration
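In rough outline, the demo looks like the sketch below. The sample schema, the input/output schema strings, and the MMTM constructor arguments are assumptions modeled on PyHealth conventions; see the bottom of mmtm.py for the exact code.

```python
from pyhealth.datasets import SampleDataset, get_dataloader
from pyhealth.models import MMTM  # added by this PR

# Toy two-modality dataset (code values are placeholders).
samples = [
    {"patient_id": "p0", "codes": ["c1", "c2"], "procedures": ["x1"], "label": 1},
    {"patient_id": "p1", "codes": ["c3"], "procedures": ["x2", "x1"], "label": 0},
]
dataset = SampleDataset(
    samples,
    input_schema={"codes": "sequence", "procedures": "sequence"},
    output_schema={"label": "binary"},
)
model = MMTM(dataset=dataset)  # constructor arguments assumed; see mmtm.py

loader = get_dataloader(dataset, batch_size=2, shuffle=True)
out = model(**next(iter(loader)))  # dict with loss, y_prob, y_true, logit
out["loss"].backward()             # forward + backward demonstration
```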
Impact
This PR adds a widely used multimodal fusion mechanism, enabling PyHealth users to:
- run CTPD baselines directly
- benchmark multimodal EHR models
- perform efficient two-modality fusion
- expand multimodal research using a well-established CVPR 2020 baseline
MMTM is simple and fast, and adding it improves the reproducibility of multimodal EHR research inside PyHealth.