Initial commit does not build
Draft PR for initial review.
Adding lowering support for RoiAlign ops with static shapes and known sampling ratios.
Current to-do list:
- Fix minor build issues (replace SubOp with SubIOp/SubFOp, etc.)
- Add a unit test
@Muzammiluddin-Syed-ECE, make sure you click the 'Resolve conversation' button under comments you believe have been addressed. This makes it much easier to iterate on pull requests.
Pausing work on this issue as a ramp-up task. It has proven more involved than expected for something assigned as part of onboarding.
Work done:
1. Identified the assumptions required for the base case of RoiAlign lowering support:
- Only support lowering inputs with a trivial batch dimension (= 1).
- Only support lowering RoiAlign when the sampling ratio is explicitly defined.
Both conditions were added to avoid data-dependent execution (for example, loop bounds and tensor indexing that depend on the values inside the inputs).
2. Identified an algorithm to implement the base case
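For reference while debugging, the base case can be sketched as a plain-Python RoiAlign under the two assumptions above. It follows torchvision's CPU kernel semantics as I understand them (half-pixel offset when `aligned` is true, ROI extent clamped to at least 1 otherwise, bilinear sampling with zero contribution outside the map); it is a reference sketch, not the lowering itself. The names `bins_per_cell`, `bilinear`, and `roi_align` are hypothetical, and tensors are nested lists in [C][H][W] layout with the batch dimension dropped. `bins_per_cell` shows why the sampling ratio must be explicit: with `sampling_ratio <= 0`, torchvision uses ceil(roi_height / pooled_height) samples per cell, so the loop bound would depend on ROI values.

```python
import math
from typing import List, Sequence

Plane = List[List[float]]  # one H x W channel (batch dimension assumed to be 1)


def bins_per_cell(roi_extent: float, pooled: int, sampling_ratio: int) -> int:
    """Sampling points per output cell along one axis (torchvision's rule)."""
    if sampling_ratio > 0:
        return sampling_ratio  # static: known when lowering
    return math.ceil(roi_extent / pooled)  # adaptive: depends on ROI data


def bilinear(plane: Plane, y: float, x: float) -> float:
    """Bilinearly sample one plane; points outside [-1, H] x [-1, W] give 0."""
    height, width = len(plane), len(plane[0])
    if y < -1.0 or y > height or x < -1.0 or x > width:
        return 0.0
    y, x = max(y, 0.0), max(x, 0.0)
    y_low, x_low = int(y), int(x)
    if y_low >= height - 1:  # clamp to the last row
        y_low = y_high = height - 1
        y = float(y_low)
    else:
        y_high = y_low + 1
    if x_low >= width - 1:  # clamp to the last column
        x_low = x_high = width - 1
        x = float(x_low)
    else:
        x_high = x_low + 1
    ly, lx = y - y_low, x - x_low
    hy, hx = 1.0 - ly, 1.0 - lx
    return (hy * hx * plane[y_low][x_low] + hy * lx * plane[y_low][x_high]
            + ly * hx * plane[y_high][x_low] + ly * lx * plane[y_high][x_high])


def roi_align(feature: Sequence[Plane], rois: Sequence[Sequence[float]],
              spatial_scale: float, pooled_h: int, pooled_w: int,
              sampling_ratio: int, aligned: bool) -> List[List[Plane]]:
    """Reference RoiAlign for the base case; returns [K][C][pooled_h][pooled_w]."""
    if sampling_ratio <= 0:
        raise ValueError("adaptive sampling (sampling_ratio <= 0) is data-dependent")
    offset = 0.5 if aligned else 0.0
    grid = sampling_ratio  # static loop bound in both axes
    out = []
    for batch_idx, x1, y1, x2, y2 in rois:
        if int(batch_idx) != 0:
            raise ValueError("only a trivial batch dimension (= 1) is supported")
        start_w, start_h = x1 * spatial_scale - offset, y1 * spatial_scale - offset
        roi_w = x2 * spatial_scale - offset - start_w
        roi_h = y2 * spatial_scale - offset - start_h
        if not aligned:
            roi_w, roi_h = max(roi_w, 1.0), max(roi_h, 1.0)
        bin_h, bin_w = roi_h / pooled_h, roi_w / pooled_w
        per_roi = []
        for plane in feature:
            pooled = [[0.0] * pooled_w for _ in range(pooled_h)]
            for ph in range(pooled_h):
                for pw in range(pooled_w):
                    acc = 0.0
                    for iy in range(grid):
                        y = start_h + ph * bin_h + (iy + 0.5) * bin_h / grid
                        for ix in range(grid):
                            x = start_w + pw * bin_w + (ix + 0.5) * bin_w / grid
                            acc += bilinear(plane, y, x)
                    pooled[ph][pw] = acc / (grid * grid)  # average of samples
            per_roi.append(pooled)
        out.append(per_roi)
    return out
```

Feeding a constant feature map through this sketch should return that constant for any ROI lying inside the map, which is a convenient first sanity check for the lowered IR.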
To do:
1. Finish debugging the current implementation
- Current error (on the final insert_slice):
error: expected type to be 'tensor<?x?x?x?xf32>' or a rank-reduced version. (size mismatch)
2. Add this unit test in an appropriate location:
```mlir
func.func @main(%arg0: !torch.vtensor<[?,256,?,?],f32>, %arg1: !torch.vtensor<[2,4],f32>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,256,7,7],f32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.1"} {
  %int2 = torch.constant.int 2
  %int7 = torch.constant.int 7
  %float1.250000e-01 = torch.constant.float 1.250000e-01
  %false = torch.constant.bool false
  %none = torch.constant.none
  %int6 = torch.constant.int 6
  %int1 = torch.constant.int 1
  %0 = torch.aten.unsqueeze %arg2, %int1 : !torch.vtensor<[2],si64>, !torch.int -> !torch.vtensor<[2,1],si64>
  %1 = torch.aten.to.dtype %0, %int6, %false, %false, %none : !torch.vtensor<[2,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,1],f32>
  %2 = torch.prim.ListConstruct %1, %arg1 : (!torch.vtensor<[2,1],f32>, !torch.vtensor<[2,4],f32>) -> !torch.list<vtensor>
  %3 = torch.aten.cat %2, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,5],f32>
  %4 = torch.torchvision.roi_align %arg0, %3, %float1.250000e-01, %int7, %int7, %int2, %false : !torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[2,5],f32>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,256,7,7],f32>
  return %4 : !torch.vtensor<[?,256,7,7],f32>
}
```
3. Create a test to verify numerics
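The unit test above packs ONNX-style inputs (boxes `%arg1` of shape [2,4] and batch indices `%arg2` of shape [2]) into torchvision's [K,5] `rois` layout via `unsqueeze`, `to.dtype` (dtype 6, i.e. f32), and `cat` along dim 1. A numerics test would need to feed a reference implementation the same layout, so here is a minimal sketch of that packing in plain Python. `pack_rois` and `roi_align_output_shape` are hypothetical helpers, not existing torch-mlir utilities. Note that torchvision's roi_align returns [K, C, pooled_h, pooled_w], so the leading `?` in the test's result type is the number of ROIs (2 here), not the input batch size.

```python
from typing import List, Sequence


def pack_rois(boxes: Sequence[Sequence[float]],
              batch_indices: Sequence[int]) -> List[List[float]]:
    """Build torchvision-style rows [batch_idx, x1, y1, x2, y2].

    Mirrors the test's unsqueeze(dim 1) -> to.dtype(f32) -> cat(dim 1) chain.
    """
    if len(boxes) != len(batch_indices):
        raise ValueError("boxes and batch_indices must have the same length")
    packed = []
    for idx, box in zip(batch_indices, boxes):
        if len(box) != 4:
            raise ValueError("each box must be [x1, y1, x2, y2]")
        # Column 0 is the float-cast batch index, columns 1..4 the box corners.
        packed.append([float(idx)] + [float(v) for v in box])
    return packed


def roi_align_output_shape(num_rois: int, channels: int,
                           pooled_h: int, pooled_w: int) -> List[int]:
    """torchvision's roi_align result shape: one pooled map per ROI."""
    return [num_rois, channels, pooled_h, pooled_w]


rois = pack_rois([[0.0, 0.0, 8.0, 8.0], [4.0, 4.0, 16.0, 16.0]], [0, 0])
print(roi_align_output_shape(len(rois), 256, 7, 7))  # [2, 256, 7, 7]
```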
Feel free to ping me for additional context and details.