ToRGB() Transform

Open kevinMEH opened this issue 2 years ago • 3 comments

🚀 The feature

Some datasets (like ImageNet) contain both RGB images and grayscale images; to make the images batchable, they all need to have the same channel dimension. The following transform solves this problem by converting all grayscale images to RGB without branching.

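import torch
from torch import Tensor

class ThreeChannel(torch.nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        # [1, H, W] is broadcast to [3, H, W]; [3, H, W] passes through unchanged.
        # expand() returns a view, so no pixel data is copied.
        return input.expand(3, -1, -1)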

Unfortunately, it only works on unbatched tensors in [C, H, W] format where C is 1 or 3; batched 4D input or any other channel count raises an error. However, since transforms are applied per image in most torchvision datasets, this shouldn't be a problem.
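As a rough illustration of the limitation described above, here is a minimal sketch of how `expand(3, -1, -1)` behaves on a few inputs. The tensor sizes and the `fails` helper are made up for the example; only `torch.Tensor.expand` and `torch.stack` come from PyTorch itself.

```python
import torch

# Hypothetical unbatched images in [C, H, W] layout.
gray = torch.rand(1, 4, 5)
rgb = torch.rand(3, 4, 5)

# The singleton channel dimension is broadcast to 3; an existing
# 3-channel tensor is returned with the same shape.
gray3 = gray.expand(3, -1, -1)
rgb3 = rgb.expand(3, -1, -1)

# Both results now share a shape, so they can be stacked into a batch.
batch = torch.stack([gray3, rgb3])

# expand() returns a view over the original storage rather than a copy.
shares_memory = gray3.data_ptr() == gray.data_ptr()


def fails(t: torch.Tensor) -> bool:
    """Return True if expand(3, -1, -1) raises for this tensor."""
    try:
        t.expand(3, -1, -1)
    except RuntimeError:
        return True
    return False


# A batched [N, 1, H, W] tensor and a 4-channel (e.g. RGBA) tensor are rejected.
batched_fails = fails(torch.rand(2, 1, 4, 5))
rgba_fails = fails(torch.rand(4, 4, 5))
```

Because the expanded tensor is a view whose three channels alias the same memory, in-place operations on it may need a `.clone()` first.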

Motivation, pitch

Some datasets (like ImageNet) contain both RGB images and grayscale images; to make the images batchable, they all need to have the same channel dimension.

Alternatives

No response

Additional context

No response

kevinMEH avatar Dec 22 '23 22:12 kevinMEH

Thanks for the feature request @kevinMEH. That sounds reasonable, but before we commit to adding a ToRGB() transform, have you considered decoding the images directly into RGB format? E.g. either using img.convert("RGB") if img is a PIL image, or if you're decoding tensors directly you could use read_image(path, mode=ImageReadMode.RGB).

LMK if this doesn't address your use-case

NicolasHug avatar Jan 01 '24 09:01 NicolasHug

Thanks for the suggestion! I did not even realize that image.convert("RGB") was an option (I am not familiar with PIL images). I am now using the following transform before ToImage() and it seems to work. Perhaps it can be added to the default transforms library?

import torch

class ToRGB(torch.nn.Module):
    def forward(self, image):
        # Expects a PIL image; convert() returns a new image in RGB mode.
        return image.convert("RGB")

kevinMEH avatar Jan 01 '24 18:01 kevinMEH

Glad it works. Calling .convert("RGB") is actually what happens by default in the datasets, but it's true that it's a bit hidden.

We could add a new ToRGB() transform to extract that logic out. One minor caveat here is that converting to RGB while decoding (e.g. through read_image(path, mode=ImageReadMode.RGB)) is probably more efficient than converting afterwards, which involves an extra copy, etc. But this can probably be simply mentioned in the docs.

NicolasHug avatar Jan 02 '24 10:01 NicolasHug