[Integrate] Failure to hoist vector transfers when memref.assume_alignment is present
I hit this error, and it is not an easy fix for me. For now, I am filing an issue to unblock the integrate. It only happens when hoisting transfers on memrefs, which is a legacy pipeline.
To repro: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-optimize-vector-transfer))" repro.mlir
The input has hundreds of lines of IR, so I am not inlining the program. See https://gist.github.com/hanhanW/a93a21297a915f9ff6eeb5bc0fdd8ba7
Adding a working IR that uses the memref directly, and it works: https://gist.github.com/hanhanW/4314bd77cb24468b3ebdea2cb9d274f5
The variable names look confusing because I made the change by hand; that was the easiest way to make the change.
I got a smaller repro, and I trimmed out the IREE specifics.
To repro: iree-opt --iree-transform-dialect-interpreter --allow-unregistered-dialect repro.mlir
// Baseline case: the vector transfers operate directly on the function
// arguments %memref0 and %memref1 (no memref.assume_alignment involved).
//
// The inner loop contains two transfer_read / transfer_write pairs:
//   * %memref1 at the constant indices [%c0, %c0] (vector<1xf32>)
//   * %memref0 at [%i, %i], where %i is the OUTER loop's induction variable
//     (vector<2xf32>)
// Neither pair indexes with the inner induction variable %j.
//
// In the "Output:" section of this report, the %memref1 pair ends up outside
// both loops and the %memref0 pair ends up outside the inner loop, with the
// vectors carried through scf.for iter_args.
//
// %val and %cmp are not referenced in the body.
func.func @hoist_vector_transfer_pairs(
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
scf.for %i = %lb to %ub step %step {
scf.for %j = %lb to %ub step %step {
// Reads of both memrefs; %cst is the padding value.
%r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
%r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
// "some_use" is written in generic op form; the repro command passes
// --allow-unregistered-dialect so that such ops are accepted.
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
%u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
// Writes go back to the same memref and the same indices as the reads above.
vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
}
// Additional use of %memref0 inside the outer loop, after the inner loop.
"unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
}
// Additional use of %memref1 after both loops.
"unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
return
}
// Failing case: same loop nest and transfer pairs as
// @hoist_vector_transfer_pairs, except that every memref use goes through the
// result of memref.assume_alignment instead of the function argument directly.
//
// In the "Output:" section of this report, this function is printed with all
// transfer_read / transfer_write ops still inside the inner loop, i.e. nothing
// was hoisted.
//
// %val and %cmp are not referenced in the body.
func.func @hoist_vector_transfer_pairs_bug(
%src_memref0: memref<?x?xf32>, %src_memref1: memref<?x?xf32>,
%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
// The only difference from the baseline function: %memref0 / %memref1 are now
// results of memref.assume_alignment (64) on the arguments.
%memref0 = memref.assume_alignment %src_memref0, 64 : memref<?x?xf32>
%memref1 = memref.assume_alignment %src_memref1, 64 : memref<?x?xf32>
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
scf.for %i = %lb to %ub step %step {
scf.for %j = %lb to %ub step %step {
// Reads use constant indices (%memref1) or the outer induction variable %i
// (%memref0); the inner induction variable %j is not used for indexing.
%r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
%r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
%u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
// Writes go back to the same memref and the same indices as the reads above.
vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
}
// Additional use of the aligned %memref0 inside the outer loop.
"unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
}
// Additional use of the aligned %memref1 after both loops.
"unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
return
}
// Transform-dialect driver for the repro (run via
// --iree-transform-dialect-interpreter, per the command above).
module attributes {transform.with_named_sequence} {
// %arg1 is the payload handle; it is marked transform.readonly.
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
// Collect every func.func op nested in the payload.
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
// Apply redundant vector transfer hoisting to the matched functions.
// The result handle is not used afterwards.
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
Output:
func.func @hoist_vector_transfer_pairs(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i1) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = vector.transfer_read %arg1[%c0, %c0], %cst : memref<?x?xf32>, vector<1xf32>
%1 = scf.for %arg7 = %arg3 to %arg4 step %arg5 iter_args(%arg8 = %0) -> (vector<1xf32>) {
%2 = vector.transfer_read %arg0[%arg7, %arg7], %cst : memref<?x?xf32>, vector<2xf32>
%3:2 = scf.for %arg9 = %arg3 to %arg4 step %arg5 iter_args(%arg10 = %arg8, %arg11 = %2) -> (vector<1xf32>, vector<2xf32>) {
%4 = "some_use"(%arg10) : (vector<1xf32>) -> vector<1xf32>
%5 = "some_use"(%arg11) : (vector<2xf32>) -> vector<2xf32>
scf.yield %4, %5 : vector<1xf32>, vector<2xf32>
}
vector.transfer_write %3#1, %arg0[%arg7, %arg7] : vector<2xf32>, memref<?x?xf32>
"unrelated_use"(%arg0) : (memref<?x?xf32>) -> ()
scf.yield %3#0 : vector<1xf32>
}
vector.transfer_write %1, %arg1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
"unrelated_use"(%arg1) : (memref<?x?xf32>) -> ()
return
}
func.func @hoist_vector_transfer_pairs_bug(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i1) {
%assume_align = memref.assume_alignment %arg0, 64 : memref<?x?xf32>
%assume_align_0 = memref.assume_alignment %arg1, 64 : memref<?x?xf32>
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
scf.for %arg7 = %arg3 to %arg4 step %arg5 {
scf.for %arg8 = %arg3 to %arg4 step %arg5 {
%0 = vector.transfer_read %assume_align_0[%c0, %c0], %cst : memref<?x?xf32>, vector<1xf32>
%1 = vector.transfer_read %assume_align[%arg7, %arg7], %cst : memref<?x?xf32>, vector<2xf32>
%2 = "some_use"(%0) : (vector<1xf32>) -> vector<1xf32>
%3 = "some_use"(%1) : (vector<2xf32>) -> vector<2xf32>
vector.transfer_write %2, %assume_align_0[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
vector.transfer_write %3, %assume_align[%arg7, %arg7] : vector<2xf32>, memref<?x?xf32>
}
"unrelated_use"(%assume_align) : (memref<?x?xf32>) -> ()
}
"unrelated_use"(%assume_align_0) : (memref<?x?xf32>) -> ()
return
}
I think the issue is about how we implement MemoryEffectOpInterface for the memref.assume_alignment op. I'll take a look later.
Upstream implementation for hoisting: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
I verified that https://github.com/llvm/llvm-project/pull/144809 can fix the issue, see https://github.com/iree-org/iree/pull/21133.