test MFMA tiled layout in migraphx
DOR (Definition of Ready)
ready
Description
The purpose of this ticket is to try the tiled mfma-friendly layout POC from rocmlir (branch: mfma_layout_migraphx_integration).
Layout explanation
The idea is to use a tiled layout: each tile we load from global memory is "packed", giving an mfma-friendly layout at the tile level that avoids LDS bank conflicts.
The A tensor of a GEMM (batch x M x K) becomes batch x (M/mPerBlock) x (K/kPerBlock) x kpackPerBlock x mPerBlock x kpack.
For example, with a perfConfig giving mPerBlock=32, kpackPerBlock=8, kpack=8 (so kPerBlock = kpackPerBlock*kpack = 64), an A of batch x M x K = 1x128x5120 becomes 1x4x80x8x32x8.
mPerBlock, kpackPerBlock and kpack are tuning parameters; mPerBlock x kPerBlock is the tile size each workgroup loads from global memory. kPerBlock is decomposed into kpackPerBlock and kpack to avoid LDS bank conflicts, because mfma/wmma instructions need the tile in this pattern.
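To make the packing concrete, here is a minimal numpy sketch (numpy stands in for the migraphx transforms; the variable names are mine) that packs A into the tiled layout and unpacks it back, matching the transpose/reshape pair in the IR below:

import numpy as np

# Tuning parameters and shapes from the example above.
batch, M, K = 1, 128, 5120
mPerBlock, kpackPerBlock, kpack = 32, 8, 8
kPerBlock = kpackPerBlock * kpack  # 64

A = np.arange(batch * M * K, dtype=np.float32).reshape(batch, M, K)

# Pack: batch x M x K -> batch x (M/mPerBlock) x (K/kPerBlock) x kpackPerBlock x mPerBlock x kpack
A_tiled = (A.reshape(batch, M // mPerBlock, mPerBlock,
                     K // kPerBlock, kpackPerBlock, kpack)
            .transpose(0, 1, 3, 4, 2, 5))
assert A_tiled.shape == (1, 4, 80, 8, 32, 8)

# Unpack: the transpose {permutation = [0, 1, 4, 2, 3, 5]} + reshape pair in the IR below.
A_back = A_tiled.transpose(0, 1, 4, 2, 3, 5).reshape(batch, M, K)
assert np.array_equal(A, A_back)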
This is example migraphx IR for an f16 GEMM kernel (M=128, K=5120, N=1280, perfConfig=v3:32,32,8,16,16,8,1,2,2,1,1):
module {
func.func @mfma_layout(%a: !migraphx.shaped<1x4x80x8x32x8xf16, 655360x163840x2048x256x8x1>, %b: !migraphx.shaped<1x40x80x8x32x8xf16, 6553600x163840x2048x256x8x1>) -> !migraphx.shaped<1x128x1280xf16, 163840x1280x1> attributes{kernel, arch = "##TOKEN_ARCH##"} {
%aTranspose = migraphx.transpose %a {permutation = [0, 1, 4, 2, 3, 5]} : !migraphx.shaped<1x4x80x8x32x8xf16, 655360x163840x2048x256x8x1> -> !migraphx.shaped<1x4x32x80x8x8xf16, 655360x8x163840x2048x256x1>
%bTranspose = migraphx.transpose %b {permutation = [0, 1, 4, 2, 3, 5]} : !migraphx.shaped<1x40x80x8x32x8xf16, 6553600x163840x2048x256x8x1> -> !migraphx.shaped<1x40x32x80x8x8xf16, 6553600x8x163840x2048x256x1>
%aReshaped = migraphx.reshape %aTranspose {dims = [1, 128, 5120]} : <1x4x32x80x8x8xf16, 655360x8x163840x2048x256x1> -> !migraphx.shaped<1x128x5120xf16, 655360x5120x1>
%bReshaped = migraphx.reshape %bTranspose {dims = [1, 1280, 5120]} : <1x40x32x80x8x8xf16, 6553600x8x163840x2048x256x1> -> !migraphx.shaped<1x1280x5120xf16, 6553600x5120x1>
%bTranspose2 = migraphx.transpose %bReshaped {permutation = [0, 2, 1]} : !migraphx.shaped<1x1280x5120xf16, 6553600x5120x1> -> !migraphx.shaped<1x5120x1280xf16, 6553600x1x5120>
%0 = migraphx.dot %aReshaped, %bTranspose2 {perf_config="v3:32,32,8,16,16,8,1,2,2,1,1"} : <1x128x5120xf16, 655360x5120x1>, <1x5120x1280xf16, 6553600x1x5120> -> <1x128x1280xf16, 163840x1280x1>
return %0 : !migraphx.shaped<1x128x1280xf16, 163840x1280x1>
}
}
Note the perfConfig format: v3:mPerBlock,nPerBlock,kpackPerBlock,?,?,kpack,... So in this case, mPerBlock=nPerBlock=32, kpackPerBlock=8 and kpack=8.
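For reference, a small helper (mine, not part of the branch) that pulls out only the fields documented above; the positions marked "?" are left unparsed:

def parse_perf_config(pc: str) -> dict:
    """Extract the v3 perfConfig fields relevant to this layout.
    Only positions 0, 1, 2 and 5 are documented above; the rest are ignored."""
    assert pc.startswith("v3:")
    fields = [int(x) for x in pc[len("v3:"):].split(",")]
    return {"mPerBlock": fields[0], "nPerBlock": fields[1],
            "kpackPerBlock": fields[2], "kpack": fields[5]}

assert parse_perf_config("v3:32,32,8,16,16,8,1,2,2,1,1") == {
    "mPerBlock": 32, "nPerBlock": 32, "kpackPerBlock": 8, "kpack": 8}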
Implementation details
The branch only works for GEMMs; do not expect any speed-up for convolution or attention kernels. There are a few things to get right for this to work:
- Disable split-k
- Use the transforms correctly (as in the example IR above)
- The kernel needs a perfConfig assigned to work correctly
- Padding is not supported, so M must be divisible by mPerBlock (and N by nPerBlock) and K must be divisible by kpackPerBlock*kpack (see the pre-check sketched below)
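Since padding is unsupported, a quick pre-check along these lines (a hypothetical helper, not in the branch) can filter out GEMM sizes the layout cannot handle:

def tileable(M, N, K, mPerBlock, nPerBlock, kpackPerBlock, kpack):
    """True if an MxNxK GEMM fits the tiled layout without padding."""
    kPerBlock = kpackPerBlock * kpack
    return M % mPerBlock == 0 and N % nPerBlock == 0 and K % kPerBlock == 0

# The example above: 128x1280x5120 with mPerBlock=nPerBlock=32, kpackPerBlock=kpack=8.
assert tileable(128, 1280, 5120, 32, 32, 8, 8)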
Also, make sure our migraphx integration works correctly (input fusions might not work in all cases). This can be checked in the rock dialect: the GEMM kernel loads data differently when it detects this layout, so the rock dialect output should carry aAccelLayout and bAccelLayout (if both inputs use the new layout):
rock.gemm %alloc = %8 * tr %9 ... {aAccelLayout, arch = ..., bAccelLayout, perf_config = ...} ...
I'd recommend testing this on a model with a lot of GEMM kernels (or an LLM/attention model with the attention kernel disabled). Also, make sure the model's GEMM sizes are in the list below.
Data
Attached is a list of GEMMs taken from the tier1 GEMM list (M, N and K padded to at least 32), along with the tuned perfConfig (and mPerBlock, nPerBlock, kpackPerBlock and kpack broken out into columns).
DOD (Definition of Done)
Measure the performance of the rocmlir branch on end-to-end models vs. the develop branch with the standard layout.
I've tried measuring the runtime of the same kernel as above, but with the output layout changed as well:
module {
func.func @mfma_layout(%a: !migraphx.shaped<1x4x80x8x32x8xf16, 655360x163840x2048x256x8x1>, %b: !migraphx.shaped<1x40x80x8x32x8xf16, 6553600x163840x2048x256x8x1>) -> !migraphx.shaped<1x4x20x8x32x8xf16, 163840x40960x64x8x1280x1> attributes{kernel, arch = "##TOKEN_ARCH##"} {
%aTranspose = migraphx.transpose %a {permutation = [0, 1, 4, 2, 3, 5]} : !migraphx.shaped<1x4x80x8x32x8xf16, 655360x163840x2048x256x8x1> -> !migraphx.shaped<1x4x32x80x8x8xf16, 655360x8x163840x2048x256x1>
%bTranspose = migraphx.transpose %b {permutation = [0, 1, 4, 2, 3, 5]} : !migraphx.shaped<1x40x80x8x32x8xf16, 6553600x163840x2048x256x8x1> -> !migraphx.shaped<1x40x32x80x8x8xf16, 6553600x8x163840x2048x256x1>
%aReshaped = migraphx.reshape %aTranspose {dims = [1, 128, 5120]} : <1x4x32x80x8x8xf16, 655360x8x163840x2048x256x1> -> !migraphx.shaped<1x128x5120xf16, 655360x5120x1>
%bReshaped = migraphx.reshape %bTranspose {dims = [1, 1280, 5120]} : <1x40x32x80x8x8xf16, 6553600x8x163840x2048x256x1> -> !migraphx.shaped<1x1280x5120xf16, 6553600x5120x1>
%bTranspose2 = migraphx.transpose %bReshaped {permutation = [0, 2, 1]} : !migraphx.shaped<1x1280x5120xf16, 6553600x5120x1> -> !migraphx.shaped<1x5120x1280xf16, 6553600x1x5120>
%0 = migraphx.dot %aReshaped, %bTranspose2 {perf_config="v3:32,32,8,16,16,8,1,2,2,1,1"} : <1x128x5120xf16, 655360x5120x1>, <1x5120x1280xf16, 6553600x1x5120> -> <1x128x1280xf16, 163840x1280x1>
// 1x128x1280 -> 1x4x20x8x32x8
%outReshaped = migraphx.reshape %0 {dims = [1, 4, 32, 20, 8, 8]} : <1x128x1280xf16, 163840x1280x1> -> !migraphx.shaped<1x4x32x20x8x8xf16, 163840x40960x1280x64x8x1>
%outTranspose = migraphx.transpose %outReshaped {permutation = [0, 1, 3, 4, 2, 5]} : !migraphx.shaped<1x4x32x20x8x8xf16, 163840x40960x1280x64x8x1> -> !migraphx.shaped<1x4x20x8x32x8xf16, 163840x40960x64x8x1280x1>
return %outTranspose : !migraphx.shaped<1x4x20x8x32x8xf16, 163840x40960x64x8x1280x1>
}
}
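As a sanity check on the output packing in this IR (numpy again standing in for the migraphx ops): N takes the role of K here, so nIter = 1280/(8*8) = 20:

import numpy as np

C = np.zeros((1, 128, 1280), dtype=np.float16)
# reshape {dims = [1, 4, 32, 20, 8, 8]} then transpose {permutation = [0, 1, 3, 4, 2, 5]}
C_tiled = C.reshape(1, 4, 32, 20, 8, 8).transpose(0, 1, 3, 4, 2, 5)
assert C_tiled.shape == (1, 4, 20, 8, 32, 8)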
Performance appears to be about the same, at least for this case.
Can you post the data as a comma-separated list or a link to the spreadsheet online?
Some notes from Daniel's slides:
The generic tiled layout is [B, dBlock, kIter, kpackPerBlock, dPerBlock, kpack], where
B = batch
K = kIter * kpackPerBlock * kpack
D = dBlock * dPerBlock
kIter = K / (kpackPerBlock * kpack)
Unpacking is [B, dBlock, kIter, kpackPerBlock, dPerBlock, kpack] => [B, dBlock, dPerBlock, kIter, kpackPerBlock, kpack], i.e. the permutation [0, 1, 4, 2, 3, 5] used by the transposes in the IR above.
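Plugging the A-tensor example into these formulas as a quick sanity check:

# A example: B=1, D=M=128, K=5120, dPerBlock=32, kpackPerBlock=8, kpack=8.
B, D, K = 1, 128, 5120
dPerBlock, kpackPerBlock, kpack = 32, 8, 8
kIter = K // (kpackPerBlock * kpack)  # 5120 / 64 = 80
dBlock = D // dPerBlock               # 128 / 32 = 4
# [B, dBlock, kIter, kpackPerBlock, dPerBlock, kpack] = [1, 4, 80, 8, 32, 8]
assert (B, dBlock, kIter, kpackPerBlock, dPerBlock, kpack) == (1, 4, 80, 8, 32, 8)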
Same file but for gfx1201: AccelLayout_gfx1201_migraphx.csv