[CS598] Add Mamba model support, new and updated MIMIC-IV tasks, and LR scheduler support for the Trainer
Daniel Kwan (NetID: dwkwan2)
This PR implements the following changes:
- Preliminary support for the Mamba model, loosely based on a reproduction of EHRMamba (https://arxiv.org/abs/2405.14567). It depends on the pure-PyTorch `mamba.py` implementation of Mamba (https://github.com/alxndrTL/mamba.py, i.e. `pip install mambapy`, a library also used in HF Transformers) for building the actual Mamba blocks; `mambapy` itself optionally depends on the official `mamba_ssm` implementation for its CUDA kernels (`use_cuda` is currently set to false to minimize dependencies). A minimal sketch of the building block follows this list.
- Two new variants of existing clinical predictive tasks (as specified in EHRMamba) on the MIMIC-IV dataset:
  - Mortality prediction within 31 days
  - Binary length-of-stay prediction (greater than one week or not)
- An updated version of the existing multi-class length-of-stay prediction function, ported to the newer PyHealth task format (the old function no longer works)
- LR scheduler support in the PyHealth Trainer, which can take either an LRScheduler class or, optionally, a function that builds a scheduler (such as a SequentialLR chaining together linear warmup and decay)
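To give a rough idea of the `mambapy` dependency (this is the underlying building block, not the new PyHealth model class), the block is constructed roughly as in the `mambapy` README; the config fields and shapes below are illustrative and reflect my reading of that README:

```python
import torch
from mambapy.mamba import Mamba, MambaConfig

# Sketch of the mambapy building block that the new model wraps.
# use_cuda=False avoids the optional mamba_ssm / CUDA-kernel dependency,
# matching the default described above.
config = MambaConfig(d_model=64, n_layers=2, use_cuda=False)
block = Mamba(config)

x = torch.randn(4, 128, 64)  # (batch, sequence length, d_model)
y = block(x)                 # output has the same shape as the input
```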
Hi John,
I added the requested test cases and an example_mamba.ipynb notebook under examples/ with a minimal example of using the Mamba model, and also pushed some remaining code that I had forgotten to earlier.
I couldn't figure out whether tasks should also have unit tests under tests/, since there don't seem to be any for MIMIC-IV tasks at the moment aside from those under the todo/ folder, so I simply demonstrated the new tasks in the same notebook.
The notebook also demonstrates the small (and very niche) use case of providing a builder fn for an LR scheduler (roughly like the sketch below); in most cases, you would simply provide the class, as with the existing optimizer behaviour.
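For reference, a builder fn along those lines is just a callable that takes the optimizer and returns a scheduler. The sketch below uses plain torch.optim.lr_scheduler with made-up step counts, and the Trainer keyword in the final comment is only an illustrative name, not the actual argument:

```python
from torch.optim.lr_scheduler import LinearLR, SequentialLR

def build_warmup_decay_scheduler(optimizer, warmup_steps=500, total_steps=10_000):
    """Hypothetical builder fn: linear warmup chained with linear decay."""
    warmup = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
    decay = LinearLR(optimizer, start_factor=1.0, end_factor=0.0,
                     total_iters=total_steps - warmup_steps)
    return SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

# Handed to the Trainer in place of a scheduler class (illustrative keyword only):
# trainer.train(..., scheduler_fn=build_warmup_decay_scheduler)
```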
Thanks,