
Port `s8s4s32` mmt4d ukernel code path to x86-64

Status: Open. bjacob opened this issue 1 year ago (0 comments).

Recently @mariecwhite has been adding s8s4s32 code paths to the mmt4d ukernel, including optimized code paths for arm64 but not for x86-64. This Issue is about adding the x86-64 pieces.

Explanation of "mmt4d": it stands for "matrix-times-matrix-transposed on 4D tensors", i.e. our matrix-multiplication ukernel.
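
To make the 4D layout concrete, here is a minimal scalar sketch of the mmt4d semantics. It is not IREE code. It uses int32 for all three element types to keep it short, and it assumes the layouts of MLIR's linalg.mmt4d op: LHS M1xK1xM0xK0, RHS N1xK1xN0xK0, output M1xN1xM0xN0. The function name mmt4d_reference_s32 is hypothetical.

```c
#include <stdint.h>

// Reference semantics of mmt4d, sketched with int32 elements for simplicity.
// Assumed row-major layouts (innermost index last):
//   lhs: [M1][K1][M0][K0]
//   rhs: [N1][K1][N0][K0]   (RHS tiles are "transposed": N comes before K)
//   out: [M1][N1][M0][N0]   (accumulated into, like += in a matmul)
// Unflattening the tiles gives out[m][n] += sum_k lhs[m][k] * rhs[n][k],
// i.e. LHS times the transpose of RHS.
void mmt4d_reference_s32(const int32_t *lhs, const int32_t *rhs, int32_t *out,
                         int M1, int N1, int K1, int M0, int N0, int K0) {
  for (int m1 = 0; m1 < M1; ++m1)
    for (int n1 = 0; n1 < N1; ++n1)
      for (int k1 = 0; k1 < K1; ++k1)
        // The three innermost loops are what a "tile function" implements.
        for (int m0 = 0; m0 < M0; ++m0)
          for (int n0 = 0; n0 < N0; ++n0)
            for (int k0 = 0; k0 < K0; ++k0) {
              int32_t a = lhs[((m1 * K1 + k1) * M0 + m0) * K0 + k0];
              int32_t b = rhs[((n1 * K1 + k1) * N0 + n0) * K0 + k0];
              out[((m1 * N1 + n1) * M0 + m0) * N0 + n0] += a * b;
            }
}
```

With M0 = N0 = K0 = 1 this degenerates to a plain row-major A * B^T, which is a convenient sanity check.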

Explanation of "s8s4s32": this is the type triple describing the mmt4d op. Here s8 is the LHS element type (signed int8), s4 is the RHS element type (signed int4), and s32 is the accumulator (output) element type (signed int32).
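
For intuition about the s4 RHS: two s4 values are packed into each byte, so consuming them means splitting every byte into two nibbles and sign-extending each one from 4 to 8 bits. Below is a minimal scalar sketch. The helper names are hypothetical, and which nibble holds the lower-indexed element is an assumption here; check the generic fallback in mmt4d_tile.c for the actual convention.

```c
#include <stdint.h>

// Sign-extends the low 4 bits of x to the range -8..7.
// (x ^ 8) - 8 maps 0..7 -> 0..7 and 8..15 -> -8..-1, avoiding shifts of
// negative values (which are undefined or implementation-defined in C).
static inline int8_t sign_extend_s4(uint8_t x) {
  return (int8_t)(((x & 0xF) ^ 8) - 8);
}

// Splits one byte holding two packed s4 values into two s8 values.
// Assumption (hypothetical convention): the lower-indexed element lives in
// the low nibble and the next element in the high nibble.
static inline void unpack_s4x2(uint8_t byte, int8_t *lo, int8_t *hi) {
  *lo = sign_extend_s4(byte);
  *hi = sign_extend_s4((uint8_t)(byte >> 4));
}
```

For example, the byte 0xF7 unpacks to lo = 7 and hi = -1.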

Get familiar with the code:

  • Entry point into the mmt4d ukernel: https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/mmt4d.c#L115-L141
  • This selects a "tile function" which is an optimized implementation of the innermost loop: https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/mmt4d.c#L125-L126
  • Currently, for s8s4s32 the x86-64 implementation just returns NULL: there is no case for that type triple, so it falls through to this default: https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c#L414-L415
  • To understand what this should be doing, start by looking at the generic fallback implementation: https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/mmt4d_tile.c#L10-L51
  • If you're curious what the optimized arm64 implementation looks like (even though it won't be directly applicable), here it is: https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_base.c#L291-L423
  • Back on x86, here is a closely related existing kernel for s8s8s32. It is almost the same, except that the RHS is s8 instead of s4, so it doesn't need to do the extra work of unpacking two 4-bit values from each byte:
    • AVX2 case, for narrow tiles from 1x8 to 4x8: https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c#L154-L201
    • AVX2 case, general tile 8x8: https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c#L154-L201
    • The other _avx512 files in this directory contain the corresponding AVX-512 cases.
  • Another related existing kernel on x86, one that does deal with 4-bit values, is s16u4s32 (the u stands for unsigned): https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c#L189-L295
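
To tie these pointers together, here is a hedged scalar sketch of what an s8s4s32 tile function computes for one M0xN0 output tile. It is not IREE's generic fallback: the function name, the parameter list and the packed layout (two s4 values per byte, element (n0, k0) at index n0*K0 + k0, even index in the low nibble) are assumptions made for illustration. An optimized x86-64 version would compute the same sums with SIMD instructions.

```c
#include <stdint.h>

// Reads the i-th s4 value from a packed buffer and sign-extends it.
// Assumption: even indices live in the low nibble, odd in the high nibble.
static int32_t s4_at(const uint8_t *packed, int i) {
  uint8_t byte = packed[i / 2];
  uint8_t nibble = (i % 2 == 0) ? (byte & 0xF) : (byte >> 4);
  return (int32_t)((nibble ^ 8) - 8);  // sign-extend 4 -> 32 bits
}

// Hypothetical scalar s8s4s32 tile function. Per step of the K loop:
//   lhs: M0*K0 s8 values, index m0*K0 + k0
//   rhs: N0*K0 s4 values packed into N0*K0/2 bytes (assumes K0 is even)
//   out: M0*N0 s32 accumulators, index m0*N0 + n0 (always accumulated into)
void mmt4d_tile_s8s4s32_scalar(int32_t *out, const int8_t *lhs,
                               const uint8_t *rhs, int K, int M0, int N0,
                               int K0) {
  for (int k = 0; k < K; ++k) {
    for (int m0 = 0; m0 < M0; ++m0)
      for (int n0 = 0; n0 < N0; ++n0)
        for (int k0 = 0; k0 < K0; ++k0)
          out[m0 * N0 + n0] +=
              (int32_t)lhs[m0 * K0 + k0] * s4_at(rhs, n0 * K0 + k0);
    lhs += M0 * K0;        // next LHS panel: M0*K0 bytes
    rhs += (N0 * K0) / 2;  // next RHS panel: N0*K0 nibbles = N0*K0/2 bytes
  }
}
```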

Explanation of the tile sizes:

  • Tile sizes are given in the M0xN0xK0 convention. For example, 8x32x4 means a tile size of 8 along the M dimension, 32 along the N dimension and 4 along the K dimension.
  • Each ukernel can dictate its preferred tile size. You are free to choose whatever tile size you want here, but I would suggest starting from what the existing s8s8s32 ukernel does on x86 and multiplying its K0 tile size by 2. Since 4-bit values are 2x smaller, doubling K0 keeps the RHS tile the same size in bytes, so that even after masking odd/even lanes you still have enough values to feed your arithmetic instructions. For instance, on AVX2 the existing tile is Mx8x2, so you could start from Mx8x4. You will need to implement M values 1, 2, 4 and 8 (M=8 is needed for the general case; the other M values are needed for narrow problems such as vector-times-matrix).
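
The K0 doubling works out in bytes: with the suggested AVX2 tile Mx8x4, one K step of the RHS holds 8*4 = 32 s4 values, i.e. 16 bytes, exactly the size of the existing Mx8x2 s8 RHS tile (8*2 = 16 bytes). Below is a sketch of the odd/even nibble masking on such a 16-byte tile. It uses 128-bit SSE2 intrinsics (always available on x86-64) plus a scalar fallback so that it compiles anywhere; an AVX2 kernel would apply the same idea with the _mm256_* counterparts on wider registers. The function name and the low-nibble-is-even convention are assumptions.

```c
#include <stdint.h>
#if defined(__SSE2__)
#include <emmintrin.h>
#endif

// Hypothetical helper: splits 16 bytes of packed s4 values into two vectors
// of 16 s8 values. `even` gets the low nibbles (elements 0, 2, 4, ...) and
// `odd` the high nibbles (elements 1, 3, 5, ...), each sign-extended.
void split_s4x32_to_s8(const uint8_t packed[16], int8_t even[16],
                       int8_t odd[16]) {
#if defined(__SSE2__)
  const __m128i v = _mm_loadu_si128((const __m128i *)packed);
  const __m128i mask = _mm_set1_epi8(0x0F);
  const __m128i eight = _mm_set1_epi8(8);
  // Low nibbles: a plain AND with 0x0F.
  __m128i lo = _mm_and_si128(v, mask);
  // High nibbles: SSE2 has no 8-bit shift, so shift 16-bit lanes right by 4
  // and mask off the bits that moved in from the neighbouring byte.
  __m128i hi = _mm_and_si128(_mm_srli_epi16(v, 4), mask);
  // Sign-extend 4 -> 8 bits with (x ^ 8) - 8 (8-bit wrapping subtract).
  lo = _mm_sub_epi8(_mm_xor_si128(lo, eight), eight);
  hi = _mm_sub_epi8(_mm_xor_si128(hi, eight), eight);
  _mm_storeu_si128((__m128i *)even, lo);
  _mm_storeu_si128((__m128i *)odd, hi);
#else
  // Portable scalar fallback with the same semantics.
  for (int i = 0; i < 16; ++i) {
    even[i] = (int8_t)(((packed[i] & 0xF) ^ 8) - 8);
    odd[i] = (int8_t)(((packed[i] >> 4) ^ 8) - 8);
  }
#endif
}
```

Under the assumed layout with K0 = 4, `even` holds the k0 in {0, 2} values and `odd` the k0 in {1, 3} values of each RHS row, so the LHS values have to be paired up accordingly before the multiply-add; how to arrange that is one of the design choices of the kernel.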

How to run the tests and micro-benchmarks (from your CMake build directory):

 ninja mmt4d_test mmt4d_benchmark && ./runtime/src/iree/builtins/ukernel/tools/mmt4d_test && ./runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark

Posted by bjacob on Mar 04 '24 22:03