
Incorrect Std and Mean in OSCDDataModule?

Open Dibz15 opened this issue 1 year ago • 3 comments

Description

I've been using the OSCDDataModule to work with OSCD data and train a model. To normalize my data for inference after training, I pulled the mean and std from the datamodule. However, after normalizing my data (loaded using the OSCD class), the per-channel mean and std of my image were not near 0 and 1, respectively.

Here are the values when using OSCDDataModule.mean and OSCDDataModule.std to normalize an image (note that these are the full 26 channels from the two pre/post images):

Before normalization:
Mean: tensor([1528.7535, 1349.5966, 1327.5056, 1416.8456, 1577.3069, 2114.3342,
        2411.4072, 2284.1833, 2601.5127,  549.7957,   10.6988, 2457.2490,
        1744.0032, 1554.7324, 1313.2804, 1176.0295, 1187.1091, 1358.0902,
        1746.9423, 1924.1726, 1883.0494, 2062.9543,  785.6702,   13.0836,
        1950.4657, 1452.3799])
Std: tensor([175.3837, 306.4943, 377.7910, 550.9829, 450.5551, 429.9716, 491.0525,
        540.5790, 530.8759,  86.6327,   1.9990, 693.3192, 586.7882, 142.9166,
        295.5955, 350.6173, 484.4846, 397.5941, 417.5098, 468.0989, 572.7402,
        512.6295, 124.8398,   1.6803, 617.9036, 535.3190])

After normalization:
Mean: tensor([-1.0407, -0.2964,  0.3155,  0.6036,  0.6684,  1.5635,  1.9915,  1.8116,
         2.0887, -1.9268, -1.1448,  1.7656,  1.2219, -0.5430, -0.7317, -1.1177,
        -0.9164, -0.8182, -1.6056, -1.9634, -1.6872, -1.7187,  1.2922, -0.6517,
        -0.6083, -0.4029])
Std: tensor([3.3602, 3.6743, 3.5743, 3.6455, 3.0554, 3.7089, 3.9859, 4.7151, 3.7530,
        1.1823, 0.4133, 3.2477, 3.2694, 2.7382, 3.5436, 3.3172, 3.2055, 2.6963,
        3.6014, 3.7996, 4.9956, 3.6240, 1.7037, 0.3474, 2.8944, 2.9826])

I made an effort to calculate the mean and std of the OSCD dataset myself, and then ran the normalization again:

After normalization:
Mean: tensor([-0.1541, -0.0384,  0.0794,  0.1540,  0.2017,  0.3333,  0.3881,  0.3339,
         0.4205, -0.2845, -0.4222,  0.4906,  0.3719, -0.0597, -0.1259, -0.2023,
        -0.1461, -0.1010, -0.1500, -0.1859, -0.1289, -0.1648,  0.4477, -0.1477,
        -0.0067,  0.0376])
Std: tensor([0.6379, 0.7391, 0.7027, 0.7197, 0.6221, 0.5656, 0.5785, 0.6236, 0.5769,
        0.2689, 0.2301, 0.6803, 0.6728, 0.5198, 0.7128, 0.6521, 0.6329, 0.5490,
        0.5492, 0.5515, 0.6607, 0.5571, 0.3875, 0.1934, 0.6063, 0.6138])

Here are my calculated mean and std values for the 13 S2 channels:

mean = torch.tensor([1571.1372, 1365.5087, 1284.8223, 1298.9539, 1431.2260, 1860.9531,
                2081.9634, 1994.7665, 2214.5986,  641.4485,   14.3672, 1957.3165,
                1419.6107])
std =  torch.tensor([274.9591,  414.6901,  537.6539,  765.5303,  724.2261,  760.2133,
                848.7888,  866.8081,  920.1696,  322.1572,    8.6878, 1019.1249,
                872.1970])

If I've done something wrong, please let me know. This is just what I had to do to get values in the range that I was expecting for my model, so I wanted to share it here.
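
For reference, a minimal sketch of the expectation behind this check, using synthetic data rather than OSCD (the tensor shape and value range below are assumptions chosen only for illustration). When a batch is normalized with its own per-channel statistics, the result has mean 0 and std 1 per channel. A single image normalized with dataset-wide statistics will not hit those values exactly, but it should land in the same rough range.

```python
import torch

torch.manual_seed(0)

# Synthetic stand-in for a (B, C, H, W) batch of 13-band imagery; not real OSCD data.
images = torch.rand(4, 13, 32, 32) * 3000.0

# Per-channel statistics over the batch, height, and width dimensions.
mean = images.mean(dim=(0, 2, 3))
std = images.std(dim=(0, 2, 3))

# Broadcast the (C,) statistics against the (B, C, H, W) batch.
normalized = (images - mean[None, :, None, None]) / std[None, :, None, None]

norm_mean = normalized.mean(dim=(0, 2, 3))
norm_std = normalized.std(dim=(0, 2, 3))
print(f"max |mean| after normalization: {norm_mean.abs().max().item():.6f}")
print(f"std range after normalization: {norm_std.min().item():.6f} to {norm_std.max().item():.6f}")
```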

Steps to reproduce

  1. Ensure kornia is installed: pip install kornia
  2. Run the following code to test.

from torchgeo.datasets import OSCD
from torchgeo.datamodules import OSCDDataModule
import torch
import kornia.augmentation as K

def normalize_sample(sample, mean, std):
    image = sample['image'].float()
    if len(image.shape) < 4:
        image = image.unsqueeze(0)
    normalize = K.Normalize(mean, std)
    normalized_image = normalize(image)
    sample['image'] = normalized_image
    return sample

def get_norm_coefficients_old(bands="rgb"):
    mean = OSCDDataModule.mean
    std = OSCDDataModule.std
    if bands == "rgb":
        mean = mean[[3, 2, 1]]
        std = std[[3, 2, 1]]
    
    mean = torch.cat([mean, mean], dim=0)
    std = torch.cat([std, std], dim=0)
    return mean, std

def get_norm_coefficients_new(bands="rgb"):
    mean = torch.tensor([1571.1372, 1365.5087, 1284.8223, 1298.9539, 1431.2260, 1860.9531,
                    2081.9634, 1994.7665, 2214.5986,  641.4485,   14.3672, 1957.3165,
                    1419.6107])
    std =  torch.tensor([274.9591,  414.6901,  537.6539,  765.5303,  724.2261,  760.2133,
                    848.7888,  866.8081,  920.1696,  322.1572,    8.6878, 1019.1249,
                    872.1970])
    if bands == "rgb":
        mean = mean[[3, 2, 1]]
        std = std[[3, 2, 1]]
    
    mean = torch.cat([mean, mean], dim=0)
    std = torch.cat([std, std], dim=0)
    return mean, std

def get_channelwise_stats(image: torch.Tensor):
    if len(image.shape) < 4:
        image = image.unsqueeze(0) # Add batch dimension
    mean = image.mean(dim=[0, 2, 3])  # average across batch dimension, height, width
    std = image.std(dim=[0, 2, 3])  # standard deviation across batch dimension, height, width
    return mean, std

def print_stats(sample):
    mean, std = get_channelwise_stats(sample['image'].float())
    print(f'Mean: {mean}')
    print(f'Std: {std}')

dataset = OSCD(root="datasets/OSCD", split="test", bands="all", download=True)
old_mean, old_std = get_norm_coefficients_old(bands="all")
new_mean, new_std = get_norm_coefficients_new(bands="all")

idx = 0
# Print stats before normalization
print("Before normalization:")
print_stats(dataset[idx])

# Normalize
sample_old = normalize_sample(dataset[idx], old_mean, old_std)
sample_new = normalize_sample(dataset[idx], new_mean, new_std)

# Print stats after normalization
print("\nAfter normalization, old coefficients:")
print_stats(sample_old)

print("\nAfter normalization, new coefficients:")
print_stats(sample_new)
  3. (Optional) Recalculate mean/std for OSCD. The code I used is below:

import torch
from torch.utils.data import DataLoader
from torchgeo.datasets import OSCD

def CalcMeanVar(root, split="train", bands="rgb"):
    def t(img_dict):
        return {'image': img_dict['image'].to(torch.float), 'mask': img_dict['mask']}

    dataset = OSCD(root=root, split=split, bands=bands, download=True, transforms = t)
    loader = DataLoader(dataset, batch_size=1, num_workers=0)

    def preproc_img(img_dict):
        images = img_dict['image']
        # The pre/post images are stacked along the channel axis: (B, C, H, W)
        B, C, H, W = images.size()
        # Split into the two images, each of shape (B, C//2, H, W)
        image1 = images[:, :(C//2), :, :]
        image2 = images[:, (C//2):, :, :]
        # Concatenate along the batch axis to get a tensor of shape (2B, C//2, H, W)
        images = torch.cat((image1, image2), dim=0)
        return images

    def compute_dataset_mean_std(dataloader):
        ex_img = preproc_img(next(iter(dataloader))).shape[1]
        total_sum = torch.zeros(ex_img)
        total_sq_sum = torch.zeros(ex_img)
        total_num_pixels = 0

        for batch in dataloader:
            image = preproc_img(batch).float()
            total_sum += image.sum(dim=[0, 2, 3])  # sum of pixel values in each channel
            total_sq_sum += (image ** 2).sum(dim=[0, 2, 3])  # sum of squared pixel values in each channel
            total_num_pixels += image.shape[0] * image.shape[2] * image.shape[3]  # number of pixels per channel in this batch

        mean = total_sum / total_num_pixels  # mean = total sum / total number of pixels
        std = (total_sq_sum / total_num_pixels - mean ** 2) ** 0.5  # std = sqrt(E[X^2] - E[X]^2)

        return mean, std
    return compute_dataset_mean_std(loader)

mean, std = CalcMeanVar(root="datasets/OSCD", split="train", bands="all")
print(f'Calculated mean: {mean}')
print(f'Calculated std: {std}')
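
The running-sum approach above can also be sketched with float64 accumulators. This is a variant under stated assumptions, not the code that produced the values in this issue: `running_channel_stats` is a hypothetical helper name, and the batches are synthetic tensors of varying size. Summing squared reflectance values over many large images in float32 can lose precision, so the sketch accumulates in float64 and clamps the variance at zero before the square root. It computes the population std (sqrt(E[X^2] - E[X]^2)), the same formula as the code above.

```python
import torch

def running_channel_stats(batches):
    """Accumulate per-channel mean/std over an iterable of (B, C, H, W) tensors."""
    total_sum = None
    total_sq_sum = None
    num_pixels = 0
    for batch in batches:
        x = batch.to(torch.float64)  # accumulate in double precision
        if total_sum is None:
            total_sum = torch.zeros(x.shape[1], dtype=torch.float64)
            total_sq_sum = torch.zeros(x.shape[1], dtype=torch.float64)
        total_sum += x.sum(dim=(0, 2, 3))
        total_sq_sum += (x ** 2).sum(dim=(0, 2, 3))
        num_pixels += x.shape[0] * x.shape[2] * x.shape[3]
    mean = total_sum / num_pixels
    # Clamp guards against tiny negative values caused by rounding.
    var = (total_sq_sum / num_pixels - mean ** 2).clamp(min=0.0)
    return mean, var.sqrt()

torch.manual_seed(0)
# Synthetic batches with different spatial sizes, as OSCD images vary in size.
batches = [torch.rand(1, 13, h, w) * 3000.0 for h, w in [(16, 20), (24, 12), (8, 8)]]
mean, std = running_channel_stats(batches)

# Reference: flatten every pixel per channel and compute the statistics directly.
flat = torch.cat(
    [b.to(torch.float64).permute(1, 0, 2, 3).reshape(13, -1) for b in batches], dim=1
)
ref_mean = flat.mean(dim=1)
ref_std = flat.std(dim=1, unbiased=False)
print(torch.allclose(mean, ref_mean), torch.allclose(std, ref_std))
```

Because the statistics are accumulated per pixel rather than averaged per image, larger images contribute proportionally more, which matches the behavior of the code above.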

Version

0.5.0.dev0 (d0c773ffa03a15124b632025e2f11df2d3c798e0)

Dibz15, Jun 23 '23 16:06

@iejMac @calebrob6 may have an idea where the original stats came from.

adamjstewart, Jun 23 '23 20:06

Ping @calebrob6

adamjstewart, Sep 06 '23 20:09