Fuse mmt4d ukernel with consumer and get perfect codegen thanks to store-to-load forwarding
mmt4d->consumer fusion is a major codegen optimization opportunity. Typical consumers include the unpack that follows the mmt4d, and element-wise ops (generics with only parallel iterators) commonly found as matmul consumers, for instance the bias additions and activation functions of NN workloads.
Ukernels preclude codegen fusion at the MLIR level, but they do not preclude late LLVM IR optimizations such as store-to-load forwarding. However, for there to be a store-to-load forwarding opportunity between the final stores in the mmt4d ukernel and the initial loads in the consumer, without relying on LLVM loop fusion/interleaving, there must be no nontrivial loop enclosing the stores in the mmt4d ukernel. This is only the case if the outer M/N dimensions of the mmt4d ukernel's shape are both constant 1, making the M/N outer loops trivial.
So there are at least two things that need to happen here:
- Enable mmt4d->consumer and/or ukernel->consumer fusion (depending on whether that is decided before or after LowerToUKernels).
- Set outer tile sizes to 1.
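For reference, the contraction performed by linalg.mmt4d can be sketched in plain Python (this is a sketch of the op's semantics, not IREE's actual ukernel code); it makes the outer M/N loops visible, which are exactly the loops that must be trivial (trip count 1) for store-to-load forwarding into the consumer to apply:

```python
import numpy as np

def mmt4d(lhs, rhs, acc):
    """Sketch of linalg.mmt4d semantics (not IREE's ukernel).

    lhs: (M1, K1, M0, K0), rhs: (N1, K1, N0, K0), acc: (M1, N1, M0, N0).
    The two outer loops over M1/N1 enclose the stores into acc;
    with M1 == N1 == 1 they are trivial and disappear.
    """
    M1, K1, M0, K0 = lhs.shape
    N1 = rhs.shape[0]
    for m1 in range(M1):          # trivial when M1 == 1
        for n1 in range(N1):      # trivial when N1 == 1
            for k1 in range(K1):
                # inner tile: (M0, K0) x (N0, K0)^T accumulated into (M0, N0)
                acc[m1, n1] += lhs[m1, k1] @ rhs[n1, k1].T
    return acc
```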
Here is a typical test case (from https://gist.github.com/bjacob/295ae908d4061178a5d1c691f12fd174):
#map_2d_identity = affine_map<(m, n) -> (m, n)>
#map_2d_first_coordinate = affine_map<(m, n) -> (m)>
func.func @matmul_bias_relu_dynamic(
    %input_activations : tensor<?x?xf32>,
    %weights : tensor<?x?xf32>,
    %bias : tensor<?xf32>
) -> tensor<?x?xf32> {
  %c0f32 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Query dimensions of dynamic-size input tensors.
  %m = tensor.dim %input_activations, %c0 : tensor<?x?xf32>
  %n = tensor.dim %weights, %c1 : tensor<?x?xf32>
  // Perform the matrix multiplication.
  %matmul_empty = tensor.empty(%m, %n) : tensor<?x?xf32>
  %matmul_accumulator = linalg.fill ins(%c0f32 : f32) outs(%matmul_empty : tensor<?x?xf32>) -> tensor<?x?xf32>
  %matmul_result = linalg.matmul
      ins(%input_activations, %weights : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%matmul_accumulator : tensor<?x?xf32>) -> tensor<?x?xf32>
  // Perform the bias-addition and ReLU.
  %result_empty = tensor.empty(%m, %n) : tensor<?x?xf32>
  %result = linalg.generic {
    indexing_maps = [
      #map_2d_identity,
      #map_2d_first_coordinate,
      #map_2d_identity
    ],
    iterator_types = ["parallel", "parallel"]
  } ins(%matmul_result, %bias : tensor<?x?xf32>, tensor<?xf32>)
    outs(%result_empty : tensor<?x?xf32>) {
  ^bb0(%matmul_entry : f32, %bias_entry : f32, %unused_result_entry : f32):
    %add = arith.addf %matmul_entry, %bias_entry : f32
    %relu = arith.maximumf %add, %c0f32 : f32
    linalg.yield %relu : f32
  } -> tensor<?x?xf32>
  return %result : tensor<?x?xf32>
}
EDIT:
Tile size 1 is not going to be optimal for other reasons (it will result in too many, too-small dispatch function calls). Perhaps once we've secured the ability to get perfect codegen with tile size 1, the next question is how we reconcile that with suitably sized dispatch function calls.
Oh, I think I know... we need to replace "mmt4d ukernel with outer sizes M, N" with "scf.for loops with trip counts M, N around an mmt4d ukernel with outer sizes 1, 1". Then this no longer has anything to do with distribution tile sizes.
But this idea relies on the ability to fuse scf.for loops?
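To illustrate the idea (a hypothetical sketch in Python, not actual IREE code): hoisting the outer M/N loops out of the ukernel as explicit loops around a 1x1-outer ukernel call leaves a natural point inside the loop body where the consumer's per-tile work could be fused, independently of distribution tile sizes:

```python
import numpy as np

def mmt4d_1x1(lhs_tile, rhs_tile, acc_tile):
    # Hypothetical stand-in for the mmt4d ukernel with outer sizes 1, 1:
    # lhs_tile (K1, M0, K0), rhs_tile (K1, N0, K0), acc_tile (M0, N0).
    for k1 in range(lhs_tile.shape[0]):
        acc_tile += lhs_tile[k1] @ rhs_tile[k1].T
    return acc_tile

def mmt4d_as_loops(lhs, rhs, acc):
    # "scf.for loops with trip counts M1, N1 around a 1x1-outer ukernel":
    # the consumer's per-tile work can be fused into this loop body,
    # right after the ukernel call, where its stores are still "hot".
    M1, N1 = acc.shape[0], acc.shape[1]
    for m1 in range(M1):
        for n1 in range(N1):
            acc[m1, n1] = mmt4d_1x1(lhs[m1], rhs[n1], acc[m1, n1])
            # fused consumer (e.g. bias + relu) would act on acc[m1, n1] here
    return acc
```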
I generated an example using cpu=znver4. In your example, I think we will get something like the following for the fusion (once we enable unpack propagation):
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
module {
  func.func @mmt4d_bias_relu_fusion(%arg0: tensor<?x?x16x1xf32>, %arg1: tensor<?x?x16x1xf32>, %arg2: tensor<?x16xf32>) -> tensor<?x?x16x16xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?x16x1xf32>
    %dim_0 = tensor.dim %arg1, %c0 : tensor<?x?x16x1xf32>
    %0 = tensor.empty(%dim, %dim_0) : tensor<?x?x16x16xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
    %2 = linalg.mmt4d ins(%arg0, %arg1 : tensor<?x?x16x1xf32>, tensor<?x?x16x1xf32>) outs(%1 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
    %3 = tensor.empty(%dim, %dim_0) : tensor<?x?x16x16xf32>
    %4 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2, %arg2 : tensor<?x?x16x16xf32>, tensor<?x16xf32>) outs(%3 : tensor<?x?x16x16xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %5 = arith.addf %in, %in_1 : f32
      %6 = arith.maximumf %5, %cst : f32
      linalg.yield %6 : f32
    } -> tensor<?x?x16x16xf32>
    return %4 : tensor<?x?x16x16xf32>
  }
}
Input IR (starting from executable-sources):
hal.executable public @mmt4d_bias_relu_fusion_dispatch_0 {
hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver4", cpu_features = "+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+sse4a,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vbmi,+avx512ifma,+avx512vpopcntdq,+avx512vbmi2,+gfni,+vpclmulqdq,+avx512vnni,+avx512bitalg,+avx512bf16,+adx,+clflushopt,+clwb,+clzero,+cx16,+cx8,+crc32,+f16c,+fsgsbase,+fxsr,+invpcid,+lzcnt,+movbe,+mwaitx,+pku,+prfchw,+rdpid,+rdpru,+rdrnd,+rdseed,+sahf,+sha,+shstk,+vaes,+wbnoinvd,+x87,+xsave,+xsavec,+xsaveopt,+xsaves,+evex512", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 64 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "all"}>) {
hal.executable.export public @mmt4d_bias_relu_fusion_dispatch_0_generic_DxDx16x16_f32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 10, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<Mmt4dTilingExpert>} {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3, %arg4, %arg5
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @mmt4d_bias_relu_fusion_dispatch_0_generic_DxDx16x16_f32() {
%c0 = arith.constant 0 : index
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = hal.interface.constant.load[5] : i32
%6 = hal.interface.constant.load[6] : i32
%7 = hal.interface.constant.load[7] : i32
%8 = hal.interface.constant.load[8] : i32
%9 = hal.interface.constant.load[9] : i32
%10 = arith.extui %0 : i32 to i64
%11 = arith.extui %1 : i32 to i64
%12 = arith.shli %11, %c32_i64 : i64
%13 = arith.ori %10, %12 : i64
%14 = arith.index_castui %13 : i64 to index
%15 = arith.extui %2 : i32 to i64
%16 = arith.extui %3 : i32 to i64
%17 = arith.shli %16, %c32_i64 : i64
%18 = arith.ori %15, %17 : i64
%19 = arith.index_castui %18 : i64 to index
%20 = arith.extui %4 : i32 to i64
%21 = arith.extui %5 : i32 to i64
%22 = arith.shli %21, %c32_i64 : i64
%23 = arith.ori %20, %22 : i64
%24 = arith.index_castui %23 : i64 to index
%25 = arith.extui %6 : i32 to i64
%26 = arith.extui %7 : i32 to i64
%27 = arith.shli %26, %c32_i64 : i64
%28 = arith.ori %25, %27 : i64
%29 = arith.index_castui %28 : i64 to index
%30 = arith.extui %8 : i32 to i64
%31 = arith.extui %9 : i32 to i64
%32 = arith.shli %31, %c32_i64 : i64
%33 = arith.ori %30, %32 : i64
%34 = arith.index_castui %33 : i64 to index
%35 = flow.dispatch.workload.ordinal %14, 0 : index
%36 = flow.dispatch.workload.ordinal %19, 1 : index
%37 = flow.dispatch.workload.ordinal %24, 2 : index
%38 = flow.dispatch.workload.ordinal %29, 3 : index
%39 = flow.dispatch.workload.ordinal %34, 4 : index
%40 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%38, %35}
%41 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%39, %36}
%42 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x16xf32>>{%37}
%43 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<?x?x16x16xf32>>{%38, %39}
%44 = flow.dispatch.tensor.load %40, offsets = [0, 0, 0, 0], sizes = [%38, %35, 16, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%38, %35} -> tensor<?x?x16x1xf32>
%45 = flow.dispatch.tensor.load %41, offsets = [0, 0, 0, 0], sizes = [%39, %36, 16, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%39, %36} -> tensor<?x?x16x1xf32>
%46 = flow.dispatch.tensor.load %42, offsets = [0, 0], sizes = [%37, 16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x16xf32>>{%37} -> tensor<?x16xf32>
%47 = tensor.empty(%38, %39) : tensor<?x?x16x16xf32>
%48 = linalg.fill ins(%cst : f32) outs(%47 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
%49 = linalg.mmt4d {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]>} ins(%44, %45 : tensor<?x?x16x1xf32>, tensor<?x?x16x1xf32>) outs(%48 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
%50 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%49, %46 : tensor<?x?x16x16xf32>, tensor<?x16xf32>)
outs(%47 : tensor<?x?x16x16xf32>)
attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 2, 0, 0], [1, 1, 16, 16], [0, 0, 0, 0], [0, 0, 0, 0]]>}
{
^bb0(%in: f32, %in_0: f32, %out: f32):
%51 = arith.addf %in, %in_0 : f32
%52 = arith.maximumf %51, %cst : f32
linalg.yield %52 : f32
} -> tensor<?x?x16x16xf32>
flow.dispatch.tensor.store %50, %43, offsets = [0, 0, 0, 0], sizes = [%38, %39, 16, 16], strides = [1, 1, 1, 1] : tensor<?x?x16x16xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?x16x16xf32>>{%38, %39}
return
}
}
}
}
The mmt4d op will be replaced with a ukernel, so only the first tile-size list matters for it. The tile sizes used by TileAndFuse are the second list on the generic op.
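To make the two tile-size lists concrete, here is a sketch (in Python, with hypothetical names) of the loop nest implied by the generic op's lowering_config: the first-level sizes [1, 2, 0, 0] tile the outer M1/N1 dims, and the second-level sizes [1, 1, 16, 16] process one full 16x16 inner tile at a time, with the bias indexed per the (d0, d2) map:

```python
import numpy as np

def tiled_bias_relu(acc, bias):
    """Sketch of the loop nest implied by the generic op's lowering_config.

    acc: (M1, N1, 16, 16) mmt4d result; bias: (M1, 16), indexed by
    map (d0, d2). First-level tile sizes [1, 2, 0, 0]; second-level
    tile sizes [1, 1, 16, 16] (a whole inner tile per iteration).
    """
    M1, N1, M0, N0 = acc.shape
    out = np.empty_like(acc)
    for m1 in range(0, M1, 1):        # first-level tile: 1 along M1
        for n1 in range(0, N1, 2):    # first-level tile: 2 along N1
            for n in range(n1, min(n1 + 2, N1)):   # second level: one 16x16 tile
                tile = acc[m1, n] + bias[m1][:, None]  # bias map (d0, d2)
                out[m1, n] = np.maximum(tile, 0.0)     # ReLU
    return out
```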
To compile the dispatch, run iree-compile --output-format=vm-bytecode --iree-llvmcpu-enable-ukernels=all ~/module_mmt4d_fusion_dispatch_0.mlir -o /tmp/a.vmfb --compile-from=executable-sources --iree-llvmcpu-keep-linker-artifacts --iree-llvmcpu-link-embedded=false
Then we can get the asm:
/tmp/mmt4d_bias_relu_fusion_dispatch_0-e687d2.so: file format elf64-x86-64
Disassembly of section .text:
0000000000001440 <iree_hal_executable_library_query-0x6a0>:
1440: 55 push rbp
1441: 48 89 e5 mov rbp,rsp
1444: 41 57 push r15
1446: 41 56 push r14
1448: 41 55 push r13
144a: 41 54 push r12
144c: 53 push rbx
144d: 48 83 e4 c0 and rsp,0xffffffffffffffc0
1451: 48 81 ec 00 09 00 00 sub rsp,0x900
1458: 48 8b 4e 18 mov rcx,QWORD PTR [rsi+0x18]
145c: 8b 7a 04 mov edi,DWORD PTR [rdx+0x4]
145f: 48 8b 41 18 mov rax,QWORD PTR [rcx+0x18]
1463: 48 89 7c 24 40 mov QWORD PTR [rsp+0x40],rdi
1468: 48 89 84 24 88 00 00 mov QWORD PTR [rsp+0x88],rax
146f: 00
1470: 48 39 f8 cmp rax,rdi
1473: 0f 8e 49 06 00 00 jle 1ac2 <iree_hal_executable_library_query-0x1e>
1479: 48 8b 46 20 mov rax,QWORD PTR [rsi+0x20]
147d: 4c 8b 49 20 mov r9,QWORD PTR [rcx+0x20]
1481: 44 8b 69 08 mov r13d,DWORD PTR [rcx+0x8]
1485: 44 8b 46 10 mov r8d,DWORD PTR [rsi+0x10]
1489: 8b 7e 0c mov edi,DWORD PTR [rsi+0xc]
148c: 8b 71 0c mov esi,DWORD PTR [rcx+0xc]
148f: 4c 8b 11 mov r10,QWORD PTR [rcx]
1492: c5 f8 57 c0 vxorps xmm0,xmm0,xmm0
1496: 4c 8b 58 08 mov r11,QWORD PTR [rax+0x8]
149a: 48 8b 08 mov rcx,QWORD PTR [rax]
149d: 4c 8b 70 10 mov r14,QWORD PTR [rax+0x10]
14a1: 48 c1 e6 29 shl rsi,0x29
14a5: 49 c1 e5 09 shl r13,0x9
14a9: 4c 89 84 24 80 00 00 mov QWORD PTR [rsp+0x80],r8
14b0: 00
14b1: 4c 89 8c 24 98 00 00 mov QWORD PTR [rsp+0x98],r9
14b8: 00
14b9: 49 09 f5 or r13,rsi
14bc: 4c 89 ac 24 a0 00 00 mov QWORD PTR [rsp+0xa0],r13
14c3: 00
14c4: 49 c1 fd 03 sar r13,0x3
14c8: 4c 89 5c 24 58 mov QWORD PTR [rsp+0x58],r11
14cd: 4d 89 cb mov r11,r9
14d0: 49 c1 e3 08 shl r11,0x8
14d4: 48 89 4c 24 70 mov QWORD PTR [rsp+0x70],rcx
14d9: 8b 0a mov ecx,DWORD PTR [rdx]
14db: 4c 89 da mov rdx,r11
14de: 48 0f af 54 24 40 imul rdx,QWORD PTR [rsp+0x40]
14e4: 4d 0f af d8 imul r11,r8
14e8: 49 c1 e3 02 shl r11,0x2
14ec: 48 8d 1c 09 lea rbx,[rcx+rcx*1]
14f0: 48 c1 e1 0b shl rcx,0xb
14f4: 4c 89 5c 24 78 mov QWORD PTR [rsp+0x78],r11
14f9: 48 89 5c 24 50 mov QWORD PTR [rsp+0x50],rbx
14fe: 48 8d 0c 91 lea rcx,[rcx+rdx*4]
1502: 48 03 48 18 add rcx,QWORD PTR [rax+0x18]
1506: 4c 89 c8 mov rax,r9
1509: 48 29 d8 sub rax,rbx
150c: 48 89 44 24 68 mov QWORD PTR [rsp+0x68],rax
1511: 4c 89 d0 mov rax,r10
1514: 48 c1 e0 09 shl rax,0x9
1518: 48 89 44 24 60 mov QWORD PTR [rsp+0x60],rax
151d: 48 8d 04 3f lea rax,[rdi+rdi*1]
1521: 48 c1 e7 0b shl rdi,0xb
1525: 48 89 84 24 90 00 00 mov QWORD PTR [rsp+0x90],rax
152c: 00
152d: 48 89 bc 24 a8 00 00 mov QWORD PTR [rsp+0xa8],rdi
1534: 00
1535: 48 89 4c 24 48 mov QWORD PTR [rsp+0x48],rcx
153a: eb 33 jmp 156f <iree_hal_executable_library_query-0x571>
153c: 0f 1f 40 00 nop DWORD PTR [rax+0x0]
1540: 48 8b 44 24 40 mov rax,QWORD PTR [rsp+0x40]
1545: 48 8b 4c 24 48 mov rcx,QWORD PTR [rsp+0x48]
154a: 48 03 84 24 80 00 00 add rax,QWORD PTR [rsp+0x80]
1551: 00
1552: 48 03 4c 24 78 add rcx,QWORD PTR [rsp+0x78]
1557: 48 89 4c 24 48 mov QWORD PTR [rsp+0x48],rcx
155c: 48 89 44 24 40 mov QWORD PTR [rsp+0x40],rax
1561: 48 3b 84 24 88 00 00 cmp rax,QWORD PTR [rsp+0x88]
1568: 00
1569: 0f 8d 53 05 00 00 jge 1ac2 <iree_hal_executable_library_query-0x1e>
156f: 4c 3b 4c 24 50 cmp r9,QWORD PTR [rsp+0x50]
1574: 7e ca jle 1540 <iree_hal_executable_library_query-0x5a0>
1576: 48 8b 44 24 60 mov rax,QWORD PTR [rsp+0x60]
157b: 48 8b 4c 24 40 mov rcx,QWORD PTR [rsp+0x40]
1580: 4c 8b 64 24 48 mov r12,QWORD PTR [rsp+0x48]
1585: 48 8b 54 24 68 mov rdx,QWORD PTR [rsp+0x68]
158a: 48 8b 7c 24 50 mov rdi,QWORD PTR [rsp+0x50]
158f: 48 0f af c1 imul rax,rcx
1593: 48 c1 e1 04 shl rcx,0x4
1597: 48 c1 f8 03 sar rax,0x3
159b: 48 03 44 24 70 add rax,QWORD PTR [rsp+0x70]
15a0: eb 48 jmp 15ea <iree_hal_executable_library_query-0x4f6>
15a2: 66 66 66 66 66 2e 0f data16 data16 data16 data16 cs nop WORD PTR [rax+rax*1+0x0]
15a9: 1f 84 00 00 00 00 00
15b0: 48 8b 94 24 90 00 00 mov rdx,QWORD PTR [rsp+0x90]
15b7: 00
15b8: 48 8b bc 24 b0 00 00 mov rdi,QWORD PTR [rsp+0xb0]
15bf: 00
15c0: 48 8b b4 24 b8 00 00 mov rsi,QWORD PTR [rsp+0xb8]
15c7: 00
15c8: 4c 03 a4 24 a8 00 00 add r12,QWORD PTR [rsp+0xa8]
15cf: 00
15d0: 4c 8b 8c 24 98 00 00 mov r9,QWORD PTR [rsp+0x98]
15d7: 00
15d8: 48 01 d7 add rdi,rdx
15db: 48 29 d6 sub rsi,rdx
15de: 48 89 f2 mov rdx,rsi
15e1: 4c 39 cf cmp rdi,r9
15e4: 0f 8d 56 ff ff ff jge 1540 <iree_hal_executable_library_query-0x5a0>
15ea: 48 8b b4 24 a0 00 00 mov rsi,QWORD PTR [rsp+0xa0]
15f1: 00
15f2: 45 31 ff xor r15d,r15d
15f5: 48 83 fa 02 cmp rdx,0x2
15f9: 48 89 94 24 b8 00 00 mov QWORD PTR [rsp+0xb8],rdx
1600: 00
1601: 4c 89 ca mov rdx,r9
1604: bb 02 00 00 00 mov ebx,0x2
1609: 4c 8b 44 24 58 mov r8,QWORD PTR [rsp+0x58]
160e: 0f 0d 8c 24 c0 00 00 prefetchw BYTE PTR [rsp+0xc0]
1615: 00
1616: 0f 18 08 prefetcht0 BYTE PTR [rax]
1619: 48 89 bc 24 b0 00 00 mov QWORD PTR [rsp+0xb0],rdi
1620: 00
1621: 0f 9d 44 24 3f setge BYTE PTR [rsp+0x3f]
1626: 48 29 fa sub rdx,rdi
1629: 48 83 fa 02 cmp rdx,0x2
162d: 48 0f 4c da cmovl rbx,rdx
1631: 48 0f af f7 imul rsi,rdi
1635: 48 c1 fe 03 sar rsi,0x3
1639: 41 0f 18 0c 30 prefetcht0 BYTE PTR [r8+rsi*1]
163e: 85 db test ebx,ebx
1640: 0f 8e fa 01 00 00 jle 1840 <iree_hal_executable_library_query-0x2a0>
1646: 48 03 74 24 58 add rsi,QWORD PTR [rsp+0x58]
164b: 48 8d bc 24 c0 00 00 lea rdi,[rsp+0xc0]
1652: 00
1653: 45 31 db xor r11d,r11d
1656: e9 d0 00 00 00 jmp 172b <iree_hal_executable_library_query-0x3b5>
165b: 0f 1f 44 00 00 nop DWORD PTR [rax+rax*1+0x0]
1660: c5 e8 57 d2 vxorps xmm2,xmm2,xmm2
1664: c5 e0 57 db vxorps xmm3,xmm3,xmm3
1668: c5 d8 57 e4 vxorps xmm4,xmm4,xmm4
166c: c5 d0 57 ed vxorps xmm5,xmm5,xmm5
1670: c5 c8 57 f6 vxorps xmm6,xmm6,xmm6
1674: c5 c0 57 ff vxorps xmm7,xmm7,xmm7
1678: c4 41 38 57 c0 vxorps xmm8,xmm8,xmm8
167d: c4 41 30 57 c9 vxorps xmm9,xmm9,xmm9
1682: c4 41 28 57 d2 vxorps xmm10,xmm10,xmm10
1687: c4 41 20 57 db vxorps xmm11,xmm11,xmm11
168c: c4 41 18 57 e4 vxorps xmm12,xmm12,xmm12
1691: c4 41 10 57 ed vxorps xmm13,xmm13,xmm13
1696: c4 41 08 57 f6 vxorps xmm14,xmm14,xmm14
169b: c4 41 00 57 ff vxorps xmm15,xmm15,xmm15
16a0: 62 a1 7c 00 57 c0 vxorps xmm16,xmm16,xmm16
16a6: 62 e1 7c 48 11 07 vmovups ZMMWORD PTR [rdi],zmm16
16ac: 62 71 7c 48 11 7f 01 vmovups ZMMWORD PTR [rdi+0x40],zmm15
16b3: 62 71 7c 48 11 77 02 vmovups ZMMWORD PTR [rdi+0x80],zmm14
16ba: 62 71 7c 48 11 6f 03 vmovups ZMMWORD PTR [rdi+0xc0],zmm13
16c1: 62 71 7c 48 11 67 04 vmovups ZMMWORD PTR [rdi+0x100],zmm12
16c8: 62 71 7c 48 11 5f 05 vmovups ZMMWORD PTR [rdi+0x140],zmm11
16cf: 62 71 7c 48 11 57 06 vmovups ZMMWORD PTR [rdi+0x180],zmm10
16d6: 62 71 7c 48 11 4f 07 vmovups ZMMWORD PTR [rdi+0x1c0],zmm9
16dd: 62 71 7c 48 11 47 08 vmovups ZMMWORD PTR [rdi+0x200],zmm8
16e4: 62 f1 7c 48 11 7f 09 vmovups ZMMWORD PTR [rdi+0x240],zmm7
16eb: 62 f1 7c 48 11 77 0a vmovups ZMMWORD PTR [rdi+0x280],zmm6
16f2: 62 f1 7c 48 11 6f 0b vmovups ZMMWORD PTR [rdi+0x2c0],zmm5
16f9: 62 f1 7c 48 11 67 0c vmovups ZMMWORD PTR [rdi+0x300],zmm4
1700: 62 f1 7c 48 11 5f 0d vmovups ZMMWORD PTR [rdi+0x340],zmm3
1707: 62 f1 7c 48 11 57 0e vmovups ZMMWORD PTR [rdi+0x380],zmm2
170e: 62 f1 7c 48 11 4f 0f vmovups ZMMWORD PTR [rdi+0x3c0],zmm1
1715: 48 81 c7 00 04 00 00 add rdi,0x400
171c: 4c 01 ee add rsi,r13
171f: 41 ff c3 inc r11d
1722: 41 39 db cmp r11d,ebx
1725: 0f 84 15 01 00 00 je 1840 <iree_hal_executable_library_query-0x2a0>
172b: 0f 18 08 prefetcht0 BYTE PTR [rax]
172e: 0f 18 0e prefetcht0 BYTE PTR [rsi]
1731: c5 f0 57 c9 vxorps xmm1,xmm1,xmm1
1735: 4d 85 d2 test r10,r10
1738: 0f 8e 22 ff ff ff jle 1660 <iree_hal_executable_library_query-0x480>
173e: c5 e8 57 d2 vxorps xmm2,xmm2,xmm2
1742: c5 e0 57 db vxorps xmm3,xmm3,xmm3
1746: c5 d8 57 e4 vxorps xmm4,xmm4,xmm4
174a: c5 d0 57 ed vxorps xmm5,xmm5,xmm5
174e: c5 c8 57 f6 vxorps xmm6,xmm6,xmm6
1752: c5 c0 57 ff vxorps xmm7,xmm7,xmm7
1756: c4 41 38 57 c0 vxorps xmm8,xmm8,xmm8
175b: c4 41 30 57 c9 vxorps xmm9,xmm9,xmm9
1760: c4 41 28 57 d2 vxorps xmm10,xmm10,xmm10
1765: c4 41 20 57 db vxorps xmm11,xmm11,xmm11
176a: c4 41 18 57 e4 vxorps xmm12,xmm12,xmm12
176f: c4 41 10 57 ed vxorps xmm13,xmm13,xmm13
1774: c4 41 08 57 f6 vxorps xmm14,xmm14,xmm14
1779: c4 41 00 57 ff vxorps xmm15,xmm15,xmm15
177e: 62 a1 7c 00 57 c0 vxorps xmm16,xmm16,xmm16
1784: 45 31 c0 xor r8d,r8d
1787: 4d 89 d1 mov r9,r10
178a: 66 0f 1f 44 00 00 nop WORD PTR [rax+rax*1+0x0]
1790: 62 a1 7c 48 10 0c 06 vmovups zmm17,ZMMWORD PTR [rsi+r8*1]
1797: 42 0f 18 8c 06 00 02 prefetcht0 BYTE PTR [rsi+r8*1+0x200]
179e: 00 00
17a0: 62 a2 75 50 b8 04 00 vfmadd231ps zmm16,zmm17,DWORD PTR [rax+r8*1]{1to16}
17a7: 62 32 75 50 b8 7c 00 vfmadd231ps zmm15,zmm17,DWORD PTR [rax+r8*1+0x4]{1to16}
17ae: 01
17af: 62 32 75 50 b8 74 00 vfmadd231ps zmm14,zmm17,DWORD PTR [rax+r8*1+0x8]{1to16}
17b6: 02
17b7: 62 32 75 50 b8 6c 00 vfmadd231ps zmm13,zmm17,DWORD PTR [rax+r8*1+0xc]{1to16}
17be: 03
17bf: 62 32 75 50 b8 64 00 vfmadd231ps zmm12,zmm17,DWORD PTR [rax+r8*1+0x10]{1to16}
17c6: 04
17c7: 62 32 75 50 b8 5c 00 vfmadd231ps zmm11,zmm17,DWORD PTR [rax+r8*1+0x14]{1to16}
17ce: 05
17cf: 62 32 75 50 b8 54 00 vfmadd231ps zmm10,zmm17,DWORD PTR [rax+r8*1+0x18]{1to16}
17d6: 06
17d7: 62 32 75 50 b8 4c 00 vfmadd231ps zmm9,zmm17,DWORD PTR [rax+r8*1+0x1c]{1to16}
17de: 07
17df: 62 32 75 50 b8 44 00 vfmadd231ps zmm8,zmm17,DWORD PTR [rax+r8*1+0x20]{1to16}
17e6: 08
17e7: 62 b2 75 50 b8 7c 00 vfmadd231ps zmm7,zmm17,DWORD PTR [rax+r8*1+0x24]{1to16}
17ee: 09
17ef: 62 b2 75 50 b8 74 00 vfmadd231ps zmm6,zmm17,DWORD PTR [rax+r8*1+0x28]{1to16}
17f6: 0a
17f7: 62 b2 75 50 b8 6c 00 vfmadd231ps zmm5,zmm17,DWORD PTR [rax+r8*1+0x2c]{1to16}
17fe: 0b
17ff: 62 b2 75 50 b8 64 00 vfmadd231ps zmm4,zmm17,DWORD PTR [rax+r8*1+0x30]{1to16}
1806: 0c
1807: 62 b2 75 50 b8 5c 00 vfmadd231ps zmm3,zmm17,DWORD PTR [rax+r8*1+0x34]{1to16}
180e: 0d
180f: 62 b2 75 50 b8 54 00 vfmadd231ps zmm2,zmm17,DWORD PTR [rax+r8*1+0x38]{1to16}
1816: 0e
1817: 62 b2 75 50 b8 4c 00 vfmadd231ps zmm1,zmm17,DWORD PTR [rax+r8*1+0x3c]{1to16}
181e: 0f
181f: 42 0f 18 8c 00 00 02 prefetcht0 BYTE PTR [rax+r8*1+0x200]
1826: 00 00
1828: 49 83 c0 40 add r8,0x40
182c: 49 ff c9 dec r9
182f: 0f 85 5b ff ff ff jne 1790 <iree_hal_executable_library_query-0x350>
1835: e9 6c fe ff ff jmp 16a6 <iree_hal_executable_library_query-0x43a>
183a: 66 0f 1f 44 00 00 nop WORD PTR [rax+rax*1+0x0]
1840: 48 85 d2 test rdx,rdx
1843: 0f 8e 67 fd ff ff jle 15b0 <iree_hal_executable_library_query-0x530>
1849: 0f b6 54 24 3f movzx edx,BYTE PTR [rsp+0x3f]
184e: 41 88 d7 mov r15b,dl
1851: ba c0 03 00 00 mov edx,0x3c0
1856: 49 ff c7 inc r15
1859: 0f 1f 80 00 00 00 00 nop DWORD PTR [rax+0x0]
1860: 62 f1 7c 48 28 4c 14 vmovaps zmm1,ZMMWORD PTR [rsp+rdx*1-0x300]
1867: f4
1868: 62 f1 7c 48 28 54 14 vmovaps zmm2,ZMMWORD PTR [rsp+rdx*1-0x2c0]
186f: f5
1870: 62 f1 7c 48 28 5c 14 vmovaps zmm3,ZMMWORD PTR [rsp+rdx*1-0x280]
1877: f6
1878: 62 f1 7c 48 28 64 14 vmovaps zmm4,ZMMWORD PTR [rsp+rdx*1-0x240]
187f: f7
1880: 62 f1 7c 48 28 6c 14 vmovaps zmm5,ZMMWORD PTR [rsp+rdx*1-0x200]
1887: f8
1888: 62 f1 7c 48 28 74 14 vmovaps zmm6,ZMMWORD PTR [rsp+rdx*1-0x1c0]
188f: f9
1890: 62 f1 7c 48 28 7c 14 vmovaps zmm7,ZMMWORD PTR [rsp+rdx*1-0x180]
1897: fa
1898: 62 71 7c 48 28 44 14 vmovaps zmm8,ZMMWORD PTR [rsp+rdx*1-0x140]
189f: fb
18a0: 62 71 7c 48 28 4c 14 vmovaps zmm9,ZMMWORD PTR [rsp+rdx*1-0x100]
18a7: fc
18a8: 62 71 7c 48 28 54 14 vmovaps zmm10,ZMMWORD PTR [rsp+rdx*1-0xc0]
18af: fd
18b0: 62 71 7c 48 28 5c 14 vmovaps zmm11,ZMMWORD PTR [rsp+rdx*1-0x80]
18b7: fe
18b8: 62 71 7c 48 28 64 14 vmovaps zmm12,ZMMWORD PTR [rsp+rdx*1-0x40]
18bf: ff
18c0: 62 71 7c 48 28 2c 14 vmovaps zmm13,ZMMWORD PTR [rsp+rdx*1]
18c7: 62 71 7c 48 28 74 14 vmovaps zmm14,ZMMWORD PTR [rsp+rdx*1+0x40]
18ce: 01
18cf: 62 71 7c 48 28 7c 14 vmovaps zmm15,ZMMWORD PTR [rsp+rdx*1+0x80]
18d6: 02
18d7: 62 e1 7c 48 28 44 14 vmovaps zmm16,ZMMWORD PTR [rsp+rdx*1+0xc0]
18de: 03
18df: 62 d1 74 58 58 0c 8e vaddps zmm1,zmm1,DWORD PTR [r14+rcx*4]{1to16}
18e6: 62 d1 6c 58 58 54 8e vaddps zmm2,zmm2,DWORD PTR [r14+rcx*4+0x4]{1to16}
18ed: 01
18ee: 62 d1 64 58 58 5c 8e vaddps zmm3,zmm3,DWORD PTR [r14+rcx*4+0x8]{1to16}
18f5: 02
18f6: 62 d1 5c 58 58 64 8e vaddps zmm4,zmm4,DWORD PTR [r14+rcx*4+0xc]{1to16}
18fd: 03
18fe: 62 d1 54 58 58 6c 8e vaddps zmm5,zmm5,DWORD PTR [r14+rcx*4+0x10]{1to16}
1905: 04
1906: 62 d1 4c 58 58 74 8e vaddps zmm6,zmm6,DWORD PTR [r14+rcx*4+0x14]{1to16}
190d: 05
190e: 62 d1 44 58 58 7c 8e vaddps zmm7,zmm7,DWORD PTR [r14+rcx*4+0x18]{1to16}
1915: 06
1916: 62 51 3c 58 58 44 8e vaddps zmm8,zmm8,DWORD PTR [r14+rcx*4+0x1c]{1to16}
191d: 07
191e: 62 51 34 58 58 4c 8e vaddps zmm9,zmm9,DWORD PTR [r14+rcx*4+0x20]{1to16}
1925: 08
1926: 62 51 2c 58 58 54 8e vaddps zmm10,zmm10,DWORD PTR [r14+rcx*4+0x24]{1to16}
192d: 09
192e: 62 51 24 58 58 5c 8e vaddps zmm11,zmm11,DWORD PTR [r14+rcx*4+0x28]{1to16}
1935: 0a
1936: 62 51 1c 58 58 64 8e vaddps zmm12,zmm12,DWORD PTR [r14+rcx*4+0x2c]{1to16}
193d: 0b
193e: 62 51 14 58 58 6c 8e vaddps zmm13,zmm13,DWORD PTR [r14+rcx*4+0x30]{1to16}
1945: 0c
1946: 62 51 0c 58 58 74 8e vaddps zmm14,zmm14,DWORD PTR [r14+rcx*4+0x34]{1to16}
194d: 0d
194e: 62 51 04 58 58 7c 8e vaddps zmm15,zmm15,DWORD PTR [r14+rcx*4+0x38]{1to16}
1955: 0e
1956: 62 c1 7c 50 58 44 8e vaddps zmm16,zmm16,DWORD PTR [r14+rcx*4+0x3c]{1to16}
195d: 0f
195e: 62 f1 74 48 c2 c8 06 vcmpnleps k1,zmm1,zmm0
1965: 62 f1 7c c9 28 c9 vmovaps zmm1{k1}{z},zmm1
196b: 62 f1 6c 48 c2 c8 06 vcmpnleps k1,zmm2,zmm0
1972: 62 d1 7c 48 29 4c 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x3c0],zmm1
1979: f1
197a: 62 f1 7c c9 28 d2 vmovaps zmm2{k1}{z},zmm2
1980: 62 f1 64 48 c2 c8 06 vcmpnleps k1,zmm3,zmm0
1987: 62 d1 7c 48 29 54 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x380],zmm2
198e: f2
198f: 62 f1 7c c9 28 db vmovaps zmm3{k1}{z},zmm3
1995: 62 f1 5c 48 c2 c8 06 vcmpnleps k1,zmm4,zmm0
199c: 62 d1 7c 48 29 5c 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x340],zmm3
19a3: f3
19a4: 62 f1 7c c9 28 e4 vmovaps zmm4{k1}{z},zmm4
19aa: 62 f1 54 48 c2 c8 06 vcmpnleps k1,zmm5,zmm0
19b1: 62 d1 7c 48 29 64 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x300],zmm4
19b8: f4
19b9: 62 f1 7c c9 28 ed vmovaps zmm5{k1}{z},zmm5
19bf: 62 f1 4c 48 c2 c8 06 vcmpnleps k1,zmm6,zmm0
19c6: 62 d1 7c 48 29 6c 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x2c0],zmm5
19cd: f5
19ce: 62 f1 7c c9 28 f6 vmovaps zmm6{k1}{z},zmm6
19d4: 62 f1 44 48 c2 c8 06 vcmpnleps k1,zmm7,zmm0
19db: 62 d1 7c 48 29 74 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x280],zmm6
19e2: f6
19e3: 62 f1 7c c9 28 ff vmovaps zmm7{k1}{z},zmm7
19e9: 62 f1 3c 48 c2 c8 06 vcmpnleps k1,zmm8,zmm0
19f0: 62 d1 7c 48 29 7c 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x240],zmm7
19f7: f7
19f8: 62 51 7c c9 28 c0 vmovaps zmm8{k1}{z},zmm8
19fe: 62 f1 34 48 c2 c8 06 vcmpnleps k1,zmm9,zmm0
1a05: 62 51 7c 48 29 44 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x200],zmm8
1a0c: f8
1a0d: 62 51 7c c9 28 c9 vmovaps zmm9{k1}{z},zmm9
1a13: 62 f1 2c 48 c2 c8 06 vcmpnleps k1,zmm10,zmm0
1a1a: 62 51 7c 48 29 4c 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x1c0],zmm9
1a21: f9
1a22: 62 51 7c c9 28 d2 vmovaps zmm10{k1}{z},zmm10
1a28: 62 f1 24 48 c2 c8 06 vcmpnleps k1,zmm11,zmm0
1a2f: 62 51 7c 48 29 54 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x180],zmm10
1a36: fa
1a37: 62 51 7c c9 28 db vmovaps zmm11{k1}{z},zmm11
1a3d: 62 f1 1c 48 c2 c8 06 vcmpnleps k1,zmm12,zmm0
1a44: 62 51 7c 48 29 5c 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x140],zmm11
1a4b: fb
1a4c: 62 51 7c c9 28 e4 vmovaps zmm12{k1}{z},zmm12
1a52: 62 f1 14 48 c2 c8 06 vcmpnleps k1,zmm13,zmm0
1a59: 62 51 7c 48 29 64 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x100],zmm12
1a60: fc
1a61: 62 51 7c c9 28 ed vmovaps zmm13{k1}{z},zmm13
1a67: 62 f1 0c 48 c2 c8 06 vcmpnleps k1,zmm14,zmm0
1a6e: 62 51 7c 48 29 6c 14 vmovaps ZMMWORD PTR [r12+rdx*1-0xc0],zmm13
1a75: fd
1a76: 62 51 7c c9 28 f6 vmovaps zmm14{k1}{z},zmm14
1a7c: 62 f1 04 48 c2 c8 06 vcmpnleps k1,zmm15,zmm0
1a83: 62 51 7c 48 29 74 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x80],zmm14
1a8a: fe
1a8b: 62 51 7c c9 28 ff vmovaps zmm15{k1}{z},zmm15
1a91: 62 f1 7c 40 c2 c8 06 vcmpnleps k1,zmm16,zmm0
1a98: 62 51 7c 48 29 7c 14 vmovaps ZMMWORD PTR [r12+rdx*1-0x40],zmm15
1a9f: ff
1aa0: 62 a1 7c c9 28 c0 vmovaps zmm16{k1}{z},zmm16
1aa6: 62 c1 7c 48 29 04 14 vmovaps ZMMWORD PTR [r12+rdx*1],zmm16
1aad: 48 81 c2 00 04 00 00 add rdx,0x400
1ab4: 49 ff cf dec r15
1ab7: 0f 85 a3 fd ff ff jne 1860 <iree_hal_executable_library_query-0x280>
1abd: e9 ee fa ff ff jmp 15b0 <iree_hal_executable_library_query-0x530>
1ac2: 31 c0 xor eax,eax
1ac4: 48 8d 65 d8 lea rsp,[rbp-0x28]
1ac8: 5b pop rbx
1ac9: 41 5c pop r12
1acb: 41 5d pop r13
1acd: 41 5e pop r14
1acf: 41 5f pop r15
1ad1: 5d pop rbp
1ad2: c5 f8 77 vzeroupper
1ad5: c3 ret
1ad6: cc int3
1ad7: cc int3
1ad8: cc int3
1ad9: cc int3
1ada: cc int3
1adb: cc int3
1adc: cc int3
1add: cc int3
1ade: cc int3
1adf: cc int3
0000000000001ae0 <iree_hal_executable_library_query>:
1ae0: 31 c0 xor eax,eax
1ae2: 83 ff 03 cmp edi,0x3
1ae5: 48 8d 0d 54 10 00 00 lea rcx,[rip+0x1054] # 2b40 <iree_hal_executable_library_query+0x1060>
1aec: 48 0f 44 c1 cmove rax,rcx
1af0: c3 ret
Is this a duplicate of #16025? (Want to make sure they're the same effort; probably worth picking one issue to use.)
https://github.com/openxla/iree/issues/16025 is more about functionality, i.e., making sure everything is under our own control. I think Benoit has more ideas about how this should work in terms of optimization, and we're tracking that optimization work in this issue.
#16025 is for fusing in terms of dispatch formation (right?). The present issue is for obtaining optimal object code. So #16025 is a prerequisite.
Yes, that's my thinking. I don't want to enable the fusion while we have poor control over tile-size selection; that would lead to unexpected behavior. This issue is more for tracking the optimization work once we enable the fusion.
This makes sense to me. I understand that you plan to limit this fusion to cases where the pack op has been propagated above the producer op, right? Any plans for the consumer counterpart?