Add BulkRNABert model for cancer prognosis

Open • gwicho38 opened this issue 1 month ago • 0 comments

Summary

This PR adds the BulkRNABert model for cancer prognosis tasks based on the paper by Gélard et al. (2025).

Key additions:

  • BulkRNABertLayer: Transformer encoder layer for gene expression data
  • BulkRNABert: Main model for cancer type classification (33 TCGA cancer types)
  • BulkRNABertForSurvival: Model for survival prediction using Cox proportional hazards loss
  • compute_c_index: Utility function for computing concordance index
  • Unit tests covering the encoder layer, both models, and compute_c_index

Paper reference: Gélard et al. (2025) "BulkRNABert: Cancer prognosis from bulk RNA-seq based language models"

  • Paper: https://www.biorxiv.org/content/10.1101/2024.06.13.598798
  • Model: https://huggingface.co/InstaDeepAI/BulkRNABert

Implementation Details

The implementation follows PyHealth's BaseModel pattern and provides:

  • Multiclass classification for 33 TCGA cancer types
  • Binary classification support
  • Survival prediction with Cox partial likelihood loss
  • C-index metric computation for survival evaluation
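For reviewers unfamiliar with the survival objective, here is a minimal sketch of a negative Cox partial log-likelihood. It is written in plain Python for readability and is not the code in this PR (which presumably operates on tensors); the function name `cox_partial_likelihood_loss_sketch`, its argument names, the averaging over observed events, and the simple handling of tied times are all assumptions made for illustration.

```python
import math


def cox_partial_likelihood_loss_sketch(risk_scores, times, events):
    """Negative Cox partial log-likelihood, averaged over observed events.

    Hypothetical helper for illustration only (not the PR's implementation).
    risk_scores: predicted log-risk per patient (higher = higher hazard)
    times:       observed follow-up time per patient
    events:      1 if the event was observed, 0 if censored
    """
    n_events = sum(1 for e in events if e)
    if n_events == 0:
        # No observed events: the partial likelihood has no terms.
        return 0.0

    total = 0.0
    for r_i, t_i, e_i in zip(risk_scores, times, events):
        if not e_i:
            continue  # censored patients only appear inside risk sets
        # Risk set: everyone still under observation at time t_i.
        risk_set = [r_j for r_j, t_j in zip(risk_scores, times) if t_j >= t_i]
        # Numerically stable log-sum-exp over the risk set.
        m = max(risk_set)
        log_sum = m + math.log(sum(math.exp(r - m) for r in risk_set))
        total += r_i - log_sum
    return -total / n_events
```

With this formulation, assigning higher risk scores to patients who experience the event earlier lowers the loss, which is the ordering that the C-index later rewards.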

Test Plan

  • [x] Unit tests for BulkRNABertLayer encoder
  • [x] Unit tests for BulkRNABert classification model
  • [x] Unit tests for BulkRNABertForSurvival model
  • [x] Unit tests for compute_c_index function
  • [x] All tests pass locally
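As an illustration of the kind of property the C-index tests can check, below is a small sketch of Harrell's concordance index together with three sanity checks. The name `c_index_sketch` and its signature are hypothetical and do not describe the actual `compute_c_index` in this PR; returning NaN when no pair is comparable is likewise an assumption.

```python
def c_index_sketch(risk_scores, times, events):
    """Harrell's concordance index (hypothetical sketch, not the PR's code).

    A pair (i, j) is comparable when patient i had an observed event and
    a strictly shorter time than patient j. The pair is concordant when
    patient i also received the higher risk score; ties count as 0.5.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored patient cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        return float("nan")  # assumption: undefined when nothing is comparable
    return concordant / comparable


# Sanity checks on a toy cohort: the last patient is censored.
times = [1, 2, 3]
events = [1, 1, 0]
assert c_index_sketch([3.0, 2.0, 1.0], times, events) == 1.0  # perfect ranking
assert c_index_sketch([1.0, 2.0, 3.0], times, events) == 0.0  # reversed ranking
assert c_index_sketch([1.0, 1.0, 1.0], times, events) == 0.5  # uninformative
```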

Context

This contribution was developed as part of a CS 598 DLH (Deep Learning for Healthcare) course project reproducing the BulkRNABert paper results.

gwicho38 • Dec 08 '25 03:12