PyHealth
Add CheX MultiLabelClassification
1. Code File Headers
File: pyhealth/datasets/chexphoto.py
"""
CheXphoto Dataset Implementation
Author: Prithvi Balaji ([email protected])
Paper: CheXphoto: 10,000+ Photos and Transformations of Chest X-rays for Benchmarking Deep Learning Robustness
Paper URL: https://arxiv.org/abs/2007.06199
Description: Loads CheXphoto dataset with natural/synthetic X-rays and 14 pathology labels.
"""
from pyhealth.datasets import BaseImageDataset
from typing import Dict, List
import os
import pandas as pd
class CheXphotoDataset(BaseImageDataset):
    """
    A dataset class for CheXphoto X-rays with natural/synthetic transformations.
    Args:
        root (str): Root directory containing images and labels.csv.
        csv_file (str, optional): Name of CSV file. Defaults to "labels.csv".
        transform (callable, optional): Transformations for images.
    Returns:
        Dict[str, Patient]: Patients with X-ray paths and multi-label vectors.
    """
    def __init__(self, root: str, csv_file: str = "labels.csv", transform=None):
        super().__init__(dataset_name="CheXphoto", root=root)
        self.csv_file = csv_file
        self.transform = transform
        self.data = self.load_data()
    def load_data(self) -> List[Dict]:
        """Loads image paths and labels from CSV."""
        # Implementation
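The body of `load_data` is elided above. As a rough illustration only, here is a minimal sketch of what such a loader might look like. It makes several assumptions that the PR does not confirm: the CSV follows the CheXpert layout (a `Path` column plus one column per observation), blank and uncertain (`-1`) entries are treated as negative, and each record uses the `image` and `label` keys that the task function reads. The helper name `load_records` is hypothetical, and the standard-library `csv` module is used in place of pandas so the sketch has no dependencies.

```python
import csv
import os
from typing import Dict, List

# The 14 CheXpert observation names. Assumption: these are the CSV label columns.
LABEL_COLUMNS = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly",
    "Lung Opacity", "Lung Lesion", "Edema", "Consolidation",
    "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
    "Pleural Other", "Fracture", "Support Devices",
]


def _is_positive(value) -> bool:
    """Return True only for an explicit positive (1 or 1.0) entry."""
    try:
        return float(value) == 1.0
    except (TypeError, ValueError):
        # Blank or missing cells are not positive.
        return False


def load_records(root: str, csv_file: str = "labels.csv") -> List[Dict]:
    """Read the label CSV into one {"image", "label"} record per row."""
    records = []
    with open(os.path.join(root, csv_file), newline="") as f:
        for row in csv.DictReader(f):
            # Assumption: blank and uncertain (-1) entries count as negative.
            label = [1 if _is_positive(row.get(col)) else 0 for col in LABEL_COLUMNS]
            records.append({"image": os.path.join(root, row["Path"]), "label": label})
    return records
```

Mapping uncertain labels to negative is only one possible policy; a real loader may want to make that choice configurable.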
2. Task Function Documentation
File: pyhealth/tasks/chexphoto_multilabel.py
"""
CheXphoto Multi-Label Classification Task
Author: Prithvi Balaji ([email protected])
Paper: CheXphoto: 10,000+ Photos and Transformations of Chest X-rays for Benchmarking Deep Learning Robustness
Paper URL: https://arxiv.org/abs/2007.06199
Description: Converts CheXphotoDataset into multi-label classification samples.
"""
from typing import Dict, List
from pyhealth.tasks import BaseTask
from pyhealth.datasets import CheXphotoDataset
def chexphoto_multilabel_task(dataset: CheXphotoDataset) -> List[Dict]:
    """
    Processes CheXphotoDataset into multi-label samples.
    Args:
        dataset (CheXphotoDataset): Loaded dataset instance.
    Returns:
        List[Dict]: Samples with keys 'input' (image path) and 'label' (14-pathology vector).
    Example:
        >>> from pyhealth.datasets import CheXphotoDataset
        >>> dataset = CheXphotoDataset(root="./data")
        >>> samples = chexphoto_multilabel_task(dataset)
    """
    samples = []
    # Each dataset record is expected to carry an "image" path and a "label" vector.
    for sample in dataset.data:
        samples.append({"input": sample["image"], "label": sample["label"]})
    return samples
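The docstring above promises samples with an `input` path and a 14-element `label` vector. A small sanity check along the following lines could catch malformed samples before training. This is a sketch, not part of the PR: the helper name `validate_samples` is hypothetical, and it assumes labels are binary (0 or 1).

```python
from typing import Dict, List

NUM_LABELS = 14  # number of pathology labels stated in the task docstring


def validate_samples(samples: List[Dict], num_labels: int = NUM_LABELS) -> None:
    """Raise ValueError on the first sample that breaks the expected format."""
    for i, sample in enumerate(samples):
        missing = {"input", "label"} - sample.keys()
        if missing:
            raise ValueError(f"sample {i} is missing keys: {sorted(missing)}")
        label = sample["label"]
        if len(label) != num_labels:
            raise ValueError(
                f"sample {i} has {len(label)} labels, expected {num_labels}"
            )
        # Assumption: labels are binary multi-hot vectors.
        if any(v not in (0, 1) for v in label):
            raise ValueError(f"sample {i} has non-binary label values")
```

Calling `validate_samples(samples)` right after the task function returns nothing on success and fails loudly otherwise.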
3. Example Notebook
File: examples/chexphoto_pyhealth_example.ipynb
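# CheXphoto Multi-Label Classification Example
# Author: Prithvi Balaji ([email protected])
from pyhealth.datasets import CheXphotoDataset, get_dataloader, split_by_patient
from pyhealth.tasks import chexphoto_multilabel_task
from pyhealth.models import CNN
from pyhealth.trainer import Trainer
# Load dataset
dataset = CheXphotoDataset(root="./data/chexphoto")
samples = chexphoto_multilabel_task(dataset)
# Split data by patient (80% train, 10% validation, 10% test)
train_ds, val_ds, test_ds = split_by_patient(samples, [0.8, 0.1, 0.1])
# Build data loaders (the batch size of 32 is an illustrative choice)
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
# Initialize model
model = CNN(
    dataset=samples,
    feature_keys=["input"],
    label_key="label",
    mode="multilabel",
    output_dim=14,
)
# Train, monitoring the validation metric
trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    monitor="roc_auc",
)
# Evaluate
trainer.evaluate(test_loader)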
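The notebook splits by patient so that images of the same person never appear in both the training and test sets. The idea can be sketched in plain Python as follows. This is an illustration of the technique, not PyHealth's `split_by_patient` implementation. It assumes CheXpert-style paths that contain a `patientNNNNN` directory component, and the function names are hypothetical.

```python
import random
import re
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

# Assumption: image paths contain a CheXpert-style "patient00001" component.
_PATIENT_RE = re.compile(r"patient\d+")


def patient_id_from_path(path: str) -> str:
    """Extract the patient ID from an image path."""
    match = _PATIENT_RE.search(path)
    if match is None:
        raise ValueError(f"no patient ID found in path: {path}")
    return match.group(0)


def split_samples_by_patient(
    samples: List[Dict],
    ratios: Sequence[float] = (0.8, 0.1, 0.1),
    seed: int = 0,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split samples into train/val/test so that no patient spans two splits."""
    by_patient = defaultdict(list)
    for sample in samples:
        by_patient[patient_id_from_path(sample["input"])].append(sample)

    # Shuffle patients (not images) deterministically, then cut by ratio.
    patients = sorted(by_patient)
    random.Random(seed).shuffle(patients)
    train_end = int(len(patients) * ratios[0])
    val_end = train_end + int(len(patients) * ratios[1])
    groups = (patients[:train_end], patients[train_end:val_end], patients[val_end:])

    # Flatten each patient group back into a list of samples.
    train, val, test = ([s for p in g for s in by_patient[p]] for g in groups)
    return train, val, test
```

Because the ratios apply to patients rather than images, the resulting image counts only approximate 80/10/10 when patients have different numbers of images.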
4. Test Cases
File: tests/test_chexphoto.py
from pyhealth.datasets import CheXphotoDataset
from pyhealth.tasks import chexphoto_multilabel_task
def test_chexphoto_dataset():
    dataset = CheXphotoDataset(root="./test_data")
    assert len(dataset) > 0, "Dataset failed to load"
    sample = dataset[0]
    assert "input" in sample and "label" in sample, "Invalid sample format"
def test_chexphoto_task():
    dataset = CheXphotoDataset(root="./test_data")
    samples = chexphoto_multilabel_task(dataset)
    assert len(samples) == len(dataset), "Task conversion failed"
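Both tests read from `./test_data`, which the PR does not describe. One option is to generate a tiny synthetic fixture at test time rather than committing real X-rays. The sketch below does that under assumptions not confirmed by the PR: a CheXpert-style `labels.csv` with a `Path` column and 14 label columns, and empty placeholder files standing in for images. The helper name `make_fixture` is hypothetical.

```python
import csv
import os
import random

# The 14 CheXpert observation names. Assumption: these are the CSV label columns.
LABEL_COLUMNS = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly",
    "Lung Opacity", "Lung Lesion", "Edema", "Consolidation",
    "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
    "Pleural Other", "Fracture", "Support Devices",
]


def make_fixture(root: str, num_images: int = 4, seed: int = 0) -> str:
    """Write a small labels.csv plus placeholder image files; return the CSV path."""
    rng = random.Random(seed)
    os.makedirs(root, exist_ok=True)
    csv_path = os.path.join(root, "labels.csv")
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Path"] + LABEL_COLUMNS)
        for i in range(num_images):
            rel_path = f"patient{i:05d}/study1/view1.jpg"
            abs_path = os.path.join(root, rel_path)
            os.makedirs(os.path.dirname(abs_path), exist_ok=True)
            # Empty placeholder file, not a decodable image.
            open(abs_path, "wb").close()
            # Random binary labels, reproducible through the seed.
            writer.writerow([rel_path] + [rng.randint(0, 1) for _ in LABEL_COLUMNS])
    return csv_path
```

With pytest, the built-in `tmp_path` fixture can supply `root`, so the tests no longer depend on a checked-in `./test_data` directory. Tests that actually decode images would need real image bytes instead of empty files.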
5. Pull Request Description
Title:
Add CheXphoto Dataset and Multi-Label Classification Task
Body:
### Who I Am
- **Name**: Prithvi Balaji
- **NetID**: pbalaji3 (UIUC student)
### Contribution Type
- **Dataset**: `CheXphotoDataset` in `pyhealth/datasets/chexphoto.py`
- **Task**: `chexphoto_multilabel_task` in `pyhealth/tasks/chexphoto_multilabel.py`
- **Example**: `examples/chexphoto_pyhealth_example.ipynb`
### High-Level Description
This PR adds support for the CheXphoto dataset (Phillips et al., 2020), enabling multi-label classification of 14 chest X-ray pathologies. Key features:
- Loads natural/synthetic X-rays with labels
- Task function for PyHealth model integration
- Example notebook for training/evaluation
### Files to Review
- `pyhealth/datasets/chexphoto.py`
- `pyhealth/tasks/chexphoto_multilabel.py`
- `examples/chexphoto_pyhealth_example.ipynb`
- `tests/test_chexphoto.py`
### Testing Instructions
1. Run `pytest tests/test_chexphoto.py`
2. Execute the example notebook end-to-end
Note: additional tests still need to be added.