PyHealth
Add BulkRNABert model for cancer prognosis
Summary
This PR adds the BulkRNABert model for cancer prognosis tasks based on the paper by Gélard et al. (2025).
Key additions:
- BulkRNABertLayer: Transformer encoder layer for gene expression data
- BulkRNABert: main model for cancer type classification (33 TCGA cancer types)
- BulkRNABertForSurvival: model for survival prediction using a Cox proportional hazards loss
- compute_c_index: utility function for computing the concordance index
- Comprehensive unit tests
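For reviewers less familiar with the survival metric, here is a minimal plain-Python sketch of Harrell's concordance index. The name concordance_index and its argument order are hypothetical and are not the signature of the PR's compute_c_index. The sketch assumes that a higher risk score means an earlier expected event, counts tied risk scores as half-concordant, and skips pairs with tied times.

```python
from typing import Sequence


def concordance_index(
    risks: Sequence[float], times: Sequence[float], events: Sequence[int]
) -> float:
    """Harrell's C-index (hypothetical helper, not the PR's compute_c_index).

    A pair (i, j) is comparable when subject i had an observed event
    (events[i] == 1) and times[i] < times[j]. The pair is concordant when
    the model assigns i the higher risk; tied risks count as 0.5.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored subject cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:  # i failed before j was last observed
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs; C-index is undefined")
    return concordant / comparable


# Risk scores perfectly ordered against survival time.
print(concordance_index([4, 3, 2, 1], [1, 2, 3, 4], [1, 1, 1, 1]))  # → 1.0
```

A value of 0.5 corresponds to random ordering and 1.0 to perfect ordering. The double loop is O(n²), which is fine for illustration; a production metric would typically sort by time first.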
Paper reference: Gélard et al. (2025) "BulkRNABert: Cancer prognosis from bulk RNA-seq based language models"
- Paper: https://www.biorxiv.org/content/10.1101/2024.06.13.598798
- Model: https://huggingface.co/InstaDeepAI/BulkRNABert
Implementation Details
The implementation follows PyHealth's BaseModel pattern and provides:
- Multiclass classification for 33 TCGA cancer types
- Binary classification support
- Survival prediction with Cox partial likelihood loss
- C-index metric computation for survival evaluation
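The Cox partial likelihood objective listed above can be illustrated with a short plain-Python sketch. With risk scores r_i, times t_i and event indicators d_i, it computes the negative log partial likelihood averaged over observed events, loss = -(1/N_events) * sum over i with d_i = 1 of [r_i - log(sum over j with t_j >= t_i of exp(r_j))], handling tied times with the Breslow approximation. The function name and argument order are hypothetical; this illustrates the loss and is not the code in BulkRNABertForSurvival.

```python
import math
from typing import Sequence


def cox_neg_log_partial_likelihood(
    risks: Sequence[float], times: Sequence[float], events: Sequence[int]
) -> float:
    """Mean negative Cox log partial likelihood (hypothetical helper).

    Tied times use the Breslow approximation: every subject with
    times[j] >= times[i] stays in the risk set of event i.
    """
    n_events = sum(1 for e in events if e)
    if n_events == 0:
        raise ValueError("at least one observed event is required")
    total = 0.0
    for i in range(len(times)):
        if not events[i]:
            continue  # censored subjects only appear inside risk sets
        risk_set = [risks[j] for j in range(len(times)) if times[j] >= times[i]]
        # Numerically stable log-sum-exp over the risk set.
        m = max(risk_set)
        log_sum_exp = m + math.log(sum(math.exp(r - m) for r in risk_set))
        total += risks[i] - log_sum_exp
    return -total / n_events


# Assigning the higher risk to the earlier event gives a lower loss.
good = cox_neg_log_partial_likelihood([2.0, 0.0], [1.0, 2.0], [1, 1])
bad = cox_neg_log_partial_likelihood([0.0, 2.0], [1.0, 2.0], [1, 1])
print(good < bad)  # → True
```

Averaging over the number of events rather than the batch size keeps the loss scale comparable across batches with different censoring rates.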
Test Plan
- [x] Unit tests for BulkRNABertLayer encoder
- [x] Unit tests for BulkRNABert classification model
- [x] Unit tests for BulkRNABertForSurvival model
- [x] Unit tests for compute_c_index function
- [x] All tests pass locally
Context
This contribution was developed as part of a CS 598 DLH (Deep Learning for Healthcare) course project reproducing the BulkRNABert paper results.