Implement `torchvision.ops.roi_align` in torchxla2
🚀 Feature
https://pytorch.org/vision/stable/generated/torchvision.ops.roi_align.html?highlight=roi_align#torchvision.ops.roi_align
A few ideas:
- Use the torch decomposition here: https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py#L115. I tried this and found that JAX OOMs at this line: https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py#L74, so the problem seems to be that the advanced indexing there creates large intermediates. The torch side needed a "loop-less" implementation to help Inductor; we could instead rewrite it using `jax.vmap` and `jax.lax.fori_loop`.
- Start from this JAX implementation: https://github.com/google-research/scenic/blob/74225e8e71ba27a76abd62e6bc56e8a64c4cc19e/scenic/projects/baselines/centernet/modeling/roi_align.py#L103. However, it takes `output_size` as an int instead of a tuple of ints (i.e. it assumes height and width are equal), so it will need some modification.
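A rough sketch of the `jax.vmap` + `jax.lax.fori_loop` idea, not a tested port: `jax.vmap` maps over ROIs, and `jax.lax.fori_loop` walks the `sampling_ratio x sampling_ratio` sampling grid while accumulating into a `(PH, PW, C)` buffer per ROI. The goal is to avoid materializing the full `(K, C, PH, PW, S, S)` tensor that the torch decomposition builds. It also accepts `output_size` as an int or an `(h, w)` tuple. Assumptions: `rois` uses torchvision's `(K, 5)` `[batch_index, x1, y1, x2, y2]` layout. Only a fixed `sampling_ratio > 0` is handled, because torchvision's default `-1` (adaptive) makes the per-ROI grid size data-dependent, which does not fit JAX's static shapes; the default of `2` here is my choice. The names `roi_align`, `_roi_align_one` and `_bilinear_sample` are hypothetical, and the bilinear edge rules follow my reading of torchvision's kernel, not a numerical comparison against it.

```python
import functools

import jax
import jax.numpy as jnp


def _bilinear_sample(feat_nhwc, b, y, x):
    """Bilinearly sample feat_nhwc[b] on the grid y (PH,) x x (PW,) -> (PH, PW, C)."""
    height, width = feat_nhwc.shape[1], feat_nhwc.shape[2]
    # Samples more than one pixel outside the feature map contribute zero.
    valid = ((y >= -1.0) & (y <= height))[:, None] & ((x >= -1.0) & (x <= width))[None, :]
    y = jnp.maximum(y, 0.0)
    x = jnp.maximum(x, 0.0)
    y_low = jnp.floor(y).astype(jnp.int32)
    x_low = jnp.floor(x).astype(jnp.int32)
    # At the last row/column, snap the sample onto it (interpolation weight 0).
    y_edge = y_low >= height - 1
    x_edge = x_low >= width - 1
    y_low = jnp.where(y_edge, height - 1, y_low)
    x_low = jnp.where(x_edge, width - 1, x_low)
    y_high = jnp.where(y_edge, height - 1, y_low + 1)
    x_high = jnp.where(x_edge, width - 1, x_low + 1)
    ly = jnp.where(y_edge, 0.0, y - y_low)
    lx = jnp.where(x_edge, 0.0, x - x_low)
    hy, hx = 1.0 - ly, 1.0 - lx

    def gather(yi, xi):
        # Adjacent advanced indices broadcast to (PH, PW); C is appended -> (PH, PW, C).
        return feat_nhwc[b, yi[:, None], xi[None, :]]

    val = ((hy[:, None] * hx[None, :])[..., None] * gather(y_low, x_low)
           + (hy[:, None] * lx[None, :])[..., None] * gather(y_low, x_high)
           + (ly[:, None] * hx[None, :])[..., None] * gather(y_high, x_low)
           + (ly[:, None] * lx[None, :])[..., None] * gather(y_high, x_high))
    return jnp.where(valid[..., None], val, 0.0)


def _roi_align_one(feat_nhwc, roi, pooled_h, pooled_w, spatial_scale, sampling_ratio, aligned):
    """RoIAlign for a single ROI [batch_index, x1, y1, x2, y2] -> (PH, PW, C)."""
    b = roi[0].astype(jnp.int32)
    offset = 0.5 if aligned else 0.0
    x1 = roi[1] * spatial_scale - offset
    y1 = roi[2] * spatial_scale - offset
    roi_w = roi[3] * spatial_scale - offset - x1
    roi_h = roi[4] * spatial_scale - offset - y1
    if not aligned:  # legacy behaviour: force ROIs to be at least 1x1
        roi_w = jnp.maximum(roi_w, 1.0)
        roi_h = jnp.maximum(roi_h, 1.0)
    bin_h = roi_h / pooled_h
    bin_w = roi_w / pooled_w
    ph = jnp.arange(pooled_h, dtype=roi.dtype)
    pw = jnp.arange(pooled_w, dtype=roi.dtype)
    s = sampling_ratio

    def body(i, acc):
        # Sampling point (iy, ix) inside every bin, at regular sub-bin centers.
        iy = (i // s).astype(roi.dtype)
        ix = (i % s).astype(roi.dtype)
        y = y1 + ph * bin_h + (iy + 0.5) * bin_h / s
        x = x1 + pw * bin_w + (ix + 0.5) * bin_w / s
        return acc + _bilinear_sample(feat_nhwc, b, y, x).astype(acc.dtype)

    init = jnp.zeros((pooled_h, pooled_w, feat_nhwc.shape[-1]), feat_nhwc.dtype)
    # Only one (PH, PW, C) sample slab is live per iteration.
    acc = jax.lax.fori_loop(0, s * s, body, init)
    return acc / (s * s)


def roi_align(features, rois, output_size, spatial_scale=1.0, sampling_ratio=2, aligned=False):
    """features: (N, C, H, W); rois: (K, 5). Returns (K, C, PH, PW)."""
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    pooled_h, pooled_w = output_size
    if sampling_ratio <= 0:
        raise NotImplementedError("adaptive sampling (sampling_ratio <= 0) is not handled")
    feat_nhwc = jnp.transpose(features, (0, 2, 3, 1))
    per_roi = functools.partial(
        _roi_align_one, pooled_h=pooled_h, pooled_w=pooled_w,
        spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned)
    out = jax.vmap(per_roi, in_axes=(None, 0))(feat_nhwc, rois)  # (K, PH, PW, C)
    return jnp.transpose(out, (0, 3, 1, 2))
```

Because `output_size`, `sampling_ratio` and `aligned` decide shapes or Python control flow, they would need to be static under `jax.jit`, e.g. `jax.jit(roi_align, static_argnames=("output_size", "spatial_scale", "sampling_ratio", "aligned"))`. Whether this actually avoids the OOM seen with the decomposition would still need to be measured.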