PyHealth
[Bounty] ConCare PyHealth 2.0 ver
Pull Request: ConCare Model Update (1.0 → 2.0)
Contributor Information
- Name: Joshua Steier
Contribution Type
- [x] Model Update
- [ ] New Model
- [ ] Dataset
- [ ] Task
- [ ] Bug Fix
- [ ] Documentation
- [ ] Other
High-Level Description
This PR updates the ConCare model from the PyHealth 1.0 API to the PyHealth 2.0 API.
Paper Reference
- Title: ConCare: Personalized Clinical Feature Embedding via Capturing the Healthcare Context
- Authors: Liantao Ma et al.
- Venue: AAAI 2020
- Link: https://ojs.aaai.org/index.php/AAAI/article/view/5428
Changes Made
- API Migration (1.0 → 2.0):
  - Replaced SampleEHRDataset with SampleDataset
  - Integrated EmbeddingModel for unified embedding handling
  - Removed explicit feature_keys, label_key, mode, and use_embedding parameters
  - Simplified constructor to derive feature information from dataset schemas
- Code Improvements:
  - Added comprehensive Google-style docstrings for all classes and methods
  - Added proper type hints throughout
  - Fixed bug in SingleAttention (removed undefined self.Wd initialization)
  - Improved code documentation with input/output descriptions
  - Added file header with paper reference and description
- Testing:
  - Added comprehensive unit tests covering:
    - Model initialization (with/without static features)
    - Forward pass validation
    - Backward pass (gradient flow)
    - Embedding extraction
    - Custom hyperparameters
    - Multiclass classification
    - Single feature input
- Examples:
  - Added example notebook for MIMIC-IV in-hospital mortality prediction
  - Added standalone Python script example
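To illustrate the constructor simplification listed under API Migration, here is a minimal sketch of a model that reads its feature and label names from the dataset instead of taking them as arguments. It does not import PyHealth: `ToySampleDataset`, `ToyModel`, the `input_schema`/`output_schema` attribute names, and the schema values are all hypothetical stand-ins chosen for illustration, not the library's confirmed API.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToySampleDataset:
    """Hypothetical stand-in for a 2.0-style dataset that carries its schemas."""

    input_schema: Dict[str, str] = field(default_factory=dict)
    output_schema: Dict[str, str] = field(default_factory=dict)


class ToyModel:
    """Hypothetical model whose constructor takes only the dataset."""

    def __init__(self, dataset: ToySampleDataset) -> None:
        # Feature and label names come from the schemas, replacing the
        # explicit feature_keys / label_key arguments of the 1.0 style.
        self.feature_keys: List[str] = list(dataset.input_schema.keys())
        self.label_keys: List[str] = list(dataset.output_schema.keys())
        if len(self.label_keys) != 1:
            raise ValueError("this sketch assumes exactly one label")
        # The task mode is read from the output schema (assumed layout),
        # so no separate mode argument is needed.
        self.mode: str = dataset.output_schema[self.label_keys[0]]


dataset = ToySampleDataset(
    input_schema={"conditions": "sequence", "procedures": "sequence"},
    output_schema={"mortality": "binary"},
)
model = ToyModel(dataset)
print(model.feature_keys, model.label_keys, model.mode)
```

The point of the pattern is that the dataset becomes the single source of truth, so a model cannot be constructed with keys that disagree with the data it is given.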
Files to Review
| File | Description |
|---|---|
| pyhealth/models/concare.py | Main model implementation (updated) |
| tests/models/test_concare.py | Unit tests for ConCare model |
| examples/concare_mimic4_example.ipynb | Example notebook for MIMIC-IV |
How to Test
1. Run Unit Tests
python -m pytest tests/models/test_concare.py -v
2. Run Quick Example (main block)
python pyhealth/models/concare.py
3. Expected Output
{
'loss': tensor(..., grad_fn=<AddBackward0>),
'y_prob': tensor([[...], [...]], grad_fn=<SigmoidBackward0>),
'y_true': tensor([[...], [...]]),
'logit': tensor([[...], [...]], grad_fn=<AddmmBackward0>)
}
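The forward and backward checks in the unit tests, together with the output dictionary shown above, could be sketched roughly as follows. This uses a tiny hypothetical GRU classifier (`TinyBinaryModel`) in place of ConCare, so it shows only the shape of the checks and is not the contents of tests/models/test_concare.py.

```python
import torch
import torch.nn as nn


class TinyBinaryModel(nn.Module):
    """Hypothetical stand-in for a binary classifier; not the ConCare model."""

    def __init__(self, input_dim: int = 4, hidden_dim: int = 8) -> None:
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> dict:
        _, h = self.gru(x)      # h: (num_layers, batch, hidden_dim)
        logit = self.fc(h[-1])  # (batch, 1)
        loss = nn.functional.binary_cross_entropy_with_logits(logit, y)
        return {
            "loss": loss,
            "y_prob": torch.sigmoid(logit),
            "y_true": y,
            "logit": logit,
        }


def check_forward_backward(model: nn.Module, x: torch.Tensor, y: torch.Tensor) -> list:
    """Run one forward/backward step; return names of parameters with no gradient."""
    out = model(x, y)
    # Forward pass: the same four keys as the expected output, and a scalar loss.
    assert {"loss", "y_prob", "y_true", "logit"} <= set(out)
    assert out["loss"].dim() == 0
    assert out["y_prob"].shape == y.shape
    # Backward pass: every trainable parameter should receive a gradient.
    out["loss"].backward()
    return [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and p.grad is None
    ]


torch.manual_seed(0)
model = TinyBinaryModel()
x = torch.randn(2, 5, 4)            # (batch, time steps, features)
y = torch.tensor([[1.0], [0.0]])    # binary labels, shape (batch, 1)
params_without_grad = check_forward_backward(model, x, y)
print(params_without_grad)
```

An empty list means gradients reached every parameter, which is the property the "Backward pass (gradient flow)" test is meant to guard.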
Checklist
- [x] Code follows PEP8 style (88 character line length)
- [x] Code follows Google-style docstrings
- [x] All functions have type hints
- [x] All functions have input/output documentation
- [x] File header includes author, paper title, link, and description
- [x] Unit tests pass
- [x] Example code runs successfully
- [x] Backward pass (loss.backward()) works
- [x] Code is rebased with main branch
Additional Notes
The ConCare model includes:
- Channel-wise GRUs: Separate GRU for each input feature dimension
- Time-aware attention: Captures temporal decay in healthcare context
- Multi-head self-attention: Captures feature interactions
- DeCov loss: Regularization to reduce feature redundancy
- Static feature support: Optional demographic/static features
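This update preserves the original ConCare functionality while adopting the cleaner 2.0 API patterns. Because the constructor no longer takes explicit feature_keys, label_key, mode, or use_embedding arguments, code written against the 1.0 constructor will likely need to be updated.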
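For readers unfamiliar with the DeCov regularizer mentioned in the notes, the following is a minimal NumPy sketch of the general formulation: half the sum of squared off-diagonal entries of the activation covariance matrix. The function name `decov_loss` and the choice to normalize by batch size are assumptions for illustration; the code in pyhealth/models/concare.py may compute or normalize the term differently.

```python
import numpy as np


def decov_loss(h: np.ndarray) -> float:
    """DeCov penalty for activations h of shape (batch, features)."""
    # Center each feature over the batch.
    centered = h - h.mean(axis=0, keepdims=True)
    # Feature-by-feature covariance, averaged over the batch (assumed normalization).
    cov = centered.T @ centered / h.shape[0]
    # Squared Frobenius norm minus the squared diagonal leaves only
    # the off-diagonal (cross-feature) covariance terms.
    off_diag_sq = np.sum(cov ** 2) - np.sum(np.diag(cov) ** 2)
    return 0.5 * float(off_diag_sq)


# Two perfectly correlated features are penalized.
correlated = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
# Two uncorrelated features incur no penalty.
uncorrelated = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])

print(decov_loss(correlated), decov_loss(uncorrelated))
```

Because only cross-feature covariance is penalized, the term pushes hidden features to carry different information without constraining each feature's own variance.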