
Converting PrimLoopOp to SCF does not properly convert tensor arguments

Open renxida opened this issue 1 year ago • 7 comments

This causes errors like

```
./scratch/scfloop.mlir:16:12: error: 'torch_c.to_builtin_tensor' op operand #0 must be Multi-dimensional array modeling Torch's Tensor type, but got 'tensor<2x3xf32>'
```

because the conversion outputs MLIR like the following:

```mlir
%15 = "scf.for"(%12, %14, %13, %9) ({
^bb0(%arg0: index loc("./scratch/scfloop.mlir":14:12), %arg1: tensor<2x3xf32> loc("./scratch/scfloop.mlir":14:12)):
  %17 = "arith.index_cast"(%arg0) : (index) -> i64 loc("./scratch/scfloop.mlir":14:12)
  %18 = "torch_c.from_i64"(%17) : (i64) -> !torch.int loc("./scratch/scfloop.mlir":14:12)
  %19 = "torch_c.to_builtin_tensor"(%arg1) : (tensor<2x3xf32>) -> tensor<2x3xf32> loc("./scratch/scfloop.mlir":16:12)
  %20 = "tensor.empty"() : () -> tensor<2x3xf32> loc("./scratch/scfloop.mlir":16:12)
  %21 = "linalg.generic"(%19, %11, %20) <{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 2, 1>}> ({
  ^bb0(%arg2: f32 loc("./scratch/scfloop.mlir":16:12), %arg3: f32 loc("./scratch/scfloop.mlir":13:10), %arg4: f32 loc("./scratch/scfloop.mlir":16:12)):
    %23 = "arith.addf"(%arg2, %arg3) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32 loc("./scratch/scfloop.mlir":16:12)
    "linalg.yield"(%23) : (f32) -> () loc("./scratch/scfloop.mlir":16:12)
  }) : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> loc("./scratch/scfloop.mlir":16:12)
  %22 = "torch_c.from_builtin_tensor"(%21) : (tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> loc("./scratch/scfloop.mlir":16:12)
  "scf.yield"(%22) : (!torch.vtensor<[2,3],f32>) -> () loc("./scratch/scfloop.mlir":14:12)
}) : (index, index, index, tensor<2x3xf32>) -> tensor<2x3xf32> loc("./scratch/scfloop.mlir":14:12)
```

renxida, Mar 15 '24 13:03

Minimal reproducing example:

```mlir
module {
  func.func @minimal_example() -> (!torch.vtensor<[2,3],f32>) {
    %true = torch.constant.bool true
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %int5 = torch.constant.int 5
    %int6 = torch.constant.int 6
    %none = torch.constant.none
    %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %2 = torch.aten.ones %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %3:1 = torch.prim.Loop %int5, %true, init(%1) {
    ^bb0(%arg1: !torch.int, %arg2: !torch.vtensor<[2,3],f32>):
      %4 = torch.aten.add.Tensor %arg2, %2, %int1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
      torch.prim.Loop.condition %true, iter(%4 : !torch.vtensor<[2,3],f32>)
    } : (!torch.int, !torch.bool, !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>)
    return %3#0 : !torch.vtensor<[2,3],f32>
  }
}
```

renxida, Mar 15 '24 13:03

Confirmed that this is a ConvertTorchPrimLoopForLikeOp issue, because:

  1. The error happened right after ConvertTorchPrimLoopForLikeOp ran.
  2. Right up until ConvertTorchPrimLoopForLikeOp ran, the loop block argument was still `%arg1: !torch.vtensor<[2,3],f32>`.

It looks like, before the conversion, the loop body converted the argument with `torch_c.to_builtin_tensor` and converted the result back with `torch_c.from_builtin_tensor` for the next iteration. ConvertTorchPrimLoopForLikeOp then also converted the loop block argument types, so the argument ends up converted twice.

But I have trouble locating exactly where this happens: the pattern seems to skip tensor arguments properly. Weird.

renxida, Mar 15 '24 13:03
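For comparison, a well-typed lowering would keep the builtin `tensor<2x3xf32>` on the `scf.for` region boundary and cast to and from `!torch.vtensor` inside the body. The sketch below assumes that shape; it is hand-written rather than produced by the pass, uses `linalg.add` in place of the `linalg.generic` from the dump, and the function `@loop_sketch` and its arguments are hypothetical.

```mlir
// Hypothetical sketch: builtin tensors cross the scf.for region boundary,
// and explicit torch_c casts bridge to !torch.vtensor inside the body.
func.func @loop_sketch(%init: tensor<2x3xf32>, %ones: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c5 = arith.constant 5 : index
  %res = scf.for %iv = %c0 to %c5 step %c1 iter_args(%acc = %init) -> (tensor<2x3xf32>) {
    // Entry cast: the builtin block argument becomes a Torch tensor again.
    %acc_t = torch_c.from_builtin_tensor %acc : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
    // Existing body code that expects a !torch.vtensor operand now type-checks.
    %lhs = torch_c.to_builtin_tensor %acc_t : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
    %empty = tensor.empty() : tensor<2x3xf32>
    %sum = linalg.add ins(%lhs, %ones : tensor<2x3xf32>, tensor<2x3xf32>) outs(%empty : tensor<2x3xf32>) -> tensor<2x3xf32>
    %sum_t = torch_c.from_builtin_tensor %sum : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
    // Exit cast: scf.yield must return the builtin type declared by iter_args.
    %next = torch_c.to_builtin_tensor %sum_t : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
    scf.yield %next : tensor<2x3xf32>
  }
  return %res : tensor<2x3xf32>
}
```

These two casts are exactly what the erroneous output lacks: there, `%arg1` is already a builtin tensor when `torch_c.to_builtin_tensor` receives it, and `scf.yield` returns a `!torch.vtensor` while the loop declares a builtin tensor result.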

Going through the places that use the TypeConverter instance.

Things that could be it:

  1. This shouldn't be it, because the vtensor should match neither `mlir::FloatType` nor `mlir::IntegerType`:

```cpp
// If the target type is non-torch type, then use TypeConverter to convert
// the type of the source.
if (targetType.isa<mlir::FloatType>()) {
  targetType = Torch::FloatType::get(op->getContext());
  torchArg = typeConverter->materializeSourceConversion(
      rewriter, scfForOp.getLoc(), targetType, {to});
} else if (targetType.isa<mlir::IntegerType>()) {
  unsigned bitWidth = targetType.getIntOrFloatBitWidth();
  if (bitWidth == 1)
    targetType = Torch::BoolType::get(op->getContext());
  else
    targetType = Torch::IntType::get(op->getContext());
  torchArg = typeConverter->materializeSourceConversion(
      rewriter, scfForOp.getLoc(), targetType, {to});
}
```
  2. This shouldn't be it either, because the `isa<Torch::BaseTensorType>` check skips tensor arguments:
```cpp
if (torchType.isa<Torch::BaseTensorType>()) {
  loopConditionIterArgs.push_back(torchArg);
  continue;
}
Value arg = typeConverter->materializeTargetConversion(
    rewriter, scfForOp.getLoc(),
    typeConverter->convertType(torchArg.getType()), {torchArg});
```

renxida, Mar 15 '24 13:03

It's unlikely, but maybe it's this call right at the beginning?

```cpp
if (failed(
        typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
  return rewriter.notifyMatchFailure(
      op, "could not convert PrimLoopOp outputs");
```

But it should only be converting the loop's result types. Let's change it anyway and see what happens.

renxida, Mar 15 '24 13:03

Nope. Skipping `BaseTensorType`s there doesn't fix it.

renxida, Mar 15 '24 15:03

Currently stuck: I can't find any other uses of the TypeConverter. These are the assumptions that got me stuck:

  1. This is a ConvertTorchPrimLoopForLikeOp issue.
  2. The issue is caused by a base tensor being converted when it's not supposed to be.
  3. The faulty conversion happens through the TypeConverter.
  4. The vtensor should not match `isa<mlir::FloatType>` or `isa<mlir::IntegerType>`.
  5. The vtensor should match `isa<Torch::BaseTensorType>`.

I wonder what's wrong.

renxida, Mar 15 '24 15:03

A fix is uploaded to: https://github.com/rsuderman/torch-mlir/tree/torch_scf

We will need to check whether the while-loop and if conversions have similar issues.

The tests only round-tripped tensors; they never included any computation in the loop body. As a result, they never materialized the compatibility casts between the Torch and builtin MLIR tensor types. As long as the body only passed the tensor through, everything worked; once the body contains any computation, the conversion returns incorrect types on the region boundaries.

rsuderman, Mar 15 '24 16:03
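To illustrate the coverage gap described above, a round-trip-only test might look like the hypothetical function below (written for illustration, not taken from the torch-mlir test suite). The carried tensor is handed straight to `torch.prim.Loop.condition`, so nothing in the body consumes it as a `!torch.vtensor` and the body never needs a cast between the two tensor types. Adding a computation on the carried tensor, as `torch.aten.add.Tensor` does in the minimal example earlier in this thread, is what exposes the bug.

```mlir
// Hypothetical round-trip-only loop: the iter arg is passed through unchanged.
func.func @roundtrip_only(%init: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
  %true = torch.constant.bool true
  %int5 = torch.constant.int 5
  %0:1 = torch.prim.Loop %int5, %true, init(%init) {
  ^bb0(%iv: !torch.int, %acc: !torch.vtensor<[2,3],f32>):
    // No op reads %acc here, so no torch_c cast is needed inside the body.
    torch.prim.Loop.condition %true, iter(%acc : !torch.vtensor<[2,3],f32>)
  } : (!torch.int, !torch.bool, !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>)
  return %0#0 : !torch.vtensor<[2,3],f32>
}
```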