RegionProposalNetwork can't be AOTInductor compiled with dynamic batch size
🐛 Describe the bug
This is a cross-post of https://github.com/pytorch/pytorch/issues/121036
I'm raising it here to notify the maintainers that I'm going to take a crack at making the RegionProposalNetwork, and potentially other modules, traceable, AOTInductor compilable, or both. Are there any current efforts in this direction I should be aware of?
For AOTInductor, I think this will at least involve changing the AnchorGenerator: it has a method that mutates an anchor attribute, and that method would instead need to return the anchor values.
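As a rough illustration of the kind of change meant here, the sketch below contrasts a module that stores anchors by mutating an attribute with one that simply returns them. The names `base_anchors`, `MutatingAnchors`, and `FunctionalAnchors` are hypothetical and are not torchvision's API; the anchor math is a simplified stand-in chosen for the example.

```python
import torch
from torch import nn


def base_anchors(scales, aspect_ratios, dtype, device):
    # Hypothetical helper: zero-centered (x1, y1, x2, y2) boxes,
    # one per (aspect_ratio, scale) pair.
    scales_t = torch.as_tensor(scales, dtype=dtype, device=device)
    ratios_t = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
    h_ratios = torch.sqrt(ratios_t)
    w_ratios = 1 / h_ratios
    ws = (w_ratios[:, None] * scales_t[None, :]).reshape(-1)
    hs = (h_ratios[:, None] * scales_t[None, :]).reshape(-1)
    return torch.stack([-ws, -hs, ws, hs], dim=1) / 2


class MutatingAnchors(nn.Module):
    """Stores the result on the module: forward() has a side effect."""

    def __init__(self, scales, aspect_ratios):
        super().__init__()
        self.scales = scales
        self.aspect_ratios = aspect_ratios
        self.cell_anchors = None  # overwritten on every call

    def forward(self, feature_map):
        self.cell_anchors = base_anchors(
            self.scales, self.aspect_ratios, feature_map.dtype, feature_map.device
        )
        return self.cell_anchors


class FunctionalAnchors(nn.Module):
    """Returns the anchors without touching module state."""

    def __init__(self, scales, aspect_ratios):
        super().__init__()
        self.scales = scales
        self.aspect_ratios = aspect_ratios

    def forward(self, feature_map):
        return base_anchors(
            self.scales, self.aspect_ratios, feature_map.dtype, feature_map.device
        )


features = torch.zeros(1, 8, 4, 4)
anchors = FunctionalAnchors((32.0, 64.0), (0.5, 1.0, 2.0))(features)
```

Both modules produce the same values; the difference is only that the second one leaves no state behind after `forward`.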
To support tracing, my plan is to address each TracerWarning (see below). First I'll be looking to remove the iteration over tensors in ImageList that prevents the model from generalizing after tracing.
/opt/workspace/./satlas-src/satlas/model/model.py:438: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
image_sizes = [(image.shape[1], image.shape[2]) for image in images]
/opt/conda/lib/python3.10/site-packages/torchvision/ops/boxes.py:166: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
/opt/conda/lib/python3.10/site-packages/torchvision/ops/boxes.py:168: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
/opt/conda/lib/python3.10/site-packages/torch/__init__.py:1560: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert condition, message
/opt/workspace/./satlas-src/satlas/model/model.py:537: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
losses = {'base': torch.tensor(0, device=device, dtype=torch.float32)}
/opt/workspace/./satlas-src/satlas/model/model.py:850: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
losses = torch.tensor(0, device=batch_tensor.device, dtype=torch.float32)
Versions
I'm using the nightlies, see https://github.com/pytorch/pytorch/issues/121036
Thanks for the report @rbavery
> Are there any current efforts in this direction I should be aware of?
No, we haven't been looking at the RPN's support for torch.compile yet.
> I think this will at least involve changing the AnchorGenerator
Just note that a lot of this code is public, and changing its behaviour (e.g. the expected input/output) would technically break backward compatibility. So adding support for AOT while still preserving BC may not be a trivial task.
Beyond the RPN, what model specifically are you interested in tracing?
Got it, I initially went with supporting TorchScript scripting since it seemed easier and would only require adding type annotations. I've made edits to this model which uses a SWIN Transformer backbone, an FPN, and a Faster RCNN head:
https://github.com/allenai/satlas/blob/main/configs/satlas_explorer_marine_infrastructure.txt https://github.com/allenai/satlas/blob/main/satlas/model/model.py
So far I have addressed TorchScript scripting issues with type annotations in the Satlas model source.
The first issue I hit with torchvision is here:
RuntimeError:
Module 'GeneralizedRCNNTransform' has no attribute 'image_mean' (This attribute exists on the Python module, but we failed to convert Python type: 'list' to a TorchScript type. List trace inputs must have elements. Its type was inferred; try adding a type annotation for the attribute.):
File "/opt/conda/lib/python3.10/site-packages/torchvision/models/detection/transform.py", line 167
)
dtype, device = image.dtype, image.device
mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
~~~~~~~~~~~~~~~ <--- HERE
std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
return (image - mean[:, None, None]) / std[:, None, None]
'GeneralizedRCNNTransform.normalize' is being compiled since it was called from 'GeneralizedRCNNTransform.forward'
File "/opt/conda/lib/python3.10/site-packages/torchvision/models/detection/transform.py", line 141
if image.dim() != 3:
raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
image = self.normalize(image)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
image, target_index = self.resize(image, target_index)
images[i] = image
I've had some trouble addressing this with typing. Since the class attribute is already typed, I'm not sure how to get TorchScript scripting to understand that this attribute can be either a List[float] or None. I might need to make code modifications. I'll try to do so in a way that preserves backwards compatibility and leaves the tests passing, and open a PR if it is helpful.
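For reference, one pattern that TorchScript documents for attributes that may be None is a class-level `Optional[...]` annotation combined with a None check on a local copy of the attribute. The sketch below applies it to a hypothetical `NormalizeSketch` module. It is not torchvision's `GeneralizedRCNNTransform`, and whether the same change resolves the error above is an assumption that has not been checked against the torchvision source.

```python
from typing import List, Optional

import torch
from torch import nn


class NormalizeSketch(nn.Module):
    # Class-level annotations give TorchScript the attribute type
    # even when the runtime value is None.
    image_mean: Optional[List[float]]
    image_std: Optional[List[float]]

    def __init__(
        self,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
    ):
        super().__init__()
        self.image_mean = image_mean
        self.image_std = image_std

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Copy to locals first: Optional refinement applies to local
        # variables, not to attribute accesses on self.
        mean = self.image_mean
        std = self.image_std
        if mean is not None and std is not None:
            mean_t = torch.tensor(mean, dtype=image.dtype, device=image.device)
            std_t = torch.tensor(std, dtype=image.dtype, device=image.device)
            return (image - mean_t[:, None, None]) / std_t[:, None, None]
        # No statistics configured: pass the image through unchanged.
        return image


scripted = torch.jit.script(NormalizeSketch([0.5, 0.5, 0.5], [0.25, 0.25, 0.25]))
normalized = scripted(torch.ones(3, 2, 2))

passthrough = torch.jit.script(NormalizeSketch())
unchanged = passthrough(torch.ones(3, 2, 2))
```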
I was able to get TorchScript scripting to work by refactoring the Satlas source code, mostly by adding typing, removing control flow in some spots, and replacing complex Python data structures containing tensors with plain tensors. Inference on dynamic batches appears to work without error. No changes to torchvision were needed.
AOTInductor compilation did not work, unfortunately. I made some progress by forking torchvision and trying to remove the use of ImageList and other Python data structures, remove control flow (often making very strong assumptions about the input data), replace Python indexing with torch.narrow, etc. But I still ran into unbacked SymInt issues when the NMS step is applied in the RPN. I wasn't sure how to get around this, since NMS is inherently data-dependent and can't be made otherwise. If it's helpful, I tried to document my progress here: https://github.com/pytorch/pytorch/issues/121036
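To make the data-dependence concrete, here is a small sketch that does not use NMS itself. `select_dynamic` returns a tensor whose length depends on the values in its input, which is the same property that makes the NMS output shape unknown at compile time. `select_fixed` shows one generic padding pattern (a fixed `k` plus a validity mask). Both function names are hypothetical, and this is not a claim that the pattern makes the RPN compile under AOTInductor.

```python
import torch


def select_dynamic(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    # Output length depends on the values in `scores`, not only on its shape.
    return torch.nonzero(scores > threshold).squeeze(1)


def select_fixed(scores: torch.Tensor, threshold: float, k: int):
    # Always returns exactly k indices, plus a mask marking which are valid.
    top_scores, top_idx = torch.topk(scores, k)
    valid = top_scores > threshold
    return top_idx, valid


scores = torch.tensor([0.9, 0.1, 0.8, 0.3])
dynamic_idx = select_dynamic(scores, 0.5)           # length varies with the data
fixed_idx, valid = select_fixed(scores, 0.5, k=3)   # length is always 3
```

The cost of the fixed-size version is that downstream code has to carry the mask along and ignore the padded entries.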
Both export methods were relatively painful. I'm hoping that AOTInductor comes up with a solution for handling data-dependent shapes, or makes it easier to write code that handles them. I realize it is still early days for AOTInductor, but documentation would go a long way. I'd be happy to contribute, but I still feel fairly new to handling data-dependent shapes.