add CopernicusBenchBiomassS3 datamodule
Add datamodule for Copernicus-Bench Biomass-S3 dataset.
TO DISCUSS: are there official normalisation stats?
> TO DISCUSS: are there official normalisation stats?
@wangyi111?
No, but I applied per-band scale factors as in https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S3_OLCI#bands.
@adamjstewart I will use these scale factors, which I understand should be multiplied with each channel.
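For the record, a minimal sketch of how those multiplicative factors translate into `K.Normalize` stats (with `mean=0` and `std = 1/scale`, dividing by `std` is the same as multiplying by the scale factor; the two band values here are just a subset for illustration):

```python
import kornia.augmentation as K
import torch

# Subset of the per-band multiplicative scale factors, for illustration.
scale = torch.tensor([0.0139465, 0.0133873])

# Normalize computes (x - mean) / std, so mean=0 and std=1/scale
# multiplies each channel by its scale factor.
normalizer = K.Normalize(mean=torch.zeros_like(scale), std=1.0 / scale)

x = torch.rand(1, 2, 4, 4) * 1000  # stand-in for raw radiance values
assert torch.allclose(normalizer(x), x * scale.view(1, -1, 1, 1))
```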
Error implies inconsistent patch sizes:
```
TRAIN
Image typical: (21, 94, 94) (93.57%) from 3000 samples
Image anomalies:
- (21, 84, 94) (64 | 2.13%)
- (21, 94, 84) (62 | 2.07%)
- (21, 94, 95) (31 | 1.03%)
- (21, 95, 94) (18 | 0.6%)
- (21, 94, 85) (10 | 0.33%)
- (21, 85, 94) (5 | 0.17%)
- (21, 95, 95) (1 | 0.03%)
- (21, 84, 95) (1 | 0.03%)
- (21, 84, 84) (1 | 0.03%)
Mask typical: (282, 282) (95.23%) from 3000 samples
Mask anomalies:
- (282, 252) (72 | 2.4%)
- (252, 282) (70 | 2.33%)
- (252, 252) (1 | 0.03%)
VAL
Image typical: (21, 94, 94) (93.7%) from 1000 samples
Image anomalies:
- (21, 94, 84) (21 | 2.1%)
- (21, 84, 94) (18 | 1.8%)
- (21, 94, 95) (10 | 1.0%)
- (21, 95, 94) (6 | 0.6%)
- (21, 94, 85) (4 | 0.4%)
- (21, 85, 94) (4 | 0.4%)
Mask typical: (282, 282) (95.3%) from 1000 samples
Mask anomalies:
- (282, 252) (25 | 2.5%)
- (252, 282) (22 | 2.2%)
TEST
Image typical: (21, 94, 94) (94.2%) from 1000 samples
Image anomalies:
- (21, 94, 84) (19 | 1.9%)
- (21, 84, 94) (14 | 1.4%)
- (21, 94, 95) (14 | 1.4%)
- (21, 94, 85) (5 | 0.5%)
- (21, 95, 94) (3 | 0.3%)
- (21, 85, 94) (2 | 0.2%)
- (21, 84, 84) (1 | 0.1%)
Mask typical: (282, 282) (95.9%) from 1000 samples
Mask anomalies:
- (282, 252) (24 | 2.4%)
- (252, 282) (16 | 1.6%)
- (252, 252) (1 | 0.1%)
```
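(For reproducibility, a census like the above can be generated with a quick loop over the raw dataset; a minimal sketch, assuming the usual `root`/`split` constructor arguments and no resizing transform:)

```python
from collections import Counter

from torchgeo.datasets import CopernicusBenchBiomassS3

for split in ('train', 'val', 'test'):
    # root path is illustrative
    ds = CopernicusBenchBiomassS3(root='/data/', split=split)
    image_shapes: Counter = Counter()
    mask_shapes: Counter = Counter()
    for i in range(len(ds)):
        sample = ds[i]
        image_shapes[tuple(sample['image'].shape)] += 1
        mask_shapes[tuple(sample['mask'].shape)] += 1
    print(split.upper())
    print('  image:', image_shapes.most_common())
    print('  mask: ', mask_shapes.most_common())
```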
@adamjstewart should we resize to the standard size?
I would recommend a couple of things:
- Change our `tests/data/copernicus/l3_biomass_s3` data to have inconsistent sizes to reproduce this issue
- Use `Resize((94, 94))` as a dataset transform in the datamodule (not an aug); see the sketch below
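A sketch of that, assuming the dataset's `transforms` callable receives the un-collated sample dict with a `(C, H, W)` float image (`keepdim=True` keeps kornia from adding a batch dimension):

```python
import kornia.augmentation as K
import torch

from torchgeo.datasets import CopernicusBenchBiomassS3

# Per-sample resize so every image is (C, 94, 94) before collation.
image_resize = K.Resize((94, 94), keepdim=True)


def transforms(sample: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    sample['image'] = image_resize(sample['image'])
    return sample


ds = CopernicusBenchBiomassS3(root='/data/', split='train', transforms=transforms)
```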
@adamjstewart and resize the mask to 282?
> Mask typical: (282, 282) (95.3%) from 1000 samples
Oh, I didn't realize the image and mask have different resolutions. That will be quite tricky. We should probably resize both to the same resolution so it becomes a simple pixel-wise regression task. @wangyi111 do you have a preference between making the image bigger or the mask smaller?
@adamjstewart For the test, I see there are some tifs; do you want one of those resized to reproduce the error?
Correct. There should be a `data.py` used to generate those files. You can make the changes there so we can reproduce the test data.
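For example (a sketch only; the real `data.py` has its own structure, and the sizes, filenames, and dtype here are illustrative):

```python
import numpy as np
import rasterio

# Write a handful of tiny test images whose sizes intentionally disagree,
# mirroring the inconsistent patch sizes seen in the real dataset.
SIZES = [(94, 94), (84, 94), (94, 84)]

for i, (height, width) in enumerate(SIZES):
    profile = {
        'driver': 'GTiff',
        'dtype': 'uint16',
        'count': 21,
        'height': height,
        'width': width,
    }
    data = np.random.randint(0, 65535, size=(21, height, width), dtype=np.uint16)
    with rasterio.open(f'image_{i}.tif', 'w', **profile) as f:
        f.write(data)
```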
> Oh, I didn't realize the image and mask have different resolutions. That will be quite tricky. We should probably resize both to the same resolution so it becomes a simple pixel-wise regression task. @wangyi111 do you have a preference between making the image bigger or the mask smaller?
I'd like to keep them as they are :) but if it's a must, then maybe make the image bigger.
@wangyi111 I mean change it on-the-fly, not in your dataset. How do you evaluate your model if the input and output resolutions don't match? Are you training a J-net instead of a U-net?
@adamjstewart Using:
```python
TARGET_SIZE = (282, 282)
...
resize = K.Resize(size=TARGET_SIZE)
extra_args = {
    DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
}
if mode == 'time-series':
    self.aug = K.AugmentationSequential(
        K.VideoSequential(resize, normalizer),
        data_keys=None,
        keepdim=True,
        same_on_batch=True,
        extra_args=extra_args,
    )
else:
    self.aug = K.AugmentationSequential(
        resize, normalizer, data_keys=None, keepdim=True, extra_args=extra_args
    )
```
I still get `RuntimeError: Trying to resize storage that is not resizable` - is this the correct resize approach?
Update: if I swap out `AugmentationSequential`, the error is cleared:
```python
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""Copernicus-Bench Biomass-S3 datamodule."""

from typing import Any

import kornia.augmentation as K
import torch
from kornia.constants import Resample

from ...datasets import CopernicusBenchBiomassS3
from ..geo import NonGeoDataModule

# Multiplicative scale factors from
# https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S3_OLCI#bands
SCALE = {
    'Oa01_radiance': 0.0139465,
    'Oa02_radiance': 0.0133873,
    'Oa03_radiance': 0.0121481,
    'Oa04_radiance': 0.0115198,
    'Oa05_radiance': 0.0100953,
    'Oa06_radiance': 0.0123538,
    'Oa07_radiance': 0.00879161,
    'Oa08_radiance': 0.00876539,
    'Oa09_radiance': 0.0095103,
    'Oa10_radiance': 0.00773378,
    'Oa11_radiance': 0.00675523,
    'Oa12_radiance': 0.0071996,
    'Oa13_radiance': 0.00749684,
    'Oa14_radiance': 0.0086512,
    'Oa15_radiance': 0.00526779,
    'Oa16_radiance': 0.00530267,
    'Oa17_radiance': 0.00493004,
    'Oa18_radiance': 0.00549962,
    'Oa19_radiance': 0.00502847,
    'Oa20_radiance': 0.00326378,
    'Oa21_radiance': 0.00324118,
}

TARGET_SIZE = (282, 282)


class CopernicusBenchBiomassS3DataModule(NonGeoDataModule):
    """LightningDataModule implementation for the Copernicus-Bench Biomass-S3 dataset.

    Uses the train/val/test splits provided with the benchmark.

    .. versionadded:: 0.81
    """

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new CopernicusBenchBiomassS3DataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.CopernicusBenchBiomassS3`.
        """
        bands = kwargs.get('bands', SCALE.keys())
        scale_factors = torch.tensor([SCALE[b] for b in bands], dtype=torch.float32)
        # Normalize computes (x - mean) / std, so mean=0 and std=1/scale
        # applies the multiplicative scale factor per band.
        self.mean = torch.zeros(len(bands), dtype=torch.float32)
        self.std = torch.reciprocal(scale_factors)

        super().__init__(CopernicusBenchBiomassS3, batch_size, num_workers, **kwargs)

        self.image_resizer = K.Resize(
            size=TARGET_SIZE,
            resample=Resample.BILINEAR.name,
            align_corners=False,
        )
        # align_corners must be None for nearest resampling: PyTorch's
        # interpolate only accepts it for the interpolating modes.
        self.mask_resizer = K.Resize(
            size=TARGET_SIZE,
            resample=Resample.NEAREST.name,
            align_corners=None,
        )
        self.normalizer = K.Normalize(mean=self.mean, std=self.std)

        def _resize_and_normalize(
            batch: dict[str, torch.Tensor],
        ) -> dict[str, torch.Tensor]:
            batch['image'] = self._resize_image(batch['image'])
            if 'mask' in batch:
                batch['mask'] = self._resize_mask(batch['mask'])
            return batch

        self.aug = _resize_and_normalize

    def _resize_image(self, image: torch.Tensor) -> torch.Tensor:
        device = image.device
        self.image_resizer = self.image_resizer.to(device)
        self.normalizer = self.normalizer.to(device)
        if image.ndim == 4:
            resized = self.image_resizer(image)
            return self.normalizer(resized)
        if image.ndim == 5:
            batch, time, channels, height, width = image.shape
            flattened = image.view(-1, channels, height, width)
            resized = self.image_resizer(flattened)
            normalized = self.normalizer(resized)
            return normalized.view(batch, time, channels, *TARGET_SIZE)
        msg = (
            'Expected image tensor with shape (batch, channels, height, width) or '
            '(batch, time, channels, height, width).'
        )
        raise ValueError(msg)

    def _resize_mask(self, mask: torch.Tensor) -> torch.Tensor:
        device = mask.device
        original_dtype = mask.dtype
        self.mask_resizer = self.mask_resizer.to(device)
        if mask.ndim == 3:
            resized = self.mask_resizer(mask.unsqueeze(1).float())
            return resized.squeeze(1).to(original_dtype)
        if mask.ndim == 4:
            batch, time, height, width = mask.shape
            flattened = mask.view(-1, 1, height, width).float()
            resized = self.mask_resizer(flattened)
            return resized.view(batch, time, *TARGET_SIZE).to(original_dtype)
        msg = (
            'Expected mask tensor with shape (batch, height, width) or '
            '(batch, time, height, width).'
        )
        raise ValueError(msg)
```
The error might be a red herring: https://github.com/ultralytics/ultralytics/issues/5319
I would use the `transforms` parameter in https://torchgeo.readthedocs.io/en/stable/api/datasets.html#torchgeo.datasets.CopernicusBenchBiomassS3.__init__
You need to do the resize on each sample before the mini-batch is collated. Collation will fail due to size mismatch.
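Concretely, something like this as the dataset `transforms` would make every sample a consistent size before the default collate runs (a sketch, assuming un-collated `(C, H, W)` float images and `(H, W)` masks):

```python
import kornia.augmentation as K
import torch
import torch.nn.functional as F

TARGET_SIZE = (282, 282)
image_resize = K.Resize(TARGET_SIZE, keepdim=True)


def transforms(sample: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    sample['image'] = image_resize(sample['image'])
    if 'mask' in sample:
        # interpolate wants (B, C, H, W); nearest avoids blending target values
        mask = sample['mask'][None, None].float()
        mask = F.interpolate(mask, size=TARGET_SIZE, mode='nearest')
        sample['mask'] = mask[0, 0].to(sample['mask'].dtype)
    return sample
```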
> @wangyi111 I mean change it on-the-fly, not in your dataset. How do you evaluate your model if the input and output resolutions don't match? Are you training a J-net instead of a U-net?
I see. When training, I was upsampling the image to match the mask.
I've iterated through a few images and they just don't appear to match the mask:
```python
import matplotlib.pyplot as plt
import torch

datamodule = CopernicusBenchBiomassS3DataModule(
    root="/data/",
    batch_size=4,
    num_workers=4,
)
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))
batch = datamodule.aug(batch)
assert batch['image'].shape == torch.Size([4, 21, 282, 282])

index = 3
channel = 10
sample_image = batch['image'][index, channel, :, :].numpy()
sample_mask = batch['mask'][index, :, :].numpy()

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(sample_image)
ax[0].set_title('Sample Image')
ax[1].imshow(sample_mask, cmap='gray')
ax[1].set_title('Sample Mask')
```
@wangyi111 can you share your code for the RGB plotting in the paper? I'm adding a utility for that; examples below:
Did you figure out the non-matching image/mask bug?
The base class has a plot method. Is this not working?
I didn't notice the base class plot method. These (using the base class plot) look decent, so my hacky plotting approach in a notebook must have an error in it.
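For anyone else verifying visually, the base class plot method can be used directly on raw samples (a minimal sketch, assuming the usual `root`/`split` arguments and the standard `suptitle` keyword):

```python
import matplotlib.pyplot as plt

from torchgeo.datasets import CopernicusBenchBiomassS3

ds = CopernicusBenchBiomassS3(root='/data/', split='train')
sample = ds[0]
# The base class plot handles band selection and image/mask layout.
fig = ds.plot(sample, suptitle='Sample 0')
plt.show()
```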
To add test coverage, add a config file to `tests/conf/` and 1 line to the `TestPixelwiseRegressionTask.test_trainer` parameterization in `tests/trainers/test_regression.py`.
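Something along these lines for the config (a sketch patterned on the other Copernicus-Bench configs; the class paths, init args, and test-data root are assumptions to check against the existing files):

```yaml
model:
  class_path: PixelwiseRegressionTask
  init_args:
    model: 'unet'
    backbone: 'resnet18'
    in_channels: 21
    loss: 'mse'
data:
  class_path: CopernicusBenchBiomassS3DataModule
  init_args:
    batch_size: 1
  dict_kwargs:
    root: 'tests/data/copernicus/l3_biomass_s3'
```

The parameterization entry would then just be the config's stem, e.g. `'copernicus_biomass_s3'` (hypothetical name).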
@adamjstewart codecov is still short
See https://app.codecov.io/gh/torchgeo/torchgeo/pull/3086 for the lines missing coverage, or install the codecov plugin to directly view them on GitHub.
New error now that I'm testing time-series:
```
RuntimeError: stack expects each tensor to be equal size, but got [21, 8, 8] at entry 0 and [21, 8, 10] at entry 1
```
@adamjstewart all green
IDK, inspecting another batch, these just don't look right. @wangyi111 WDYT?
I have managed to train a model, so perhaps I am just not used to viewing the S3 images. However, given these are large areas of forest, perhaps that is plausible even if the images are poorly registered.
@adamjstewart Upgraded to kornia 0.8.2 and the image/mask still look unregistered. I don't think this is a code issue; perhaps the images themselves are not well registered. I suggest a follow-up investigation be performed by @wangyi111.