Don't assume tv_tensors.Image for pure tensors
🚀 The feature
Make transforms.v2 ignore raw tensors / let them pass through instead of assuming that they are images.
- This could be activated manually by a flag in order to not disrupt the existing API.
- Or, if one element of the sample is already an `Image` (detected upon calling `tree_flatten`), don't assume the pure tensors are images too.
- Alternatively, implement a new `TVTensor` type which is registered for all transformation kernels with an identity/passthrough op. It would need to be implemented thoroughly so that the `TVTensor` type does not get unwrapped by some transformations.
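The second bullet can be illustrated with a toy sketch. The classes and the `plan` function below are pure-Python stand-ins invented for this example, not torchvision internals; the "first pure tensor is treated as the image" fallback is an assumption based on the documented v1-compatible behaviour:

```python
# Toy stand-ins for torchvision's tensor and Image types (illustration only).
class Tensor:
    pass

class Image(Tensor):
    pass

def plan(flat_sample):
    """Return, per leaf, whether a kernel would transform it.

    If an explicit Image is present, every pure tensor passes through;
    otherwise (the v1-compatible fallback) the first pure tensor is
    treated as the image and the rest pass through.
    """
    has_image = any(isinstance(x, Image) for x in flat_sample)
    first_tensor_claimed = has_image
    decisions = []
    for x in flat_sample:
        if isinstance(x, Image):
            decisions.append("transform")
        elif isinstance(x, Tensor) and not first_tensor_claimed:
            first_tensor_claimed = True
            decisions.append("transform")
        else:
            # remaining pure tensors, labels, strings, ... pass through
            decisions.append("passthrough")
    return decisions
```

With an `Image` in the sample, `plan([Image(), Tensor(), "label"])` leaves the pure tensor alone; without one, `plan([Tensor(), Tensor()])` transforms only the first tensor.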
I understand that the current behaviour reduces the migration work for v1 code, but it also makes fresh code a bit odd to work with, IMHO.
Motivation, pitch
I'm trying to use transforms.v2 in combination with tv_tensors. My samples contain a combination of an image, boxes, keypoints, labels, and some extra attributes stored as tensors (they are multi-dimensional, not scalar). Unfortunately, transforms.v2 modules get confused by the tensors.
Alternatives
- Implement a `PassthroughTVTensor` myself. This is verbose and tedious since I don't know the complete list of transformation kernels.
- Don't use `tv_tensors` and implement all transformations as `nn.Module` wrappers instead. I pass only the sample fields that need transformations and recompose the sample back together.
```python
import torch.nn as nn
import torchvision.transforms.v2 as T

class RandomPhotometricDistort(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.t = T.RandomPhotometricDistort(*args, **kwargs)

    def forward(self, sample):
        # Transform the image field only; all other fields pass through.
        return sample | {"image": self.t(sample["image"])}
```
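The same wrapper pattern can be written once, generically, without torchvision. `FieldTransform` and the toy `add_one`-style callable below are made-up names for illustration, not torchvision API; the `dict | dict` merge requires Python 3.9+:

```python
from typing import Any, Callable

class FieldTransform:
    """Apply `fn` to sample[key]; every other field passes through."""

    def __init__(self, key: str, fn: Callable[[Any], Any]):
        self.key = key
        self.fn = fn

    def __call__(self, sample: dict) -> dict:
        # Return a new dict; the input sample is left unmodified.
        return sample | {self.key: self.fn(sample[self.key])}

# Toy transform and sample: only "image" is touched.
t = FieldTransform("image", lambda x: [v + 1 for v in x])
sample = {"image": [1, 2, 3], "label": 7}
out = t(sample)
# out == {"image": [2, 3, 4], "label": 7}; sample is unchanged
```

This avoids writing one wrapper class per transform, at the cost of losing the heuristic dispatch that `transforms.v2` applies to structured samples.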
Additional context
https://github.com/pytorch/vision/blob/ccb801b88af136454798b945175c4c87e636ac33/torchvision/transforms/v2/_type_conversion.py#L42
https://github.com/pytorch/vision/blob/ccb801b88af136454798b945175c4c87e636ac33/torchvision/transforms/v2/_type_conversion.py#L71
Hi @nlgranger
Thanks for the feature request. Can you help me understand exactly what you mean by "transforms.v2 modules get confused by the tensors" in:
> I'm trying to use transforms.v2 in combination with tv_tensors. My samples contain a combination of an image, boxes, keypoints, labels, and some extra attributes stored as tensors (they are multi-dimensional, not scalar). Unfortunately, transforms.v2 modules get confused by the tensors.
You suggested:
> Or if one element of the sample is already an `Image` (detected upon calling `tree_flatten`), don't assume the pure tensors are images too.
This should already be the default behavior (see the note in this tutorial):
> If there is an `Image`, `Video`, or `PIL.Image.Image` instance in the input, all other pure tensors are passed-through.
If that's not the case then this is possibly a bug.