Failure to lower mhlo.scatter to LinalgExt
What happened?
This may be related to #9456, but it is still failing on the latest release.
Assuming steps are redundant with reproducers.
Steps to reproduce your issue
What component(s) does this issue relate to?
MLIR
Version information
iree-compiler 20220629.188
iree-runtime 20220629.188
iree-tools-tf 20220629.188
iree-tools-tflite 20220629.188
iree-tools-xla 20220629.188
Additional context
I attempted to run the iree-sample for huggingface, but with an additional train method. It works without that method.
@NatashaKnk could you look at this too?
Could you perhaps paste the failing MHLO scatter here? (It makes it easier to see before I get to triage.)
The error I'm getting is "error: failed to legalize operation 'mhlo.scatter' that was explicitly marked illegal", and the scatter op in the output is:
%13798 = "mhlo.scatter"(%49, %261, %13797) ({
^bb0(%arg4: tensor<f32> loc(fused["UnsortedSegmentSum:", callsite("AddN_89/inputs_1@__inference_train_15002"("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/imperative_grad.py":67:0) at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/backprop.py":1100:0 at callsite("/home/dan/SHARK/tank/tf/huggingface_MiniLM_train.py":54:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1116:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":3251:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py":677:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1141:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2627:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2711:0 at "/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2533:0)))))))))]), %arg5: tensor<f32> loc(fused["UnsortedSegmentSum:", callsite("AddN_89/inputs_1@__inference_train_15002"("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/imperative_grad.py":67:0) at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/backprop.py":1100:0 at callsite("/home/dan/SHARK/tank/tf/huggingface_MiniLM_train.py":54:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1116:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":3251:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py":677:0 at 
callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1141:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2627:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2711:0 at "/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2533:0)))))))))])):
%14284 = "mhlo.add"(%arg4, %arg5) : (tensor<f32>, tensor<f32>) -> tensor<f32> loc(fused["UnsortedSegmentSum:", callsite("AddN_89/inputs_1@__inference_train_15002"("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/imperative_grad.py":67:0) at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/backprop.py":1100:0 at callsite("/home/dan/SHARK/tank/tf/huggingface_MiniLM_train.py":54:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1116:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":3251:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py":677:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1141:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2627:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2711:0 at "/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2533:0)))))))))])
"mhlo.return"(%14284) : (tensor<f32>) -> () loc(fused["UnsortedSegmentSum:", callsite("AddN_89/inputs_1@__inference_train_15002"("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/imperative_grad.py":67:0) at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/backprop.py":1100:0 at callsite("/home/dan/SHARK/tank/tf/huggingface_MiniLM_train.py":54:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1116:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":3251:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py":677:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1141:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2627:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2711:0 at "/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2533:0)))))))))])
}) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<30522x384xf32>, tensor<512xi32>, tensor<512x384xf32>) -> tensor<30522x384xf32> loc(fused["UnsortedSegmentSum:", callsite("AddN_89/inputs_1@__inference_train_15002"("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/imperative_grad.py":67:0) at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/backprop.py":1100:0 at callsite("/home/dan/SHARK/tank/tf/huggingface_MiniLM_train.py":54:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1116:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":3251:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py":677:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1141:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2627:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2711:0 at "/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2533:0)))))))))])
@jpienaar bumping to P1 from today's Nod meeting.
Local reproducer
func.func @main(%arg0: tensor<30522x384xf32>, %arg1: tensor<512xi32>, %arg2: tensor<512x384xf32>) -> tensor<30522x384xf32> {
  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
    %1 = mhlo.add %arg3, %arg4 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<30522x384xf32>, tensor<512xi32>, tensor<512x384xf32>) -> tensor<30522x384xf32>
  return %0 : tensor<30522x384xf32>
}
{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "func.func(iree-mhlo-to-linalg-on-tensors)",
      disable_threading: true,
      verify_each: true
    }
  }
#-}
$tools/iree-opt /tmp/repro.mlir
results in
failed to legalize operation 'mhlo.scatter' that was explicitly marked illegal
Unfortunately, notifyMatchFailure isn't used, so the debug output is just
//===-------------------------------------------===//
Legalizing operation : 'mhlo.scatter'(0x8ea4b50) {
  * Fold {
  } -> FAILURE : unable to fold
  * Pattern : 'mhlo.scatter -> ()' {
Trying to match "mlir::(anonymous namespace)::ScatterUpdateConversion"
"mlir::(anonymous namespace)::ScatterUpdateConversion" result 0
What seems to happen is that the folder fails when casting the scatter indices (i.e. %arg1: tensor<512xi32>) to a DenseIntElementsAttr. I'm not sure why that would be the case (is there something obvious I'm missing?), so I'm looking further into it.
Could you add a link to where in the code it fails?
Yeah, sorry! As far as I understand, it returns false here.
So scatter_indices is not a constant input. Do we get to https://github.com/iree-org/iree/blob/8d0975e3ea50fd65f8d11e793743f3c08e98978c/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp#L291 after that?
So I think this is not a legal mhlo.scatter. I cannot generate this operation via JAX, and the nearest replication broadcasts to add a length-1 dimension.
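For comparison, here is a minimal pure-Python sketch of what the scatter in the reproducer computes (an UnsortedSegmentSum-style row accumulation), written so it accepts the indices either as rank 1 (512) or with an explicit trailing length-1 dimension (512x1). The helper name `scatter_add` and the tiny shapes are hypothetical and only for illustration; this is not IREE or MHLO code. It assumes the XLA scatter semantics in which `index_vector_dim` equal to the rank of the indices implies a trailing dimension of size 1, and that out-of-bounds updates are skipped.

```python
def scatter_add(operand, indices, updates):
    """Reference scatter-add over rows (hypothetical helper, not an IREE API).

    operand: list of rows; indices: list of row numbers, or of 1-element
    lists; updates: one row per index. Rows addressed more than once
    accumulate, matching the mhlo.add in the scatter body.
    """
    result = [list(row) for row in operand]  # copy so the input is not mutated
    for index, update_row in zip(indices, updates):
        # Accept both the rank-1 form (N) and the explicit form (N x 1).
        row = index[0] if isinstance(index, (list, tuple)) else index
        if not 0 <= row < len(result):
            continue  # assumption: out-of-bounds updates are skipped
        for col, value in enumerate(update_row):
            result[row][col] += value
    return result


# Small stand-ins for tensor<30522x384xf32>, tensor<512xi32>, tensor<512x384xf32>.
operand = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
updates = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]

rank1 = scatter_add(operand, [2, 0, 2], updates)           # indices of shape (3,)
explicit = scatter_add(operand, [[2], [0], [2]], updates)  # indices of shape (3, 1)
```

Under those assumptions both calls produce the same result, which is why reshaping the indices to add a length-1 dimension is a plausible way to express the same computation.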
Potentially missing verification? Could you flag MHLO folks to verify?
@rsuderman Hey Rob, do you have any update here?
Hey @rsuderman @NatashaKnk Any update on this? This is a P1 item for the Nod folks at the moment. I think we were pinging MHLO folks?
I'm on it; the fix should be in today or tomorrow.
Hey @NatashaKnk any update?
Sorry for the delay! Should be fixed when 10095 is submitted.
Validated in iree-samples that the case is now fixed