[DataPrep] PyTorch torchvision transforms for datamodule

Open jejjohnson opened this issue 4 years ago • 0 comments

We currently use the albumentations library for our transforms. I suggest switching to torchvision because it is more consistent. The switch should require only minimal API changes, but we need to make sure the behaviour stays consistent.


Demo w. pytorch

Note: rasterio also uses CxHxW so we don't need to permute the channels in this case.

Source: PyTorch Docs

Example:

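from torchvision import transforms

# Stacked transforms.
# torchvision has no PermuteChannels transform, and none is needed here:
# rasterio already returns CxHxW arrays.
# Rescale and RandomCrop are assumed to be the custom callable classes from the
# PyTorch data-loading tutorial (the "PyTorch Docs" source above); they are not
# torchvision built-ins and must be defined or imported separately.
transform_toTensor = transforms.ToTensor()  # defined but not applied in this demo

scale = Rescale(256)
crop = RandomCrop(224)
mega_transform = transforms.Compose([scale, crop])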

pt_ds = WorldFloodsDataset(image_files, image_prefix, gt_prefix, transforms=mega_transform)
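
Rescale and RandomCrop are not defined in the snippet above. As a rough illustration of what such a callable might look like on rasterio-style CxHxW arrays, here is a minimal NumPy-only sketch. The class name `PairedRandomCrop`, the `image`/`mask` dictionary keys, and the band count are hypothetical and not taken from ml4floods or torchvision. The point is only that the same crop window has to be applied to the image and to its ground truth.

```python
import numpy as np


class PairedRandomCrop:
    """Hypothetical transform: crop image and mask with the same random window.

    Assumes sample["image"] is CxHxW and sample["mask"] is HxW.
    """

    def __init__(self, size, seed=None):
        self.size = size
        self.rng = np.random.default_rng(seed)

    def __call__(self, sample):
        image, mask = sample["image"], sample["mask"]
        _, height, width = image.shape
        if self.size > height or self.size > width:
            raise ValueError("crop size exceeds image dimensions")
        # Draw one window and reuse it for both arrays.
        top = int(self.rng.integers(0, height - self.size + 1))
        left = int(self.rng.integers(0, width - self.size + 1))
        rows = slice(top, top + self.size)
        cols = slice(left, left + self.size)
        return {"image": image[:, rows, cols], "mask": mask[rows, cols]}


# Dummy sample with an arbitrary number of bands (4) in CxHxW layout.
sample = {
    "image": np.zeros((4, 256, 256), dtype=np.float32),
    "mask": np.zeros((256, 256), dtype=np.uint8),
}
cropped = PairedRandomCrop(224, seed=0)(sample)
print(cropped["image"].shape, cropped["mask"].shape)
```

Because the callable takes and returns a whole sample, it could be placed in a Compose-style list, provided the dataset passes samples in that form. That is an assumption about WorldFloodsDataset, not something this issue confirms.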

Demo w. albumentations

Very similar to torchvision, except that albumentations expects arrays in HxWxC layout instead of CxHxW. So we need a dedicated PermuteChannels() class to convert between the two layouts.

Note: rasterio uses CxHxW, so we do need to permute the channels in this case.

Source: Our Notebook

Example:

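# Stacked transforms.
# `transformations` is the project's own module. The other transform_* objects
# (inverse permute, resize, noise, blur, rotation, flip) are not shown here and
# are assumed to be defined earlier in the notebook.
transform_permute = transformations.PermuteChannels()
transform_toTensor = transformations.ToTensor()
transform_oneHotEncoding = transformations.OneHotEncoding(num_classes=3)

mega_transform = transformations.Compose([
    transform_invpermutechannels,  # presumably CxHxW -> HxWxC for albumentations
    transform_resize,
    transform_gaussnoise,
    transform_motionblur,
    transform_rr90,
    transform_flip,
    transform_permute,  # presumably HxWxC -> CxHxW for PyTorch
    # transform_toTensor,
    # transform_oneHotEncoding,
])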

pt_ds = WorldFloodsDataset(image_files, image_prefix, gt_prefix, transforms=mega_transform)
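
The pipeline above is bracketed by two layout changes: rasterio delivers CxHxW arrays, albumentations operates on HxWxC, and PyTorch wants CxHxW back. A minimal sketch of that round trip with NumPy follows. The helper names `chw_to_hwc` and `hwc_to_chw` are hypothetical stand-ins, not the project's actual PermuteChannels implementation.

```python
import numpy as np


def chw_to_hwc(array):
    """Move the channel axis last: CxHxW -> HxWxC."""
    return np.transpose(array, (1, 2, 0))


def hwc_to_chw(array):
    """Move the channel axis first: HxWxC -> CxHxW."""
    return np.transpose(array, (2, 0, 1))


chw = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # C=2, H=3, W=4
hwc = chw_to_hwc(chw)        # layout albumentations works with
restored = hwc_to_chw(hwc)   # back to the layout PyTorch expects

print(hwc.shape)
print(restored.shape)
```

With torchvision-style transforms that already work on CxHxW, neither step would be needed, which is the main simplification this issue is after.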

Posted by jejjohnson, Feb 22 '21 15:02