lightly
Supporting Tensor as input to ImageCollateFunction
I'm using ImageCollateFunction to augment images in the SimCLR and BYOL methods. I've implemented a torch Dataset, which is then passed to LightlyDataset via LightlyDataset.from_torch_dataset(torch_dataset). My torch dataset's __getitem__ method returns an (image, label) pair.
The issue is that ImageCollateFunction supports only PIL images as input, while my dataset is naturally in Tensor format.
I've added a conversion to PIL images on my side, but I think that Tensor input should be supported, as it's a natural format for both PyTorch and many datasets. Also, almost all of the transformation functions support both PIL and Tensor formats. It would be faster to avoid the torch -> PIL -> torch round trip and instead just keep the data in torch format.
My suggestion is to do the following:
- The GaussianBlur transform currently relies on PIL's Gaussian blur, which supports only PIL images. This class should be changed to use torchvision's GaussianBlur, which supports both PIL and Tensor as input (and output).
- Instead of T.ToTensor() in the list of transforms, a custom ToTensorCustom() should be called, which supports both Tensor and PIL image as input, and returns a Tensor. The logic would be trivial:

    import torchvision.transforms as T
    from torch import Tensor

    def ToTensorCustom(img):
        # Tensors pass through unchanged; PIL images are converted.
        if isinstance(img, Tensor):
            return img
        return T.ToTensor()(img)

What do you think about this, is there an easier or clearer way to do it? I'd be happy to implement this myself and create a PR.
Kudos to @guarin who helped me with this.
Yes, I think it's now safe to switch to torchvision's Gaussian blur. We initially kept our own version to preserve backward compatibility, but torchvision already introduced the blur in version 0.8. We could also support both tensors and PIL images as input. We just need to make sure the ordering of the channels is correct and does not cause any issues, and we would need to check whether all the other transforms we apply already support both formats. I'd try to stay backward compatible with at least torchvision 0.8 if possible.
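To illustrate the channel-ordering concern: a PIL image converted to an array is laid out as height x width x channels (HWC), while torchvision's ToTensor returns channels x height x width (CHW) and rescales 8-bit values to the [0, 1] range. The sketch below mimics that conversion in plain Python, with nested lists standing in for images. The helper name hwc_to_chw is hypothetical and is not part of lightly or torchvision.

```python
from typing import List

Image3D = List[List[List[float]]]


def hwc_to_chw(image: Image3D, scale: float = 255.0) -> Image3D:
    """Reorder a height x width x channels image to channels x height x width.

    Values are divided by `scale`, mirroring how ToTensor maps 8-bit
    pixel values to the [0, 1] range.
    """
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    return [
        [[image[y][x][c] / scale for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A 1 x 2 RGB image: one red pixel followed by one blue pixel.
hwc = [[[255, 0, 0], [0, 0, 255]]]
chw = hwc_to_chw(hwc)  # three channel planes, each of shape 1 x 2
```

A dataset that already yields CHW tensors must skip this step, otherwise the axes would be reordered (or the values rescaled) a second time.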
Thanks @lukasugar for making the issue and investigations!
I would suggest the following changes:
- Change gaussian blur to torchvision
- Add deprecation warning when using our gaussian blur transform
- Add a to_tensor argument to the collate functions. If True, then add a ToTensor transform to the applied transforms, otherwise don't add it. This is consistent with how we handle image normalization using the normalize argument.
- The ToTensorCustom transform should not be necessary when adding a to_tensor argument.
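A minimal sketch of how such a to_tensor flag could sit next to normalize, written in plain Python so it runs without torchvision. Everything here is an assumption for illustration: build_transforms and apply_all are hypothetical helpers, the lambdas are stand-ins for the real ToTensor and Normalize transforms, and normalize is simplified to a single (mean, std) pair. It is not the actual lightly implementation.

```python
from typing import Callable, List, Optional, Tuple


def build_transforms(
    augmentations: List[Callable],
    to_tensor: bool = True,
    normalize: Optional[Tuple[float, float]] = None,
) -> List[Callable]:
    """Assemble the transforms a collate function would apply.

    Both `to_tensor` and `normalize` are optional steps appended
    after the augmentations.
    """
    transforms = list(augmentations)
    if to_tensor:
        # Stand-in for ToTensor: scale 8-bit values to [0, 1].
        transforms.append(lambda values: [v / 255.0 for v in values])
    if normalize is not None:
        mean, std = normalize
        # Stand-in for Normalize: subtract the mean, divide by the std.
        transforms.append(lambda values: [(v - mean) / std for v in values])
    return transforms


def apply_all(transforms: List[Callable], values):
    """Apply each transform in order."""
    for transform in transforms:
        values = transform(values)
    return values


# PIL-like input (8-bit values) needs the conversion step.
from_pil = apply_all(build_transforms([], to_tensor=True), [0, 255])

# Tensor-like input (already in [0, 1]) skips it.
from_tensor = apply_all(build_transforms([], to_tensor=False), [0.0, 1.0])
```

With the flag, a dataset that already returns tensors only has to pass to_tensor=False, so no type check inside the transform is needed.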
I agree, there's no need for ToTensorCustom then. I'll create a PR with these changes and link the issue.
That sounds great! Thanks a lot, looking forward to your PR :)
Sadly, I forgot to make this change a long time ago... I see now that GaussianBlur has been changed to use a radius, which is incompatible with the torchvision version of GaussianBlur, so I will mark this issue as "won't fix".
Yes, we didn't use GaussianBlur from torchvision because it was added in torchvision 0.8 and we didn't want to increase the minimum torchvision version. We'll do this eventually.
In the meantime, you could use an extra function to modify the transforms along the lines of:
    import lightly.transforms.gaussian_blur
    from lightly.transforms import SimCLRTransform
    from torchvision.transforms import ToTensor, Compose, GaussianBlur

    def get_transform(t):
        if isinstance(t, ToTensor):
            # return identity transform
            return lambda x: x
        elif isinstance(t, lightly.transforms.gaussian_blur.GaussianBlur):
            # put kernel size you need
            return GaussianBlur(5)
        # keep every other transform unchanged
        return t

    transform = SimCLRTransform()
    # replace the transforms of each view in place
    for view_transform in transform.transforms:
        view_transform.transform = Compose([
            get_transform(t) for t in view_transform.transform.transforms
        ])