
test MFMA tiled layout in migraphx

Open dhernandez0 opened this issue 4 months ago • 5 comments

DOR (Definition of Ready)

ready

Description

The purpose of this ticket is to try the tiled mfma-friendly layout POC from rocmlir (branch: mfma_layout_migraphx_integration).

Layout explanation

The idea is that we have a tiled layout: each tile we load from global memory is "packed", and within each tile the data is arranged in an mfma-friendly order so that LDS bank conflicts are avoided.

The A tensor of a GEMM (batch x M x K) would become batch x (M/mPerBlock) x (K/kPerBlock) x kpackPerBlock x mPerBlock x kpack.

As an example, take a perfConfig with mPerBlock=32, kpackPerBlock=8, kpack=8, so kPerBlock=kpackPerBlock*kpack=64. If A is batch x M x K = 1x128x5120, it becomes 1x4x80x8x32x8.

mPerBlock, kpackPerBlock and kpack are tuning parameters; mPerBlock x kPerBlock is the tile size each workgroup loads from global memory. kPerBlock is decomposed into kpackPerBlock x kpack to avoid LDS bank conflicts, because mfma/wmma instructions need the tile in this pattern.
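
For intuition, here is a minimal NumPy sketch (not the rocMLIR implementation) of how a row-major batch x M x K A tensor maps onto this tiled layout, using the example sizes above:

import numpy as np

B, M, K = 1, 128, 5120
mPerBlock, kpackPerBlock, kpack = 32, 8, 8
kPerBlock = kpackPerBlock * kpack  # 64

a = np.arange(B * M * K).reshape(B, M, K)

# Split M into (M/mPerBlock, mPerBlock) and K into (K/kPerBlock, kpackPerBlock, kpack) ...
a_split = a.reshape(B, M // mPerBlock, mPerBlock, K // kPerBlock, kpackPerBlock, kpack)
# ... then reorder so each mPerBlock x kPerBlock tile is stored in the mfma-friendly order.
a_tiled = a_split.transpose(0, 1, 3, 4, 2, 5)

print(a_tiled.shape)  # (1, 4, 80, 8, 32, 8), matching the 1x4x80x8x32x8 example

Note that transpose(0, 1, 3, 4, 2, 5) is the inverse of the permutation [0, 1, 4, 2, 3, 5] used in the IR below, which goes the other way (tiled layout back to 1x128x5120 before the dot).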

This is an example of migraphx IR for an f16 GEMM kernel (M=128, K=5120, N=1280, perfConfig=v3:32,32,8,16,16,8,1,2,2,1,1):

module {
  func.func @mfma_layout(%a: !migraphx.shaped<1x4x80x8x32x8xf16, 655360x163840x2048x256x8x1>, %b: !migraphx.shaped<1x40x80x8x32x8xf16, 6553600x163840x2048x256x8x1>) -> !migraphx.shaped<1x128x1280xf16, 163840x1280x1> attributes{kernel, arch = "##TOKEN_ARCH##"} {
    %aTranspose = migraphx.transpose %a {permutation = [0, 1, 4, 2, 3, 5]} : !migraphx.shaped<1x4x80x8x32x8xf16, 655360x163840x2048x256x8x1> -> !migraphx.shaped<1x4x32x80x8x8xf16, 655360x8x163840x2048x256x1>
    %bTranspose = migraphx.transpose %b {permutation = [0, 1, 4, 2, 3, 5]} : !migraphx.shaped<1x40x80x8x32x8xf16, 6553600x163840x2048x256x8x1> -> !migraphx.shaped<1x40x32x80x8x8xf16, 6553600x8x163840x2048x256x1>
    %aReshaped = migraphx.reshape %aTranspose {dims = [1, 128, 5120]} : <1x4x32x80x8x8xf16, 655360x8x163840x2048x256x1> -> !migraphx.shaped<1x128x5120xf16, 655360x5120x1>
    %bReshaped = migraphx.reshape %bTranspose {dims = [1, 1280, 5120]} : <1x40x32x80x8x8xf16, 6553600x8x163840x2048x256x1> -> !migraphx.shaped<1x1280x5120xf16, 6553600x5120x1>
    %bTranspose2 = migraphx.transpose %bReshaped {permutation = [0, 2, 1]} : !migraphx.shaped<1x1280x5120xf16, 6553600x5120x1> -> !migraphx.shaped<1x5120x1280xf16, 6553600x1x5120>
    %0 = migraphx.dot %aReshaped, %bTranspose2 {perf_config="v3:32,32,8,16,16,8,1,2,2,1,1"} : <1x128x5120xf16, 655360x5120x1>, <1x5120x1280xf16, 6553600x1x5120> -> <1x128x1280xf16, 163840x1280x1>
    return %0 : !migraphx.shaped<1x128x1280xf16, 163840x1280x1>
  }
}

Note that perfConfig=v3:mPerBlock,nPerBlock,kpackPerBlock,?,?,kpack,... So, in this case, mPerBlock=nPerBlock=32, kpackPerBlock=8, kpack=8.
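
As a convenience, a small helper to pull out just the fields documented above; this is only an assumption based on the field positions noted here, not an official parser, and the remaining positions are deliberately left unnamed:

# Hedged helper: extract the fields this ticket cares about from a v3 perfConfig string.
def parse_v3_perf_config(perf_config: str) -> dict:
    version, _, rest = perf_config.partition(":")
    assert version == "v3", "only the v3 format is described in this ticket"
    fields = [int(x) for x in rest.split(",")]
    return {
        "mPerBlock": fields[0],
        "nPerBlock": fields[1],
        "kpackPerBlock": fields[2],
        "kpack": fields[5],
        "other": fields[3:5] + fields[6:],  # undocumented positions
    }

print(parse_v3_perf_config("v3:32,32,8,16,16,8,1,2,2,1,1"))
# {'mPerBlock': 32, 'nPerBlock': 32, 'kpackPerBlock': 8, 'kpack': 8, 'other': [16, 16, 1, 2, 2, 1, 1]}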

Implementation details

The branch only works for GEMMs; do not expect any speedup for convolution or attention kernels. There are a few things to get right for this to work:

  • Disable split-k
  • Use the transforms correctly (as in the example IR above)
  • The kernel needs a perfConfig assigned to work correctly
  • Padding is not supported, so M must be divisible by mPerBlock (and N by nPerBlock), and K must be divisible by kpackPerBlock*kpack (see the quick check sketched after this list)
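
A minimal sketch of that divisibility check; the parameter names mirror the tuning parameters in this ticket, this is not a rocMLIR API:

# Quick check of the no-padding constraint above.
def fits_without_padding(M, N, K, mPerBlock, nPerBlock, kpackPerBlock, kpack):
    return (M % mPerBlock == 0
            and N % nPerBlock == 0
            and K % (kpackPerBlock * kpack) == 0)

# Example from this ticket: M=128, N=1280, K=5120 with mPerBlock=nPerBlock=32,
# kpackPerBlock=8, kpack=8, so K must be divisible by 64.
print(fits_without_padding(128, 1280, 5120, 32, 32, 8, 8))  # True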

Also, make sure our migraphx integration works correctly (input fusions might not work in all cases). This can be done by checking the rock dialect.

This is because the GEMM kernel loads data differently if it detects that we are using this layout. To verify, check the rock dialect output; it should have aAccelLayout and bAccelLayout (if both inputs are using the new layout):

  rock.gemm %alloc = %8 * tr %9 ... {aAccelLayout, arch = ..., bAccelLayout, perf_config = ...} ...
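
For example, a small script over a dumped rock-dialect IR file could confirm both attributes are present; this is just a sketch, and the dump file name is a hypothetical placeholder:

# Sanity check: does every rock.gemm in the dump carry both accel-layout attributes?
def uses_accel_layout(ir_text: str) -> bool:
    gemm_lines = [line for line in ir_text.splitlines() if "rock.gemm" in line]
    return bool(gemm_lines) and all(
        "aAccelLayout" in line and "bAccelLayout" in line for line in gemm_lines
    )

with open("dump.rock.mlir") as f:  # hypothetical dump file name
    print(uses_accel_layout(f.read()))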

I'd recommend testing this on a model that has a lot of GEMM kernels (or an LLM/attention one, with the attention kernel disabled). Also, make sure the model's GEMM sizes are in the list below.

Data

I attach here a list of GEMMs that come from the tier1 GEMM list (M, N and K padded to at least 32), together with the tuned perfConfig (with mPerBlock, nPerBlock, kpackPerBlock and kpack broken out in columns).

AccelLayout_migraphx.xlsx

DOD (Definition of Done)

Measure performance of the rocmlir branch on end-to-end models vs. the develop branch with the standard layout.

dhernandez0 (Aug 11 '25 10:08)

I've tried getting the run time for the same kernel as above, but changing the output layout as well:

module {
  func.func @mfma_layout(%a: !migraphx.shaped<1x4x80x8x32x8xf16, 655360x163840x2048x256x8x1>, %b: !migraphx.shaped<1x40x80x8x32x8xf16, 6553600x163840x2048x256x8x1>) -> !migraphx.shaped<1x4x20x8x32x8xf16, 163840x40960x64x8x1280x1> attributes{kernel, arch = "##TOKEN_ARCH##"} {
    %aTranspose = migraphx.transpose %a {permutation = [0, 1, 4, 2, 3, 5]} : !migraphx.shaped<1x4x80x8x32x8xf16, 655360x163840x2048x256x8x1> -> !migraphx.shaped<1x4x32x80x8x8xf16, 655360x8x163840x2048x256x1>
    %bTranspose = migraphx.transpose %b {permutation = [0, 1, 4, 2, 3, 5]} : !migraphx.shaped<1x40x80x8x32x8xf16, 6553600x163840x2048x256x8x1> -> !migraphx.shaped<1x40x32x80x8x8xf16, 6553600x8x163840x2048x256x1>
    %aReshaped = migraphx.reshape %aTranspose {dims = [1, 128, 5120]} : <1x4x32x80x8x8xf16, 655360x8x163840x2048x256x1> -> !migraphx.shaped<1x128x5120xf16, 655360x5120x1>
    %bReshaped = migraphx.reshape %bTranspose {dims = [1, 1280, 5120]} : <1x40x32x80x8x8xf16, 6553600x8x163840x2048x256x1> -> !migraphx.shaped<1x1280x5120xf16, 6553600x5120x1>
    %bTranspose2 = migraphx.transpose %bReshaped {permutation = [0, 2, 1]} : !migraphx.shaped<1x1280x5120xf16, 6553600x5120x1> -> !migraphx.shaped<1x5120x1280xf16, 6553600x1x5120>
    %0 = migraphx.dot %aReshaped, %bTranspose2 {perf_config="v3:32,32,8,16,16,8,1,2,2,1,1"} : <1x128x5120xf16, 655360x5120x1>, <1x5120x1280xf16, 6553600x1x5120> -> <1x128x1280xf16, 163840x1280x1>
    // 1x128x1280 -> 1x4x20x8x32x8
    %outReshaped = migraphx.reshape %0 {dims = [1, 4, 32, 20, 8, 8]} : <1x128x1280xf16, 163840x1280x1> -> !migraphx.shaped<1x4x32x20x8x8xf16, 163840x40960x1280x64x8x1>
    %outTranspose = migraphx.transpose %outReshaped {permutation = [0, 1, 3, 4, 2, 5]} : !migraphx.shaped<1x4x32x20x8x8xf16, 163840x40960x1280x64x8x1> -> !migraphx.shaped<1x4x20x8x32x8xf16, 163840x40960x64x8x1280x1>
    return %outTranspose : !migraphx.shaped<1x4x20x8x32x8xf16, 163840x40960x64x8x1280x1>
  }
}

It seems performance is about the same, at least for this case.
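
For reference, the extra output step in that module is the same tiling pattern applied to the 1x128x1280 result; a quick NumPy shape check (not the kernel code):

import numpy as np

# 1x128x1280 -> reshape 1x4x32x20x8x8 -> transpose(0, 1, 3, 4, 2, 5) -> 1x4x20x8x32x8
out = np.zeros((1, 128, 1280), dtype=np.float16)
out_tiled = out.reshape(1, 4, 32, 20, 8, 8).transpose(0, 1, 3, 4, 2, 5)
print(out_tiled.shape)  # (1, 4, 20, 8, 32, 8)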

dhernandez0 (Aug 11 '25 11:08)

Can you post the data as a comma-separated list or a link to the spreadsheet online?

pfultz2 (Aug 11 '25 16:08)

> Can you post the data as a comma-separated list or a link to the spreadsheet online?

AccelLayout_migraphx.csv

dhernandez0 (Aug 11 '25 17:08)

Some notes from Daniel's slides:

[B, dBlock, kIter, kpackPerBlock, dPerBlock, kpack]

B = Batch
K = kIter*kpackPerBlock*kpack
D = dBlock*dPerBlock
kIter = K/(kpackPerBlock*kpack)

[B, dBlock, kIter, kpackPerBlock, dPerBlock, kpack] => [B, dBlock, dPerBlock, kIter, kpackPerBlock, kpack]
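
A quick check of these formulas against the A-tensor example from the ticket (d = M here); a worked sketch, nothing more:

K, kpackPerBlock, kpack = 5120, 8, 8
M, dPerBlock = 128, 32

kIter = K // (kpackPerBlock * kpack)   # 80
dBlock = M // dPerBlock                # 4

# [B, dBlock, kIter, kpackPerBlock, dPerBlock, kpack] = [1, 4, 80, 8, 32, 8]
print([1, dBlock, kIter, kpackPerBlock, dPerBlock, kpack])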

pfultz2 (Sep 03 '25 15:09)

Same file but for gfx1201: AccelLayout_gfx1201_migraphx.csv

dhernandez0 (Sep 17 '25 09:09)