prioritize batching for torchvision::nms
🚀 The feature
Implement batched NMS!
Motivation, pitch
My motivation for creating this issue was this warning:
[W BatchedFallback.cpp:84] Warning: There is a performance drop because we have not yet implemented the batching rule for torchvision::nms. Please file us an issue on GitHub so that we can prioritize its implementation. (function warnFallback)
Alternatives
No response
Additional context
No response
Hi @RishiMalhotra920, can you please share a minimal reproducible example of the code you tried to run? I suspect the warning you're observing comes from torch core rather than from torchvision directly.
Note that we do have a batched version of nms here: https://pytorch.org/vision/main/generated/torchvision.ops.batched_nms.html#torchvision.ops.batched_nms
Ah, unfortunately I can't find the place where I saw the warning again, but I think I was doing something with torch.vmap and torchvision.ops.box_iou.
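It was probably something like this (reconstructing from memory, so this may not be exactly what I ran; the shapes and threshold here are made up):

```python
import torch
from torchvision.ops import nms

# Batch of 4 images with 100 candidate boxes each, in (x1, y1, x2, y2) format.
boxes = torch.rand(4, 100, 4)
boxes[..., 2:] += boxes[..., :2]  # make sure x2 > x1 and y2 > y1
scores = torch.rand(4, 100)

# vmap over the batch dimension; iou_threshold is not batched (in_dims=None).
# This emits the BatchedFallback warning because torchvision::nms has no
# batching rule, so vmap falls back to running nms once per image.
keep = torch.vmap(nms, in_dims=(0, 0, None))(boxes, scores, 0.5)
```

(In practice this tends to error right after the warning anyway, since nms returns a different number of indices per image and vmap can't stack ragged outputs.)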
Additionally, I saw batched_nms earlier, and it threw me off at first since the docs say it does not perform NMS between objects of different categories. However, to make it work for multiple images in a batch, I would just assign each image in the batch a different category index, and it would work as expected (sketch below). Thanks for pointing me to this!
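Concretely, I'd do something like this (just a sketch, the shapes are arbitrary):

```python
import torch
from torchvision.ops import batched_nms

B, N = 4, 100  # batch size, candidate boxes per image
boxes = torch.rand(B, N, 4)
boxes[..., 2:] += boxes[..., :2]  # make sure x2 > x1 and y2 > y1
scores = torch.rand(B, N)

# Label every box with the index of the image it belongs to, so batched_nms
# suppresses overlapping boxes within an image but never across images.
image_idxs = torch.arange(B).repeat_interleave(N)  # [0, 0, ..., 1, 1, ...]

keep = batched_nms(boxes.reshape(-1, 4), scores.reshape(-1), image_idxs, iou_threshold=0.5)

# `keep` indexes into the flattened (B * N) boxes; recover (image, box) pairs:
img_of_kept, box_of_kept = keep // N, keep % N
```

If per-class NMS within each image is also needed, the index can encode both, e.g. something like `image_idx * num_classes + class_idx`.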
Once you've noted the suggestion, feel free to close this issue if you don't have any other questions.
Thanks for the reply, @RishiMalhotra920.
> However, to make it work for multiple images in a batch, I would just assign each image in the batch a different category index, and it would work as expected
Can you share a reproducible example of this? If all examples in a batch have a different label, then batched_nms should basically be a no-op.
Thanks to the Team for developing PyTorch!
It would be very useful to have a true batch-NMS rather than only the multiclass NMS we have now. It would be extremely helpful for those who work with Triton and Jetson and need to torch.jit.trace or torch.jit.script their models to use them with the Triton backend. The Triton Python backend doesn't support the GPU, so compiling Torch models is the only way to run them on the GPU, and that seems unlikely to change because of the poor connection between Python and the Jetson GPU (see https://github.com/triton-inference-server/server/issues/4772#issuecomment-1217253091).
Please, it would be great to have an NMS that returns the boxes with a batch dimension, padded with fake boxes to keep the tensor shape fixed.
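For example, something along these lines (the wrapper below is only a sketch of the desired behavior, not an existing torchvision op; the name and padding values are made up):

```python
import torch
from torchvision.ops import nms


@torch.jit.script
def padded_batch_nms(boxes: torch.Tensor,   # (B, N, 4)
                     scores: torch.Tensor,  # (B, N)
                     iou_threshold: float,
                     max_out: int):
    """Run NMS per image and pad the results to a fixed size, so the
    output shapes are static and the function can be scripted/traced."""
    B = boxes.shape[0]
    kept_boxes = torch.zeros([B, max_out, 4], dtype=boxes.dtype, device=boxes.device)
    # Pad scores with -1 so the fake boxes are easy to mask out downstream.
    kept_scores = torch.full([B, max_out], -1.0, dtype=scores.dtype, device=scores.device)
    for b in range(B):
        keep = nms(boxes[b], scores[b], iou_threshold)[:max_out]
        k = keep.shape[0]
        kept_boxes[b, :k] = boxes[b][keep]
        kept_scores[b, :k] = scores[b][keep]
    return kept_boxes, kept_scores
```

Downstream code can then drop the padding with `kept_scores >= 0`, and the output shapes stay fixed no matter how many boxes survive, which is exactly what tracing and the Triton backend need.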
@another-sasha I'd like to give this a try. Can you share a minimal reproducible example (even with for loops) so that I understand your ask better?