SanitizeBoundingBoxes does not support semantic segmentation masks
🐛 Describe the bug
I want to use a single transform Compose to apply data augmentation to my images together with their bounding boxes, bounding-box targets, and a semantic segmentation mask. However, torchvision.transforms.v2.SanitizeBoundingBoxes fails when used inside a v2.Compose that receives both bounding boxes and a semantic segmentation mask as inputs.
According to the docs, this transform should only remove invalid bounding boxes and their corresponding labels or instance masks. However, in practice, it also tries to apply the per-box boolean validity mask to the semantic mask, which causes a shape mismatch and a crash.
```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# Synthetic data
image = tv_tensors.Image(torch.randint(0, 255, (3, 1080, 1920), dtype=torch.uint8))
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 0, 10, 10], [50, 50, 70, 70], [100, 100, 101, 101]]),
    format="XYXY",
    canvas_size=(1080, 1920),
)
labels = torch.tensor([1, 2, 3])

# Semantic (2D) segmentation mask, not per-instance
semantic_mask = tv_tensors.Mask(torch.zeros((1080, 1920), dtype=torch.uint8))

# Custom label getter
def labels_getter(inputs):
    img, boxes, labels, mask = inputs
    return labels

transform = v2.Compose([
    v2.SanitizeBoundingBoxes(labels_getter=labels_getter),
])

# This crashes
out = transform(image, boxes, labels, semantic_mask)
```
```
.local/lib/python3.10/site-packages/torchvision/tv_tensors/_tv_tensor.py", line 77, in __torch_function__
    output = func(*args, **kwargs or dict())
IndexError: The shape of the mask [3] at index 0 does not match the shape of the indexed tensor [1080, 1920] at index 0
```
SanitizeBoundingBoxes.transform() currently does:
```python
is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
output = inpt[params["valid"]]
```
This treats any tv_tensors.Mask as per-box indexable, even semantic masks shaped [H, W]. When params["valid"] has shape [num_boxes], indexing a [H, W] tensor with it causes the IndexError.
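To see why the semantic mask fails while per-instance masks work, here is a minimal sketch using plain torch boolean indexing, independent of torchvision (the small sizes are arbitrary). A boolean index of length num_boxes lines up with dim 0 of a [num_boxes, H, W] instance mask, but not with dim 0 of a [H, W] semantic mask, so the latter raises as soon as H differs from the number of boxes (and would silently drop image rows if the two happened to be equal):

```python
import torch

num_boxes, height, width = 3, 4, 5
valid = torch.tensor([True, False, True])  # one flag per box, like params["valid"]

# Per-instance masks: one [H, W] mask per box, so dim 0 lines up with `valid`.
instance_masks = torch.zeros((num_boxes, height, width), dtype=torch.uint8)
kept = instance_masks[valid]
print(kept.shape)  # torch.Size([2, 4, 5])

# Semantic mask: a single [H, W] map, so dim 0 is the image height, not the box count.
semantic_mask = torch.zeros((height, width), dtype=torch.uint8)
raised = False
try:
    semantic_mask[valid]
except IndexError as err:
    raised = True
    print(err)  # a shape-mismatch message of the same kind as in the traceback
```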
For now, I apply the sanitizer separately, after the Compose and only on the boxes, but it could be worth fixing this for homogeneity. Thanks a lot!
Versions
```
PyTorch version: 2.2.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: Could not collect
CMake version: version 4.1.0
Libc version: glibc-2.35
```
Hey @leenheart, thanks for opening this issue. I understand your concerns. However, the documentation for SanitizeBoundingBoxes clearly states that the transform applies to tv_tensors.Mask, so I am not sure what you mean by "fix it for homogeneity".
I know it's not ideal, but would it be possible to unblock you by calling the transform separately for the mask with something along those lines?
```python
out = transform(image, boxes, labels)
semantic_mask_out = transform(semantic_mask)
```
Hi, thanks for your response! I agree that the transform operates on tv_tensors.Mask. However, my reading of the SanitizeBoundingBoxes documentation is that sanitization should only apply to the targets or masks returned by labels_getter. The documentation explicitly states: "labels_getter – Indicates how to identify the labels in the input (or anything else that needs to be sanitized along with the bounding boxes)." From this, I infer that if the mask is not returned by labels_getter (as in the small example I mentioned), it should not be sanitized.
By "homogeneity," I mean that either:
- All tv_tensors.Mask objects should be treated as representing boxes and thus sanitized by the function. In this case, the documentation should clarify this behavior more explicitly.
- Only masks provided via labels_getter should be sanitized. In this scenario, the mask I pass as a separate argument should remain untouched by the transform.
I hope the second interpretation is correct, as my use case requires applying rotation transforms to both the boxes and the semantic mask simultaneously. If they are not transformed together, they will no longer align with the input image. For example, consider a Compose transform like this:
```python
import torch
from torchvision.transforms import v2 as transforms  # assumed alias: "transforms" is the v2 API

def create_transform(self, with_data_augmentation=False):
    transforms_list = []
    transforms_list.append(transforms.ToImage())
    transforms_list.append(transforms.ToDtype(torch.float32, scale=True))
    if with_data_augmentation:
        if self.cfg_transform.iou_crop:
            transforms_list.append(transforms.RandomIoUCrop())
        if self.cfg_transform.horizontal_flip:
            transforms_list.append(transforms.RandomHorizontalFlip())
        if self.cfg_transform.rotation:
            transforms_list.append(transforms.RandomRotation(self.cfg_transform.rotation_degree))
    transforms_list.append(transforms.ClampBoundingBoxes())
    transforms_list.append(transforms.SanitizeBoundingBoxes())
    return transforms.Compose(transforms_list)
```
When applied to image, boxes, bboxes_targets, and semantic_segmentation_2d like this:
```python
image, boxes, bboxes_targets, semantic_segmentation_2d = self.transforms(image, boxes, boxes_targets, semantic_mask)
```
the process fails and crashes at the SanitizeBoundingBoxes step.
The solution, as I understand it, is to only sanitize the masks provided by the labels_getter function. While this might affect existing PyTorch datasets, it would make the behavior much clearer and more intuitive.
If I haven’t been clear enough, feel free to ask for further clarification. Don’t worry, I’m not blocked, as I’ve isolated SanitizeBoundingBoxes from the rest of the Compose transform for now. I just wanted to share my findings after spending several hours debugging the issue.
@AntoineSimoulin and I discussed this a bit offline. My current thoughts are:
- we shouldn't expect users to pass masks to labels_getter for them to be sanitized; that would be a BC-breaking change, since masks are currently sanitized regardless of whether they're in labels_getter.
- maybe the sanitization (subsampling) of masks should be a "best-effort" thing:
  - if there is a 1:1 mapping between the boxes and masks, then we sanitize the masks;
  - if there is NOT a 1:1 mapping between boxes and masks, we just pass the masks through. Right now we error in this case, which is, IIUC, @leenheart's problem.

Would that work for you @leenheart?
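The proposed best-effort rule could look roughly like the standalone helper below. This is a hypothetical sketch, not torchvision code: sanitize_mask_best_effort is an illustrative name, and it assumes a "1:1 mapping" is detected by comparing the mask's leading dimension with the number of boxes.

```python
import torch


def sanitize_mask_best_effort(mask: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not a torchvision API) for the proposed best-effort rule.

    valid: 1D boolean tensor with one entry per bounding box.
    """
    # Treat the mask as per-box only if it is at least 3D and its leading
    # dimension matches the number of boxes (the "1:1 mapping" case).
    one_mask_per_box = mask.ndim >= 3 and mask.shape[0] == valid.shape[0]
    if one_mask_per_box:
        return mask[valid]  # [N, H, W] -> [num_valid, H, W]
    return mask  # no 1:1 mapping (e.g. a [H, W] semantic mask): pass through


valid = torch.tensor([True, False, True])
instance_masks = torch.ones((3, 4, 5), dtype=torch.uint8)
semantic_mask = torch.zeros((4, 5), dtype=torch.uint8)

kept_instances = sanitize_mask_best_effort(instance_masks, valid)
kept_semantic = sanitize_mask_best_effort(semantic_mask, valid)
print(kept_instances.shape)  # torch.Size([2, 4, 5])
print(kept_semantic.shape)   # torch.Size([4, 5])
```

The ndim >= 3 guard is an extra assumption beyond the thread's proposal: a leading-dimension check alone would silently subsample the rows of a [H, W] semantic mask whose height happens to equal the number of boxes. Even with the guard, a [1, H, W] semantic mask paired with exactly one box remains ambiguous.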
That would be perfect and completely understandable :)
Thanks a lot !