[Feature Request] PadSquare: Square Padding to Preserve Aspect Ratios When Resizing Images with Varied Shapes in torchvision.transforms.v2
🚀 The feature
A new transform class, PadSquare, that pads non-square images to make them square by padding the shorter side. Configuration is inspired by torchvision.transforms.v2.Pad. Note that Pad's positional padding argument is dropped, since the padding is computed from the shape of the non-square input itself. This feature would be beneficial in situations where square inputs are required by downstream models or processes, and it simplifies pipelines by making this transformation a first-class part of torchvision.transforms.v2.
Case 1 (Width > Height):
Case 2 (Height > Width):
Case 3 (Height == Width): Nothing changes :-)
Image Sources: VOC2012
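To make the three cases concrete, here is the padding arithmetic in isolation (a hand-rolled sketch; square_padding is an illustrative helper, not part of the proposed API):

from typing import List

def square_padding(h: int, w: int) -> List[int]:
    # Returns [left, top, right, bottom] padding that centers the input in a square
    d = abs(h - w)
    lo, hi = d // 2, d - d // 2
    return [lo, 0, hi, 0] if h > w else [0, lo, 0, hi]

assert square_padding(300, 500) == [0, 100, 0, 100]  # Case 1: pad height
assert square_padding(500, 300) == [100, 0, 100, 0]  # Case 2: pad width
assert square_padding(224, 224) == [0, 0, 0, 0]      # Case 3: nothing changes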
Motivation, pitch
I’m working on a multi-label classification project that requires square images, but the input dataset has a variety of shapes and aspect ratios. PadSquare would streamline the preprocessing pipeline by automatically padding these images to square while allowing flexible padding modes. Unlike naive resizing, this avoids distorting the content when the images are subsequently resized, and it simplifies handling inputs of varying shapes. In short, this feature request is about making square inputs straightforward and robust via consistent padding.
Alternatives
I have considered using existing padding methods within torchvision, but they require additional logic to conditionally apply padding only to the shorter side, which makes the code less modular (e.g. as demonstrated in this discussion). Current alternatives involve manually calculating the padding and applying it to achieve a square shape, as sketched below. A dedicated PadSquare transform would turn this common operation into a reusable, convenient utility.
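For reference, this is roughly the ad-hoc logic each user currently has to write and wrap in something like v2.Lambda (a sketch against the existing functional API; pad_to_square is just an illustrative name):

import torch
import torchvision.transforms.v2.functional as F

def pad_to_square(img: torch.Tensor) -> torch.Tensor:
    h, w = img.shape[-2], img.shape[-1]
    d = abs(h - w)
    if h < w:
        # Shorter side is the height: pad top/bottom
        padding = [0, d // 2, 0, d - d // 2]
    else:
        # Shorter side is the width: pad left/right (no-op when already square)
        padding = [d // 2, 0, d - d // 2, 0]
    return F.pad(img, padding=padding)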
Additional context
The PadSquare class uses the _get_params method to calculate the necessary padding values, ensuring the padded image is centered. It also supports multiple padding modes and allows for a specified fill value when using 'constant' mode. It would enhance the versatility of torchvision.transforms.v2 by providing a reusable utility for data preprocessing. Let me know what you think of it! :-)
Initial Implementation
My initial implementation of PadSquare is inspired by the implementation of Pad.
from typing import Any, Dict, List, Literal, Union, Type
import torchvision.transforms.v2.functional as F
from torchvision.transforms import v2
from torchvision.transforms.v2._utils import (
_check_padding_mode_arg,
_get_fill,
_setup_fill_arg,
_FillType,
)
class PadSquare(v2.Transform):
"""Pad a non-square input to make it square by padding the shorter side to match the longer side.
Args:
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is "constant".
- constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
- symmetric: pads with reflection of image repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]
Example:
>>> import torch
>>> from torchvision.transforms.v2 import PadSquare
>>> rectangular_image = torch.randint(0, 255, (3, 224, 168), dtype=torch.uint8)
>>> transform = PadSquare(padding_mode='constant', fill=0)
>>> square_image = transform(rectangular_image)
>>> print(square_image.size())
torch.Size([3, 224, 224])
"""
def __init__(
self,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
):
super().__init__()
        # _check_padding_mode_arg already raises ValueError for unsupported modes
        _check_padding_mode_arg(padding_mode)
self.padding_mode = padding_mode
self.fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
# Get the original height and width from the inputs
orig_height, orig_width = v2.query_size(flat_inputs)
# Find the target size (maximum of height and width)
target_size = max(orig_height, orig_width)
if orig_height < target_size:
# Need to pad height
pad_height = target_size - orig_height
pad_top = pad_height // 2
pad_bottom = pad_height - pad_top
pad_left = 0
pad_right = 0
else:
# Need to pad width
pad_width = target_size - orig_width
pad_left = pad_width // 2
pad_right = pad_width - pad_left
pad_top = 0
pad_bottom = 0
# The padding needs to be in the format [left, top, right, bottom]
return dict(padding=[pad_left, pad_top, pad_right, pad_bottom])
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Resolve the fill value for this input's type (supports per-type fill dicts)
        fill = _get_fill(self.fill, type(inpt))
return self._call_kernel(
F.pad,
inpt,
padding=params["padding"],
padding_mode=self.padding_mode,
fill=fill
)
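A quick sanity check of the sketch above on heterogeneous inputs, exercising the per-type fill dictionary (assumes the PadSquare class defined above is in scope):

import torch
from torchvision import tv_tensors

image = tv_tensors.Image(torch.randint(0, 255, (3, 224, 168), dtype=torch.uint8))
mask = tv_tensors.Mask(torch.zeros(224, 168, dtype=torch.uint8))

# Image borders are filled with gray, mask borders with background index 0
transform = PadSquare(fill={tv_tensors.Image: 127, tv_tensors.Mask: 0})
out_image, out_mask = transform(image, mask)
print(out_image.shape, out_mask.shape)  # (3, 224, 224) and (224, 224)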
Thanks a lot for the super detailed feature request, @geezah!
This sounds reasonable but before we move towards a PR, can you help me understand why you think padding is preferable to resizing the input image here?
Side note: using query_size(flat_inputs) as suggested in the snippet above will enforce that all images in the input are of the same original shape. I don't think we can avoid such enforcement (at least not easily), but I just wanted to point that out in case that's not desirable for your own use-case.
Thanks for the feedback! The padding approach was suggested mainly for cases where preserving aspect ratios could be beneficial, such as:
- Object detection tasks where object proportions are crucial
- OCR/text recognition where character aspect ratios matter
- Fine-grained classification where shape distortion might hide important features
You're right about the constraint that all inputs in a single call must share the same original shape. For handling variable input sizes, one could implement a custom collate_fn that performs random resizing at batch creation time instead of during the transform pipeline, as sketched below. This would allow more flexibility while maintaining batch efficiency.
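One way that collate_fn could look (a sketch under the assumption that PadSquare already ran in the per-sample pipeline, so every sample is square but sizes still differ across samples; resize_collate is an illustrative name and labels are assumed to be integers):

import random
import torch
from torchvision.transforms.v2 import functional as F

def resize_collate(batch):
    # One random target size per batch keeps the tensors stackable
    # while still varying the resolution across batches
    size = random.choice([160, 192, 224])
    images = [F.resize(img, [size, size]) for img, _ in batch]
    labels = [label for _, label in batch]
    return torch.stack(images), torch.tensor(labels)

# loader = DataLoader(dataset, batch_size=32, collate_fn=resize_collate)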
Thanks for coming back to me, @geezah. This sounds good, please feel free to submit a PR! Let's go with the simple approach of using query_size first; we can consider the collate_fn approach later if needed.
Alright 😄 Thank you for coming back to it quickly!
This sounds like a good idea. Some VLM processors, e.g. LLaVA, require a square image: the internal mechanism changes the image size by resizing, whereas I want to preserve the aspect ratio.
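For that use-case, the proposed transform would compose naturally with a deterministic resize (a sketch; 336 is only an illustrative target resolution, not a claim about LLaVA's actual preprocessing):

from torchvision.transforms import v2

preprocess = v2.Compose([
    PadSquare(fill=0),      # square via padding, so the aspect ratio is preserved
    v2.Resize((336, 336)),  # then scale to the model's expected input resolution
])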