multi-temporal-crop-classification-baseline

show_random_patches does not show image

robmarkcole opened this issue 1 year ago · 3 comments

Image is not visible:

[Screenshot]

robmarkcole · Jan 08 '24 16:01

Hi Robin, that function is from an older release of the code, which we later decided to remove. I will add the correct version sometime this week.

samKhallaghi · Jan 08 '24 19:01

You can use the code snippet below to visualize random samples from your custom dataset. Note that you should call the function on the dataset itself, before passing the dataset to the PyTorch DataLoader.
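A minimal sketch of why the order matters, assuming a map-style dataset whose `__getitem__` returns `(image, label, extra)` with the image shaped `(C, H, W)`. `ToyChipDataset` and its sizes are hypothetical stand-ins for a real crop-classification dataset. Indexing the dataset yields one chip, which is what `show_random_patches` expects. A `DataLoader` batch adds a leading batch dimension, so `img[band, :, :]` would then select a sample rather than a band.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyChipDataset(Dataset):
    """Hypothetical stand-in: random 6-band chips with per-pixel class labels."""

    def __init__(self, n=8, bands=6, size=32, n_classes=4):
        self.images = torch.rand(n, bands, size, size)
        self.labels = torch.randint(0, n_classes, (n, size, size))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Returns (image, label, extra) to mirror the three-item unpacking
        # in show_random_patches; the extra item here is just the index.
        return self.images[idx], self.labels[idx], idx


dataset = ToyChipDataset()
img, lbl, _ = dataset[0]
print(img.shape, lbl.shape)  # one chip: (C, H, W) image, (H, W) label

loader = DataLoader(dataset, batch_size=4)
batch_img, batch_lbl, _ = next(iter(loader))
print(batch_img.shape, batch_lbl.shape)  # batched: (B, C, H, W), (B, H, W)
```

Passing `dataset` (not `loader`) to `show_random_patches` keeps the band axis first, as the function's indexing assumes.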

samKhallaghi · Jan 22 '24 18:01

```python
import numbers
import random

import matplotlib.pyplot as plt
import torch


def show_random_patches(dataset, sample_num, rgb_bands=(3, 2, 1)):
    """
    Plots a user-defined number of image chips and the corresponding labels.
    """
    # Validate arguments: exactly three numeric band indices, and a sample
    # count between 1 and the dataset size.
    if not (isinstance(rgb_bands, (tuple, list)) and len(rgb_bands) == 3 and
            all(isinstance(band, numbers.Number) for band in rgb_bands) and
            1 <= sample_num <= len(dataset)):
        raise ValueError("'sample_num' or 'rgb_bands' are not properly defined")

    # Sample indices for visualization (without replacement)
    sample_indices = random.sample(range(len(dataset)), sample_num)

    # One row per sample: image chip on the left, label on the right
    fig, axs = plt.subplots(nrows=sample_num, ncols=2,
                            figsize=(16, sample_num * 16 / 2), squeeze=False)

    for i, idx in enumerate(sample_indices):
        # The dataset is expected to return (image, label, <extra item>)
        img, lbl, _ = dataset[idx]
        # Select the three bands used as R, G, B; each becomes a (1, H, W) tensor
        r, g, b = (img[band, :, :].cpu().view(1, *lbl.shape) for band in rgb_bands)

        axs[i, 0].set_title(f'Image Patch #{idx}')
        # (3, H, W) -> (H, W, 3), the channel-last layout imshow expects
        axs[i, 0].imshow(torch.cat([r, g, b], 0).permute(1, 2, 0).numpy())
        axs[i, 1].set_title(f'Label Patch #{idx}')
        axs[i, 1].imshow(lbl.cpu().numpy())

    plt.show()
```
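One common reason an RGB chip renders as a blank or saturated panel: for floating-point RGB arrays, Matplotlib's `imshow` expects values in [0, 1] and clips anything outside that range (with a warning), so raw reflectances or standardized inputs can look empty. A minimal sketch, assuming the bands are float arrays; the helper name `stretch_to_unit` and the 2nd/98th percentile cut-offs are hypothetical choices, not part of the repository. It rescales each channel independently before display.

```python
import numpy as np


def stretch_to_unit(rgb, low=2, high=98):
    """Hypothetical helper: percentile-stretch an (H, W, 3) array to [0, 1].

    Each channel is rescaled independently between its `low` and `high`
    percentiles, then clipped so imshow receives valid float RGB values.
    """
    rgb = np.asarray(rgb, dtype=np.float64)
    out = np.empty_like(rgb)
    for c in range(rgb.shape[-1]):
        band = rgb[..., c]
        lo, hi = np.nanpercentile(band, [low, high])
        # Guard against constant bands, which would otherwise divide by zero
        scale = hi - lo if hi > lo else 1.0
        out[..., c] = np.clip((band - lo) / scale, 0.0, 1.0)
    return out


# Example: standardized values well outside [0, 1] become displayable
chip = np.random.default_rng(0).normal(loc=0.0, scale=3.0, size=(32, 32, 3))
display = stretch_to_unit(chip)
print(display.min() >= 0.0, display.max() <= 1.0)
```

Inside `show_random_patches`, the idea would be to pass `stretch_to_unit(torch.cat([r, g, b], 0).permute(1, 2, 0).numpy())` to `imshow` instead of the raw tensor. A per-channel stretch changes the color balance, so it is meant for visual inspection only.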

samKhallaghi · Jan 22 '24 18:01