AMDMIGraphX

Improve `simplify_algebra` to find more horizontal fusion opportunities

Open • kahmed10 opened this issue

In the Stable Diffusion CLIP text encoder, there is an opportunity to fuse all of the add kernels (@21, @25 and @35):

```
@15 = gpu::code_object[code_object=7632,symbol_name=mlir_dot_add,global=133632,local=256,](@13,@12,@5,@14) -> half_type, {24, 77, 2304}, {177408, 2304, 1}: 0.0934304ms, 2%
@16 = hip::hip_copy_literal[id=main:@literal:78] -> half_type, {768}, {1}: 0.00109522ms, 1%
@17 = hip::hip_copy_literal[id=main:@literal:59] -> half_type, {768}, {1}: 0.00108192ms, 1%
@18 = slice[axes={2},starts={768},ends={1536}](@15) -> half_type, {24, 77, 768}, {177408, 2304, 1}: 0.00165542ms, 1%
@19 = multibroadcast[out_lens={24, 77, 768},out_dyn_dims={}](@17) -> half_type, {24, 77, 768}, {0, 0, 1}: 0.00094074ms, 1%
@20 = load[offset=18184320,end=21022848](@1) -> half_type, {24, 77, 768}, {59136, 768, 1}: 0.00076536ms, 1%
@21 = gpu::code_object[code_object=5128,symbol_name=add_kernel,global=354816,local=1024,](@19,@18,@20) -> half_type, {24, 77, 768}, {59136, 768, 1}: 0.0211362ms, 1%
@22 = load[offset=11354112,end=14192640](@1) -> half_type, {24, 77, 768}, {59136, 768, 1}: 0.00099472ms, 1%
@23 = multibroadcast[out_lens={24, 77, 768},out_dyn_dims={}](@16) -> half_type, {24, 77, 768}, {0, 0, 1}: 0.00182424ms, 1%
@24 = slice[axes={2},starts={0},ends={768}](@15) -> half_type, {24, 77, 768}, {177408, 2304, 1}: 0.00103286ms, 1%
@25 = gpu::code_object[code_object=5136,symbol_name=mul_add_kernel,global=354816,local=1024,](@24,@23,@22) -> half_type, {24, 77, 768}, {59136, 768, 1}: 0.0413997ms, 1%
@26 = load[offset=14769216,end=18184320](@1) -> half_type, {24, 12, 77, 77}, {71148, 5929, 77, 1}: 0.00105ms, 1%
@27 = gpu::code_object[code_object=6736,symbol_name=mlir_reshape_transpose_reshape_transpose_dot,global=73728,local=256,](@25,@21,@26) -> half_type, {24, 12, 77, 77}, {71148, 5929, 77, 1}: 0.0248955ms, 1%
...
@32 = load[offset=14769216,end=17607744](@1) -> half_type, {24, 77, 768}, {59136, 768, 1}
@33 = multibroadcast[out_lens={24, 77, 768},out_dyn_dims={}](@31) -> half_type, {24, 77, 768}, {0, 0, 1}
@34 = slice[axes={2},starts={1536},ends={2304}](@14) -> half_type, {24, 77, 768}, {177408, 2304, 1}
@35 = gpu::code_object[code_object=5128,symbol_name=add_kernel,global=354816,local=1024,](@33,@34,@32) -> half_type, {24, 77, 768}, {59136, 768, 1}
```
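The fusion relies on a simple identity: adding a per-slice bias to each slice of a tensor gives the same result as adding the concatenated bias to the whole tensor and slicing afterwards, so several add kernels can collapse into one. The minimal sketch below checks this on plain Python lists that stand in for one row of the `{24, 77, 2304}` dot output. The helper names and values are hypothetical, and the sketch assumes every branch is a plain bias add of this form.

```python
# Toy model of the proposed horizontal fusion on plain Python lists.
# One "row" stands in for the last axis of the {24, 77, 2304} dot output;
# the values and helper names are hypothetical.


def add_vec(a, b):
    """Elementwise add of two equal-length lists."""
    return [x + y for x, y in zip(a, b)]


def unfused(row, biases, width):
    """One add kernel per slice: slice the row, then add that slice's bias."""
    return [add_vec(row[i * width:(i + 1) * width], bias)
            for i, bias in enumerate(biases)]


def fused(row, biases, width):
    """A single add kernel: add the concatenated bias to the whole row, then slice."""
    concat_bias = [v for bias in biases for v in bias]
    summed = add_vec(row, concat_bias)
    return [summed[i * width:(i + 1) * width] for i in range(len(biases))]


width = 4  # stands in for 768
row = [float(i) for i in range(3 * width)]
biases = [[1.0] * width, [2.0] * width, [3.0] * width]
assert fused(row, biases, width) == unfused(row, biases, width)
```

The biases are concatenated in slice order (`starts={0}`, `{768}`, `{1536}`), which is what allows a single broadcast bias to line up with the full output.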

Here, the `mul_add` kernel is actually a scalar multiply followed by an add:

```
module: "main:pointwise10"
main:pointwise10:x1 = @param:x1 -> half_type, {1}, {0}
main:pointwise10:x0 = @param:x0 -> half_type, {1}, {0}
main:pointwise10:@2 = @literal{0.125} -> half_type, {1}, {0}
main:pointwise10:@3 = mul(main:pointwise10:@2,main:pointwise10:x0) -> half_type, {1}, {0}
main:pointwise10:@4 = add(main:pointwise10:@3,main:pointwise10:x1) -> half_type, {1}, {0}
main:pointwise10:@5 = @return(main:pointwise10:@4)
```
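`main:pointwise10` computes `0.125 * x0 + x1`. One plausible origin, which is an assumption and not visible in the dump, is that the graph originally scaled after adding the bias, `(x + b) * 0.125`, and a `find_mul_add`-style rewrite distributed the multiply so that the bias operand became `b * 0.125`. The constant 0.125 would match a query projection scaled by 1/sqrt(64) for the 12 heads of size 64 implied by the `{24, 12, 77, 77}` shape. The sketch mirrors the pointwise module in Python and checks that both forms agree; the inputs are dyadic values, so the float comparisons are exact.

```python
# Mirrors main:pointwise10 in plain Python. The "before" form is an
# assumption about the graph prior to find_mul_add, not taken from the dump.
SCALE = 0.125  # the @literal{0.125} in main:pointwise10


def pointwise10(x0, x1):
    """mul(0.125, x0) followed by add(..., x1), as in the module above."""
    return SCALE * x0 + x1


def scale_after_bias(x, b):
    """Assumed original form: add the bias first, then scale."""
    return (x + b) * SCALE


xs = [1.0, -2.5, 3.75]
bs = [0.5, 4.0, -1.25]
for x, b in zip(xs, bs):
    # Distributing the multiply turns the bias operand into b * SCALE.
    assert pointwise10(x, b * SCALE) == scale_after_bias(x, b)
```

In the distributed form, the add no longer looks like a plain `add(slice, bias)`, which would explain why it is not grouped with the other two add kernels.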

One possible solution is to restructure `simplify_algebra` into two loops: the first looks for horizontal fusion opportunities, and the second rewrites expressions.
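Why the order of the two loops matters can be sketched on a toy expression IR. The sketch assumes the query branch starts out as `mul(add(slice, bias), 0.125)` before any rewrite, which is an assumption about the pre-rewrite graph. The tuple encoding and the functions `rewrite_mul_add`, `fuse_horizontal` and `run` are hypothetical illustrations, not MIGraphX's matcher API.

```python
# Toy expression IR: nodes are tuples (op, *operands). Every name here is
# hypothetical; this is not MIGraphX's matcher API.
DOT = ("dot",)


def qkv_graph():
    """Three branches slicing one dot output. The query branch is assumed to
    be mul(add(slice, bias), 0.125) before any rewrite."""
    q = ("mul", ("add", ("slice", DOT, 0), ("bias", 0)), ("scalar", 0.125))
    k = ("add", ("slice", DOT, 1), ("bias", 1))
    v = ("add", ("slice", DOT, 2), ("bias", 2))
    return [q, k, v]


def is_bias_add(node):
    """Matches add(slice(src, i), bias(i))."""
    return node[0] == "add" and node[1][0] == "slice" and node[2][0] == "bias"


def rewrite_mul_add(node):
    """find_mul_add-style rewrite: mul(add(x, c), s) -> add(mul(x, s), mul(c, s))."""
    if node[0] == "mul" and node[1][0] == "add" and node[2][0] == "scalar":
        x, c, s = node[1][1], node[1][2], node[2]
        return ("add", ("mul", x, s), ("mul", c, s))
    return node


def fuse_horizontal(outputs):
    """Replace bias adds that share a sliced source with slices of one fused add.
    Returns the new outputs and how many adds were merged."""
    found = []

    def collect(node):
        if is_bias_add(node):
            found.append(node)
        for child in node[1:]:
            if isinstance(child, tuple):
                collect(child)

    for out in outputs:
        collect(out)
    groups = {}
    for node in found:
        groups.setdefault(node[1][1], []).append(node)
    fusable = {n for group in groups.values() if len(group) > 1 for n in group}

    def replace(node):
        if node in fusable:
            return ("slice", ("fused_add", node[1][1]), node[1][2])
        return (node[0],) + tuple(
            replace(c) if isinstance(c, tuple) else c for c in node[1:])

    return [replace(out) for out in outputs], len(fusable)


def run(outputs, loops):
    """Apply the loops in order; report how many adds ended up fused."""
    fused = 0
    for loop in loops:
        if loop == "fuse":
            outputs, fused = fuse_horizontal(outputs)
        else:
            outputs = [rewrite_mul_add(out) for out in outputs]
    return outputs, fused


_, fused_rewrite_first = run(qkv_graph(), ["rewrite", "fuse"])
out, fused_fuse_first = run(qkv_graph(), ["fuse", "rewrite"])
print(fused_rewrite_first, fused_fuse_first)  # 2 3
```

Rewriting first hides the query add, so only two adds fuse; fusing first merges all three and leaves the query branch as `mul(slice(fused_add), 0.125)`, a standalone scalar multiply.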

After this, the scalar multiply may be left as a standalone operation, so `find_unary_shape_transforms` would need to be extended to handle it as well.
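A multiply by a scalar literal behaves like a unary op with respect to shape transforms: scaling commutes with slicing and transposing, so it could be moved across them. The sketch checks that property on 2-D lists and shows one hypothetical way a matcher could classify such a multiply as unary-like. The predicate `is_unary_like` and the tuple node encoding are assumptions for illustration, not how `find_unary_shape_transforms` is actually written.

```python
# Why a standalone scalar multiply can be treated like a unary op around shape
# transforms. The predicate below is hypothetical, not MIGraphX code.
SCALE = 0.125


def scale(m, s):
    """Multiply every element of a 2-D list by the scalar s."""
    return [[s * v for v in row] for row in m]


def transpose(m):
    """Swap rows and columns of a 2-D list."""
    return [list(col) for col in zip(*m)]


def slice_cols(m, start, end):
    """Keep columns [start, end) of a 2-D list."""
    return [row[start:end] for row in m]


def is_unary_like(node):
    """Accept true unary ops, plus a mul whose other operand is a scalar
    literal: both have exactly one non-constant input."""
    op, *args = node
    if op in ("neg", "exp", "sqrt"):
        return True
    return op == "mul" and len(args) == 2 and any(a[0] == "scalar" for a in args)


m = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
# Scaling commutes with slicing and transposing, so it can move across them.
assert slice_cols(scale(m, SCALE), 0, 2) == scale(slice_cols(m, 0, 2), SCALE)
assert transpose(scale(m, SCALE)) == scale(transpose(m), SCALE)
assert is_unary_like(("mul", ("slice", ("dot",), 0), ("scalar", SCALE)))
assert not is_unary_like(("mul", ("slice", ("dot",), 0), ("slice", ("dot",), 1)))
```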

We may also need to add an exception to `find_mul_add` so that it skips the rewrite when the input is a scalar and feeds into a GEMM or convolution.
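The exception could take the form of a guard evaluated before distributing the multiply. The minimal sketch below interprets it as: skip `mul(add(x, c), s)` when `s` is a scalar and the multiply's result reaches a dot or convolution. Looking through reshape and transpose is an extra assumption, modeled on the `mlir_reshape_transpose_reshape_transpose_dot` kernel (@27) that consumes the `mul_add` output. The graph encoding and the functions `feeds_gemm` and `skip_mul_add` are hypothetical.

```python
# Hypothetical guard for a find_mul_add-style rewrite of
# mul(add(x, c), s) -> add(mul(x, s), mul(c, s)). The graph encoding and
# the set of "shape-only" ops are assumptions for illustration.
GEMM_LIKE = {"dot", "convolution"}
SHAPE_ONLY = {"reshape", "transpose", "multibroadcast"}


def feeds_gemm(name, users, ops):
    """True if `name` reaches a dot or convolution, possibly through shape-only ops."""
    for user in users.get(name, []):
        if ops[user] in GEMM_LIKE:
            return True
        if ops[user] in SHAPE_ONLY and feeds_gemm(user, users, ops):
            return True
    return False


def skip_mul_add(mul_name, scale_is_scalar, users, ops):
    """Leave mul(add(x, c), s) alone when s is a scalar and the result feeds a gemm/conv."""
    return scale_is_scalar and feeds_gemm(mul_name, users, ops)


# Query branch mirroring @27: scaled value -> reshape -> transpose -> dot.
users = {"q": ["q_r"], "q_r": ["q_t"], "q_t": ["scores"], "v": ["out"]}
ops = {"q_r": "reshape", "q_t": "transpose", "scores": "dot", "out": "add"}

print(skip_mul_add("q", True, users, ops))   # True
print(skip_mul_add("v", True, users, ops))   # False
print(skip_mul_add("q", False, users, ops))  # False
```

Keeping the multiply on the outside in these cases preserves the plain `add(slice, bias)` shape that horizontal fusion looks for.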

kahmed10, Sep 10 '24 21:09