Torch to TOSA conversion fails to legalize 'torch.constant.int'

Open Svoch opened this issue 2 years ago • 6 comments

I am trying to compile a portion of a PyTorch self-attention module down to the TOSA backend and am hitting an error legalizing the torch.constant.int op in the TOSA conversion pass. The issue arises only when the output type is set to torch_mlir.OutputType.TOSA in the Torch-MLIR compile API; the conversion to the Linalg dialect and further down to the backend works fine, but the Torch to TOSA conversion fails.

Error log

Exception: 
Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.constant.int'
note: see current operation: %5 = "torch.constant.int"() {value = 0 : i64} : () -> !torch.int

Steps to reproduce

The script to reproduce the error is up in this draft PR on a local fork. The error can be reproduced with the code snippet below, containing the module definition and the torch_mlir.compile() API call:

import torch
import torch.nn as nn

import torch_mlir


class AttentionScores(torch.nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super(AttentionScores, self).__init__()
        self.query = nn.Linear(embedding_dim, num_heads)
        self.key = nn.Linear(embedding_dim, num_heads)

    def forward(self, inputs):
        query = self.query(inputs)
        key = self.key(inputs)
        scores = torch.matmul(query, key.transpose(0, 1))
        return scores


attention_scores = AttentionScores(embedding_dim=10, num_heads=2)
inputs = torch.rand(5, 10)
tosa_module = torch_mlir.compile(attention_scores, inputs, output_type=torch_mlir.OutputType.TOSA)

This issue is potentially relevant to what @nithinsubbiah, @rdadolf and I are seeing in #910. I was also able to reproduce the error by simplifying the module above to a single Transpose op (i.e. torch.Tensor.transpose in the forward method); a minimal version is sketched below.
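
A minimal sketch of such a transpose-only repro (the class name here is illustrative; it assumes the same torch_mlir.compile usage as above):

import torch
import torch_mlir


class TransposeOnly(torch.nn.Module):
    def forward(self, inputs):
        # A single transpose is enough to leave its axis constants unlegalized.
        return inputs.transpose(0, 1)


tosa_module = torch_mlir.compile(TransposeOnly(), torch.rand(5, 10),
                                 output_type=torch_mlir.OutputType.TOSA)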

cc @sjarus @powderluv @silvasean - wonder if you have seen this or have any insight on what might have been going wrong here.

Svoch avatar Jun 22 '22 16:06 Svoch

I've encountered this already, @Svoch - it also impacts MobileNetV3. Working on a fix internally, but am getting some BERT ones out first.

sjarus avatar Jun 22 '22 16:06 sjarus

As an intermediate step while converting BERT to TOSA, I just convert ConstantIntOp to arith.
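
A rough sketch of what such a workaround pattern can look like (modeled on torch-mlir's torch-to-arith conversion; the class name and exact includes here are illustrative, not the upstream code):

// Illustrative pattern converting torch.constant.int to arith.constant.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;

struct ConvertTorchConstantIntOp
    : public OpConversionPattern<Torch::ConstantIntOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // !torch.int is backed by an i64 attribute, so it maps directly onto
    // an arith.constant of type i64.
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValueAttr());
    return success();
  }
};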

YellowHCH avatar Jun 27 '22 09:06 YellowHCH

Below is the IR I get lowering the model above to the Torch dialect. If I've understood correctly, even though rewriters for AtenMmOp and AtenLinearOp do exist in the TorchToTosa lowering pass, there is no lowering pattern for AtenTransposeIntOp; that seems to be the underlying problem here. Does that make sense, @sjarus?

module attributes {torch.debug_module_name = "AttentionScores"} {
  func.func @forward(%arg0: !torch.vtensor<[5,10],f32>) -> !torch.vtensor<[5,5],f32> {
    %0 = torch.vtensor.literal(dense<[0.143874153, 0.230392322]> : tensor<2xf32>) : !torch.vtensor<[2],f32>
    %1 = torch.vtensor.literal(dense<[[0.226313293, -0.279624343, -0.211782753, 2.701980e-01, 0.0410803184, 0.25729695, -0.00262214779, -0.0828355625, 0.145104617, -0.0266586915], [0.239182726, 0.00810842216, -0.00369983795, -0.132520735, -0.254919976, 0.0812480971, -0.196122929, 0.10878253, 0.158111736, 0.306294829]]> : tensor<2x10xf32>) : !torch.vtensor<[2,10],f32>
    %2 = torch.vtensor.literal(dense<[0.249187589, 0.0799354762]> : tensor<2xf32>) : !torch.vtensor<[2],f32>
    %3 = torch.vtensor.literal(dense<[[2.573700e-01, 0.215394989, 0.238999024, 0.163943127, 0.212515891, 0.231857046, -0.28136012, -0.118400358, -0.217035949, -0.0496219955], [-0.3136262, 0.105287395, -0.037995059, -0.129876241, -0.142800108, -0.208000481, -0.0777831152, 0.144313246, -0.086798042, 0.255681425]]> : tensor<2x10xf32>) : !torch.vtensor<[2,10],f32>
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %4 = torch.aten.linear %arg0, %3, %2 : !torch.vtensor<[5,10],f32>, !torch.vtensor<[2,10],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[5,2],f32>
    %5 = torch.aten.linear %arg0, %1, %0 : !torch.vtensor<[5,10],f32>, !torch.vtensor<[2,10],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[5,2],f32>
    %6 = torch.aten.transpose.int %5, %int0, %int1 : !torch.vtensor<[5,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,5],f32>
    %7 = torch.aten.mm %4, %6 : !torch.vtensor<[5,2],f32>, !torch.vtensor<[2,5],f32> -> !torch.vtensor<[5,5],f32>
    return %7 : !torch.vtensor<[5,5],f32>
  }
}

Svoch avatar Jun 28 '22 12:06 Svoch

The IR is really helpful to test against what I have and see if it legalizes right. I'll check today after morning meetings and post an update.

sjarus avatar Jun 28 '22 15:06 sjarus

I am encountering the same issue with this patch https://github.com/llvm/torch-mlir/pull/862 for the static ResNet18 model. @sjarus error: failed to legalize operation 'torch.constant.int'

Shukla-Gaurav avatar Jul 06 '22 16:07 Shukla-Gaurav

Just pushed https://github.com/llvm/torch-mlir/pull/1017 for this.

sjarus avatar Jul 06 '22 17:07 sjarus

I encountered the same issue lowering the HuggingFace GPT-2 model: https://gist.github.com/AmosLewis/9b929414d5677afda3528122f92bbc73 @sjarus error: failed to legalize operation 'torch.constant.int'

AmosLewis avatar Sep 22 '22 02:09 AmosLewis

torch.constant.int is a known missing conversion.

sjarus avatar Sep 22 '22 02:09 sjarus

Was this fixed?

silvasean avatar Oct 07 '22 13:10 silvasean

@silvasean - Yes, with the Torch to TOSA conversion of the Transpose op merged, this issue can be marked as fixed.

Please note that the symptom, i.e. error: failed to legalize operation 'torch.constant.int', is common to most ops without a TOSA lowering in Torch-MLIR, since unlowered ops usually leave dangling attribute-like operands (such as axis integers) behind after the Torch to TOSA conversion.
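
For illustration, in the Torch dialect IR above the transpose axes are exactly such leftovers - if torch.aten.transpose.int has no TOSA rewrite, the constants feeding it cannot be erased and the legalizer reports them:

%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// With no pattern for the consumer below, the two constants above survive
// the conversion and trigger 'failed to legalize operation'.
%6 = torch.aten.transpose.int %5, %int0, %int1 : !torch.vtensor<[5,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,5],f32>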

Svoch avatar Oct 07 '22 14:10 Svoch

The torch.constant.int error just means there are aten ops using that torch.constant.int as an operand which haven't been lowered successfully by your lowering code. You need to find the op that was not lowered successfully in the debug IR output, understand each line of your lowering code related to that op, and come up with a new plan.

The error will come back again and again for each op we try to lower until we lower it successfully. As you can see in the comments, I hit this error many times when I started working on GPT. The error disappears once you lower your ops to TOSA (or the other dialects: StableHLO/Linalg/TMTensor) correctly.

torch-mlir-opt -convert-torch-to-tosa /tmp/aten_as_tride.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug --mlir-print-ir-before-all

This is the command you might need to get more debug info. Just replace /tmp/aten_as_tride.mlir with the op.mlir file you created manually. You can take my where.mlir file and the commands in its comments as examples; here is the link: https://gist.github.com/AmosLewis/32847885f8b3ff27b7ef6564154fec59
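
For instance, a minimal hand-written transpose test case in that spirit (a sketch modeled on the IR earlier in this thread, not taken from the gist) could be:

// op.mlir: minimal test case to feed to torch-mlir-opt -convert-torch-to-tosa.
func.func @forward(%arg0: !torch.vtensor<[5,2],f32>) -> !torch.vtensor<[2,5],f32> {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[5,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,5],f32>
  return %0 : !torch.vtensor<[2,5],f32>
}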

For those working on TOSA, here is the relationship between the two TOSA-related flags for torch-mlir-opt that you need to understand before diving into debugging:

-pass-pipeline='torch-backend-to-tosa-backend-pipeline' == -convert-torch-to-tosa + some other cleanup/standard conversion passes (like clearing the torch.constant.int ops whose aten consumers were successfully lowered to TOSA)

-pass-pipeline='torch-backend-to-tosa-backend-pipeline' will run the whole function createTorchBackendToTosaBackendPipeline(OpPassManager &pm) at line 100: https://github.com/llvm/torch-mlir/blob/804f9f1f8f2002c3f1537e1fb0c919e30f6b4600/lib/Dialect/TorchConversion/Transforms/Passes.cpp#L100

-convert-torch-to-tosa will only run line 102, https://github.com/llvm/torch-mlir/blob/804f9f1f8f2002c3f1537e1fb0c919e30f6b4600/lib/Dialect/TorchConversion/Transforms/Passes.cpp#L102, which calls the conversion function createConvertTorchToTosaPass() that you extend for your ops in TorchToTosa.cpp: https://github.com/llvm/torch-mlir/blob/804f9f1f8f2002c3f1537e1fb0c919e30f6b4600/lib/Conversion/TorchToTosa/TorchToTosa.cpp#L3944. That in turn calls ConvertTorchToTosa::runOnOperation(), https://github.com/llvm/torch-mlir/blob/804f9f1f8f2002c3f1537e1fb0c919e30f6b4600/lib/Conversion/TorchToTosa/TorchToTosa.cpp#L3734, which is where the matchAndRewrite lowering patterns we add get registered and run.

The torch.constant.int ops should be cleaned up around line 113, https://github.com/llvm/torch-mlir/blob/804f9f1f8f2002c3f1537e1fb0c919e30f6b4600/lib/Dialect/TorchConversion/Transforms/Passes.cpp#L113, provided the -convert-torch-to-tosa pass at line 102 did its job, as the comment at line 111 of that file explains.
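
Concretely, with an op.mlir file like the sketch above, the two invocations compare as follows (same tool, same input):

# Full backend pipeline: the conversion pass plus the cleanup/finalization passes.
torch-mlir-opt -pass-pipeline='torch-backend-to-tosa-backend-pipeline' op.mlir

# Conversion pass only: leftover torch.constant.int ops are expected here
# whenever one of their consumers failed to lower.
torch-mlir-opt -convert-torch-to-tosa op.mlir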

And in each matchAndRewrite pattern, each aten op has a corresponding adaptor. The adaptor is the MLIR-internal view of the aten op: its operands have already been converted to builtin types. For example, with a where.mlir file containing torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>), calling atenop.getSelf().dump() prints the torch-level type !torch.vtensor<[1,1,5,5],i1>, while adaptor.getSelf().dump() prints tensor<1x1x5x5xi1>.

Those useful op helper functions like getSelf() can be found in your build directory, build/tools/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchOps.h.inc; their implementations are in build/tools/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc. These are generated automatically by MLIR's TableGen from .td files. The TableGen file lives at the corresponding location in the torch-mlir source tree: https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td. In this .td file you will find the detailed types of each aten op, which is very useful when coming up with new lowering plans.

And to work with the adaptor's types, which are MLIR builtin types like RankedTensorType and TensorType, you will need the functions in external/llvm-project: https://github.com/llvm/llvm-project/blob/798fa4b415eea55c868ae42b874083cb9886991e/mlir/include/mlir/IR/Types.h and https://github.com/llvm/llvm-project/blob/798fa4b415eea55c868ae42b874083cb9886991e/mlir/include/mlir/IR/BuiltinTypes.h.
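
Putting those pieces together, a skeleton of such a pattern for the where.mlir example (a sketch only - the upstream lowering also handles broadcasting and type promotion, which are omitted here) looks roughly like:

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;

struct ConvertAtenWhereSelfOp
    : public OpConversionPattern<Torch::AtenWhereSelfOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(Torch::AtenWhereSelfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // op.getSelf() is the Torch-level value (!torch.vtensor<...>);
    // adaptor.getSelf() is the already-converted builtin tensor (tensor<...>).
    auto selfTy = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
    if (!selfTy)
      return rewriter.notifyMatchFailure(op, "only ranked tensors supported");
    // Convert the result type through the registered type converter, then
    // replace the aten op with its TOSA counterpart.
    auto resultTy = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tosa::SelectOp>(
        op, resultTy, adaptor.getCondition(), adaptor.getSelf(),
        adaptor.getOther());
    return success();
  }
};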

We will have to go deep, read this code, understand its design structure, and get familiar with it; otherwise we cannot debug anything successfully. This code is like the raw food for a cook: C++ and Python are our cooking tools, and our job is to come up with a recipe (a lowering plan) and use those tools to cook (implement and debug) it with this raw food.

AmosLewis avatar Dec 12 '22 03:12 AmosLewis