ToRGB() Transform
🚀 The feature
Some datasets (like ImageNet) contain both RGB images and grayscale images; to make the images batchable, they all need to have the same channel dimension. The following transform solves this problem by converting all grayscale images to RGB without branching.
import torch
from torch import Tensor

class ThreeChannel(torch.nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        # expand() broadcasts a [1, H, W] grayscale image to [3, H, W];
        # for an already-[3, H, W] RGB image it is a no-op.
        return input.expand(3, -1, -1)
Unfortunately, it is only compatible with 3D Tensors in [C, H, W] format. However, since most torchvision datasets apply transforms per image, this shouldn't be a problem.
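To illustrate, a quick check (with made-up sizes) showing that the same call handles both cases:

import torch

gray = torch.rand(1, 28, 28)  # grayscale, [1, H, W]
rgb = torch.rand(3, 28, 28)   # already RGB, [3, H, W]

to_three = ThreeChannel()
print(to_three(gray).shape)  # torch.Size([3, 28, 28])
print(to_three(rgb).shape)   # torch.Size([3, 28, 28])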
Motivation, pitch
Some datasets (like ImageNet) contain both RGB images and grayscale images; to make the images batchable, they all need to have the same channel dimension.
Thanks for the feature request @kevinMEH. That sounds reasonable, but before we commit to adding a ToRGB() transform, have you considered decoding the images directly into RGB format? E.g. use img.convert("RGB") if img is a PIL image, or read_image(mode=ImageReadMode.RGB) if you're decoding tensors directly.
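For reference, a minimal sketch of the decode-time option (the file path is a placeholder):

from torchvision.io import read_image, ImageReadMode

# Grayscale files decode straight to a 3-channel uint8 tensor.
img = read_image("path/to/image.jpg", mode=ImageReadMode.RGB)
print(img.shape)  # torch.Size([3, H, W])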
LMK if this doesn't address your use-case
Thanks for the suggestion! I did not even realize that image.convert("RGB") was an option (I am not familiar with PIL images). I am now using the following transform before ToImage() and it seems to work. Perhaps it can be added to the default transforms library?
import torch

class ToRGB(torch.nn.Module):
    def forward(self, image):
        # PIL's convert() returns an "RGB" copy: grayscale ("L") images
        # gain two channels, RGB images come back unchanged.
        return image.convert("RGB")
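For context, here is roughly how I am using it; the ToDtype step is just my setup, not a requirement:

import torch
from torchvision.transforms import v2

transform = v2.Compose([
    ToRGB(),                                # 3-channel PIL image out
    v2.ToImage(),                           # PIL -> tv_tensors.Image
    v2.ToDtype(torch.float32, scale=True),  # uint8 -> float32 in [0, 1]
])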
Glad it works. Calling .convert("RGB") is actually what happens by default in the datasets, but it's true that it's a bit hidden.
We could add a new ToRGB() transform to extract that logic out. One minor caveat here is that converting to RGB while decoding (e.g. through read_image(mode=ImageReadMode.RGB)) is probably more efficient than converting after (extra copy, etc.). But this can probably be simply mentioned in the docs.
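To make that concrete, a rough sketch of the two paths (the file name is a placeholder):

from PIL import Image
from torchvision.io import read_image, ImageReadMode

# Conversion is folded into decoding: no separate RGB copy afterwards.
img_decode = read_image("sample.jpg", mode=ImageReadMode.RGB)

# Decode first, then convert: allocates a second image for the RGB copy.
img_convert = Image.open("sample.jpg").convert("RGB")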