Using continue in a for loop hangs the torch op e2e test
If I put the following loop in a shape propagation function in torch-mlir/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py:
for i in range(result_rank):
    if i in (dim1, dim2):
        continue
    result_shape[i] = self[input_dim_idx]
    input_dim_idx += 1
the generated Torch MLIR in torch-mlir/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp contains a loop with an extremely high upper bound (9223372036854775807, i.e. INT64_MAX), and no warning is emitted:
" %int9223372036854775807 = torch.constant.int 9223372036854775807\n"
...
" %23:2 = torch.prim.Loop %int9223372036854775807, %true, init(%int0, %int0) {\n"
This in turn hangs the end-to-end test when it is run via the e2e script: ./projects/pt1/tools/e2e_test.sh
To get it working, I had to rewrite the loop as follows:
for i in range(result_rank):
    if i not in (dim1, dim2):
        result_shape[i] = self[input_dim_idx]
        input_dim_idx += 1
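In plain Python the two loops compute exactly the same thing, so the rewrite changes only the control-flow form that gets scripted, not the shape logic itself. The sketch below checks that equivalence outside TorchScript. The function names, the signature, and the -1 placeholder left at positions dim1 and dim2 are hypothetical and for illustration only; this is not the actual shape function from abstract_interp_lib_gen.py.

```python
from typing import List


def fill_with_continue(self: List[int], result_rank: int, dim1: int, dim2: int) -> List[int]:
    # Hypothetical helper: -1 marks dim1/dim2, which the real shape function sets elsewhere.
    result_shape = [-1] * result_rank
    input_dim_idx = 0
    for i in range(result_rank):
        if i in (dim1, dim2):
            continue  # the form that produced the INT64_MAX loop bound
        result_shape[i] = self[input_dim_idx]
        input_dim_idx += 1
    return result_shape


def fill_without_continue(self: List[int], result_rank: int, dim1: int, dim2: int) -> List[int]:
    # Same logic, written with a negated condition instead of continue.
    result_shape = [-1] * result_rank
    input_dim_idx = 0
    for i in range(result_rank):
        if i not in (dim1, dim2):
            result_shape[i] = self[input_dim_idx]
            input_dim_idx += 1
    return result_shape


# Both variants fill the non-skipped positions with the input dims in order.
print(fill_with_continue([2, 3], 4, 1, 3))     # [2, -1, 3, -1]
print(fill_without_continue([2, 3], 4, 1, 3))  # [2, -1, 3, -1]
```

Because the outputs agree for every choice of dim1 and dim2, the hang points at how the continue is lowered into torch.prim.Loop rather than at the shape computation.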
I observed the same issue when I lowered aten.diagonal, as I explained in my PR at the time. The problem was caused by the continue statement, but I didn't understand exactly why.
The "When things go wrong" section of the instructions docs says that this sometimes requires writing things a little awkwardly.
So is the quality of torch-mlir's generation of shape and dtype functions improved on an as-needed basis? I'm sure many people have run into tricky situations like this. We are, after all, trying to compile generic Python code, and coverage is always challenging.
All of the shape inference infrastructure exists to support the old TorchScript path. The new FX path uses upstream PyTorch for all of this. As such, I'm not sure this gets much attention outside of the test suite (which still uses the old path).