torch-mlir

Decomposing `aten.as_strided` causes assertion failure in MLIR

Status: Open. Opened by momchil-velikov 3 months ago; 2 comments.

Environment: a recent IREE (fd8715fd14eeb4b929c4b2d052377e60709b5c82) with its corresponding torch-mlir (7000187be292710f3a5044f46e577f22e6cfef57) and llvm-project (a376df0140e67c86a1a48d4ab18ca8a3984b1b0c) revisions.

The following MLIR testcase (repro-1.mlir) is extracted from code generated by IREE Turbine from a Hugging Face GPT-2 model:

  func.func @f(%in : !torch.vtensor<[1,10,2304],f32>) -> !torch.vtensor<[1,10,768],f32> {
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int10 = torch.constant.int 10
    %int768 = torch.constant.int 768
    %int2304 = torch.constant.int 2304
    %int23040 = torch.constant.int 23040

    %lst0 = torch.prim.ListConstruct %int1, %int10, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %lst1 = torch.prim.ListConstruct %int23040, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>

    %out = torch.aten.as_strided %in, %lst0, %lst1, %int0
      : !torch.vtensor<[1,10,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,10,768],f32>

    return %out : !torch.vtensor<[1,10,768],f32>
  }
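To make the testcase's intent concrete, here is a minimal pure-Python model of `aten.as_strided` semantics. It is an illustration only, not torch-mlir code, and the `as_strided` helper is hypothetical. It assumes the `[1,10,2304]` input is contiguous (strides `[23040,2304,1]`). Under that assumption, sizes `[1,10,768]` with the same strides and offset 0 simply select the first 768 of the 2304 columns in each row, which is the same as a slice.

```python
# Pure-Python model of aten.as_strided: element [i0, i1, ...] of the result
# is flat[offset + i0*strides[0] + i1*strides[1] + ...].

def as_strided(flat, sizes, strides, offset=0):
    """Gather a nested list of shape `sizes` from the flat buffer `flat`."""
    def build(dim, base):
        if dim == len(sizes):
            return flat[base]
        return [build(dim + 1, base + i * strides[dim]) for i in range(sizes[dim])]
    return build(0, offset)

# Input of the testcase: shape [1, 10, 2304], assumed contiguous.
# Each element holds its own flat index so the result is easy to check.
inp = list(range(1 * 10 * 2304))
out = as_strided(inp, [1, 10, 768], [23040, 2304, 1], 0)

# The view keeps the first 768 of 2304 columns in every row,
# i.e. it matches the slice in[:, :, 0:768].
expected = [[[row * 2304 + k for k in range(768)] for row in range(10)]]
print(out == expected)  # True
```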

Compiling it with `iree-opt --torch-decompose-complex-ops --torch-scalarize-shapes --convert-torch-to-tmtensor --convert-torch-to-tensor --convert-torch-to-linalg repro-1.mlir` results in an assertion failure:

iree-opt: /work/iree/main/third_party/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::Value, From = mlir::OpFoldResult]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
Stack dump:
0.      Program arguments: ./tools/iree-opt --torch-decompose-complex-ops --torch-scalarize-shapes --convert-torch-to-tmtensor --convert-torch-to-tensor --convert-torch-to-linalg repro-1.mlir
 #0 0x0000fb2df12f9e64 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0x6929e64)
 #1 0x0000fb2df12f7884 llvm::sys::RunSignalHandlers() (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0x6927884)
 #2 0x0000fb2df12faa40 SignalHandler(int, siginfo_t*, void*) Signals.cpp:0:0
 #3 0x0000fb2dfa579968 (linux-vdso.so.1+0x968)
 #4 0x0000fb2dea697608 __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
 #5 0x0000fb2dea64cb3c raise ./signal/../sysdeps/posix/raise.c:27:6
 #6 0x0000fb2dea637e00 abort ./stdlib/abort.c:81:7
 #7 0x0000fb2dea645cc0 __assert_fail_base ./assert/assert.c:93:7
 #8 0x0000fb2dea645d30 __assert_perror_fail ./assert/assert-perr.c:31:1
 #9 0x0000fb2df70043bc mlir::matchConstantIndex() (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0xc6343bc)
#10 0x0000fb2df6e92980 mlir::tensor::ExpandShapeOp::inferOutputShape(mlir::OpBuilder&, mlir::Location, mlir::RankedTensorType, llvm::ArrayRef<llvm::SmallVector<long, 2u>>, llvm::ArrayRef<mlir::OpFoldResult>) (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0xc4c2980)
#11 0x0000fb2df6e92e10 mlir::tensor::ExpandShapeOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::Type, mlir::Value, llvm::ArrayRef<llvm::SmallVector<long, 2u>>) (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0xc4c2e10)
#12 0x0000fb2df24a6138 mlir::tensor::ExpandShapeOp mlir::OpBuilder::create<mlir::tensor::ExpandShapeOp, mlir::Type&, mlir::Value, llvm::SmallVector<llvm::SmallVector<long, 2u>, 1u>&>(mlir::Location, mlir::Type&, mlir::Value&&, llvm::SmallVector<llvm::SmallVector<long, 2u>, 1u>&) DataMovement.cpp:0:0
#13 0x0000fb2df24982a4 (anonymous namespace)::ConvertAtenUnflattenIntOp::matchAndRewrite(mlir::torch::Torch::AtenUnflattenIntOp, mlir::torch::Torch::AtenUnflattenIntOpAdaptor, mlir::ConversionPatternRewriter&) const DataMovement.cpp:0:0
...

In the program above, `torch.aten.as_strided` is lowered by the DecomposeAtenAsStridedOp pattern (torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp) into a sequence that contains several `torch.aten.view` operations, such as

  %8 = torch.aten.view %6, %7 : !torch.vtensor<[10],si64>, !torch.list<int> -> !torch.vtensor<[1,?,1],si64>

which the ScalarizeShapes pass in turn rewrites into problematic `torch.aten.unflatten.int` ops, as in this reduced example (repro-2.mlir):

  func.func @f() -> !torch.vtensor<[1,?,1],si64> {
    %none = torch.constant.none
    %int-1 = torch.constant.int -1
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int10 = torch.constant.int 10

    %steps = torch.aten.arange.start_step %int0, %int10, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],si64>
    %sizes = torch.prim.ListConstruct %int1, %int-1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>

    %ret = torch.aten.unflatten.int %steps, %int0, %sizes : !torch.vtensor<[10],si64>, !torch.int, !torch.list<int> -> !torch.vtensor<[1,?,1],si64>
    return %ret : !torch.vtensor<[1,?,1],si64>
  }

Compiling this with `iree-opt --convert-torch-to-linalg repro-2.mlir` yields the same assertion failure and stack trace.
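The `[10] -> [1,?,1]` view of an arange seen in the decomposition is consistent with an index-based lowering of `as_strided`. In such a lowering, each output dimension gets its own `arange(size) * stride`, reshaped so that it broadcasts along only that dimension, and the broadcast sum plus the offset gives the flat indices to gather. The Python sketch below assumes that scheme. It is not a transcription of DecomposeAtenAsStridedOp, and the helper names are hypothetical. Note that the per-dimension view shape for dimension 1 is fully static.

```python
# Hypothetical sketch of an index-based as_strided decomposition (assumed
# scheme, not the actual torch-mlir pattern): dimension d contributes
# arange(sizes[d]) * strides[d], viewed to a shape that is 1 everywhere
# except at d; broadcast-summing all terms plus offset yields flat indices.

def per_dim_view_shape(sizes, d):
    """Shape that the [sizes[d]] arange for dimension d would be viewed to."""
    return [sizes[d] if i == d else 1 for i in range(len(sizes))]

def flat_indices(sizes, strides, offset=0):
    """Broadcast-sum the per-dimension terms into row-major flat indices."""
    indices = [offset]
    for size, stride in zip(sizes, strides):
        # Combine every index so far with every position along this dimension.
        indices = [base + i * stride for base in indices for i in range(size)]
    return indices

sizes, strides = [1, 10, 768], [23040, 2304, 1]
print(per_dim_view_shape(sizes, 1))   # [1, 10, 1]
idx = flat_indices(sizes, strides)
print(len(idx), idx[:3], idx[768])    # 7680 [0, 1, 2] 2304
```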

Looking at some of the generated ops, such as

  %8 = torch.aten.view %6, %7 : !torch.vtensor<[10],si64>, !torch.list<int> -> !torch.vtensor<[1,?,1],si64>

there is no real need to produce a result type with a dynamic dimension: the input has 10 elements and the other two result dimensions are 1, so the `?` can only be 10. While not incorrect per se, the dynamic shape is a loss of information. Fixing the decomposition to emit instead

  %8 = torch.aten.view %6, %7 : !torch.vtensor<[10],si64>, !torch.list<int> -> !torch.vtensor<[1,10,1],si64>

helps avoid this issue.
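The same reasoning applies to the `torch.aten.unflatten.int` in repro-2.mlir. Dimension 0 of `%steps` has static extent 10 and the known sizes in `[1, -1, 1]` multiply to 1, so the `-1` resolves to 10 at compile time. Below is a minimal Python sketch of that inference. It follows PyTorch's documented rule that at most one entry of `sizes` may be -1, and the helper name is hypothetical.

```python
from math import prod

def infer_unflatten_sizes(dim_size, sizes):
    """Resolve at most one -1 in `sizes` so the sizes multiply to dim_size.

    Hypothetical helper illustrating unflatten's shape inference; raises
    ValueError when the sizes are inconsistent with dim_size.
    """
    if sizes.count(-1) > 1:
        raise ValueError("at most one size may be -1")
    known = prod(s for s in sizes if s != -1)
    if -1 not in sizes:
        if known != dim_size:
            raise ValueError("sizes do not multiply to dim_size")
        return list(sizes)
    if known == 0 or dim_size % known != 0:
        raise ValueError("cannot infer the -1 entry")
    return [dim_size // known if s == -1 else s for s in sizes]

# repro-2.mlir unflattens dim 0 (static extent 10) into [1, -1, 1].
print(infer_unflatten_sizes(10, [1, -1, 1]))  # [1, 10, 1]
```

With every entry resolved, the result type can be written as `!torch.vtensor<[1,10,1],si64>` rather than `!torch.vtensor<[1,?,1],si64>`.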

(momchil-velikov, Sep 17 '25 12:09)

This behavior started appearing after https://github.com/llvm/torch-mlir/pull/4269.

(momchil-velikov, Sep 17 '25 13:09)

@vivekkhandelwal1 @benvanik This is likely a TorchToLinalg problem, but one that affects an IREE user.

(sjarus, Sep 19 '25 14:09)