AMDMIGraphX
GEMM fusion (over slice or not)
From the 22 Feb 2024 performance model review of Distilgpt2:
Several GEMMs are chained together (this is the tail end of attention):
@17 = hip::hip_copy_literal[id=main:@literal:6] -> half_type, {348, 2304}, {2304, 1}
@18 = hip::hip_copy_literal[id=main:@literal:73] -> half_type, {768, 2304}, {2304, 1}
@21 = gpu::gemm[alpha=1,beta=1,compute_fp32=1,trans_batch=0,solution_idx=0](@20,@18,@17,@19) -> half_type, {348, 2304}, {2304, 1}
@22 = reshape_lazy[dims={1, 348, 36, 64}](@21) -> half_type, {1, 348, 36, 64}, {801792, 2304, 64, 1}
@23 = transpose[permutation={0, 2, 1, 3}](@22) -> half_type, {1, 36, 348, 64}, {801792, 64, 2304, 1}
@35 = slice[axes={1},starts={24},ends={36}](@23) -> half_type, {1, 12, 348, 64}, {801792, 64, 2304, 1}
@36 = gpu::gemm[alpha=1,beta=0,compute_fp32=1,trans_batch=1,solution_idx=0](@32,@35,@34) -> half_type, {1, 12, 348, 64}, {267264, 64, 768, 1}
@37 = hip::hip_copy_literal[id=main:@literal:72] -> half_type, {768, 768}, {768, 1}
@38 = load[offset=1069056,end=1603584](@1) -> half_type, {348, 768}, {768, 1}
@39 = transpose[permutation={0, 2, 1, 3}](@36) -> half_type, {1, 348, 12, 64}, {267264, 768, 64, 1}
@40 = reshape_lazy[dims={348, 768}](@39) -> half_type, {348, 768}, {768, 1}
@41 = gpu::gemm[alpha=1,beta=0,compute_fp32=1,trans_batch=0,solution_idx=0](@40,@37,@38) -> half_type, {348, 768}, {768, 1}
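Looking past the slice (which undoes part of a horizontal fusion), this computes roughly X * (Y*A + b) * C
(where * is matmul), with X = @32, Y = @20, A = @18, b = @17 and C = @37.
Since matmul is associative and distributes over addition, we could rewrite it as X * (Y*A*C + b*C).
A, b and C are all literals, so after constant folding this becomes X * (Y*A' + b') with A' = A*C and b' = b*C,
which eliminates the final GEMM (@41) completely.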
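The rewrite step by step, in the simplified 2-D form used above (this drops the reshape, transpose and per-head batching from the IR, so it is a sketch of the idea rather than the exact per-head computation):

$$
\begin{aligned}
X\,(YA + b)\,C &= X\,\big((YA + b)\,C\big) && \text{(associativity)}\\
&= X\,(YAC + bC) && \text{(distributivity)}\\
&= X\,(YA' + b'), \qquad A' = AC,\; b' = bC
\end{aligned}
$$

Because A, b and C are compile-time literals, A' and b' are computed once, leaving two GEMMs where there were three.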
The same fusion should also be generalized to the case where there is no slice
operator, which simplifies the manipulations needed.
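Why the slice does not block the fusion, again in the simplified 2-D model (ignoring the reshape, transpose and per-head batching): a column slice commutes with the GEMM-plus-bias, so it can be pushed into the literals before folding. Writing the slice as columns s to e-1:

$$
\begin{aligned}
(YA + b)_{:,\,s:e} &= Y\,A_{:,\,s:e} + b_{:,\,s:e}\\
A' &= A_{:,\,s:e}\,C, \qquad b' = b_{:,\,s:e}\,C
\end{aligned}
$$

In the IR above, @35 takes heads 24 to 35 of the 36 heads of size 64, i.e. columns 1536 to 2303 (the last 768 columns) of the {348, 2304} output of @21. Without a slice, s = 0 and e is the full width, so the fold reduces to A' = AC and b' = bC.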
Deliverables:
- Figure out how to implement the GEMM fusion over the slice and create a matcher for it
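A toy sketch of the kind of matcher the deliverable asks for, written against a hypothetical mini IR. The `Node` class, the op names (`dot`, `add`, `lit`, `param`, `slice_cols`) and `fold_trailing_gemm` are all made up for illustration and are not MIGraphX's matcher API. It ignores reshape/transpose and batching, matches `dot(dot(X, V), lit C)` where V is `Y*A + b` with or without a column slice, and folds the literals into A' and b':

```python
from dataclasses import dataclass

# Matrices are tuples of row tuples so that nodes stay immutable.
def matmul(p, q):
    return tuple(tuple(sum(p[i][k] * q[k][j] for k in range(len(q)))
                       for j in range(len(q[0]))) for i in range(len(p)))

def matadd(p, q):
    return tuple(tuple(x + y for x, y in zip(rp, rq)) for rp, rq in zip(p, q))

def cols(m, start, end):
    return tuple(row[start:end] for row in m)

@dataclass(frozen=True)
class Node:  # hypothetical mini IR, not MIGraphX's instruction class
    op: str               # "param", "lit", "dot", "add" or "slice_cols"
    inputs: tuple = ()
    value: tuple = None   # payload of a "lit" node
    name: str = ""        # binding name of a "param" node
    start: int = 0        # column range [start, end) of a "slice_cols" node
    end: int = 0

def evaluate(n, env):
    """Reference interpreter used to check that a rewrite preserves results."""
    if n.op == "param":
        return env[n.name]
    if n.op == "lit":
        return n.value
    args = [evaluate(i, env) for i in n.inputs]
    if n.op == "dot":
        return matmul(*args)
    if n.op == "add":
        return matadd(*args)
    if n.op == "slice_cols":
        return cols(args[0], n.start, n.end)
    raise ValueError(f"unknown op {n.op}")

def fold_trailing_gemm(n):
    """Match dot(dot(X, V), lit C) with V = [slice_cols](add(dot(Y, lit A), lit b))
    and return dot(X, add(dot(Y, lit A'), lit b')), or None if it does not match."""
    if n.op != "dot" or n.inputs[0].op != "dot" or n.inputs[1].op != "lit":
        return None
    c = n.inputs[1].value
    x, v = n.inputs[0].inputs
    window = None
    if v.op == "slice_cols":  # optional slice between the two GEMMs
        window, v = (v.start, v.end), v.inputs[0]
    if v.op != "add":
        return None
    prod, bias = v.inputs
    if prod.op != "dot" or prod.inputs[1].op != "lit" or bias.op != "lit":
        return None
    y, a, b = prod.inputs[0], prod.inputs[1].value, bias.value
    if window is not None:  # slice(Y*A + b) == Y*slice(A) + slice(b)
        a, b = cols(a, *window), cols(b, *window)
    a_folded = Node("lit", value=matmul(a, c))  # A' = A*C, computed once
    b_folded = Node("lit", value=matmul(b, c))  # b' = b*C
    return Node("dot", (x, Node("add", (Node("dot", (y, a_folded)), b_folded))))

# Usage: X * slice(Y*A + b) * C with tiny integer matrices.
X, Y = Node("param", name="X"), Node("param", name="Y")
A = Node("lit", value=((1, 2, 0), (0, 1, 3)))
b = Node("lit", value=((1, 0, 2), (2, 1, 0)))
C = Node("lit", value=((2, 1), (1, 0)))
V = Node("slice_cols", (Node("add", (Node("dot", (Y, A)), b)),), start=1, end=3)
graph = Node("dot", (Node("dot", (X, V)), C))
env = {"X": ((1, 1), (0, 2)), "Y": ((3, 1), (1, 2))}
folded = fold_trailing_gemm(graph)
print(evaluate(graph, env), evaluate(folded, env))  # both ((35, 12), (32, 10))
```

Small integer inputs keep the comparison exact. A real pass over the IR above would also need to see through the `reshape_lazy`/`transpose` ops between the GEMMs.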