Add batch visualization function to `torchvision.utils`
### 🚀 The feature
Currently, torchvision detection models commonly return a list of dictionaries:
```python
import torch
from torchvision import io
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_resnet50_fpn_v2,
)
from torchvision.transforms import functional as f

model = fasterrcnn_resnet50_fpn_v2(
    weights=FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1,
)
model.eval()
dog_int = io.read_image("dog.jpeg")
batch = f.convert_image_dtype(dog_int)
response = model(batch.unsqueeze(0))
# [{'boxes': tensor([[ 61.9160,  49.2223, 185.7204, 184.4109],
#                    [143.5477, 143.5992, 175.1468, 184.4712],
#                    [136.6003, 217.3529, 161.4755, 223.8254]]),
#   'labels': tensor([18, 20, 57]),
#   'scores': tensor([0.9989, 0.1069, 0.0699])}]
```
This is visualized with torchvision by:
```python
from torchvision import utils

dog_image = utils.draw_bounding_boxes(
    dog_int, response[0]["boxes"][response[0]["scores"] > 0.3]
)
f.to_pil_image(dog_image)
```
### Motivation, pitch
However, there is currently no way to visualize a whole batch with a function in the utilities, e.g. continuing from the previous example:
```python
with torch.no_grad():
    imgs = [batch.detach().clone() for _ in range(16)]
    model.eval()
    response = model(imgs)
# The input is a float tensor batch;
# the output is an equal-length list of output dictionaries.
```
There is currently no method available for the user to visualize this batch. I suggest a utils function:
```python
def visualize_batch(
    image_batch: Tensor | list[Tensor],
    batch_response: list[dict[str, Tensor]],
    **visualization_arguments,
):
    """Visualize an image batch in a suitable grid and return the result as a tensor.

    Arguments:
        image_batch (Tensor): Float tensor batch, internally transformed to
            uint8 for the visualization utils.
        batch_response: List of response dictionaries; works for both mask and
            RCNN models as well as training batches by inferring the correct
            behavior from the dictionary keys.
    """
    # ... implementation
```
where the keyword arguments relate to the arguments of the current utils functions.
### Alternatives
Users could follow the utilities gallery example, which implements a matplotlib `show` function. This is non-ideal, as users seem to need similar functionality without the related boilerplate code.
### Additional context
I am willing to contribute given green light.
Thanks for the feature request @vahvero
I think what you're trying to do should be reasonably achievable by:
- manually looping over all images in the batch and calling `draw_bounding_boxes()` individually, with the appropriate parameters (note that if there were a `visualize_batch()` functionality, it would need to take in all of these parameters as well as the score threshold, making its API much more complex)
- a final call to `make_grid()` to put all of those images + bboxes in a single tensor img
LMK if this isn't what you're looking for
@NicolasHug I personally feel that the parameter count would not be an issue here. Torchvision models already take a lot of default keyword parameters, so the function call signature would not be out of place for users. By using reasonable defaults, most of the functionality could be abstracted in a manner no different from model `__init__` methods such as FasterRCNN.
I think that this type of function would have a tangible benefit for most users. Instead of the mentioned self-implemented functionality, which I suspect every single user has either copied from the example or abstracted themselves to achieve similar behavior to this feature request, the library would offer a standardized function for it.
To me, this seems like very commonly needed behavior which for some reason has not been included in the library, despite the library already including some drawing utilities.
Having thought about this more, I think the function
```python
def visualize_batch(
    image_batch: Tensor | list[Tensor],
    batch_response: list[dict[str, Tensor]],
    **visualization_arguments,  # these are naturally expanded
) -> list[Tensor]:
    ...
```
would allow the user to pass the resulting tensors to whichever visualization library they are using.