ml4floods
[DataPrep] PyTorch torchvision transforms for datamodule
Currently, we're using the `albumentations` library. I would suggest we switch to `torchvision` because its `CxHxW` convention is consistent with what `rasterio` returns. It would require minimal API changes, but we need to make sure the channel ordering stays consistent across the pipeline.
Demo w. PyTorch
Note: `rasterio` also uses `CxHxW`, so we don't need to permute the channels in this case.
Source: PyTorch Docs
Example:
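```python
from torchvision import transforms

# Rescale and RandomCrop are the custom (sample-dict) transforms defined in the
# PyTorch data-loading tutorial; they are not part of torchvision itself.
# No PermuteChannels step: torchvision has no such transform, and rasterio
# already returns CxHxW arrays.
transform_toTensor = transforms.ToTensor()  # note: treats ndarray input as HxWxC
scale = Rescale(256)
crop = RandomCrop(128)

# Stacked Transforms
mega_transform = transforms.Compose([Rescale(256),
                                     RandomCrop(224)])

# WorldFloodsDataset, image_files, image_prefix and gt_prefix come from ml4floods.
pt_ds = WorldFloodsDataset(image_files, image_prefix, gt_prefix, transforms=mega_transform)
```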
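Because `rasterio` reads rasters as `(C, H, W)` arrays, any transform that only touches the trailing two (spatial) axes, which is how torchvision's tensor transforms treat `[..., H, W]` inputs, can be applied without permuting. A minimal NumPy sketch of that idea; `random_crop_chw` is a hypothetical helper, not part of torchvision or ml4floods:

```python
import numpy as np


def random_crop_chw(image, size, rng=None):
    """Randomly crop a (C, H, W) array to (C, size, size).

    Hypothetical helper: only the last two (spatial) axes are sliced,
    so the channel axis is never moved.
    """
    rng = np.random.default_rng() if rng is None else rng
    _, h, w = image.shape
    if size > h or size > w:
        raise ValueError("crop size larger than image")
    top = int(rng.integers(0, h - size + 1))
    left = int(rng.integers(0, w - size + 1))
    return image[:, top:top + size, left:left + size]


# A fake 3-band 256x256 raster in rasterio's (C, H, W) layout.
raster = np.zeros((3, 256, 256), dtype=np.float32)
crop = random_crop_chw(raster, 224, rng=np.random.default_rng(0))
print(crop.shape)  # (3, 224, 224)
```

Each band keeps its own values after the crop, so no `PermuteChannels` step is needed on this path.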
Demo w. albumentations
Very similar to `torchvision`, except that `albumentations` expects images as `HxWxC` instead of `CxHxW`. So we need a dedicated `PermuteChannels()` class to make sure the channels are in the order each library expects.
Note: `rasterio` returns `CxHxW`, so we do need to permute the channels in this case.
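Since `albumentations` works on `(H, W, C)` arrays and `rasterio` returns `(C, H, W)`, the permutation is just a pair of axis transposes. A minimal NumPy sketch, assuming the notebook's `PermuteChannels` goes `HxWxC -> CxHxW` and its inverse goes the other way; these class bodies are illustrative, not the ml4floods implementation:

```python
import numpy as np


class InversePermuteChannels:
    """(C, H, W) -> (H, W, C): rasterio layout to albumentations layout."""

    def __call__(self, image):
        return np.transpose(image, (1, 2, 0))


class PermuteChannels:
    """(H, W, C) -> (C, H, W): back to the rasterio/PyTorch layout."""

    def __call__(self, image):
        return np.transpose(image, (2, 0, 1))


chw = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # C=2, H=3, W=4
hwc = InversePermuteChannels()(chw)
print(hwc.shape)  # (3, 4, 2)
assert np.array_equal(PermuteChannels()(hwc), chw)  # round trip is lossless
```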
Source: Our Notebook
Example:
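```python
# Stacked Transforms
# `transformations` is our notebook's module; transform_invpermutechannels,
# transform_resize, transform_gaussnoise, transform_motionblur, transform_rr90
# and transform_flip are defined in earlier notebook cells (not shown here).
transform_permute = transformations.PermuteChannels()
transform_toTensor = transformations.ToTensor()
transform_oneHotEncoding = transformations.OneHotEncoding(num_classes=3)

mega_transform = transformations.Compose([
    transform_invpermutechannels,  # CxHxW (rasterio) -> HxWxC (albumentations)
    transform_resize,
    transform_gaussnoise,
    transform_motionblur,
    transform_rr90,
    transform_flip,
    transform_permute,             # HxWxC -> CxHxW
    # transform_toTensor,
    # transform_oneHotEncoding,
])

# WorldFloodsDataset, image_files, image_prefix and gt_prefix come from ml4floods.
pt_ds = WorldFloodsDataset(image_files, image_prefix, gt_prefix, transforms=mega_transform)
```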
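The order of that list matters: the inverse permute has to come first so the albumentations-style transforms see `(H, W, C)`, and the permute has to come last so the dataset hands back `(C, H, W)`. A minimal NumPy sketch of that ordering; `Compose`, `hflip_hwc` and the two permute functions are illustrative stand-ins, not the `transformations` module's API:

```python
import numpy as np


class Compose:
    """Apply a list of callables in order (same idea as torchvision's Compose)."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, image):
        for t in self.transforms:
            image = t(image)
        return image


def inv_permute_channels(image):
    # (C, H, W) -> (H, W, C)
    return np.transpose(image, (1, 2, 0))


def permute_channels(image):
    # (H, W, C) -> (C, H, W)
    return np.transpose(image, (2, 0, 1))


def hflip_hwc(image):
    # Stand-in for an albumentations-style transform that assumes (H, W, C):
    # a left-right flip reverses axis 1, the width axis in HWC layout.
    return image[:, ::-1, :]


pipeline = Compose([inv_permute_channels, hflip_hwc, permute_channels])
chw = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # C=2, H=3, W=4
out = pipeline(chw)
print(out.shape)  # (2, 3, 4)
# Net effect: every band is flipped left-right and the layout is unchanged.
assert np.array_equal(out, chw[:, :, ::-1])
```

If the inverse permute is dropped, `hflip_hwc` reverses axis 1 of a `(C, H, W)` array, which is the height axis, so the pipeline silently produces a vertical flip instead.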