
Implement `torchvision.ops.roi_align` in torchxla2

Open qihqi opened this issue 4 months ago • 0 comments

🚀 Feature

https://pytorch.org/vision/stable/generated/torchvision.ops.roi_align.html?highlight=roi_align#torchvision.ops.roi_align

A few ideas:

  1. Use the torch decomposition here: https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py#L115. I tried this and found that JAX runs out of memory (OOM) at https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py#L74, so the problem seems to be that the advanced indexing used there creates large intermediate tensors. The torch side needed a "loop-less" implementation to help Inductor; we could instead rewrite it using `jax.vmap` and `jax.lax.fori_loop`.
  2. Start from this JAX implementation: https://github.com/google-research/scenic/blob/74225e8e71ba27a76abd62e6bc56e8a64c4cc19e/scenic/projects/baselines/centernet/modeling/roi_align.py#L103. However, it takes `output_size` as an int rather than a tuple of ints (i.e., it assumes the output height and width are equal), so it will need some modification.
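A minimal sketch of idea 1, combined with the tuple `output_size` needed for idea 2, might look like the following. It makes several assumptions. `rois` is a `(K, 5)` array of `(batch_idx, x1, y1, x2, y2)` rows; torchvision also accepts a list of per-image boxes, which is not handled here. `sampling_ratio` is a fixed positive int, because torchvision's default of `-1` picks a data-dependent sampling grid per ROI that would need padding and masking under JAX tracing. The bilinear boundary handling is meant to follow torchvision's, but this sketch has not been checked against its outputs. `jax.vmap` maps over ROIs and over the output grid, while `jax.lax.fori_loop` accumulates one sub-sample position at a time, so each step materializes roughly `K * out_h * out_w * C` sampled values rather than all `sampling_ratio**2` times that at once. The names `roi_align` and `_bilinear` are hypothetical, not an existing torchxla2 API.

```python
import jax
import jax.numpy as jnp


def _bilinear(feat, y, x):
    """Bilinearly sample feat (C, H, W) at scalar coordinates (y, x).

    Intended to mirror torchvision's boundary rules: points outside
    [-1, H] x [-1, W] contribute 0; otherwise coordinates are clamped
    to the valid range before interpolating.
    """
    _, height, width = feat.shape
    valid = (y >= -1.0) & (y <= height) & (x >= -1.0) & (x <= width)
    y = jnp.clip(y, 0.0, height - 1)
    x = jnp.clip(x, 0.0, width - 1)
    y0 = jnp.floor(y).astype(jnp.int32)
    x0 = jnp.floor(x).astype(jnp.int32)
    y1 = jnp.minimum(y0 + 1, height - 1)
    x1 = jnp.minimum(x0 + 1, width - 1)
    ly, lx = y - y0, x - x0
    hy, hx = 1.0 - ly, 1.0 - lx
    val = (hy * hx * feat[:, y0, x0] + hy * lx * feat[:, y0, x1]
           + ly * hx * feat[:, y1, x0] + ly * lx * feat[:, y1, x1])
    return jnp.where(valid, val, 0.0)  # shape (C,)


def roi_align(features, rois, output_size, spatial_scale=1.0,
              sampling_ratio=2, aligned=False):
    """Hypothetical JAX RoIAlign sketch.

    features: (N, C, H, W); rois: (K, 5) rows of (batch_idx, x1, y1, x2, y2);
    output_size: int or (out_h, out_w); sampling_ratio: fixed positive int.
    Returns an array of shape (K, C, out_h, out_w).
    """
    if isinstance(output_size, int):
        out_h = out_w = output_size
    else:
        out_h, out_w = output_size
    offset = 0.5 if aligned else 0.0
    ph = jnp.arange(out_h, dtype=features.dtype)
    pw = jnp.arange(out_w, dtype=features.dtype)
    # Evaluate _bilinear on the full output grid, giving (out_h, out_w, C).
    sample_grid = jax.vmap(jax.vmap(_bilinear, in_axes=(None, None, 0)),
                           in_axes=(None, 0, None))
    n_samples = sampling_ratio * sampling_ratio

    def one_roi(roi):
        feat = features[roi[0].astype(jnp.int32)]  # (C, H, W)
        x1, y1, x2, y2 = roi[1:] * spatial_scale - offset
        roi_w, roi_h = x2 - x1, y2 - y1
        if not aligned:  # non-aligned mode: force ROIs to be at least 1x1
            roi_w = jnp.maximum(roi_w, 1.0)
            roi_h = jnp.maximum(roi_h, 1.0)
        bin_h, bin_w = roi_h / out_h, roi_w / out_w

        def body(i, acc):
            # Sub-sample (iy, ix) is the centre of one cell of a
            # sampling_ratio x sampling_ratio grid inside every bin.
            iy, ix = i // sampling_ratio, i % sampling_ratio
            ys = y1 + (ph + (iy + 0.5) / sampling_ratio) * bin_h
            xs = x1 + (pw + (ix + 0.5) / sampling_ratio) * bin_w
            return acc + sample_grid(feat, ys, xs)

        acc0 = jnp.zeros((out_h, out_w, feat.shape[0]), features.dtype)
        total = jax.lax.fori_loop(0, n_samples, body, acc0)
        # Average the sub-samples and move channels first: (C, out_h, out_w).
        return jnp.transpose(total / n_samples, (2, 0, 1))

    # Map over ROIs; each ROI runs its own fori_loop over sub-sample positions.
    return jax.vmap(one_roi)(rois)
```

For example, `roi_align(x, rois, (7, 7), spatial_scale=1 / 16, sampling_ratio=2, aligned=True)` would return an array of shape `(K, C, 7, 7)`, matching torchvision's output layout. Using `fori_loop` for the sub-samples trades some parallelism for a smaller peak intermediate; replacing it with another `vmap` would be the opposite trade-off.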

Motivation

Pitch

Alternatives

Additional context

qihqi, Oct 18 '24 21:10