iree icon indicating copy to clipboard operation
iree copied to clipboard

[Integrate] Failure to hoist vector transfers when memref.assume_alignment is present

Open hanhanW opened this issue 7 months ago • 4 comments

I hit this issue, and it is not an easy fix for me. For now, I am filing an issue to unblock the integrate. It only happens when hoisting transfers on memrefs, which is the legacy pipeline.

To repro: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-optimize-vector-transfer))" repro.mlir

The input is hundreds of lines of IR, so I am not inlining it here. See https://gist.github.com/hanhanW/a93a21297a915f9ff6eeb5bc0fdd8ba7

hanhanW avatar May 28 '25 08:05 hanhanW

Here is a version of the IR that uses the memref directly; hoisting works on it: https://gist.github.com/hanhanW/4314bd77cb24468b3ebdea2cb9d274f5

The variable names may look confusing because I made the change by hand; that was the easiest way to make it.

hanhanW avatar Jun 02 '25 11:06 hanhanW

I have a smaller repro, with the IREE-specific parts trimmed out.

To repro: iree-opt --iree-transform-dialect-interpreter --allow-unregistered-dialect repro.mlir

// Control case: the vector transfers index the memref function arguments
// directly (no memref.assume_alignment). Running
// hoist_redundant_vector_transfers on this function hoists both
// read/write pairs out of the loops.
// Note: %val and %cmp are not used by the body.
func.func @hoist_vector_transfer_pairs(
    %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
    %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      // Indices (%c0, %c0) are invariant in both loops.
      %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
      // Indices depend on the outer IV %i, so they are invariant only in the inner loop.
      %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
      vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
      vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
    }
    // Use of %memref0 in the outer loop body, after the inner loop.
    "unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
  }
  // Use of %memref1 after both loops.
  "unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
  return
}

// Failing case: identical loop body to @hoist_vector_transfer_pairs, except
// that the memrefs the transfers access are the results of
// memref.assume_alignment instead of the function arguments. With this
// change, hoist_redundant_vector_transfers leaves every transfer inside the
// inner loop.
// Note: %val and %cmp are not used by the body.
func.func @hoist_vector_transfer_pairs_bug(
    %src_memref0: memref<?x?xf32>, %src_memref1: memref<?x?xf32>,
    %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
  // The only difference from the control case: alias the arguments through
  // memref.assume_alignment.
  %memref0 = memref.assume_alignment %src_memref0, 64 : memref<?x?xf32>
  %memref1 = memref.assume_alignment %src_memref1, 64 : memref<?x?xf32>
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
      %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
      vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
      vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
    }
    "unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
  }
  "unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
  return
}

// Transform script driven by --iree-transform-dialect-interpreter: matches
// every func.func in the payload and applies
// transform.structured.hoist_redundant_vector_transfers to it.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

Output:

// Expected result: both transfer pairs were hoisted. The %arg1[%c0, %c0] pair
// now sits outside both loops, with its value carried through iter_args.
// The %arg0[%arg7, %arg7] pair sits outside the inner loop only.
func.func @hoist_vector_transfer_pairs(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i1) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  // Hoisted out of both loops.
  %0 = vector.transfer_read %arg1[%c0, %c0], %cst : memref<?x?xf32>, vector<1xf32>
  %1 = scf.for %arg7 = %arg3 to %arg4 step %arg5 iter_args(%arg8 = %0) -> (vector<1xf32>) {
    // Hoisted out of the inner loop only, because the indices depend on %arg7.
    %2 = vector.transfer_read %arg0[%arg7, %arg7], %cst : memref<?x?xf32>, vector<2xf32>
    %3:2 = scf.for %arg9 = %arg3 to %arg4 step %arg5 iter_args(%arg10 = %arg8, %arg11 = %2) -> (vector<1xf32>, vector<2xf32>) {
      %4 = "some_use"(%arg10) : (vector<1xf32>) -> vector<1xf32>
      %5 = "some_use"(%arg11) : (vector<2xf32>) -> vector<2xf32>
      scf.yield %4, %5 : vector<1xf32>, vector<2xf32>
    }
    vector.transfer_write %3#1, %arg0[%arg7, %arg7] : vector<2xf32>, memref<?x?xf32>
    "unrelated_use"(%arg0) : (memref<?x?xf32>) -> ()
    scf.yield %3#0 : vector<1xf32>
  }
  vector.transfer_write %1, %arg1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
  "unrelated_use"(%arg1) : (memref<?x?xf32>) -> ()
  return
}

// Buggy result: nothing was hoisted. All four transfers remain inside the inner
// loop, unlike the control case, where both pairs are moved out.
func.func @hoist_vector_transfer_pairs_bug(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i1) {
  %assume_align = memref.assume_alignment %arg0, 64 : memref<?x?xf32>
  %assume_align_0 = memref.assume_alignment %arg1, 64 : memref<?x?xf32>
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  scf.for %arg7 = %arg3 to %arg4 step %arg5 {
    scf.for %arg8 = %arg3 to %arg4 step %arg5 {
      // Still inside the inner loop (not hoisted).
      %0 = vector.transfer_read %assume_align_0[%c0, %c0], %cst : memref<?x?xf32>, vector<1xf32>
      %1 = vector.transfer_read %assume_align[%arg7, %arg7], %cst : memref<?x?xf32>, vector<2xf32>
      %2 = "some_use"(%0) : (vector<1xf32>) -> vector<1xf32>
      %3 = "some_use"(%1) : (vector<2xf32>) -> vector<2xf32>
      vector.transfer_write %2, %assume_align_0[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
      vector.transfer_write %3, %assume_align[%arg7, %arg7] : vector<2xf32>, memref<?x?xf32>
    }
    "unrelated_use"(%assume_align) : (memref<?x?xf32>) -> ()
  }
  "unrelated_use"(%assume_align_0) : (memref<?x?xf32>) -> ()
  return
}

hanhanW avatar Jun 02 '25 16:06 hanhanW

I think the issue is in how MemoryEffectOpInterface is implemented for the memref.assume_alignment op. I'll take a look later.

Upstream implementation for hoisting: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

hanhanW avatar Jun 02 '25 16:06 hanhanW

I verified that https://github.com/llvm/llvm-project/pull/144809 fixes the issue; see https://github.com/iree-org/iree/pull/21133.

hanhanW avatar Jun 18 '25 22:06 hanhanW