Converting PrimLoopOp to SCF does not properly convert tensor arguments
This causes errors like:
```
./scratch/scfloop.mlir:16:12: error: 'torch_c.to_builtin_tensor' op operand #0 must be Multi-dimensional array modeling Torch's Tensor type, but got 'tensor<2x3xf32>'
```
by outputting MLIR like this:
```mlir
%15 = "scf.for"(%12, %14, %13, %9) ({
^bb0(%arg0: index loc("./scratch/scfloop.mlir":14:12), %arg1: tensor loc("./scratch/scfloop.mlir":14:12)):
```
Note that `%arg1` has become a builtin `tensor` instead of a `!torch.vtensor`.
Minimal replicating example:
```mlir
module {
  func.func @minimal_example() -> (!torch.vtensor<[2,3],f32>) {
    %true = torch.constant.bool true
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %int5 = torch.constant.int 5
    %int6 = torch.constant.int 6
    %none = torch.constant.none
    %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %2 = torch.aten.ones %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %3:1 = torch.prim.Loop %int5, %true, init(%1) {
    ^bb0(%arg1: !torch.int, %arg2: !torch.vtensor<[2,3],f32>):
      %4 = torch.aten.add.Tensor %arg2, %2, %int1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
      torch.prim.Loop.condition %true, iter(%4 : !torch.vtensor<[2,3],f32>)
    } : (!torch.int, !torch.bool, !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>)
    return %3#0 : !torch.vtensor<[2,3],f32>
  }
}
```
Confirmed that this is a `ConvertTorchPrimLoopForLikeOp` issue, because:
- the error happened right after `ConvertTorchPrimLoopForLikeOp` ran
- all the way up to the `ConvertTorchPrimLoopForLikeOp` run, the loop still had `%arg1: !torch.vtensor<[2,3],f32>`
It looks like, before the conversion, the loop body was converting the argument with `torch_c.to_builtin_tensor` and converting it back for the next iteration. `ConvertTorchPrimLoopForLikeOp` then introduced a double conversion by also converting the loop block argument types.
But I have trouble locating exactly where that happens. The pass seems to be properly skipping tensor arguments. Weird.
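My working theory of the double conversion, as a hand-written sketch (not actual pass output; syntax approximate):

```mlir
// Before the pass: the body itself casts at the region boundary.
^bb0(%arg2: !torch.vtensor<[2,3],f32>):
  %t = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
  // ...

// After the pass also converts the block argument type, the existing
// cast's operand is already a builtin tensor, which its verifier
// rejects ("must be ... Torch's Tensor type, but got 'tensor<2x3xf32>'").
^bb0(%arg2: tensor<2x3xf32>):
  %t = torch_c.to_builtin_tensor %arg2 : tensor<2x3xf32> -> tensor<2x3xf32>  // error
```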
Going through things that use the `TypeConverter` instance.
Things that could be it:
- this shouldn't be it, because the vtensor should match neither `mlir::FloatType` nor `mlir::IntegerType`:
```cpp
// If the target type is non-torch type, then use TypeConverter to convert
// the type of the source.
if (targetType.isa<mlir::FloatType>()) {
  targetType = Torch::FloatType::get(op->getContext());
  torchArg = typeConverter->materializeSourceConversion(
      rewriter, scfForOp.getLoc(), targetType, {to});
} else if (targetType.isa<mlir::IntegerType>()) {
  unsigned bitWidth = targetType.getIntOrFloatBitWidth();
  if (bitWidth == 1)
    targetType = Torch::BoolType::get(op->getContext());
  else
    targetType = Torch::IntType::get(op->getContext());
  torchArg = typeConverter->materializeSourceConversion(
      rewriter, scfForOp.getLoc(), targetType, {to});
}
```
- this shouldn't be it, because the `isa<Torch::BaseTensorType>` check would skip this one:
```cpp
if (torchType.isa<Torch::BaseTensorType>()) {
  loopConditionIterArgs.push_back(torchArg);
  continue;
}
Value arg = typeConverter->materializeTargetConversion(
    rewriter, scfForOp.getLoc(),
    typeConverter->convertType(torchArg.getType()), {torchArg});
```
- it's unlikely, but maybe it's this one right at the beginning?
```cpp
if (failed(
        typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
  return rewriter.notifyMatchFailure(
      op, "could not convert PrimLoopOp outputs");
```
But that is only supposed to convert the loop result types. Let's change it anyway and see what happens.
Nope. Skipping `BaseTensorType`s there doesn't work either.
Currently stuck. Can't find any other uses of the `TypeConverter`. Current assumptions that got me stuck:
- this is a `ConvertTorchPrimLoopForLikeOp` issue
- the issue is caused by a base tensor being converted when it's not supposed to be
- the faulty conversion happens through the `TypeConverter`
- the vtensor should match neither `isa<mlir::FloatType>` nor `isa<mlir::IntegerType>`
- the vtensor should match `isa<Torch::BaseTensorType>`

I wonder what's wrong.
Fix is uploaded to: https://github.com/rsuderman/torch-mlir/tree/torch_scf
We will need to check whether `while` or `if` have similar issues.
The tests only round-tripped tensors; they never included any computation in the body. As a result they never materialized the compatibility casts between the torch and builtin MLIR tensor types. As long as the loop was just round-tripping the tensor, things worked; once there is any computation in the body, it starts returning incorrect types at the region boundaries.
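For reference, this is roughly what I'd expect a fixed lowering of the minimal example's loop to look like, with the compatibility casts materialized at the region boundary (hand-written sketch, not verified pass output; `%lb`, `%ub`, `%c1` etc. are placeholder names):

```mlir
// Cast the torch tensors into builtin tensors once, outside the loop.
%init = torch_c.to_builtin_tensor %1 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
%ones = torch_c.to_builtin_tensor %2 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
%r = scf.for %iv = %lb to %ub step %c1 iter_args(%it = %init) -> (tensor<2x3xf32>) {
  // Cast back to the torch type so the torch-level body still type-checks.
  %cur = torch_c.from_builtin_tensor %it : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
  %o = torch_c.from_builtin_tensor %ones : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
  %sum = torch.aten.add.Tensor %cur, %o, %int1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
  // And back to the builtin type for the next iteration.
  %next = torch_c.to_builtin_tensor %sum : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
  scf.yield %next : tensor<2x3xf32>
}
```

A round-trip-only test would have an empty body here (just the two casts feeding `scf.yield`), which is exactly why the bug went unnoticed.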