Torch to TOSA conversion fails to legalize 'torch.constant.int'
I am trying to compile a portion of a PyTorch self-attention module down to the TOSA backend and am hitting an error legalizing the torch.constant.int op in the TOSA conversion pass. The issue arises only when the output type is set to torch_mlir.OutputType.TOSA in the Torch-MLIR compile API. The conversion to the Linalg dialect and further down to the backend works fine; however, the Torch-to-TOSA conversion fails.
Error log
Exception:
Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.constant.int'
note: see current operation: %5 = "torch.constant.int"() {value = 0 : i64} : () -> !torch.int
Steps to reproduce
The script to reproduce the error is up in a draft PR on a local fork. The error can be reproduced using the code snippet below, with the module definition and the torch_mlir.compile() API call:
import torch
import torch.nn as nn

import torch_mlir


class AttentionScores(torch.nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super().__init__()
        self.query = nn.Linear(embedding_dim, num_heads)
        self.key = nn.Linear(embedding_dim, num_heads)

    def forward(self, inputs):
        query = self.query(inputs)
        key = self.key(inputs)
        scores = torch.matmul(query, key.transpose(0, 1))
        return scores


attention_scores = AttentionScores(embedding_dim=10, num_heads=2)
inputs = torch.rand(5, 10)
tosa_module = torch_mlir.compile(attention_scores, inputs, output_type=torch_mlir.OutputType.TOSA)
This issue is potentially relevant to what @nithinsubbiah, @rdadolf and I are seeing in #910. I was also able to reproduce the error by simplifying the module above to a single transpose op (i.e. torch.Tensor.transpose in the forward method).
cc @sjarus @powderluv @silvasean - wondering if you have seen this or have any insight on what might be going wrong here.
I've encountered this already @Svoch - it also impacts MobileNetV3. Working on a fix internally, but am getting some BERT ones out first.
I just convert ConstantIntOp to arith while converting BERT to TOSA, as an intermediate workaround.
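For anyone looking for a starting point, a minimal sketch of such an intermediate workaround could look like the following. The pattern name ConvertTorchConstantIntOp is hypothetical, the accessor spelling (getValueAttr() vs. valueAttr()) and the Arith header path vary across torch-mlir/LLVM versions, and this assumes the type converter maps !torch.int to i64:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;

namespace {
// Rewrites `torch.constant.int` into `arith.constant : i64`. The type
// converter is expected to map !torch.int to i64, so the replacement
// type-checks under the dialect conversion framework.
class ConvertTorchConstantIntOp
    : public OpConversionPattern<Torch::ConstantIntOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The op carries its value as an i64 IntegerAttr; materialize it as a
    // plain arith.constant with the same attribute.
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValueAttr());
    return success();
  }
};
} // namespace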
Below is the IR I get when lowering the model above to the Torch dialect. If I've understood correctly, even though rewriters for AtenMmOp and AtenLinearOp do exist in the TorchToTosa lowering pass, there is no lowering pattern for AtenTransposeIntOp (a sketch of what such a pattern could look like follows the IR below). That seems to be the underlying problem here. Does this make sense, @sjarus?
module attributes {torch.debug_module_name = "AttentionScores"} {
func.func @forward(%arg0: !torch.vtensor<[5,10],f32>) -> !torch.vtensor<[5,5],f32> {
%0 = torch.vtensor.literal(dense<[0.143874153, 0.230392322]> : tensor<2xf32>) : !torch.vtensor<[2],f32>
%1 = torch.vtensor.literal(dense<[[0.226313293, -0.279624343, -0.211782753, 2.701980e-01, 0.0410803184, 0.25729695, -0.00262214779, -0.0828355625, 0.145104617, -0.0266586915], [0.239182726, 0.00810842216, -0.00369983795, -0.132520735, -0.254919976, 0.0812480971, -0.196122929, 0.10878253, 0.158111736, 0.306294829]]> : tensor<2x10xf32>) : !torch.vtensor<[2,10],f32>
%2 = torch.vtensor.literal(dense<[0.249187589, 0.0799354762]> : tensor<2xf32>) : !torch.vtensor<[2],f32>
%3 = torch.vtensor.literal(dense<[[2.573700e-01, 0.215394989, 0.238999024, 0.163943127, 0.212515891, 0.231857046, -0.28136012, -0.118400358, -0.217035949, -0.0496219955], [-0.3136262, 0.105287395, -0.037995059, -0.129876241, -0.142800108, -0.208000481, -0.0777831152, 0.144313246, -0.086798042, 0.255681425]]> : tensor<2x10xf32>) : !torch.vtensor<[2,10],f32>
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%4 = torch.aten.linear %arg0, %3, %2 : !torch.vtensor<[5,10],f32>, !torch.vtensor<[2,10],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[5,2],f32>
%5 = torch.aten.linear %arg0, %1, %0 : !torch.vtensor<[5,10],f32>, !torch.vtensor<[2,10],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[5,2],f32>
%6 = torch.aten.transpose.int %5, %int0, %int1 : !torch.vtensor<[5,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,5],f32>
%7 = torch.aten.mm %4, %6 : !torch.vtensor<[5,2],f32>, !torch.vtensor<[2,5],f32> -> !torch.vtensor<[5,5],f32>
return %7 : !torch.vtensor<[5,5],f32>
}
}
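For reference, and hedged since the eventually merged patch may differ, here is a sketch of what a TorchToTosa pattern for AtenTransposeIntOp could look like. ConvertAtenTransposeIntOp is a hypothetical name, accessor spellings (getDim0() vs. dim0()) vary by torch-mlir version, and tosa::TransposeOp's perms handling differs across MLIR versions (a dense-array attribute here vs. a tosa.const operand in older releases):

#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class ConvertAtenTransposeIntOp
    : public OpConversionPattern<AtenTransposeIntOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto selfTy = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
    if (!selfTy)
      return rewriter.notifyMatchFailure(op, "expected a ranked tensor");
    int64_t rank = selfTy.getRank();

    // The two dims arrive as torch.constant.int operands. Matching them as
    // constants here is what lets the later cleanup passes erase the
    // now-unused torch.constant.int ops.
    int64_t dim0, dim1;
    if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0)) ||
        !matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
      return rewriter.notifyMatchFailure(op, "dims must be constant ints");
    dim0 = dim0 < 0 ? dim0 + rank : dim0;
    dim1 = dim1 < 0 ? dim1 + rank : dim1;

    // Identity permutation with dim0 and dim1 swapped.
    SmallVector<int32_t> perms(rank);
    std::iota(perms.begin(), perms.end(), 0);
    std::swap(perms[dim0], perms[dim1]);

    Type resultTy = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        op, resultTy, adaptor.getSelf(),
        rewriter.getDenseI32ArrayAttr(perms));
    return success();
  }
};
} // namespace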
The IR is really helpful to test against what I have and see if it legalizes right. I'll check today after morning meetings and post an update.
I am encountering the same issue with this patch https://github.com/llvm/torch-mlir/pull/862 for the static ResNet18 model. @sjarus
error: failed to legalize operation 'torch.constant.int'
Just pushed https://github.com/llvm/torch-mlir/pull/1017 for this.
I encountered the same issue while lowering the Hugging Face GPT-2 model: https://gist.github.com/AmosLewis/9b929414d5677afda3528122f92bbc73 @sjarus error: failed to legalize operation 'torch.constant.int'
torch.constant.int is a known missing conversion.
Was this fixed?
@silvasean - Yes, with the Torch-to-TOSA conversion of the transpose op merged, this issue can be marked as fixed.
Please note that the symptom, i.e. error: failed to legalize operation 'torch.constant.int', is common to most ops without a TOSA lowering in Torch-MLIR, since they usually leave dangling constant operands (such as axis integers) behind after the Torch-to-TOSA conversion.
The torch.constant.int error just means that some aten op using this torch.constant.int as an operand hasn't been lowered successfully by your lowering code. You need to find the op that is not lowered successfully in the IR from the debug info, understand each line of your lowering code related to that op, and come up with a new plan.
The error will come up again and again for each op we try to lower, until we lower it successfully. As you can see in the comments above, I hit this error many times when I started to work on GPT-2. It disappears once you lower your own ops correctly to TOSA (or to other dialects: StableHLO/Linalg/TMTensor).
torch-mlir-opt -convert-torch-to-tosa /tmp/aten_as_tride.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug --mlir-print-ir-before-all
This is the command you might need to get more debug info; just replace /tmp/aten_as_tride.mlir with an op.mlir file you create manually. You can take my where.mlir file and the commands in the comments here as examples: https://gist.github.com/AmosLewis/32847885f8b3ff27b7ef6564154fec59
For those working on TOSA, here is the relationship between the two TOSA-related flags for torch-mlir-opt that you need to understand before diving into debugging:
-pass-pipeline='torch-backend-to-tosa-backend-pipeline' == -convert-torch-to-tosa + some other cleanup/standard conversion passes (like clearing out the torch.constant.int ops whose aten consumers were successfully lowered to TOSA).
-pass-pipeline='torch-backend-to-tosa-backend-pipeline' runs the whole function createTorchBackendToTosaBackendPipeline(OpPassManager &pm) starting at line 100: https://github.com/llvm/torch-mlir/blob/804f9f1f8f2002c3f1537e1fb0c919e30f6b4600/lib/Dialect/TorchConversion/Transforms/Passes.cpp#L100
-convert-torch-to-tosa only runs line 102: https://github.com/llvm/torch-mlir/blob/804f9f1f8f2002c3f1537e1fb0c919e30f6b4600/lib/Dialect/TorchConversion/Transforms/Passes.cpp#L102, which calls the conversion function createConvertTorchToTosaPass() that you extend for your ops in TorchToTosa.cpp: https://github.com/llvm/torch-mlir/blob/804f9f1f8f2002c3f1537e1fb0c919e30f6b4600/lib/Conversion/TorchToTosa/TorchToTosa.cpp#L3944. That in turn runs ConvertTorchToTosa::runOnOperation(): https://github.com/llvm/torch-mlir/blob/804f9f1f8f2002c3f1537e1fb0c919e30f6b4600/lib/Conversion/TorchToTosa/TorchToTosa.cpp#L3734. This function is where the matchAndRewrite pattern-lowering code we add gets wired in.
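Paraphrasing, that runOnOperation() has roughly the following shape; the exact dialect list and registered patterns differ by version, and ConvertAtenTransposeIntOp stands in for whichever pattern class you are adding:

namespace {
class ConvertTorchToTosa
    : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ConversionTarget target(*context);
    // Anything already in these dialects is a legal result of the pass.
    target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
                           arith::ArithDialect>();

    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });
    // Wires up the !torch.vtensor <-> builtin tensor conversions.
    TorchConversion::setupBackendTypeConversion(target, typeConverter);

    RewritePatternSet patterns(context);
    // Each supported aten op is marked illegal and paired with a pattern,
    // so the conversion fails loudly when a pattern does not apply.
    target.addIllegalOp<AtenTransposeIntOp>();
    patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
    // ... many more ops and patterns ...

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace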
The torch.constant.int ops should get cleaned up around line 113, provided the -convert-torch-to-tosa step at line 102 did its job: https://github.com/llvm/torch-mlir/blob/804f9f1f8f2002c3f1537e1fb0c919e30f6b4600/lib/Dialect/TorchConversion/Transforms/Passes.cpp#L113, as the comment at line 111 explains.
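To make the pipeline-vs-pass relationship concrete, here is a paraphrased sketch of what createTorchBackendToTosaBackendPipeline looks like around that commit; the exact pass list and signature are version-dependent:

void TorchConversion::createTorchBackendToTosaBackendPipeline(
    OpPassManager &pm) {
  // Line 102: the op-by-op Torch-to-TOSA conversion itself.
  pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
  // Convert function signatures and calls from !torch types to builtin types.
  pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  // Around line 113: finalizes the type conversion and erases now-unused
  // torch.constant.int ops; this only succeeds if every aten op that
  // consumed them was lowered by the conversion above.
  pm.addNestedPass<func::FuncOp>(
      TorchConversion::createFinalizingBackendTypeConversionPass());
  pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
}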
In each matchAndRewrite pattern, every aten op has a corresponding adaptor. The adaptor is the MLIR-internal view of the aten op's operands. For example, given a where.mlir file containing torch.aten.where.self with %arg0: !torch.vtensor<[1,1,5,5],i1>, calling atenop.getSelf().dump() prints the Torch-level type !torch.vtensor<[1,1,5,5],i1>, while adaptor.getSelf().dump() prints the builtin tensor<1x1x5x5xi1>.
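As a hedged illustration of that distinction (ConvertAtenWhereSelfOp is a hypothetical pattern class, the actual lowering body is elided, and the types in the comments are the ones from the where.mlir example above):

namespace {
class ConvertAtenWhereSelfOp : public OpConversionPattern<AtenWhereSelfOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenWhereSelfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // op.getSelf(): the Torch-dialect operand, e.g.
    //   !torch.vtensor<[1,1,5,5],i1>
    // adaptor.getSelf(): the same operand after the type converter ran, e.g.
    //   tensor<1x1x5x5xi1>
    // Builtin-type queries (RankedTensorType etc., from IR/BuiltinTypes.h)
    // therefore only make sense on the adaptor side.
    auto selfTy = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
    if (!selfTy)
      return rewriter.notifyMatchFailure(op, "expected a ranked tensor");
    // ... the actual lowering to tosa.select would go here ...
    return rewriter.notifyMatchFailure(op, "sketch only: lowering elided");
  }
};
} // namespace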
Useful op helper functions like getSelf() can be found in your build directory at build/tools/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchOps.h.inc; their implementations are in build/tools/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc. These files are auto-generated by MLIR's TableGen from .td files. The TableGen file lives in the matching directory of the torch-mlir source tree: https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td. In this .td file you will find the detailed types of each aten op, which is very useful when coming up with new lowering plans.
To work with the adaptor's types, which are MLIR-internal types like RankedTensorType and TensorType, you will need the functions in the external llvm-project headers https://github.com/llvm/llvm-project/blob/798fa4b415eea55c868ae42b874083cb9886991e/mlir/include/mlir/IR/Types.h and https://github.com/llvm/llvm-project/blob/798fa4b415eea55c868ae42b874083cb9886991e/mlir/include/mlir/IR/BuiltinTypes.h.
We will have to go deep into this code, read it, understand its design, and get familiar with it; otherwise we cannot debug anything successfully. This code is like the raw ingredients for a cook: C++ and Python are our cooking tools, and our job is to come up with a recipe (a lowering plan) and use those tools to cook (implement and debug) it with the raw ingredients.