Add CheX MultiLabelClassification

Open prithbalaji opened this issue 1 year ago • 0 comments

1. Code File Headers

File: pyhealth/datasets/chexphoto.py

"""  
CheXphoto Dataset Implementation  
Author: Prithvi Balaji ([email protected])  
Paper: CheXphoto: 10,000+ Photos and Transformations of Chest X-rays for Benchmarking Deep Learning Robustness  
Paper URL: https://arxiv.org/abs/2007.06199  
Description: Loads CheXphoto dataset with natural/synthetic X-rays and 14 pathology labels.  
"""

from pyhealth.datasets import BaseImageDataset  
from typing import Dict, List  
import os  
import pandas as pd  

class CheXphotoDataset(BaseImageDataset):  
    """  
    A dataset class for CheXphoto X-rays with natural/synthetic transformations.  

    Args:  
        root (str): Root directory containing images and labels.csv.  
        csv_file (str, optional): Name of CSV file. Defaults to "labels.csv".  
        transform (callable, optional): Transformations for images.  

    Attributes:  
        data (List[Dict]): One record per X-ray with its image path and multi-label vector.  
    """  
    def __init__(self, root: str, csv_file: str = "labels.csv", transform=None):  
        super().__init__(dataset_name="CheXphoto", root=root)  
        self.csv_file = csv_file  
        self.transform = transform  
        self.data = self.load_data()  

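    def load_data(self) -> List[Dict]:  
        """Loads image paths and labels from CSV, one record per image."""  
        # Implementation pending; the task function expects each record to  
        # carry "image" and "label" keys.  
        raise NotImplementedError  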
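One way the `load_data` stub could be filled in is sketched below, using only the standard library. It assumes `labels.csv` follows the CheXpert layout: a `Path` column plus one column per observation, with `1.0`, `0.0`, `-1.0` (uncertain) or blank values. The function name `load_records` and the policy of mapping blank and uncertain entries to 0 are illustrative assumptions, not part of the proposal.

```python
import csv
import os
from typing import Dict, List

# The 14 CheXpert observation names; assumed here to be the label columns.
PATHOLOGIES = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly",
    "Lung Opacity", "Lung Lesion", "Edema", "Consolidation",
    "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
    "Pleural Other", "Fracture", "Support Devices",
]


def load_records(root: str, csv_file: str = "labels.csv") -> List[Dict]:
    """Read one record per CSV row: an image path plus a 14-element 0/1 vector."""
    records = []
    with open(os.path.join(root, csv_file), newline="") as f:
        for row in csv.DictReader(f):
            # Blank and uncertain (-1) entries are mapped to 0 here; this is
            # one possible policy, not necessarily the intended one.
            label = [1 if row.get(name, "") in ("1", "1.0") else 0
                     for name in PATHOLOGIES]
            records.append({"image": os.path.join(root, row["Path"]),
                            "label": label})
    return records
```

Returning a list of `{"image", "label"}` records matches what the task function in section 2 reads from `dataset.data`.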

2. Task Function Documentation

File: pyhealth/tasks/chexphoto_multilabel.py

"""  
CheXphoto Multi-Label Classification Task  
Author: Prithvi Balaji ([email protected])  
Paper: CheXphoto: 10,000+ Photos and Transformations of Chest X-rays for Benchmarking Deep Learning Robustness  
Paper URL: https://arxiv.org/abs/2007.06199  
Description: Converts CheXphotoDataset into multi-label classification samples.  
"""

from typing import Dict, List  

from pyhealth.tasks import BaseTask  
from pyhealth.datasets import CheXphotoDataset  

def chexphoto_multilabel_task(dataset: CheXphotoDataset) -> List[Dict]:  
    """  
    Processes CheXphotoDataset into multi-label samples.  

    Args:  
        dataset (CheXphotoDataset): Loaded dataset instance.  

    Returns:  
        List[Dict]: Samples with keys 'input' (image path) and 'label' (14-pathology vector).  

    Example:  
        >>> from pyhealth.datasets import CheXphotoDataset  
        >>> dataset = CheXphotoDataset(root="./data")  
        >>> samples = chexphoto_multilabel_task(dataset)  
    """  
    samples = []  
    for sample in dataset.data:  
        samples.append({"input": sample["image"], "label": sample["label"]})  
    return samples  
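The conversion above only needs an object whose `.data` attribute is a list of records with `image` and `label` keys. The sketch below checks that behaviour without PyHealth installed; `to_multilabel_samples` and `fake_dataset` are hypothetical stand-ins for the task function and a loaded `CheXphotoDataset`.

```python
from types import SimpleNamespace
from typing import Dict, List


def to_multilabel_samples(dataset) -> List[Dict]:
    """Same conversion as the task function, for any object exposing `.data`."""
    return [{"input": rec["image"], "label": rec["label"]}
            for rec in dataset.data]


# Hypothetical stand-in for a loaded dataset: two records, 14 labels each.
fake_dataset = SimpleNamespace(data=[
    {"image": "img/a.jpg", "label": [1] + [0] * 13},
    {"image": "img/b.jpg", "label": [0, 1] + [0] * 12},
])

samples = to_multilabel_samples(fake_dataset)
print(samples[0]["input"], sum(samples[0]["label"]))
```

If `dataset.data` were a dict instead of a list, the loop would iterate over keys and the `rec["image"]` lookup would fail, so the record layout matters.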

3. Example Notebook

File: examples/chexphoto_pyhealth_example.ipynb

# CheXphoto Multi-Label Classification Example  
# Author: Prithvi Balaji ([email protected])  

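from pyhealth.datasets import CheXphotoDataset, get_dataloader, split_by_patient  
from pyhealth.tasks import chexphoto_multilabel_task  
from pyhealth.models import CNN  
from pyhealth.trainer import Trainer  

# Load dataset  
dataset = CheXphotoDataset(root="./data/chexphoto")  
samples = chexphoto_multilabel_task(dataset)  

# Split data  
train_ds, val_ds, test_ds = split_by_patient(samples, [0.8, 0.1, 0.1])  

# Build data loaders (batch size is an arbitrary choice)  
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)  
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)  
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)  

# Initialize model  
model = CNN(  
    dataset=samples,  
    feature_keys=["input"],  
    label_key="label",  
    mode="multilabel",  
    output_dim=14  
)  

# Train (multilabel mode reports averaged metrics such as roc_auc_macro)  
trainer = Trainer(model=model, metrics=["roc_auc_macro"])  
trainer.train(  
    train_dataloader=train_loader,  
    val_dataloader=val_loader,  
    epochs=50,  
    monitor="roc_auc_macro",  
)  

# Evaluate  
trainer.evaluate(test_loader)  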
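The split step matters because one patient can contribute several X-rays, and images of the same patient should not land in different splits. Below is a minimal standard-library sketch of a patient-level split. It is not PyHealth's `split_by_patient`; it assumes CheXpert-style paths that contain a `patientNNNNN` directory, and the helper names are hypothetical.

```python
import random
import re
from typing import Dict, List, Sequence, Tuple


def patient_id(path: str) -> str:
    """Extract a 'patientNNNNN' token from the path; fall back to the path."""
    match = re.search(r"patient\d+", path)
    return match.group(0) if match else path


def split_samples_by_patient(
    samples: List[Dict], ratios: Sequence[float], seed: int = 0
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Shuffle patient ids, cut them by `ratios`, and route each sample
    to the split that owns its patient."""
    patients = sorted({patient_id(s["input"]) for s in samples})
    random.Random(seed).shuffle(patients)
    n_train = int(len(patients) * ratios[0])
    n_val = int(len(patients) * ratios[1])
    train_ids = set(patients[:n_train])
    val_ids = set(patients[n_train:n_train + n_val])
    train, val, test = [], [], []
    for s in samples:
        pid = patient_id(s["input"])
        bucket = train if pid in train_ids else val if pid in val_ids else test
        bucket.append(s)
    return train, val, test
```

The ratios apply to patients, not images, so the per-split image counts can drift from 80/10/10 when patients have different numbers of X-rays.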

4. Test Cases

File: tests/test_chexphoto.py

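from pyhealth.datasets import CheXphotoDataset  
from pyhealth.tasks import chexphoto_multilabel_task  

def test_chexphoto_dataset():  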
    dataset = CheXphotoDataset(root="./test_data")  
    assert len(dataset) > 0, "Dataset failed to load"  
    sample = dataset[0]  
    assert "image" in sample and "label" in sample, "Invalid sample format"  

def test_chexphoto_task():  
    dataset = CheXphotoDataset(root="./test_data")  
    samples = chexphoto_multilabel_task(dataset)  
    assert len(samples) == len(dataset), "Task conversion failed"  
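Both tests read from `./test_data`, which has to exist wherever the tests run. One option is to generate a tiny synthetic dataset on the fly, as sketched below. The helper `make_tiny_dataset`, the placeholder column names, and the empty image files are all assumptions for illustration; a real fixture would need to match whatever CSV layout the dataset class ends up expecting.

```python
import csv
import os
import tempfile

LABEL_COLUMNS = [f"label_{i}" for i in range(14)]  # placeholder column names


def make_tiny_dataset(root: str, n_images: int = 3) -> str:
    """Write `n_images` empty placeholder files and a matching labels.csv."""
    csv_path = os.path.join(root, "labels.csv")
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Path"] + LABEL_COLUMNS)
        for i in range(n_images):
            name = f"img_{i}.jpg"
            open(os.path.join(root, name), "wb").close()  # empty stand-in file
            # Mark exactly one positive label per image.
            writer.writerow([name] + [int(j == i % 14) for j in range(14)])
    return csv_path


with tempfile.TemporaryDirectory() as tmp:
    path = make_tiny_dataset(tmp)
    with open(path, newline="") as f:
        rows = list(csv.DictReader(f))
    print(len(rows), "rows written")
```

Under pytest, the built-in `tmp_path` fixture could supply the directory in place of `tempfile.TemporaryDirectory`.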

5. Pull Request Description

Title:
Add CheXphoto Dataset and Multi-Label Classification Task

Body:

### Who I Am  
- **Name**: Prithvi Balaji  
- **NetID**: pbalaji3 (UIUC student)  

### Contribution Type  
- **Dataset**: `CheXphotoDataset` in `pyhealth/datasets/chexphoto.py`  
- **Task**: `chexphoto_multilabel_task` in `pyhealth/tasks/chexphoto_multilabel.py`  
- **Example**: `examples/chexphoto_pyhealth_example.ipynb`  

### High-Level Description  
This PR adds support for the CheXphoto dataset (Phillips et al., 2020), enabling multi-label classification of 14 chest X-ray pathologies. Key features:  
- Loads natural/synthetic X-rays with labels  
- Task function for PyHealth model integration  
- Example notebook for training/evaluation  

### Files to Review  
- `pyhealth/datasets/chexphoto.py`  
- `pyhealth/tasks/chexphoto_multilabel.py`  
- `examples/chexphoto_pyhealth_example.ipynb`  
- `tests/test_chexphoto.py`  

### Testing Instructions  
1. Run `pytest tests/test_chexphoto.py`  
2. Execute the example notebook end-to-end  



- Need to add tests later.

prithbalaji, May 08 '25 09:05