multi-temporal-crop-classification-baseline
multi-temporal-crop-classification-baseline copied to clipboard
`show_random_patches` does not display the image
Image is not visible:
Hi Robin, that function is from an older release of the code, and we later decided to remove it. I will add the corrected version sometime this week.
You can use this code snippet to visualize random samples from your custom dataset. Note that you should call the function before passing the dataset to the PyTorch DataLoader.
import numbers
import random

import matplotlib.pyplot as plt
import torch
def _stretch_to_unit(rgb, percentiles):
    """Linearly rescale each channel of an (H, W, 3) float tensor into [0, 1].

    Values below the lower / above the upper percentile are clipped, so a few
    very bright or very dark pixels cannot wash out the rest of the patch.
    NaNs are ignored when computing the percentiles and rendered as 0.
    """
    low_q, high_q = (p / 100.0 for p in percentiles)
    flat = rgb.reshape(-1, rgb.shape[-1])
    low = torch.nanquantile(flat, low_q, dim=0)
    high = torch.nanquantile(flat, high_q, dim=0)
    # Guard against division by zero for constant-valued channels.
    span = (high - low).clamp_min(torch.finfo(rgb.dtype).eps)
    scaled = ((rgb - low) / span).clamp(0.0, 1.0)
    return torch.nan_to_num(scaled, nan=0.0)


def show_random_patches(dataset, sample_num, rgb_bands=(3, 2, 1), *, percentiles=(2, 98)):
    """Plot a user-defined number of random image chips next to their labels.

    Call this on the dataset itself, *before* wrapping it in a PyTorch
    ``DataLoader``, so every item is a single un-batched sample.

    Args:
        dataset: Sized, indexable dataset whose items are sequences beginning
            with ``(image, label, ...)``. The image is indexed channel-first,
            i.e. ``image[band]`` must yield one 2-D band.
        sample_num: Number of distinct samples to draw, in ``1..len(dataset)``.
        rgb_bands: Three integer channel indices displayed as R, G and B.
        percentiles: Lower/upper percentiles used to contrast-stretch each RGB
            channel into [0, 1]. Without a stretch, raw band values outside
            [0, 1] are clipped by matplotlib and the image looks blank.

    Raises:
        ValueError: If ``sample_num`` or ``rgb_bands`` is not properly defined.
    """
    valid = (
        isinstance(rgb_bands, (tuple, list))
        and len(rgb_bands) == 3
        # Integral, not Number: float band indices cannot index a tensor.
        and all(isinstance(band, numbers.Integral) for band in rgb_bands)
        and isinstance(sample_num, numbers.Integral)
        and 1 <= sample_num <= len(dataset)
    )
    if not valid:
        raise ValueError("'sample_num' or 'rgb_bands' are not properly defined")

    # Sample distinct indices for visualization.
    sample_indices = random.sample(range(len(dataset)), sample_num)
    fig, axs = plt.subplots(nrows=sample_num, ncols=2, figsize=(16, sample_num * 16 / 2), squeeze=False)

    for row, idx in zip(axs, sample_indices):
        sample = dataset[idx]
        # Move to CPU and drop autograd so the data can be converted to numpy.
        img = torch.as_tensor(sample[0]).detach().cpu()
        lbl = torch.as_tensor(sample[1]).detach().cpu().squeeze()

        # Build an (H, W, 3) image from the selected bands.
        rgb = torch.stack([img[band] for band in rgb_bands], dim=-1).float()

        row[0].set_title(f'Image Patch #{idx}')
        row[0].imshow(_stretch_to_unit(rgb, percentiles).numpy())
        row[1].set_title(f'Label Patch #{idx}')
        # Nearest-neighbour keeps class ids crisp instead of blending them.
        row[1].imshow(lbl.numpy(), interpolation="nearest")

    fig.tight_layout()
    plt.show()