wilds
VAL_CENTER index in camelyon17_dataset.py
Hi everyone,
I hope this is the right place to ask a question about the Camelyon17 dataset. My question is regarding the center-metadata indices for TEST_CENTER and VAL_CENTER, as defined in the camelyon17_dataset.py file. According to that file, the test and validation (OOD) centers are 0-indexed, with TEST_CENTER at index 2 and VAL_CENTER at index 1. My understanding is that this should correspond to the images shown in columns 5 and 4 of the paper (see the first image for reference). Is that correct?
When I naively plot the images by center label, one center per row (see the second image), I would expect rows 2 and 1 (0-indexed) to show the test and validation (OOD) slides, respectively. Instead, the validation center index appears to be switched: the test images do correspond to index 2, but the validation (OOD) images correspond to index 4 rather than 1. Inspecting the images directly in the data/patches directory shows the same behaviour: for example, patient 96 from center 4 looks like Val (OOD), while patient 34 from center 1 looks like part of train, at least visually to a layman.
Did I misunderstand something in the indexing or do you have any clue what could be wrong? Below are the images for reference and the code I used to generate them:
Wilds slides:
Slides as per my Code:
Thanks in advance for any clarification!
Code:
import os
import pandas as pd
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from collections import defaultdict
# constants
DATA_DIR = '/data/camelyon17_v1.0'
PATCHES_DIR = os.path.join(DATA_DIR, 'patches')
METADATA_CSV = os.path.join(DATA_DIR, 'metadata.csv')
MAX_IMAGES_PER_COMBINATION = 5
NUM_DOMAINS = 5
NUM_CLASSES = 2
# Load the metadata
metadata_df = pd.read_csv(
    METADATA_CSV,
    index_col=0,
    dtype={'patient': 'str'}
)
# Get labels
y_array = torch.LongTensor(metadata_df['tumor'].values)
# Get input image paths
input_paths = [
    os.path.join(
        PATCHES_DIR,
        f'patient_{patient}_node_{node}',
        f'patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png'
    )
    for patient, node, x, y in metadata_df[['patient', 'node', 'x_coord', 'y_coord']].values
]
# Get domains (centers)
centers = metadata_df['center'].astype(int).values
# Organize images into a dictionary keyed by (domain, class)
images_dict = defaultdict(list)
for img_path, label, domain in zip(input_paths, y_array, centers):
    key = (domain, label.item())
    if len(images_dict[key]) < MAX_IMAGES_PER_COMBINATION:
        try:
            img = Image.open(img_path).convert('RGB')
            images_dict[key].append(img)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
# plot
fig, axes = plt.subplots(nrows=NUM_DOMAINS, ncols=NUM_CLASSES * MAX_IMAGES_PER_COMBINATION, figsize=(24, 12))
plt.subplots_adjust(wspace=0.05, hspace=0.05)
for domain_idx in range(NUM_DOMAINS):
    for class_idx in range(NUM_CLASSES):
        key = (domain_idx, class_idx)
        images = images_dict.get(key, [])
        for img_idx in range(MAX_IMAGES_PER_COMBINATION):
            col_idx = class_idx * MAX_IMAGES_PER_COMBINATION + img_idx
            ax = axes[domain_idx, col_idx]
            if img_idx < len(images):
                ax.imshow(images[img_idx])
            ax.axis('off')
            if domain_idx == 0 and img_idx == 0:
                ax.set_title(f"Class {class_idx}")
        # Add domain labels to the first image in each row
        if class_idx == 0:
            ax = axes[domain_idx, 0]
            ax.text(-30, 32, f"Domain {domain_idx}", rotation=90, va='center', fontsize=12)
            #ax.text(-150, images[0].size[1] // 2, f"Domain {domain_idx}", rotation=90, va='center', fontsize=12)
plt.tight_layout()
plt.savefig("slides_per_domain_class.png")
Or, more directly, save a few images from the validation subset and inspect them:
import os
from wilds import get_dataset
def save_images():
    # Create the 'images' directory if it doesn't exist
    os.makedirs('images', exist_ok=True)
    # Load the camelyon17 dataset
    dataset = get_dataset(dataset='camelyon17', download=True)
    # Get the validation (OOD) subset
    val_data = dataset.get_subset('val')
    # Save the first 10 images from the validation set
    for i in range(10):
        x, y, metadata = val_data[i]
        x.save('images/val{}.png'.format(i+1))

if __name__ == '__main__':
    save_images()
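A further check that might help narrow this down is to list which patients belong to which center in metadata.csv, labelled with the two constants as I read them in camelyon17_dataset.py (TEST_CENTER = 2, VAL_CENTER = 1). This is only a minimal sketch: the helper names `patients_by_center` and `role_of` are my own, the "train / id_val" label for the remaining centers is my assumption, and the rows below are made-up placeholders rather than real metadata.

```python
import csv
import io
from collections import defaultdict

# Constants as I read them in camelyon17_dataset.py (0-indexed centers).
TEST_CENTER = 2
VAL_CENTER = 1


def patients_by_center(rows):
    """Map each center id to the sorted list of patient ids seen for it."""
    grouped = defaultdict(set)
    for row in rows:
        grouped[int(row["center"])].add(row["patient"])
    return {center: sorted(patients) for center, patients in sorted(grouped.items())}


def role_of(center):
    """Label a center using the two constants above (my assumption for the rest)."""
    if center == TEST_CENTER:
        return "test (OOD)"
    if center == VAL_CENTER:
        return "val (OOD)"
    return "train / id_val"


# Made-up rows, only to show the expected input shape (not real metadata).
TOY_CSV = """patient,center
p_a,0
p_b,1
p_b,1
p_c,2
p_d,4
"""

toy = patients_by_center(csv.DictReader(io.StringIO(TOY_CSV)))
for center, patients in toy.items():
    print(center, role_of(center), patients)
```

To run it on the real file, the toy reader could be swapped for `csv.DictReader(open(METADATA_CSV, newline=''))`, assuming the `patient` and `center` columns used in my plotting code above. If patient 96 shows up under center 4 and patient 34 under center 1, then the metadata agrees with the directory layout, and the question is only which center the paper figure's validation (OOD) column actually shows.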
Cheers,
David
Originally posted by @David-Drexlin in https://github.com/p-lambda/wilds/discussions/163