Failure to lower mhlo.scatter to LinalgExt

Open dan-garvey opened this issue 3 years ago • 15 comments

What happened?

May be related to #9456, but this is still failing on the latest release.

iree reproducer zip

Assuming steps are redundant with reproducers.

Steps to reproduce your issue

  1. Go to '...'
  2. Click on '....'
  3. Scroll down to '....'
  4. See error

What component(s) does this issue relate to?

MLIR

Version information

iree-compiler 20220629.188 iree-runtime 20220629.188 iree-tools-tf 20220629.188 iree-tools-tflite 20220629.188 iree-tools-xla 20220629.188

Additional context

Attempted to run the iree-sample for huggingface but with an additional train method. It works without that.

dan-garvey avatar Jun 29 '22 19:06 dan-garvey

@NatashaKnk could you look at this too?

jpienaar avatar Jun 29 '22 22:06 jpienaar

Could you perhaps paste the failing MHLO scatter here? (It makes it easier to see before I get to triage.)

jpienaar avatar Jun 30 '22 00:06 jpienaar

The error I'm getting is "error: failed to legalize operation 'mhlo.scatter' that was explicitly marked illegal", and the output scatter op is:

%13798 = "mhlo.scatter"(%49, %261, %13797) ({
^bb0(%arg4: tensor<f32> loc(fused["UnsortedSegmentSum:", callsite("AddN_89/inputs_1@__inference_train_15002"("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/imperative_grad.py":67:0) at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/backprop.py":1100:0 at callsite("/home/dan/SHARK/tank/tf/huggingface_MiniLM_train.py":54:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1116:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":3251:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py":677:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1141:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2627:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2711:0 at "/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2533:0)))))))))]), %arg5: tensor<f32> loc(fused["UnsortedSegmentSum:", callsite("AddN_89/inputs_1@__inference_train_15002"("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/imperative_grad.py":67:0) at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/backprop.py":1100:0 at callsite("/home/dan/SHARK/tank/tf/huggingface_MiniLM_train.py":54:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1116:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":3251:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py":677:0 at 
callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1141:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2627:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2711:0 at "/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2533:0)))))))))])):
  %14284 = "mhlo.add"(%arg4, %arg5) : (tensor<f32>, tensor<f32>) -> tensor<f32> loc(fused["UnsortedSegmentSum:", callsite("AddN_89/inputs_1@__inference_train_15002"("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/imperative_grad.py":67:0) at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/backprop.py":1100:0 at callsite("/home/dan/SHARK/tank/tf/huggingface_MiniLM_train.py":54:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1116:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":3251:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py":677:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1141:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2627:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2711:0 at "/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2533:0)))))))))])
  "mhlo.return"(%14284) : (tensor<f32>) -> () loc(fused["UnsortedSegmentSum:", callsite("AddN_89/inputs_1@__inference_train_15002"("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/imperative_grad.py":67:0) at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/backprop.py":1100:0 at callsite("/home/dan/SHARK/tank/tf/huggingface_MiniLM_train.py":54:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1116:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":3251:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py":677:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1141:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2627:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2711:0 at "/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2533:0)))))))))])
}) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<30522x384xf32>, tensor<512xi32>, tensor<512x384xf32>) -> tensor<30522x384xf32> loc(fused["UnsortedSegmentSum:", callsite("AddN_89/inputs_1@__inference_train_15002"("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/imperative_grad.py":67:0) at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/backprop.py":1100:0 at callsite("/home/dan/SHARK/tank/tf/huggingface_MiniLM_train.py":54:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1116:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":3251:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py":677:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py":1141:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2627:0 at callsite("/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2711:0 at "/home/dan/SHARK/shark.venv/lib/python3.10/site-packages/tensorflow/python/eager/function.py":2533:0)))))))))])

NatashaKnk avatar Jun 30 '22 11:06 NatashaKnk

@jpienaar bumping to P1 from today's Nod meeting.

allieculp avatar Jun 30 '22 17:06 allieculp

Local reproducer

func.func @main(%arg0: tensor<30522x384xf32>, %arg1: tensor<512xi32>, %arg2: tensor<512x384xf32>) -> tensor<30522x384xf32> {
  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
    %1 = mhlo.add %arg3, %arg4 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<30522x384xf32>, tensor<512xi32>, tensor<512x384xf32>) -> tensor<30522x384xf32>
  return %0 : tensor<30522x384xf32>
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "func.func(iree-mhlo-to-linalg-on-tensors)",
      disable_threading: true,
      verify_each: true
    }
  }
#-}

$tools/iree-opt /tmp/repro.mlir

results in

failed to legalize operation 'mhlo.scatter' that was explicitly marked illegal
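
For readers following along, the `loc` on the original op names this as an UnsortedSegmentSum: each row of the updates is added into the operand row selected by the matching index. Below is a minimal pure-Python sketch of that reading, using tiny shapes in place of 30522x384 and 512. The function name `scatter_add_rows` is made up for illustration, and the sketch assumes every index is in bounds; it is not how IREE lowers the op.

```python
def scatter_add_rows(operand, indices, updates):
    """Add updates[i] into operand[indices[i]], row by row.

    Mirrors the reproducer's dimension numbers: update_window_dims = [1],
    inserted_window_dims = [0], scatter_dims_to_operand_dims = [0].
    Assumption: every index is in bounds for the operand.
    """
    result = [row[:] for row in operand]  # the op returns a new tensor
    for i, row_index in enumerate(indices):
        for j, value in enumerate(updates[i]):
            # The region body in the reproducer is a plain mhlo.add.
            result[row_index][j] += value
    return result


# Tiny stand-ins for tensor<30522x384xf32>, tensor<512xi32>, tensor<512x384xf32>.
operand = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
indices = [2, 0, 2]  # a repeated index, as allowed by unique_indices = false
updates = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

print(scatter_add_rows(operand, indices, updates))
```

Row 2 receives two updates here because its index repeats, which is the accumulation behaviour the `mhlo.add` region expresses.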

Unfortunately, notifyMatchFailure isn't used, so the debug output is just:

//===-------------------------------------------===//
Legalizing operation : 'mhlo.scatter'(0x8ea4b50) {

  • Fold {
  } -> FAILURE : unable to fold

  • Pattern : 'mhlo.scatter -> ()' {
Trying to match "mlir::(anonymous namespace)::ScatterUpdateConversion"
"mlir::(anonymous namespace)::ScatterUpdateConversion" result 0

jpienaar avatar Jul 14 '22 04:07 jpienaar

What seems to happen is that the folder fails when casting the scatter_dimension attribute (i.e. %arg1: tensor<512xi32>) to a DenseIntElementsAttr. I'm not exactly sure why that would be the case (is there something obvious I'm missing?). Looking further into it.

NatashaKnk avatar Jul 18 '22 12:07 NatashaKnk

Could you add a link to where in the code it fails?

jpienaar avatar Jul 18 '22 13:07 jpienaar

Yeah, sorry! As far as I understand it returns false here.

NatashaKnk avatar Jul 18 '22 13:07 NatashaKnk

So scatter_indices are not constant input, do we get to https://github.com/iree-org/iree/blob/8d0975e3ea50fd65f8d11e793743f3c08e98978c/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp#L291 post?

jpienaar avatar Jul 18 '22 14:07 jpienaar

So I think this is not a legal mhlo.scatter. I cannot generate this operation via JAX, and the nearest replication broadcasts to add a length-1 dimension.
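
To make the shape difference concrete: the op in this issue passes its indices as tensor<512xi32> with index_vector_dim = 1, while the form described above carries an explicit trailing length-1 dimension, which would look like tensor<512x1xi32>. The pure-Python sketch below only illustrates that reshaping, on the assumption that the rank-1 form is meant as one scalar index per update row. The helper names `add_index_vector_dim` and `shape_of` are hypothetical.

```python
def add_index_vector_dim(indices):
    """Reshape [N] scalar indices into [N, 1] length-1 index vectors."""
    return [[index] for index in indices]


def shape_of(nested):
    """Return the shape of a non-empty, rectangular nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return shape


flat_indices = [2, 0, 2]                             # like tensor<512xi32>
index_vectors = add_index_vector_dim(flat_indices)   # like tensor<512x1xi32>

print(shape_of(flat_indices), "->", shape_of(index_vectors))
```

With the explicit dimension, index_vector_dim = 1 points at a dimension that actually exists in the indices tensor, instead of one past its rank.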

rsuderman avatar Jul 20 '22 22:07 rsuderman

Potentially missing verification? Could you flag MHLO folks to verify?

jpienaar avatar Jul 20 '22 22:07 jpienaar

@rsuderman Hey Rob, do you have any update here?

allieculp avatar Jul 28 '22 17:07 allieculp

Hey @rsuderman @NatashaKnk Any update on this? This is a P1 item for the Nod folks at the moment. I think we were pinging MHLO folks?

allieculp avatar Aug 04 '22 17:08 allieculp

I'm on it; the fix should be in today or tomorrow.

NatashaKnk avatar Aug 04 '22 17:08 NatashaKnk

Hey @NatashaKnk any update?

allieculp avatar Aug 11 '22 17:08 allieculp

Sorry for the delay! Should be fixed when 10095 is submitted.

NatashaKnk avatar Aug 15 '22 16:08 NatashaKnk

Validated in iree-samples that the case is now fixed

rsuderman avatar Aug 16 '22 23:08 rsuderman