Transforms and data augmentation: use Kornia

Open remtav opened this issue 2 years ago • 1 comments

Kornia offers great transforms for data augmentation. GDL could replace the in-house augmentations in augmentation.py with Kornia's. Initially, the problem was that torchvision's transforms were based on PIL and did not accept more than 3 bands. Kornia does not have this limitation: it is much more flexible and offers features that our in-house augmentations lack, such as GPU support and applying transforms to an entire batch.

remtav, Sep 02 '22 15:09

(draft)

CURRENT - Torchvision approach

Instantiation:

  • SegmentationDataset()
  • transforms as a torchvision.transforms.Compose object (built from a list of custom transform classes, each implementing at least __init__ and __call__)
    • ends with a ToTensorTarget() that converts numpy arrays to PyTorch tensors (and a bit more...)

During training:

  1. SegmentationDataset's get_item returns a "sample" as a dict containing numpy arrays (keys: 'sat_img', 'map_img')
  2. transforms are applied sequentially
    • input: sample (dict with values as np.arrays)
    • output: sample (dict with values as torch.Tensors)

PROPOSED - translate to Kornia

Requirements:

  • input as torch.Tensor; it can be a dict with torch.Tensor values if working with torchgeo's version of AugmentationSequential
  • all transforms, custom or not, must subclass nn.Module
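
As a rough sketch of the second requirement, a custom transform written as an nn.Module might look like the following. ScaleIntensity and its gain parameter are hypothetical names used only for illustration; they are not part of GDL, Kornia, or torchgeo.

```python
import torch
from torch import nn


class ScaleIntensity(nn.Module):
    """Hypothetical radiometric transform: multiply all bands by a fixed gain."""

    def __init__(self, gain: float = 1.0):
        super().__init__()
        self.gain = gain

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Operates on a whole batch (B, C, H, W) with any number of bands.
        return image * self.gain


# Because it is an nn.Module, it can be chained like any other layer.
transform = nn.Sequential(ScaleIntensity(gain=0.5))
batch = torch.ones(2, 5, 8, 8)  # 2 samples, 5 bands
out = transform(batch)
```

Writing the logic in forward() rather than __call__ is what lets the transform sit next to Kornia's own modules and follow the batch onto the GPU.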

Instantiation:

  • SegmentationDataset()
  • transforms as list of transforms passed to torchgeo.transforms.AugmentationSequential()
    • Radiometric transforms
    • Geometric transforms

Implementation:

  • SegmentationDataset: _load_image() in get_item
  • SegmentationDataset: _load_target() in get_item
  • SegmentationDataset: subclass torchgeo's VisionDataset!
  • augmentations.py: class Augmentation(), which accepts a variable-length list of transforms as nn.Module (ex.: Intensity)

Test:

  1. Test with GDL's input (dict of numpy arrays) but with Kornia's transforms inside a torchgeo.transforms.AugmentationSequential()
  2. Returns transformed sample with same keys?

remtav, Oct 19 '22 14:10