[Integrate] Upstream narrow type emulation is breaking iree test

Open raikonenfnu opened this issue 8 months ago • 16 comments

What happened?

https://github.com/llvm/llvm-project/pull/133231 seems to be breaking our subbyte emulation tests.

A simple repro is https://gist.github.com/raikonenfnu/1ea07b7e231d8997bfa1c29502df637d

Original test: https://github.com/iree-org/iree/blob/1a8d229431e62b50eb297e75ca4bf1dba3b67f65/tests/e2e/linalg/fp_to_subbyte.mlir

Based on the above test, IREE currently generates code in which the store has a trailing dim that differs from the memref's, and a "virtually" non-constant (i.e. fake non-constant) offset: in vector.store %0[%2], %2 is actually a constant, but because the store sits inside a branch region, %c0 is passed in as %2, a region argument, which is no longer a constant.

This condition sends our program down the code path where the pattern cannot determine foldedNumFrontPadElems (the !foldedNumFrontPadElems check), so the lowering fails, as seen in: https://github.com/llvm/llvm-project/blob/2de936b6eb38e7a37224a97c2a22aa79b9dfb9dc/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp#L619-L629

IIUC, that PR is largely correct: as long as the trailing dim does not match, we may need partial stores.
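A minimal sketch of why a non-constant offset is a problem, assuming a simplified model (2 i4 elements per byte, element 0 in the low nibble) rather than the upstream code: the number of front-padding elements, roughly what foldedNumFrontPadElems holds when it folds, is the linear i4 index modulo 2. With a constant index this is known at compile time; with a region argument like %2 it is not, so the pattern cannot tell whether partial stores are needed. store_layout is a hypothetical helper for illustration.

```python
# Hypothetical model of how an i4 store maps onto an i8 memref.
# Assumes 2 i4 elements per byte, element 0 in the low nibble.
ELEMS_PER_BYTE = 2


def store_layout(i4_index: int, num_elems: int):
    """Return (first_byte, front_pad, num_bytes, needs_partial) for storing
    num_elems i4 values starting at linear i4 index i4_index."""
    front_pad = i4_index % ELEMS_PER_BYTE          # i4 slots skipped in the first byte
    first_byte = i4_index // ELEMS_PER_BYTE
    total = front_pad + num_elems                  # i4 slots touched, incl. front padding
    num_bytes = (total + ELEMS_PER_BYTE - 1) // ELEMS_PER_BYTE
    back_pad = num_bytes * ELEMS_PER_BYTE - total  # unused i4 slots in the last byte
    needs_partial = front_pad != 0 or back_pad != 0
    return first_byte, front_pad, num_bytes, needs_partial


print(store_layout(4, 4))  # even index, even count: whole bytes only
print(store_layout(1, 4))  # odd index: first and last bytes are partial
```

The key point is that front_pad depends on the runtime value of the index, so when the index is not a compile-time constant the answer is only available at runtime.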

I don't think I have enough context on the best way to solve this. We can:

  1. Fold away that region, since the conditions of the regions are constants, or otherwise constant-fold so that vector.store gets a constant offset.
  2. Add code to the vector dialect's EmulateNarrowType vector::StoreOp conversion to handle non-constant cases through a series of bitcasts and memref.generic_atomic_rmw ops. (This seems much harder and requires more thinking.)
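Option 2 comes down to a read-modify-write on each byte that the store only partially covers. A minimal sketch of the per-byte update that the body of such an atomic RMW would have to compute, in plain Python; it ignores atomicity and is not the upstream lowering. merge_nibble is a hypothetical name, and the low-nibble-first layout is an assumption.

```python
def merge_nibble(old_byte: int, value: int, nibble_pos: int) -> int:
    """Write the low 4 bits of `value` into nibble `nibble_pos`
    (0 = low nibble, 1 = high nibble) of `old_byte`, keeping the other nibble."""
    shift = 4 * nibble_pos
    mask = 0xF << shift
    # Clear the target nibble (masking to 8 bits), then OR in the new value.
    return (old_byte & ~mask & 0xFF) | ((value & 0xF) << shift)


print(hex(merge_nibble(0xAB, 0x5, 0)))  # low nibble replaced
print(hex(merge_nibble(0xAB, 0x5, 1)))  # high nibble replaced
```

In the emulation, only the first and last bytes of an unaligned store need this treatment; the bytes in between can be written whole.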

Steps to reproduce your issue

  1. wget https://gist.githubusercontent.com/raikonenfnu/1ea07b7e231d8997bfa1c29502df637d/raw/6e3311f6794dd5afeaeac0f82422cbe498d67272/emulate_failure.mlir

  2. iree-opt --iree-codegen-emulate-narrow-type emulate_failure.mlir -o out.mlir

emulate.mlir:18:3: error: failed to legalize operation 'vector.store' that was explicitly marked illegal

What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

raikonenfnu · Apr 26 '25 03:04

So, I think what should be done there is that the slow path should be guarded by a runtime condition as in https://github.com/llvm/llvm-project/pull/135014

But, in addition, the narrow type emulation should be moved before scf-to-cf - heck, before OptimizeIntArithmetic - so that we can take advantage of divisibility analysis

krzysz00 avatar Apr 26 '25 05:04 krzysz00

> So, I think what should be done there is that the slow path should be guarded by a runtime condition as in llvm/llvm-project#135014

> But, in addition, the narrow type emulation should be moved before scf-to-cf - heck, before OptimizeIntArithmetic - so that we can take advantage of divisibility analysis

That is the way to go. The transformation is currently also unable to handle some other cases, all of which need a runtime check.

lialan · Apr 28 '25 14:04

To clarify, which "other cases"?

krzysz00 · Apr 28 '25 23:04

> To clarify, which "other cases"?

All the cases with a non-constant storing index.

lialan · Apr 29 '25 15:04

@raikonenfnu Was the emulation in the case of non-constant/dynamic indexing already working before? I think it was not working as intended even before this problematic patch.

lialan · Apr 29 '25 15:04

I do think the patch was fixing a real correctness issue, but it didn't do so in a performant way.

krzysz00 · Apr 29 '25 16:04

Sorry that my patch is causing this issue - I'd like to see if I can help resolve it.

To start, I’m trying to better understand the underlying problem. When I lower the reproducer (with my patch reverted), I get the following IR:

  func.func @_f32_to_i4_1d_dispatch_0_elementwise_8_f32xi4() {
    %cst = arith.constant dense<4> : vector<2xi32>
    %cst_0 = arith.constant dense<15> : vector<2xi32>
    %c4 = arith.constant 4 : index
    %c8 = arith.constant 8 : index
    %c0 = arith.constant 0 : index
    %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<8xf32>
    memref.assume_alignment %0, 64 : memref<8xf32>
    %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(Indirect) : memref<4xi8>
    memref.assume_alignment %1, 64 : memref<4xi8>
    cf.br ^bb1(%c0 : index)
  ^bb1(%2: index):  // 2 preds: ^bb0, ^bb2
    %3 = arith.cmpi slt, %2, %c8 : index
    cf.cond_br %3, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    %4 = vector.load %0[%2] : memref<8xf32>, vector<4xf32>
    %5 = arith.fptoui %4 : vector<4xf32> to vector<4xi32>
    %6 = affine.apply #map()[%2]
    %7 = vector.shuffle %5, %5 [0, 2] : vector<4xi32>, vector<4xi32>
    %8 = arith.andi %7, %cst_0 : vector<2xi32>
    %9 = vector.shuffle %5, %5 [1, 3] : vector<4xi32>, vector<4xi32>
    %10 = arith.andi %9, %cst_0 : vector<2xi32>
    %11 = arith.shli %10, %cst : vector<2xi32>
    %12 = arith.ori %8, %11 : vector<2xi32>
    %13 = arith.trunci %12 : vector<2xi32> to vector<2xi8>
    vector.store %13, %1[%6] : memref<4xi8>, vector<2xi8>
    %14 = arith.addi %2, %c4 : index
    cf.br ^bb1(%14 : index)
  ^bb3:  // pred: ^bb1
    return
  }
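For reference, the arithmetic in ^bb2 can be mirrored in Python to see what each op computes on its own. The sketch starts from the i32 lanes produced by arith.fptoui and follows the ops literally: arith.andi with dense<15> keeps the low 4 bits of each lane, and the odd lanes are shifted by dense<4> into the high nibble before the arith.ori. It models only the data path, not the index %6 (whose #map is not shown above). pack_like_bb2 is a hypothetical name.

```python
def pack_like_bb2(lanes):
    """Mirror the ops in ^bb2 on a 4-lane vector of i32 values (after fptoui)."""
    assert len(lanes) == 4
    even = [lanes[0], lanes[2]]               # vector.shuffle %5, %5 [0, 2]
    odd = [lanes[1], lanes[3]]                # vector.shuffle %5, %5 [1, 3]
    lo = [x & 15 for x in even]               # arith.andi with %cst_0 (dense<15>)
    hi = [(x & 15) << 4 for x in odd]         # arith.andi with %cst_0, then arith.shli by %cst (dense<4>)
    packed = [a | b for a, b in zip(lo, hi)]  # arith.ori
    return [p & 0xFF for p in packed]         # arith.trunci to i8


print(pack_like_bb2([1, 2, 3, 4]))
```

Each resulting i8 holds two consecutive input lanes, with the even lane in the low nibble.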

What stands out to me is that %2 (the branch+load/store index) doesn’t appear to be used to mask the input of the vector.store. I’m not yet sure what this constant (that's used for masking) represents ...

    %cst_0 = arith.constant dense<15> : vector<2xi32>

... but it doesn't look correct. So, it feels like the failing test is exercising an incorrect lowering path in upstream MLIR? And my patch effectively disables that.

Does this agree with your understanding?

banach-space · Apr 30 '25 15:04

I think I know what's happening. Below is the IR I get when I move the emulation before the scf->cf lowering, which is easier to read. It also trims down IREE-specific ops.

When we tile in IREE, the tile sizes are driven by the native vector size on CPU. For i4 emulation, the stores are always aligned (i.e., no partial stores) because the tile sizes are typically multiples of 2, so every store in the for-loop is aligned.

func.func @main(%arg0: memref<8xf32>, %arg1: memref<8xi4>) {
  %c4 = arith.constant 4 : index
  %c8 = arith.constant 8 : index
  %c0 = arith.constant 0 : index
  scf.for %arg2 = %c0 to %c8 step %c4 {
    %0 = vector.load %arg0[%arg2] : memref<8xf32>, vector<4xf32>
    %1 = arith.fptoui %0 : vector<4xf32> to vector<4xi32>
    %2 = arith.trunci %1 : vector<4xi32> to vector<4xi4>
    vector.store %2, %arg1[%arg2] : memref<8xi4>, vector<4xi4>
  }
  return
}

The upstream fix is reasonable, and it enables support for unaligned cases. Here, I think we are missing a hint for the pattern. This can be either:

  1. Add a mode to the upstream patterns that assumes stores are always aligned. Then it is the user's responsibility to tile correctly.
  2. We need an integer range analysis that identifies whether the flattened index is aligned. If it is, we can always convert the store to a vector.bitcast followed by the store. Otherwise, we'll need to support it separately.
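Option (2) needs to prove that the flattened index is a multiple of the number of i4 elements per byte. A minimal sketch of that kind of divisibility reasoning, applied to the scf.for above (lower bound %c0, step %c4): each index is abstracted by a known divisor, and the store is aligned when both the index divisor and the element count are multiples of 2. This is a toy lattice written for illustration, not the API of MLIR's integer range analysis or of IREE's divisibility analysis; all function names are hypothetical.

```python
import math

# Toy divisibility lattice (illustration only): an index value is abstracted
# by a divisor d meaning "the value is known to be a multiple of d".
# d == 0 encodes the constant 0 (a multiple of everything); d == 1 means
# nothing useful is known.


def div_const(c: int) -> int:
    return abs(c)


def div_add(d1: int, d2: int) -> int:
    # If x = d1*k and y = d2*m, then x + y is a multiple of gcd(d1, d2).
    return math.gcd(d1, d2)


def div_mul_const(d: int, c: int) -> int:
    return d * abs(c)


def div_loop_iv(lb_div: int, step_div: int) -> int:
    # iv = lb + k * step, with k an unknown non-negative integer.
    return math.gcd(lb_div, step_div)


def is_i4_store_aligned(index_div: int, num_elems: int, elems_per_byte: int = 2) -> bool:
    # Aligned iff the start index and the element count both fill whole bytes.
    return index_div % elems_per_byte == 0 and num_elems % elems_per_byte == 0


iv = div_loop_iv(div_const(0), div_const(4))  # %arg2 in the scf.for above
print(iv, is_i4_store_aligned(iv, 4))
```

With a lower bound of 0 and a step of 4, the induction variable is known to be a multiple of 4, so the vector<4xi4> store is aligned. With an unknown lower bound the divisor collapses to 1 and the slow path would be required. This reasoning is easy on the scf.for form and much harder once the loop has been lowered to cf branches.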

I don't know how to achieve (2) at the moment; maybe ValueRange analysis can help. I have only used it for querying upper bounds. Let me take a look at whether it can check if a value is a multiple of something.

In IREE, we have the util.assume.int op, which tells you whether a value is divisible by a given udiv.

https://github.com/iree-org/iree/blob/5140464cee7635eda92fdbbff7df580e37a44652/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir#L14-L17

hanhanW · May 30 '25 04:05

Okay, I think upstream may not support this. To fix it properly in IREE, I think we should move some patterns back to IREE and use TensorDynamicDimAnalysis to get the information. We can prioritize the IREE patterns in this case.

hanhanW · May 30 '25 04:05

yeah, especially for stuff like this I just assume (hah) upstream doesn't support what we need in conjunction with our better type support and assume ops.

really good detective work on this - very tricky interactions!

benvanik · May 30 '25 04:05

@lialan may be busy with other stuff. Assigning it to me for now; I can either work on it myself or find someone who can. (cc @MaheshRavishankar)

hanhanW · May 30 '25 04:05

We could also get around to adding align = N to vector.load and vector.store and such, which is @efric (and then an absent align = or an align = 0 would mean we need the unaligned case)

krzysz00 · May 30 '25 06:05

Another thing that can be done upstream is to put the slow and the fast cases in two arms of an if statement and let integer range analysis clear off the slow path.

This is what was done for the buffer OOB handling.

krzysz00 · May 30 '25 06:05

> We could also get around to adding align = N to vector.load and vector.store and such, which is @efric (and then an absent align = or an align = 0 would mean we need the unaligned case)

I'd prefer not to use this approach atm, because we can get all the needed information from the IR. I can be convinced if there are other use cases.

> Another thing that can be done upstream is to put the slow and the fast cases in two arms of an if statement and let integer range analysis clear off the slow path.

I'm not sure how integer range analysis works with upstream patterns. It seems like the analysis should only be run once for efficiency; then we collect the ops and apply the transform. Also, the analysis is only available in IREE, so I'm not sure how the two can be connected.

Anyway, I think we have a solid plan, and we can fix it once someone picks this up.

hanhanW · May 30 '25 06:05

Agreed RE upstream analysis interaction. I think we could provide interface implementations that distill the info, if upstream uses such interfaces (I don't think we do). The particular benefit of the util.assume.int op is the correlation between independent SSA values ("if A is udiv=4 then B is udiv=8" or "if A is umin=8 then B is umin=16"). We can do a lot more with that (@qedawkins's specialization is an example: specializing just for the sets of assumptions rather than the full combinatorial explosion of them). We should probably plan on keeping anything important that relies on analysis on the IREE side, while exposing what we can (the union/intersection of all our assume pairs exposed as a single value range).
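A hypothetical Python model of why the correlation matters: joining all assumption rows into one range per value (the "single value range" view) keeps only the weakest facts, while keeping the rows intact lets a specialization see that A being udiv=4 comes together with B being umin=16. The data layout, the sample rows, and the function names are illustrative assumptions, not IREE's representation; only udiv and umin are modeled, and the join shown is the union.

```python
import math
from functools import reduce

# Hypothetical assumption rows: each row is one set of facts that hold
# together (e.g. "if A is udiv=4 then B is udiv=8").
ROWS = [
    {"A": {"udiv": 4, "umin": 8}, "B": {"udiv": 8, "umin": 16}},
    {"A": {"udiv": 2, "umin": 0}, "B": {"udiv": 16, "umin": 0}},
]


def joined_range(rows, name):
    """Union over all rows for one value: the uncorrelated, single-range view."""
    udiv = reduce(math.gcd, (r[name]["udiv"] for r in rows))  # weakest divisor
    umin = min(r[name]["umin"] for r in rows)                 # weakest lower bound
    return {"udiv": udiv, "umin": umin}


def rows_where(rows, name, key, value):
    """Correlated view: the rows in which rows[name][key] == value."""
    return [r for r in rows if r[name][key] == value]


print(joined_range(ROWS, "A"), joined_range(ROWS, "B"))
print(rows_where(ROWS, "A", "udiv", 4)[0]["B"])
```

The joined view says only that B is a multiple of 8 with no useful lower bound, whereas the row where A is udiv=4 also carries B's umin=16.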

benvanik · May 30 '25 06:05

> I'm not sure how integer range analysis works with upstream patterns. It seems like the analysis should only be run once because of efficiency. Then we collect the ops and apply the transform. Also, the analysis is only available in IREE, so I'm not sure how they can be connected.

Integer range analysis lives upstream; the only thing that doesn't is divisibility analysis, and that should be fairly easy to upstream.

And even if integer range analysis doesn't work, having the pattern generate

if (linearI4Index % 2 == 0) { fastPath } else { slowPath }

instead of just generating slowPath in cases where the divisibility isn't statically obvious will allow LLVM to take care of it
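The guarded form suggested above can be sketched in Python, with a bytearray standing in for the i8 memref and a runtime check on linear_i4_index selecting the path. fast_store, slow_store and emulated_store are hypothetical names for illustration, and the low-nibble-first layout is an assumption. An actual pattern would emit IR for both arms instead of performing the stores; if divisibility can be proven, the slow arm becomes dead code.

```python
def slow_store(buf: bytearray, i4_index: int, values):
    """Element-by-element read-modify-write; correct for any alignment."""
    for k, v in enumerate(values):
        byte, nib = divmod(i4_index + k, 2)
        shift = 4 * nib
        buf[byte] = (buf[byte] & ~(0xF << shift) & 0xFF) | ((v & 0xF) << shift)


def fast_store(buf: bytearray, i4_index: int, values):
    """Whole-byte store; only valid when i4_index and len(values) are even."""
    first = i4_index // 2
    for j in range(len(values) // 2):
        buf[first + j] = (values[2 * j] & 0xF) | ((values[2 * j + 1] & 0xF) << 4)


def emulated_store(buf: bytearray, linear_i4_index: int, values):
    # Runtime guard: take the fast path only when the store covers whole bytes.
    if linear_i4_index % 2 == 0 and len(values) % 2 == 0:
        fast_store(buf, linear_i4_index, values)
    else:
        slow_store(buf, linear_i4_index, values)


buf = bytearray(4)
emulated_store(buf, 1, [1, 2, 3, 4])  # odd index: takes the slow path
print(buf.hex())
```

Both arms produce the same bytes for aligned stores, so the guard only affects performance, not results.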

krzysz00 · May 30 '25 06:05