
VAL_CENTER index in camelyon17_dataset.py

Open David-Drexlin opened this issue 3 months ago • 0 comments

Hi everyone,

I hope this is the right place to ask a question about the Camelyon17 dataset. My question is regarding the center-metadata indices for TEST_CENTER and VAL_CENTER, as defined in the camelyon17_dataset.py file. According to that file, the test and validation (OOD) centers are 0-indexed, with TEST_CENTER at index 2 and VAL_CENTER at index 1. My understanding is that this should correspond to the images shown in columns 5 and 4 of the paper (see the first image for reference). Is that correct?

When I naively plot the images by center label, one row per center (see the second image), I would expect rows 2 and 1 (zero-indexed) to show the test and validation (OOD) slides. Instead, the validation center index appears to be switched: the test images do correspond to index 2, but the validation (OOD) images appear under index 4 rather than 1. Inspecting the images directly in the data/patches directory shows the same behaviour: for example, patient 96 from center 4 looks like Val (OOD), while patient 34 from center 1 looks like part of train, at least visually to a layman.
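One way to check the patient-to-center assignment numerically rather than by eye might be a small lookup like the one below. This is only a sketch: the helper name `centers_per_patient` and the toy rows are made up for illustration, and only the `patient` and `center` column names come from the metadata file.

```python
import pandas as pd

def centers_per_patient(metadata_df: pd.DataFrame) -> pd.Series:
    """Map each patient ID to the sorted list of centers its patches are labeled with."""
    return (
        metadata_df.groupby('patient')['center']
        .unique()
        .apply(lambda values: sorted(int(v) for v in values))
    )

# Toy rows, made up for illustration (NOT real Camelyon17 metadata).
toy = pd.DataFrame({
    'patient': ['001', '001', '002', '003'],
    'center': [0, 0, 1, 4],
})

lookup = centers_per_patient(toy)
print(lookup)
```

Applied to the real metadata, a patient that lists more than one center, or a center that differs from what the directory layout suggests, would point to a labeling problem rather than a visual impression.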

Did I misunderstand something in the indexing or do you have any clue what could be wrong? Below are the images for reference and the code I used to generate them:

Wilds slides: camelyon_dataset

Slides as per my Code: slides_per_domain_class

Thanks in advance for any clarification!

Code:

import os
import pandas as pd
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from collections import defaultdict

# constants
DATA_DIR = '/data/camelyon17_v1.0'
PATCHES_DIR = os.path.join(DATA_DIR, 'patches')
METADATA_CSV = os.path.join(DATA_DIR, 'metadata.csv')
MAX_IMAGES_PER_COMBINATION = 5
NUM_DOMAINS = 5  
NUM_CLASSES = 2  

# Load the metadata
metadata_df = pd.read_csv(
    METADATA_CSV,
    index_col=0,
    dtype={'patient': 'str'}
)

# Get labels
y_array = torch.LongTensor(metadata_df['tumor'].values)

# Get input image paths
input_paths = [
    os.path.join(
        PATCHES_DIR,
        f'patient_{patient}_node_{node}',
        f'patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png'
    )
    for patient, node, x, y in metadata_df[['patient', 'node', 'x_coord', 'y_coord']].values
]

# Get domains (centers)
centers = metadata_df['center'].astype(int).values

# Organize images into a dictionary keyed by (domain, class)
images_dict = defaultdict(list)

for img_path, label, domain in zip(input_paths, y_array, centers):
    key = (domain, label.item())
    if len(images_dict[key]) < MAX_IMAGES_PER_COMBINATION:
        try:
            img = Image.open(img_path).convert('RGB')
            images_dict[key].append(img)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")

# plot 
fig, axes = plt.subplots(nrows=NUM_DOMAINS, ncols=NUM_CLASSES * MAX_IMAGES_PER_COMBINATION, figsize=(24, 12))
plt.subplots_adjust(wspace=0.05, hspace=0.05)

for domain_idx in range(NUM_DOMAINS):
    for class_idx in range(NUM_CLASSES):
        key = (domain_idx, class_idx)
        images = images_dict.get(key, [])
        for img_idx in range(MAX_IMAGES_PER_COMBINATION):
            col_idx = class_idx * MAX_IMAGES_PER_COMBINATION + img_idx
            ax = axes[domain_idx, col_idx]
            if img_idx < len(images):
                ax.imshow(images[img_idx])
            ax.axis('off')

            if domain_idx == 0 and img_idx == 0:
                ax.set_title(f"Class {class_idx}")

        # Add domain labels to the first image in each row
        if class_idx == 0:
            ax = axes[domain_idx, 0]
            ax.text(-30, 32, f"Domain {domain_idx}", rotation=90, va='center', fontsize=12)
            #ax.text(-150, images[0].size[1] // 2, f"Domain {domain_idx}", rotation=90, va='center', fontsize=12)

plt.tight_layout()
plt.savefig("slides_per_domain_class.png")

Or, more simply, save a few validation images and inspect them directly:

import os
from wilds import get_dataset

def save_images():
    # Create the 'images' directory if it doesn't exist
    os.makedirs('images', exist_ok=True)

    # Load the camelyon17 dataset
    dataset = get_dataset(dataset='camelyon17', download=True)
    
    # Get the validation (OOD) subset
    val_data = dataset.get_subset('val')

    # Save the first 10 images from the validation set
    for i in range(10):
        x, y, metadata = val_data[i]
        x.save('images/val{}.png'.format(i+1))

if __name__ == '__main__':
    save_images()
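A non-visual cross-check could tabulate how many patches of each center fall into each split. The sketch below is an assumption-laden illustration: `center_split_table` is a hypothetical helper, and the toy arrays are made up. With the WILDS dataset object, the real inputs would presumably come from `dataset.split_array` and the center column of `dataset.metadata_array`, but I have not verified the exact metadata field name, so treat that part as an assumption.

```python
import pandas as pd

def center_split_table(centers, splits) -> pd.DataFrame:
    """Count patches per (center, split) pair; rows are centers, columns are splits."""
    return pd.crosstab(
        pd.Series(list(centers), name='center'),
        pd.Series(list(splits), name='split'),
    )

# Toy values, made up for illustration (NOT real Camelyon17 data).
toy_centers = [0, 0, 1, 2, 3, 4]
toy_splits = [0, 1, 3, 2, 0, 0]

table = center_split_table(toy_centers, toy_splits)
print(table)
```

If each OOD split draws its patches from exactly one center, the corresponding column of the table has a single non-zero row, which shows directly which center index feeds the validation (OOD) and test splits.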

Cheers, David

Originally posted by @David-Drexlin in https://github.com/p-lambda/wilds/discussions/163

David-Drexlin avatar Nov 19 '24 13:11 David-Drexlin