Fuse mmt4d ukernel with consumer and get perfect codegen thanks to store-to-load forwarding

Open bjacob opened this issue 1 year ago • 6 comments

mmt4d->consumer fusion is a major codegen optimization opportunity. The typical consumers include the unpack following the mmt4d, and element-wise ops (generics with only parallel iterators) that are typically found as matmul consumers, for instance the bias-additions and activation functions of NN workloads.

Ukernels preclude codegen fusion at the MLIR level, but they do not preclude late LLVM IR optimizations such as store-to-load forwarding. However, for there to be a store-to-load forwarding opportunity between the final stores in the mmt4d ukernel and the initial loads in the consumer without relying on LLVM loop fusion/interleaving, there must be no nontrivial loop enclosing the stores in the mmt4d ukernel. This is only the case if the outer M and N dimensions of the mmt4d ukernel are both the constant 1, which makes the M/N outer loops trivial.

So there are at least two things that need to happen here:

  1. Enable mmt4d->consumer and/or ukernel->consumer fusion (depending on whether that is decided before or after LowerToUKernels).
  2. Set outer tile sizes to 1.
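
As a rough illustration of the condition described above, here is a minimal C sketch. It is not IREE's actual ukernel code: `mmt4d_tile_1x1`, `bias_relu_tile`, and `fused_tile` are hypothetical names, and the sketch assumes 16x16 f32 tiles, an inner K tile of 1, and a per-row bias. With outer M and N both equal to 1, the producer's final stores into `tmp` are followed directly by the consumer's loads from `tmp`, with only constant-trip-count loops in between. That is the shape in which store-to-load forwarding has a chance to remove the round trip through memory once both functions are inlined; whether a given compiler actually does so is not guaranteed by this sketch.

```c
#include <assert.h>
#include <stddef.h>

enum { TILE_M = 16, TILE_N = 16, TILE_ELEMS = TILE_M * TILE_N };

/* Hypothetical stand-in for one mmt4d ukernel call with outer M = N = 1.
 * lhs and rhs each hold k_outer x 16 x 1 values; out is one 16x16
 * accumulator tile that is loaded, updated, and stored back. */
static inline void mmt4d_tile_1x1(const float *lhs, const float *rhs,
                                  size_t k_outer, float *out) {
  assert(lhs && rhs && out);
  float acc[TILE_ELEMS];
  for (int e = 0; e < TILE_ELEMS; ++e) acc[e] = out[e];
  for (size_t k = 0; k < k_outer; ++k)
    for (int i = 0; i < TILE_M; ++i)
      for (int j = 0; j < TILE_N; ++j)
        acc[i * TILE_N + j] += lhs[k * TILE_M + i] * rhs[k * TILE_N + j];
  /* Final stores: not enclosed by any outer M/N loop. */
  for (int e = 0; e < TILE_ELEMS; ++e) out[e] = acc[e];
}

/* Hypothetical consumer: per-row bias-add followed by ReLU on one tile. */
static inline void bias_relu_tile(const float *tile, const float *bias,
                                  float *result) {
  for (int i = 0; i < TILE_M; ++i)
    for (int j = 0; j < TILE_N; ++j) {
      float v = tile[i * TILE_N + j] + bias[i];
      result[i * TILE_N + j] = v > 0.0f ? v : 0.0f;
    }
}

/* Producer and consumer run back to back on a local buffer. */
void fused_tile(const float *lhs, const float *rhs, size_t k_outer,
                const float *bias, float *result) {
  float tmp[TILE_ELEMS] = {0.0f}; /* plays the role of linalg.fill with 0 */
  mmt4d_tile_1x1(lhs, rhs, k_outer, tmp);
  bias_relu_tile(tmp, bias, result);
}
```

If the ukernel instead looped over several outer M/N tiles before returning, all of its stores would complete before the first consumer load, and forwarding would then depend on LLVM fusing or interleaving the loops.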

Here is a typical test case (from https://gist.github.com/bjacob/295ae908d4061178a5d1c691f12fd174):

#map_2d_identity = affine_map<(m,n) -> (m, n)>
#map_2d_first_coordinate = affine_map<(m,n)->(m)>
func.func @matmul_bias_relu_dynamic(
      %input_activations : tensor<?x?xf32>,
      %weights : tensor<?x?xf32>,
      %bias : tensor<?xf32>
    ) -> tensor<?x?xf32> {
  %c0f32 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Query dimensions of dynamic-size input tensors.
  %m = tensor.dim %input_activations, %c0 : tensor<?x?xf32>
  %n = tensor.dim %weights, %c1 : tensor<?x?xf32>
  // Perform the matrix multiplication.
  %matmul_empty = tensor.empty(%m, %n) : tensor<?x?xf32>
  %matmul_accumulator = linalg.fill ins(%c0f32 : f32) outs(%matmul_empty : tensor<?x?xf32>) -> tensor<?x?xf32>
  %matmul_result = linalg.matmul
      ins(%input_activations, %weights : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%matmul_accumulator : tensor<?x?xf32>) -> tensor<?x?xf32>
  // Perform the bias-addition and ReLU
  %result_empty =  tensor.empty(%m, %n) : tensor<?x?xf32>
  %result = linalg.generic {
    indexing_maps=[
      #map_2d_identity,
      #map_2d_first_coordinate,
      #map_2d_identity
    ],
    iterator_types=["parallel", "parallel"]
  } ins(%matmul_result, %bias : tensor<?x?xf32>, tensor<?xf32>)
    outs(%result_empty : tensor<?x?xf32>) {
    ^bb0(%matmul_entry : f32, %bias_entry : f32, %unused_result_entry : f32):
      %add = arith.addf %matmul_entry, %bias_entry : f32
      %relu = arith.maximumf %add, %c0f32 : f32
      linalg.yield %relu : f32
  } -> tensor<?x?xf32>
  return %result : tensor<?x?xf32>
}

EDIT:

Tile size 1 is not going to be optimal for other reasons (it will result in too many, too-small dispatch function calls). Perhaps once we've secured the ability to get perfect codegen with tile size 1, the next question is how we reconcile that with suitably sized dispatch function calls.

Oh I think I know... we need to replace "mmt4d ukernel with outer sizes M, N" by "scf.for loops with trip counts M, N around mmt4d ukernel with outer sizes 1, 1". Then this no longer has anything to do with distribution tile sizes.

But this idea relies on the ability to perform a fusion of scf.for loops?
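
To make the loop-fusion question concrete, here is a small C sketch under stated assumptions. The function names (`ukernel_1x1`, `bias_relu_1x1`, `run_unfused`, `run_fused`) are hypothetical, tiles are assumed to be 16x16 f32 with an inner K tile of 1, and the bias is assumed to be per row. `run_unfused` is the form with two separate M/N loop nests and a full intermediate buffer; `run_fused` is the form after fusing the two loop nests, where the consumer sees each tile right after the outer-1x1 ukernel call that produced it.

```c
#include <assert.h>
#include <stddef.h>

enum { T = 16, TILE = T * T };

/* Hypothetical stand-in for the mmt4d ukernel with outer sizes 1, 1. */
static void ukernel_1x1(const float *lhs, const float *rhs, size_t k_outer,
                        float *tile) {
  for (int e = 0; e < TILE; ++e) tile[e] = 0.0f;
  for (size_t k = 0; k < k_outer; ++k)
    for (int i = 0; i < T; ++i)
      for (int j = 0; j < T; ++j)
        tile[i * T + j] += lhs[k * T + i] * rhs[k * T + j];
}

/* Hypothetical consumer: per-row bias-add and ReLU on one tile. */
static void bias_relu_1x1(const float *tile, const float *bias, float *out) {
  for (int i = 0; i < T; ++i)
    for (int j = 0; j < T; ++j) {
      float v = tile[i * T + j] + bias[i];
      out[i * T + j] = v > 0.0f ? v : 0.0f;
    }
}

/* Two loop nests: every tile goes through the scratch buffer. */
void run_unfused(const float *lhs, const float *rhs, const float *bias,
                 size_t M, size_t N, size_t K, float *scratch, float *out) {
  assert(scratch && out);
  for (size_t m = 0; m < M; ++m)
    for (size_t n = 0; n < N; ++n)
      ukernel_1x1(lhs + m * K * T, rhs + n * K * T, K,
                  scratch + (m * N + n) * TILE);
  for (size_t m = 0; m < M; ++m)
    for (size_t n = 0; n < N; ++n)
      bias_relu_1x1(scratch + (m * N + n) * TILE, bias + m * T,
                    out + (m * N + n) * TILE);
}

/* One loop nest: producer and consumer share the loop body. */
void run_fused(const float *lhs, const float *rhs, const float *bias,
               size_t M, size_t N, size_t K, float *out) {
  assert(out);
  for (size_t m = 0; m < M; ++m)
    for (size_t n = 0; n < N; ++n) {
      float tile[TILE];
      ukernel_1x1(lhs + m * K * T, rhs + n * K * T, K, tile);
      bias_relu_1x1(tile, bias + m * T, out + (m * N + n) * TILE);
    }
}
```

Both functions compute the same result; the fused form only needs one tile of temporary storage, which is what makes the per-tile store-to-load forwarding plausible.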

bjacob avatar Jan 24 '24 16:01 bjacob

I generated an example using cpu=znver4. In your example, I think we will get something like the following for fusion (when we enable unpack propagation):

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
module {
  func.func @mmt4d_bias_relu_fusion(%arg0: tensor<?x?x16x1xf32>, %arg1: tensor<?x?x16x1xf32>, %arg2: tensor<?x16xf32>) -> tensor<?x?x16x16xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?x16x1xf32>
    %dim_0 = tensor.dim %arg1, %c0 : tensor<?x?x16x1xf32>
    %0 = tensor.empty(%dim, %dim_0) : tensor<?x?x16x16xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
    %2 = linalg.mmt4d ins(%arg0, %arg1 : tensor<?x?x16x1xf32>, tensor<?x?x16x1xf32>) outs(%1 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
    %3 = tensor.empty(%dim, %dim_0) : tensor<?x?x16x16xf32>
    %4 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2, %arg2 : tensor<?x?x16x16xf32>, tensor<?x16xf32>) outs(%3 : tensor<?x?x16x16xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %5 = arith.addf %in, %in_1 : f32
      %6 = arith.maximumf %5, %cst : f32
      linalg.yield %6 : f32
    } -> tensor<?x?x16x16xf32>
    return %4 : tensor<?x?x16x16xf32>
  }
}

Input IR (starting from executable-sources):

hal.executable public @mmt4d_bias_relu_fusion_dispatch_0 {
  hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver4", cpu_features = "+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+sse4a,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vbmi,+avx512ifma,+avx512vpopcntdq,+avx512vbmi2,+gfni,+vpclmulqdq,+avx512vnni,+avx512bitalg,+avx512bf16,+adx,+clflushopt,+clwb,+clzero,+cx16,+cx8,+crc32,+f16c,+fsgsbase,+fxsr,+invpcid,+lzcnt,+movbe,+mwaitx,+pku,+prfchw,+rdpid,+rdpru,+rdrnd,+rdseed,+sahf,+sha,+shstk,+vaes,+wbnoinvd,+x87,+xsave,+xsavec,+xsaveopt,+xsaves,+evex512", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 64 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "all"}>) {
    hal.executable.export public @mmt4d_bias_relu_fusion_dispatch_0_generic_DxDx16x16_f32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 10, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<Mmt4dTilingExpert>} {
    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3, %arg4, %arg5
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @mmt4d_bias_relu_fusion_dispatch_0_generic_DxDx16x16_f32() {
        %c0 = arith.constant 0 : index
        %c32_i64 = arith.constant 32 : i64
        %cst = arith.constant 0.000000e+00 : f32
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = hal.interface.constant.load[2] : i32
        %3 = hal.interface.constant.load[3] : i32
        %4 = hal.interface.constant.load[4] : i32
        %5 = hal.interface.constant.load[5] : i32
        %6 = hal.interface.constant.load[6] : i32
        %7 = hal.interface.constant.load[7] : i32
        %8 = hal.interface.constant.load[8] : i32
        %9 = hal.interface.constant.load[9] : i32
        %10 = arith.extui %0 : i32 to i64
        %11 = arith.extui %1 : i32 to i64
        %12 = arith.shli %11, %c32_i64 : i64
        %13 = arith.ori %10, %12 : i64
        %14 = arith.index_castui %13 : i64 to index
        %15 = arith.extui %2 : i32 to i64
        %16 = arith.extui %3 : i32 to i64
        %17 = arith.shli %16, %c32_i64 : i64
        %18 = arith.ori %15, %17 : i64
        %19 = arith.index_castui %18 : i64 to index
        %20 = arith.extui %4 : i32 to i64
        %21 = arith.extui %5 : i32 to i64
        %22 = arith.shli %21, %c32_i64 : i64
        %23 = arith.ori %20, %22 : i64
        %24 = arith.index_castui %23 : i64 to index
        %25 = arith.extui %6 : i32 to i64
        %26 = arith.extui %7 : i32 to i64
        %27 = arith.shli %26, %c32_i64 : i64
        %28 = arith.ori %25, %27 : i64
        %29 = arith.index_castui %28 : i64 to index
        %30 = arith.extui %8 : i32 to i64
        %31 = arith.extui %9 : i32 to i64
        %32 = arith.shli %31, %c32_i64 : i64
        %33 = arith.ori %30, %32 : i64
        %34 = arith.index_castui %33 : i64 to index
        %35 = flow.dispatch.workload.ordinal %14, 0 : index
        %36 = flow.dispatch.workload.ordinal %19, 1 : index
        %37 = flow.dispatch.workload.ordinal %24, 2 : index
        %38 = flow.dispatch.workload.ordinal %29, 3 : index
        %39 = flow.dispatch.workload.ordinal %34, 4 : index
        %40 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%38, %35}
        %41 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%39, %36}
        %42 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x16xf32>>{%37}
        %43 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<?x?x16x16xf32>>{%38, %39}
        %44 = flow.dispatch.tensor.load %40, offsets = [0, 0, 0, 0], sizes = [%38, %35, 16, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%38, %35} -> tensor<?x?x16x1xf32>
        %45 = flow.dispatch.tensor.load %41, offsets = [0, 0, 0, 0], sizes = [%39, %36, 16, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%39, %36} -> tensor<?x?x16x1xf32>
        %46 = flow.dispatch.tensor.load %42, offsets = [0, 0], sizes = [%37, 16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x16xf32>>{%37} -> tensor<?x16xf32>
        %47 = tensor.empty(%38, %39) : tensor<?x?x16x16xf32>
        %48 = linalg.fill ins(%cst : f32) outs(%47 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
        %49 = linalg.mmt4d {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]>} ins(%44, %45 : tensor<?x?x16x1xf32>, tensor<?x?x16x1xf32>) outs(%48 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
        %50 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                                               affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
                                               affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
                              iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
          ins(%49, %46 : tensor<?x?x16x16xf32>, tensor<?x16xf32>)
          outs(%47 : tensor<?x?x16x16xf32>)
          attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 2, 0, 0], [1, 1, 16, 16], [0, 0, 0, 0], [0, 0, 0, 0]]>}
          {
        ^bb0(%in: f32, %in_0: f32, %out: f32):
          %51 = arith.addf %in, %in_0 : f32
          %52 = arith.maximumf %51, %cst : f32
          linalg.yield %52 : f32
        } -> tensor<?x?x16x16xf32>
        flow.dispatch.tensor.store %50, %43, offsets = [0, 0, 0, 0], sizes = [%38, %39, 16, 16], strides = [1, 1, 1, 1] : tensor<?x?x16x16xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?x16x16xf32>>{%38, %39}
        return
      }
    }
  }
}

The mmt4d op will be replaced with a ukernel, so only the first list of its tile sizes is important. The tile sizes used in TileAndFuse are the second list in the generic op.
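
As a rough sketch of how the two lists relate, the C code below counts tiles at both levels for the outer M/N dimensions. It assumes that the first list ([1, 2, 0, 0]) gives the distribution tile sizes, that the second list ([1, 1, 16, 16]) tiles within each distributed tile, and that partial tiles at the boundary are clamped. `TileCounts` and `count_tiles` are hypothetical names used only for illustration, and the sketch covers just the two outer dimensions.

```c
#include <assert.h>
#include <stddef.h>

typedef struct {
  size_t workgroup_tiles; /* tiles produced by the first list */
  size_t inner_tiles;     /* tiles produced by the second list */
} TileCounts;

static size_t min_size(size_t a, size_t b) { return a < b ? a : b; }

/* Walks an m_outer x n_outer space of outer tiles. dist_m/dist_n are the
 * first-list sizes and inner_m/inner_n the second-list sizes for the outer
 * M/N dimensions; all four are assumed to be nonzero in this sketch. */
TileCounts count_tiles(size_t m_outer, size_t n_outer, size_t dist_m,
                       size_t dist_n, size_t inner_m, size_t inner_n) {
  assert(dist_m > 0 && dist_n > 0 && inner_m > 0 && inner_n > 0);
  TileCounts counts = {0, 0};
  for (size_t m = 0; m < m_outer; m += dist_m)
    for (size_t n = 0; n < n_outer; n += dist_n) {
      /* Clamp the tile at the boundary of the iteration space. */
      size_t m_size = min_size(dist_m, m_outer - m);
      size_t n_size = min_size(dist_n, n_outer - n);
      ++counts.workgroup_tiles;
      for (size_t mi = 0; mi < m_size; mi += inner_m)
        for (size_t ni = 0; ni < n_size; ni += inner_n)
          ++counts.inner_tiles;
    }
  return counts;
}
```

For example, with 3 x 5 outer tiles, distribution sizes (1, 2), and inner sizes (1, 1), the walk visits 9 distributed tiles and 15 inner tiles, one inner tile per 16x16 block.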

To compile the dispatch, run:

iree-compile --output-format=vm-bytecode --iree-llvmcpu-enable-ukernels=all ~/module_mmt4d_fusion_dispatch_0.mlir -o /tmp/a.vmfb --compile-from=executable-sources --iree-llvmcpu-keep-linker-artifacts --iree-llvmcpu-link-embedded=false

Then we can get the asm:


/tmp/mmt4d_bias_relu_fusion_dispatch_0-e687d2.so:     file format elf64-x86-64


Disassembly of section .text:

0000000000001440 <iree_hal_executable_library_query-0x6a0>:
    1440:	55                   	push   rbp
    1441:	48 89 e5             	mov    rbp,rsp
    1444:	41 57                	push   r15
    1446:	41 56                	push   r14
    1448:	41 55                	push   r13
    144a:	41 54                	push   r12
    144c:	53                   	push   rbx
    144d:	48 83 e4 c0          	and    rsp,0xffffffffffffffc0
    1451:	48 81 ec 00 09 00 00 	sub    rsp,0x900
    1458:	48 8b 4e 18          	mov    rcx,QWORD PTR [rsi+0x18]
    145c:	8b 7a 04             	mov    edi,DWORD PTR [rdx+0x4]
    145f:	48 8b 41 18          	mov    rax,QWORD PTR [rcx+0x18]
    1463:	48 89 7c 24 40       	mov    QWORD PTR [rsp+0x40],rdi
    1468:	48 89 84 24 88 00 00 	mov    QWORD PTR [rsp+0x88],rax
    146f:	00 
    1470:	48 39 f8             	cmp    rax,rdi
    1473:	0f 8e 49 06 00 00    	jle    1ac2 <iree_hal_executable_library_query-0x1e>
    1479:	48 8b 46 20          	mov    rax,QWORD PTR [rsi+0x20]
    147d:	4c 8b 49 20          	mov    r9,QWORD PTR [rcx+0x20]
    1481:	44 8b 69 08          	mov    r13d,DWORD PTR [rcx+0x8]
    1485:	44 8b 46 10          	mov    r8d,DWORD PTR [rsi+0x10]
    1489:	8b 7e 0c             	mov    edi,DWORD PTR [rsi+0xc]
    148c:	8b 71 0c             	mov    esi,DWORD PTR [rcx+0xc]
    148f:	4c 8b 11             	mov    r10,QWORD PTR [rcx]
    1492:	c5 f8 57 c0          	vxorps xmm0,xmm0,xmm0
    1496:	4c 8b 58 08          	mov    r11,QWORD PTR [rax+0x8]
    149a:	48 8b 08             	mov    rcx,QWORD PTR [rax]
    149d:	4c 8b 70 10          	mov    r14,QWORD PTR [rax+0x10]
    14a1:	48 c1 e6 29          	shl    rsi,0x29
    14a5:	49 c1 e5 09          	shl    r13,0x9
    14a9:	4c 89 84 24 80 00 00 	mov    QWORD PTR [rsp+0x80],r8
    14b0:	00 
    14b1:	4c 89 8c 24 98 00 00 	mov    QWORD PTR [rsp+0x98],r9
    14b8:	00 
    14b9:	49 09 f5             	or     r13,rsi
    14bc:	4c 89 ac 24 a0 00 00 	mov    QWORD PTR [rsp+0xa0],r13
    14c3:	00 
    14c4:	49 c1 fd 03          	sar    r13,0x3
    14c8:	4c 89 5c 24 58       	mov    QWORD PTR [rsp+0x58],r11
    14cd:	4d 89 cb             	mov    r11,r9
    14d0:	49 c1 e3 08          	shl    r11,0x8
    14d4:	48 89 4c 24 70       	mov    QWORD PTR [rsp+0x70],rcx
    14d9:	8b 0a                	mov    ecx,DWORD PTR [rdx]
    14db:	4c 89 da             	mov    rdx,r11
    14de:	48 0f af 54 24 40    	imul   rdx,QWORD PTR [rsp+0x40]
    14e4:	4d 0f af d8          	imul   r11,r8
    14e8:	49 c1 e3 02          	shl    r11,0x2
    14ec:	48 8d 1c 09          	lea    rbx,[rcx+rcx*1]
    14f0:	48 c1 e1 0b          	shl    rcx,0xb
    14f4:	4c 89 5c 24 78       	mov    QWORD PTR [rsp+0x78],r11
    14f9:	48 89 5c 24 50       	mov    QWORD PTR [rsp+0x50],rbx
    14fe:	48 8d 0c 91          	lea    rcx,[rcx+rdx*4]
    1502:	48 03 48 18          	add    rcx,QWORD PTR [rax+0x18]
    1506:	4c 89 c8             	mov    rax,r9
    1509:	48 29 d8             	sub    rax,rbx
    150c:	48 89 44 24 68       	mov    QWORD PTR [rsp+0x68],rax
    1511:	4c 89 d0             	mov    rax,r10
    1514:	48 c1 e0 09          	shl    rax,0x9
    1518:	48 89 44 24 60       	mov    QWORD PTR [rsp+0x60],rax
    151d:	48 8d 04 3f          	lea    rax,[rdi+rdi*1]
    1521:	48 c1 e7 0b          	shl    rdi,0xb
    1525:	48 89 84 24 90 00 00 	mov    QWORD PTR [rsp+0x90],rax
    152c:	00 
    152d:	48 89 bc 24 a8 00 00 	mov    QWORD PTR [rsp+0xa8],rdi
    1534:	00 
    1535:	48 89 4c 24 48       	mov    QWORD PTR [rsp+0x48],rcx
    153a:	eb 33                	jmp    156f <iree_hal_executable_library_query-0x571>
    153c:	0f 1f 40 00          	nop    DWORD PTR [rax+0x0]
    1540:	48 8b 44 24 40       	mov    rax,QWORD PTR [rsp+0x40]
    1545:	48 8b 4c 24 48       	mov    rcx,QWORD PTR [rsp+0x48]
    154a:	48 03 84 24 80 00 00 	add    rax,QWORD PTR [rsp+0x80]
    1551:	00 
    1552:	48 03 4c 24 78       	add    rcx,QWORD PTR [rsp+0x78]
    1557:	48 89 4c 24 48       	mov    QWORD PTR [rsp+0x48],rcx
    155c:	48 89 44 24 40       	mov    QWORD PTR [rsp+0x40],rax
    1561:	48 3b 84 24 88 00 00 	cmp    rax,QWORD PTR [rsp+0x88]
    1568:	00 
    1569:	0f 8d 53 05 00 00    	jge    1ac2 <iree_hal_executable_library_query-0x1e>
    156f:	4c 3b 4c 24 50       	cmp    r9,QWORD PTR [rsp+0x50]
    1574:	7e ca                	jle    1540 <iree_hal_executable_library_query-0x5a0>
    1576:	48 8b 44 24 60       	mov    rax,QWORD PTR [rsp+0x60]
    157b:	48 8b 4c 24 40       	mov    rcx,QWORD PTR [rsp+0x40]
    1580:	4c 8b 64 24 48       	mov    r12,QWORD PTR [rsp+0x48]
    1585:	48 8b 54 24 68       	mov    rdx,QWORD PTR [rsp+0x68]
    158a:	48 8b 7c 24 50       	mov    rdi,QWORD PTR [rsp+0x50]
    158f:	48 0f af c1          	imul   rax,rcx
    1593:	48 c1 e1 04          	shl    rcx,0x4
    1597:	48 c1 f8 03          	sar    rax,0x3
    159b:	48 03 44 24 70       	add    rax,QWORD PTR [rsp+0x70]
    15a0:	eb 48                	jmp    15ea <iree_hal_executable_library_query-0x4f6>
    15a2:	66 66 66 66 66 2e 0f 	data16 data16 data16 data16 cs nop WORD PTR [rax+rax*1+0x0]
    15a9:	1f 84 00 00 00 00 00 
    15b0:	48 8b 94 24 90 00 00 	mov    rdx,QWORD PTR [rsp+0x90]
    15b7:	00 
    15b8:	48 8b bc 24 b0 00 00 	mov    rdi,QWORD PTR [rsp+0xb0]
    15bf:	00 
    15c0:	48 8b b4 24 b8 00 00 	mov    rsi,QWORD PTR [rsp+0xb8]
    15c7:	00 
    15c8:	4c 03 a4 24 a8 00 00 	add    r12,QWORD PTR [rsp+0xa8]
    15cf:	00 
    15d0:	4c 8b 8c 24 98 00 00 	mov    r9,QWORD PTR [rsp+0x98]
    15d7:	00 
    15d8:	48 01 d7             	add    rdi,rdx
    15db:	48 29 d6             	sub    rsi,rdx
    15de:	48 89 f2             	mov    rdx,rsi
    15e1:	4c 39 cf             	cmp    rdi,r9
    15e4:	0f 8d 56 ff ff ff    	jge    1540 <iree_hal_executable_library_query-0x5a0>
    15ea:	48 8b b4 24 a0 00 00 	mov    rsi,QWORD PTR [rsp+0xa0]
    15f1:	00 
    15f2:	45 31 ff             	xor    r15d,r15d
    15f5:	48 83 fa 02          	cmp    rdx,0x2
    15f9:	48 89 94 24 b8 00 00 	mov    QWORD PTR [rsp+0xb8],rdx
    1600:	00 
    1601:	4c 89 ca             	mov    rdx,r9
    1604:	bb 02 00 00 00       	mov    ebx,0x2
    1609:	4c 8b 44 24 58       	mov    r8,QWORD PTR [rsp+0x58]
    160e:	0f 0d 8c 24 c0 00 00 	prefetchw BYTE PTR [rsp+0xc0]
    1615:	00 
    1616:	0f 18 08             	prefetcht0 BYTE PTR [rax]
    1619:	48 89 bc 24 b0 00 00 	mov    QWORD PTR [rsp+0xb0],rdi
    1620:	00 
    1621:	0f 9d 44 24 3f       	setge  BYTE PTR [rsp+0x3f]
    1626:	48 29 fa             	sub    rdx,rdi
    1629:	48 83 fa 02          	cmp    rdx,0x2
    162d:	48 0f 4c da          	cmovl  rbx,rdx
    1631:	48 0f af f7          	imul   rsi,rdi
    1635:	48 c1 fe 03          	sar    rsi,0x3
    1639:	41 0f 18 0c 30       	prefetcht0 BYTE PTR [r8+rsi*1]
    163e:	85 db                	test   ebx,ebx
    1640:	0f 8e fa 01 00 00    	jle    1840 <iree_hal_executable_library_query-0x2a0>
    1646:	48 03 74 24 58       	add    rsi,QWORD PTR [rsp+0x58]
    164b:	48 8d bc 24 c0 00 00 	lea    rdi,[rsp+0xc0]
    1652:	00 
    1653:	45 31 db             	xor    r11d,r11d
    1656:	e9 d0 00 00 00       	jmp    172b <iree_hal_executable_library_query-0x3b5>
    165b:	0f 1f 44 00 00       	nop    DWORD PTR [rax+rax*1+0x0]
    1660:	c5 e8 57 d2          	vxorps xmm2,xmm2,xmm2
    1664:	c5 e0 57 db          	vxorps xmm3,xmm3,xmm3
    1668:	c5 d8 57 e4          	vxorps xmm4,xmm4,xmm4
    166c:	c5 d0 57 ed          	vxorps xmm5,xmm5,xmm5
    1670:	c5 c8 57 f6          	vxorps xmm6,xmm6,xmm6
    1674:	c5 c0 57 ff          	vxorps xmm7,xmm7,xmm7
    1678:	c4 41 38 57 c0       	vxorps xmm8,xmm8,xmm8
    167d:	c4 41 30 57 c9       	vxorps xmm9,xmm9,xmm9
    1682:	c4 41 28 57 d2       	vxorps xmm10,xmm10,xmm10
    1687:	c4 41 20 57 db       	vxorps xmm11,xmm11,xmm11
    168c:	c4 41 18 57 e4       	vxorps xmm12,xmm12,xmm12
    1691:	c4 41 10 57 ed       	vxorps xmm13,xmm13,xmm13
    1696:	c4 41 08 57 f6       	vxorps xmm14,xmm14,xmm14
    169b:	c4 41 00 57 ff       	vxorps xmm15,xmm15,xmm15
    16a0:	62 a1 7c 00 57 c0    	vxorps xmm16,xmm16,xmm16
    16a6:	62 e1 7c 48 11 07    	vmovups ZMMWORD PTR [rdi],zmm16
    16ac:	62 71 7c 48 11 7f 01 	vmovups ZMMWORD PTR [rdi+0x40],zmm15
    16b3:	62 71 7c 48 11 77 02 	vmovups ZMMWORD PTR [rdi+0x80],zmm14
    16ba:	62 71 7c 48 11 6f 03 	vmovups ZMMWORD PTR [rdi+0xc0],zmm13
    16c1:	62 71 7c 48 11 67 04 	vmovups ZMMWORD PTR [rdi+0x100],zmm12
    16c8:	62 71 7c 48 11 5f 05 	vmovups ZMMWORD PTR [rdi+0x140],zmm11
    16cf:	62 71 7c 48 11 57 06 	vmovups ZMMWORD PTR [rdi+0x180],zmm10
    16d6:	62 71 7c 48 11 4f 07 	vmovups ZMMWORD PTR [rdi+0x1c0],zmm9
    16dd:	62 71 7c 48 11 47 08 	vmovups ZMMWORD PTR [rdi+0x200],zmm8
    16e4:	62 f1 7c 48 11 7f 09 	vmovups ZMMWORD PTR [rdi+0x240],zmm7
    16eb:	62 f1 7c 48 11 77 0a 	vmovups ZMMWORD PTR [rdi+0x280],zmm6
    16f2:	62 f1 7c 48 11 6f 0b 	vmovups ZMMWORD PTR [rdi+0x2c0],zmm5
    16f9:	62 f1 7c 48 11 67 0c 	vmovups ZMMWORD PTR [rdi+0x300],zmm4
    1700:	62 f1 7c 48 11 5f 0d 	vmovups ZMMWORD PTR [rdi+0x340],zmm3
    1707:	62 f1 7c 48 11 57 0e 	vmovups ZMMWORD PTR [rdi+0x380],zmm2
    170e:	62 f1 7c 48 11 4f 0f 	vmovups ZMMWORD PTR [rdi+0x3c0],zmm1
    1715:	48 81 c7 00 04 00 00 	add    rdi,0x400
    171c:	4c 01 ee             	add    rsi,r13
    171f:	41 ff c3             	inc    r11d
    1722:	41 39 db             	cmp    r11d,ebx
    1725:	0f 84 15 01 00 00    	je     1840 <iree_hal_executable_library_query-0x2a0>
    172b:	0f 18 08             	prefetcht0 BYTE PTR [rax]
    172e:	0f 18 0e             	prefetcht0 BYTE PTR [rsi]
    1731:	c5 f0 57 c9          	vxorps xmm1,xmm1,xmm1
    1735:	4d 85 d2             	test   r10,r10
    1738:	0f 8e 22 ff ff ff    	jle    1660 <iree_hal_executable_library_query-0x480>
    173e:	c5 e8 57 d2          	vxorps xmm2,xmm2,xmm2
    1742:	c5 e0 57 db          	vxorps xmm3,xmm3,xmm3
    1746:	c5 d8 57 e4          	vxorps xmm4,xmm4,xmm4
    174a:	c5 d0 57 ed          	vxorps xmm5,xmm5,xmm5
    174e:	c5 c8 57 f6          	vxorps xmm6,xmm6,xmm6
    1752:	c5 c0 57 ff          	vxorps xmm7,xmm7,xmm7
    1756:	c4 41 38 57 c0       	vxorps xmm8,xmm8,xmm8
    175b:	c4 41 30 57 c9       	vxorps xmm9,xmm9,xmm9
    1760:	c4 41 28 57 d2       	vxorps xmm10,xmm10,xmm10
    1765:	c4 41 20 57 db       	vxorps xmm11,xmm11,xmm11
    176a:	c4 41 18 57 e4       	vxorps xmm12,xmm12,xmm12
    176f:	c4 41 10 57 ed       	vxorps xmm13,xmm13,xmm13
    1774:	c4 41 08 57 f6       	vxorps xmm14,xmm14,xmm14
    1779:	c4 41 00 57 ff       	vxorps xmm15,xmm15,xmm15
    177e:	62 a1 7c 00 57 c0    	vxorps xmm16,xmm16,xmm16
    1784:	45 31 c0             	xor    r8d,r8d
    1787:	4d 89 d1             	mov    r9,r10
    178a:	66 0f 1f 44 00 00    	nop    WORD PTR [rax+rax*1+0x0]
    1790:	62 a1 7c 48 10 0c 06 	vmovups zmm17,ZMMWORD PTR [rsi+r8*1]
    1797:	42 0f 18 8c 06 00 02 	prefetcht0 BYTE PTR [rsi+r8*1+0x200]
    179e:	00 00 
    17a0:	62 a2 75 50 b8 04 00 	vfmadd231ps zmm16,zmm17,DWORD PTR [rax+r8*1]{1to16}
    17a7:	62 32 75 50 b8 7c 00 	vfmadd231ps zmm15,zmm17,DWORD PTR [rax+r8*1+0x4]{1to16}
    17ae:	01 
    17af:	62 32 75 50 b8 74 00 	vfmadd231ps zmm14,zmm17,DWORD PTR [rax+r8*1+0x8]{1to16}
    17b6:	02 
    17b7:	62 32 75 50 b8 6c 00 	vfmadd231ps zmm13,zmm17,DWORD PTR [rax+r8*1+0xc]{1to16}
    17be:	03 
    17bf:	62 32 75 50 b8 64 00 	vfmadd231ps zmm12,zmm17,DWORD PTR [rax+r8*1+0x10]{1to16}
    17c6:	04 
    17c7:	62 32 75 50 b8 5c 00 	vfmadd231ps zmm11,zmm17,DWORD PTR [rax+r8*1+0x14]{1to16}
    17ce:	05 
    17cf:	62 32 75 50 b8 54 00 	vfmadd231ps zmm10,zmm17,DWORD PTR [rax+r8*1+0x18]{1to16}
    17d6:	06 
    17d7:	62 32 75 50 b8 4c 00 	vfmadd231ps zmm9,zmm17,DWORD PTR [rax+r8*1+0x1c]{1to16}
    17de:	07 
    17df:	62 32 75 50 b8 44 00 	vfmadd231ps zmm8,zmm17,DWORD PTR [rax+r8*1+0x20]{1to16}
    17e6:	08 
    17e7:	62 b2 75 50 b8 7c 00 	vfmadd231ps zmm7,zmm17,DWORD PTR [rax+r8*1+0x24]{1to16}
    17ee:	09 
    17ef:	62 b2 75 50 b8 74 00 	vfmadd231ps zmm6,zmm17,DWORD PTR [rax+r8*1+0x28]{1to16}
    17f6:	0a 
    17f7:	62 b2 75 50 b8 6c 00 	vfmadd231ps zmm5,zmm17,DWORD PTR [rax+r8*1+0x2c]{1to16}
    17fe:	0b 
    17ff:	62 b2 75 50 b8 64 00 	vfmadd231ps zmm4,zmm17,DWORD PTR [rax+r8*1+0x30]{1to16}
    1806:	0c 
    1807:	62 b2 75 50 b8 5c 00 	vfmadd231ps zmm3,zmm17,DWORD PTR [rax+r8*1+0x34]{1to16}
    180e:	0d 
    180f:	62 b2 75 50 b8 54 00 	vfmadd231ps zmm2,zmm17,DWORD PTR [rax+r8*1+0x38]{1to16}
    1816:	0e 
    1817:	62 b2 75 50 b8 4c 00 	vfmadd231ps zmm1,zmm17,DWORD PTR [rax+r8*1+0x3c]{1to16}
    181e:	0f 
    181f:	42 0f 18 8c 00 00 02 	prefetcht0 BYTE PTR [rax+r8*1+0x200]
    1826:	00 00 
    1828:	49 83 c0 40          	add    r8,0x40
    182c:	49 ff c9             	dec    r9
    182f:	0f 85 5b ff ff ff    	jne    1790 <iree_hal_executable_library_query-0x350>
    1835:	e9 6c fe ff ff       	jmp    16a6 <iree_hal_executable_library_query-0x43a>
    183a:	66 0f 1f 44 00 00    	nop    WORD PTR [rax+rax*1+0x0]
    1840:	48 85 d2             	test   rdx,rdx
    1843:	0f 8e 67 fd ff ff    	jle    15b0 <iree_hal_executable_library_query-0x530>
    1849:	0f b6 54 24 3f       	movzx  edx,BYTE PTR [rsp+0x3f]
    184e:	41 88 d7             	mov    r15b,dl
    1851:	ba c0 03 00 00       	mov    edx,0x3c0
    1856:	49 ff c7             	inc    r15
    1859:	0f 1f 80 00 00 00 00 	nop    DWORD PTR [rax+0x0]
    1860:	62 f1 7c 48 28 4c 14 	vmovaps zmm1,ZMMWORD PTR [rsp+rdx*1-0x300]
    1867:	f4 
    1868:	62 f1 7c 48 28 54 14 	vmovaps zmm2,ZMMWORD PTR [rsp+rdx*1-0x2c0]
    186f:	f5 
    1870:	62 f1 7c 48 28 5c 14 	vmovaps zmm3,ZMMWORD PTR [rsp+rdx*1-0x280]
    1877:	f6 
    1878:	62 f1 7c 48 28 64 14 	vmovaps zmm4,ZMMWORD PTR [rsp+rdx*1-0x240]
    187f:	f7 
    1880:	62 f1 7c 48 28 6c 14 	vmovaps zmm5,ZMMWORD PTR [rsp+rdx*1-0x200]
    1887:	f8 
    1888:	62 f1 7c 48 28 74 14 	vmovaps zmm6,ZMMWORD PTR [rsp+rdx*1-0x1c0]
    188f:	f9 
    1890:	62 f1 7c 48 28 7c 14 	vmovaps zmm7,ZMMWORD PTR [rsp+rdx*1-0x180]
    1897:	fa 
    1898:	62 71 7c 48 28 44 14 	vmovaps zmm8,ZMMWORD PTR [rsp+rdx*1-0x140]
    189f:	fb 
    18a0:	62 71 7c 48 28 4c 14 	vmovaps zmm9,ZMMWORD PTR [rsp+rdx*1-0x100]
    18a7:	fc 
    18a8:	62 71 7c 48 28 54 14 	vmovaps zmm10,ZMMWORD PTR [rsp+rdx*1-0xc0]
    18af:	fd 
    18b0:	62 71 7c 48 28 5c 14 	vmovaps zmm11,ZMMWORD PTR [rsp+rdx*1-0x80]
    18b7:	fe 
    18b8:	62 71 7c 48 28 64 14 	vmovaps zmm12,ZMMWORD PTR [rsp+rdx*1-0x40]
    18bf:	ff 
    18c0:	62 71 7c 48 28 2c 14 	vmovaps zmm13,ZMMWORD PTR [rsp+rdx*1]
    18c7:	62 71 7c 48 28 74 14 	vmovaps zmm14,ZMMWORD PTR [rsp+rdx*1+0x40]
    18ce:	01 
    18cf:	62 71 7c 48 28 7c 14 	vmovaps zmm15,ZMMWORD PTR [rsp+rdx*1+0x80]
    18d6:	02 
    18d7:	62 e1 7c 48 28 44 14 	vmovaps zmm16,ZMMWORD PTR [rsp+rdx*1+0xc0]
    18de:	03 
    18df:	62 d1 74 58 58 0c 8e 	vaddps zmm1,zmm1,DWORD PTR [r14+rcx*4]{1to16}
    18e6:	62 d1 6c 58 58 54 8e 	vaddps zmm2,zmm2,DWORD PTR [r14+rcx*4+0x4]{1to16}
    18ed:	01 
    18ee:	62 d1 64 58 58 5c 8e 	vaddps zmm3,zmm3,DWORD PTR [r14+rcx*4+0x8]{1to16}
    18f5:	02 
    18f6:	62 d1 5c 58 58 64 8e 	vaddps zmm4,zmm4,DWORD PTR [r14+rcx*4+0xc]{1to16}
    18fd:	03 
    18fe:	62 d1 54 58 58 6c 8e 	vaddps zmm5,zmm5,DWORD PTR [r14+rcx*4+0x10]{1to16}
    1905:	04 
    1906:	62 d1 4c 58 58 74 8e 	vaddps zmm6,zmm6,DWORD PTR [r14+rcx*4+0x14]{1to16}
    190d:	05 
    190e:	62 d1 44 58 58 7c 8e 	vaddps zmm7,zmm7,DWORD PTR [r14+rcx*4+0x18]{1to16}
    1915:	06 
    1916:	62 51 3c 58 58 44 8e 	vaddps zmm8,zmm8,DWORD PTR [r14+rcx*4+0x1c]{1to16}
    191d:	07 
    191e:	62 51 34 58 58 4c 8e 	vaddps zmm9,zmm9,DWORD PTR [r14+rcx*4+0x20]{1to16}
    1925:	08 
    1926:	62 51 2c 58 58 54 8e 	vaddps zmm10,zmm10,DWORD PTR [r14+rcx*4+0x24]{1to16}
    192d:	09 
    192e:	62 51 24 58 58 5c 8e 	vaddps zmm11,zmm11,DWORD PTR [r14+rcx*4+0x28]{1to16}
    1935:	0a 
    1936:	62 51 1c 58 58 64 8e 	vaddps zmm12,zmm12,DWORD PTR [r14+rcx*4+0x2c]{1to16}
    193d:	0b 
    193e:	62 51 14 58 58 6c 8e 	vaddps zmm13,zmm13,DWORD PTR [r14+rcx*4+0x30]{1to16}
    1945:	0c 
    1946:	62 51 0c 58 58 74 8e 	vaddps zmm14,zmm14,DWORD PTR [r14+rcx*4+0x34]{1to16}
    194d:	0d 
    194e:	62 51 04 58 58 7c 8e 	vaddps zmm15,zmm15,DWORD PTR [r14+rcx*4+0x38]{1to16}
    1955:	0e 
    1956:	62 c1 7c 50 58 44 8e 	vaddps zmm16,zmm16,DWORD PTR [r14+rcx*4+0x3c]{1to16}
    195d:	0f 
    195e:	62 f1 74 48 c2 c8 06 	vcmpnleps k1,zmm1,zmm0
    1965:	62 f1 7c c9 28 c9    	vmovaps zmm1{k1}{z},zmm1
    196b:	62 f1 6c 48 c2 c8 06 	vcmpnleps k1,zmm2,zmm0
    1972:	62 d1 7c 48 29 4c 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x3c0],zmm1
    1979:	f1 
    197a:	62 f1 7c c9 28 d2    	vmovaps zmm2{k1}{z},zmm2
    1980:	62 f1 64 48 c2 c8 06 	vcmpnleps k1,zmm3,zmm0
    1987:	62 d1 7c 48 29 54 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x380],zmm2
    198e:	f2 
    198f:	62 f1 7c c9 28 db    	vmovaps zmm3{k1}{z},zmm3
    1995:	62 f1 5c 48 c2 c8 06 	vcmpnleps k1,zmm4,zmm0
    199c:	62 d1 7c 48 29 5c 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x340],zmm3
    19a3:	f3 
    19a4:	62 f1 7c c9 28 e4    	vmovaps zmm4{k1}{z},zmm4
    19aa:	62 f1 54 48 c2 c8 06 	vcmpnleps k1,zmm5,zmm0
    19b1:	62 d1 7c 48 29 64 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x300],zmm4
    19b8:	f4 
    19b9:	62 f1 7c c9 28 ed    	vmovaps zmm5{k1}{z},zmm5
    19bf:	62 f1 4c 48 c2 c8 06 	vcmpnleps k1,zmm6,zmm0
    19c6:	62 d1 7c 48 29 6c 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x2c0],zmm5
    19cd:	f5 
    19ce:	62 f1 7c c9 28 f6    	vmovaps zmm6{k1}{z},zmm6
    19d4:	62 f1 44 48 c2 c8 06 	vcmpnleps k1,zmm7,zmm0
    19db:	62 d1 7c 48 29 74 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x280],zmm6
    19e2:	f6 
    19e3:	62 f1 7c c9 28 ff    	vmovaps zmm7{k1}{z},zmm7
    19e9:	62 f1 3c 48 c2 c8 06 	vcmpnleps k1,zmm8,zmm0
    19f0:	62 d1 7c 48 29 7c 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x240],zmm7
    19f7:	f7 
    19f8:	62 51 7c c9 28 c0    	vmovaps zmm8{k1}{z},zmm8
    19fe:	62 f1 34 48 c2 c8 06 	vcmpnleps k1,zmm9,zmm0
    1a05:	62 51 7c 48 29 44 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x200],zmm8
    1a0c:	f8 
    1a0d:	62 51 7c c9 28 c9    	vmovaps zmm9{k1}{z},zmm9
    1a13:	62 f1 2c 48 c2 c8 06 	vcmpnleps k1,zmm10,zmm0
    1a1a:	62 51 7c 48 29 4c 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x1c0],zmm9
    1a21:	f9 
    1a22:	62 51 7c c9 28 d2    	vmovaps zmm10{k1}{z},zmm10
    1a28:	62 f1 24 48 c2 c8 06 	vcmpnleps k1,zmm11,zmm0
    1a2f:	62 51 7c 48 29 54 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x180],zmm10
    1a36:	fa 
    1a37:	62 51 7c c9 28 db    	vmovaps zmm11{k1}{z},zmm11
    1a3d:	62 f1 1c 48 c2 c8 06 	vcmpnleps k1,zmm12,zmm0
    1a44:	62 51 7c 48 29 5c 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x140],zmm11
    1a4b:	fb 
    1a4c:	62 51 7c c9 28 e4    	vmovaps zmm12{k1}{z},zmm12
    1a52:	62 f1 14 48 c2 c8 06 	vcmpnleps k1,zmm13,zmm0
    1a59:	62 51 7c 48 29 64 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x100],zmm12
    1a60:	fc 
    1a61:	62 51 7c c9 28 ed    	vmovaps zmm13{k1}{z},zmm13
    1a67:	62 f1 0c 48 c2 c8 06 	vcmpnleps k1,zmm14,zmm0
    1a6e:	62 51 7c 48 29 6c 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0xc0],zmm13
    1a75:	fd 
    1a76:	62 51 7c c9 28 f6    	vmovaps zmm14{k1}{z},zmm14
    1a7c:	62 f1 04 48 c2 c8 06 	vcmpnleps k1,zmm15,zmm0
    1a83:	62 51 7c 48 29 74 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x80],zmm14
    1a8a:	fe 
    1a8b:	62 51 7c c9 28 ff    	vmovaps zmm15{k1}{z},zmm15
    1a91:	62 f1 7c 40 c2 c8 06 	vcmpnleps k1,zmm16,zmm0
    1a98:	62 51 7c 48 29 7c 14 	vmovaps ZMMWORD PTR [r12+rdx*1-0x40],zmm15
    1a9f:	ff 
    1aa0:	62 a1 7c c9 28 c0    	vmovaps zmm16{k1}{z},zmm16
    1aa6:	62 c1 7c 48 29 04 14 	vmovaps ZMMWORD PTR [r12+rdx*1],zmm16
    1aad:	48 81 c2 00 04 00 00 	add    rdx,0x400
    1ab4:	49 ff cf             	dec    r15
    1ab7:	0f 85 a3 fd ff ff    	jne    1860 <iree_hal_executable_library_query-0x280>
    1abd:	e9 ee fa ff ff       	jmp    15b0 <iree_hal_executable_library_query-0x530>
    1ac2:	31 c0                	xor    eax,eax
    1ac4:	48 8d 65 d8          	lea    rsp,[rbp-0x28]
    1ac8:	5b                   	pop    rbx
    1ac9:	41 5c                	pop    r12
    1acb:	41 5d                	pop    r13
    1acd:	41 5e                	pop    r14
    1acf:	41 5f                	pop    r15
    1ad1:	5d                   	pop    rbp
    1ad2:	c5 f8 77             	vzeroupper 
    1ad5:	c3                   	ret    
    1ad6:	cc                   	int3   
    1ad7:	cc                   	int3   
    1ad8:	cc                   	int3   
    1ad9:	cc                   	int3   
    1ada:	cc                   	int3   
    1adb:	cc                   	int3   
    1adc:	cc                   	int3   
    1add:	cc                   	int3   
    1ade:	cc                   	int3   
    1adf:	cc                   	int3   

0000000000001ae0 <iree_hal_executable_library_query>:
    1ae0:	31 c0                	xor    eax,eax
    1ae2:	83 ff 03             	cmp    edi,0x3
    1ae5:	48 8d 0d 54 10 00 00 	lea    rcx,[rip+0x1054]        # 2b40 <iree_hal_executable_library_query+0x1060>
    1aec:	48 0f 44 c1          	cmove  rax,rcx
    1af0:	c3                   	ret    

hanhanW avatar Jan 24 '24 16:01 hanhanW

is this a duplicate of #16025? (want to make sure they're the same effort - probably worth picking one issue to use?)

benvanik avatar Jan 29 '24 19:01 benvanik

https://github.com/openxla/iree/issues/16025 is more about functionality, i.e., making sure everything is under our own control. I think Benoit has more ideas about how it should work in terms of optimization, and we're tracking the optimization in this issue.

hanhanW avatar Jan 29 '24 19:01 hanhanW

#16025 is for fusing in terms of dispatch formation (right?). The present issue is for obtaining optimal object code. So #16025 is a prerequisite.

bjacob avatar Jan 29 '24 19:01 bjacob

Yes, that's what I have in mind. I don't want to enable the fusion while we have poor control over tile size selection, or while the selected tile sizes are unexpected. This issue is more for tracking optimization once we enable the fusion.

hanhanW avatar Jan 29 '24 20:01 hanhanW

This makes sense to me. I understand that you plan to limit this fusion to cases where the pack op has been propagated above the producer op, right? Any plans for the consumer counterpart?

dcaballe avatar Feb 01 '24 01:02 dcaballe