Fuse `where` into MLIR attention
From the 22 Feb 2024 performance model review of Distilgpt2:
There is a `where` before the softmax, which prevents us from using flash attention:

```
@34 = gpu::code_object[code_object=9224,symbol_name=where_kernel,global=363312,local=1024,](@33,@30,@32) -> half_type, {1, 12, 348, 348}, {1453248, 121104, 348, 1}
@35 = load[offset=0,end=8352](@1) -> half_type, {1, 12, 348, 1}, {4176, 348, 1, 1}
@36 = gpu::code_object[code_object=9296,symbol_name=reduce_max_kernel,global=534528,local=128,](@34,@35) -> half_type, {1, 12, 348, 1}, {4176, 348, 1, 1}
```
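For reference, the pattern in question is a masked (causal) softmax over the attention scores. A minimal NumPy sketch of what the graph computes, with the shapes taken from the trace above (the mask and fill value are illustrative, not from the trace):

```python
import numpy as np

# Shapes from the IR above: batch=1, heads=12, seq_len=348
scores = np.random.randn(1, 12, 348, 348).astype(np.float16)
# Causal mask, GPT-2 style (illustrative)
mask = np.tril(np.ones((348, 348), dtype=bool))

# @34: standalone `where` kernel - masked positions get a large
# negative value so they vanish after the softmax
masked = np.where(mask, scores, np.float16(-10000.0))

# @36 onwards: the softmax begins with a reduce_max over the last axis.
# Because the `where` runs as its own kernel between the GEMM and the
# softmax, the attention pattern is broken up and flash attention
# cannot be used.
m = masked.max(axis=-1, keepdims=True)
e = np.exp(masked - m)
attn = e / e.sum(axis=-1, keepdims=True)
```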
With MLIR compiling everything, the `where` does get fused into the softmax kernel:
```
@27 = gpu::code_object[code_object=5704,symbol_name=mul_where_reduce_max_sub_exp_reduce_sum_div_kernel,global=801792,local=192,](@24,@26,@25) -> half_type, {1, 12, 348, 348}, {1453248, 121104, 348, 1}, target_id=0: 0.015923ms, 2%
```
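The kernel name spells out the fusion: the chain mul → where → reduce_max → sub → exp → reduce_sum → div is exactly a scaled, masked, numerically stable softmax. A sketch of that decomposition (the scale and fill value are illustrative):

```python
import numpy as np

def fused_masked_softmax(scores, mask, scale=1.0, fill=-10000.0):
    x = scores * scale                        # mul
    x = np.where(mask, x, fill)               # where
    m = x.max(axis=-1, keepdims=True)         # reduce_max
    e = np.exp(x - m)                         # sub + exp
    return e / e.sum(axis=-1, keepdims=True)  # reduce_sum + div
```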
MLIR's flash attention does support a `where` before the softmax, so it might be possible to just fuse it directly in MLIR. We need some tweaking to how we currently fuse flash attention, and we should create a ticket for MLIR about this once it has been investigated some more.
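To illustrate why this fusion is plausible: in a flash-attention style kernel the softmax is computed online over tiles of the score matrix, so a mask can be applied to each tile right after the Q·Kᵀ product, before the running max/sum update. A single-head NumPy sketch of the idea (tile size and helper names are illustrative, not MLIR's actual implementation):

```python
import numpy as np

def flash_attention_with_where(q, k, v, mask, tile=64):
    """Online-softmax attention; the `where` is applied per tile
    before the running max and sum are updated."""
    n, d = q.shape
    out = np.zeros((n, d))
    m = np.full((n, 1), -np.inf)  # running max
    l = np.zeros((n, 1))          # running sum of exponentials
    for j in range(0, n, tile):
        s = q @ k[j:j + tile].T / np.sqrt(d)           # tile of scores
        s = np.where(mask[:, j:j + tile], s, -np.inf)  # fused `where`
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)
        correction = np.exp(m - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ v[j:j + tile]
        m = m_new
    return out / l
```

With the mask folded into the tile loop, nothing ever materializes the full {348, 348} masked score matrix, which is the property the fusion would exploit.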
Deliverables:
- Create tickets on the MLIR repo:
  - Confirm that `where` is supported by MLIR
  - Provide an example of this fusion usecase with flash attention (see the sketch after this list)
- Have the `where` fused with the MLIR flash attention kernel
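For the example deliverable, a starting point might be a small reproducer of the `where` → softmax pattern built with the MIGraphX Python bindings. A hedged sketch, assuming the standard `migraphx` Python API (the shapes and the broadcast fill input are illustrative):

```python
import migraphx

p = migraphx.program()
mm = p.get_main_module()

# Attention-score shaped inputs: batch=1, heads=12, seq_len=348
s = migraphx.shape(type="half_type", lens=[1, 12, 348, 348])
b = migraphx.shape(type="bool_type", lens=[1, 12, 348, 348])

scores = mm.add_parameter("scores", s)
mask = mm.add_parameter("mask", b)
fill = mm.add_parameter("fill", s)  # e.g. a broadcast -10000.0

# The pattern from the trace: where -> softmax
w = mm.add_instruction(migraphx.op("where"), [mask, scores, fill])
sm = mm.add_instruction(migraphx.op("softmax", axis=3), [w])
mm.add_return([sm])
```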