
Fuse `where` into MLIR attention

CharlieL7 opened this issue on Feb 22, 2024 · 0 comments

From the 22 Feb 2024 performance model review of Distilgpt2:

There is a `where` before the softmax that prevents us from using flash attention:

@34 = gpu::code_object[code_object=9224,symbol_name=where_kernel,global=363312,local=1024,](@33,@30,@32) -> half_type, {1, 12, 348, 348}, {1453248, 121104, 348, 1}
@35 = load[offset=0,end=8352](@1) -> half_type, {1, 12, 348, 1}, {4176, 348, 1, 1}
@36 = gpu::code_object[code_object=9296,symbol_name=reduce_max_kernel,global=534528,local=128,](@34,@35) -> half_type, {1, 12, 348, 1}, {4176, 348, 1, 1}

With MLIR, this all compiles into a single fused kernel:

@27 = gpu::code_object[code_object=5704,symbol_name=mul_where_reduce_max_sub_exp_reduce_sum_div_kernel,global=801792,local=192,](@24,@26,@25) -> half_type, {1, 12, 348, 348}, {1453248, 121104, 348, 1}, target_id=0: 0.015923ms, 2%

MLIR’s flash attention does support a `where` before the softmax, so it may be possible to fuse it directly in MLIR. This will require some tweaking of how we currently fuse flash attention. Once this is investigated further, we should also create a ticket on the MLIR repo.
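For reference, the op chain named in the fused kernel above (`mul_where_reduce_max_sub_exp_reduce_sum_div`) corresponds to a masked, numerically stable softmax. A minimal NumPy sketch of that sequence (function name and float32 dtype are illustrative; the real kernel runs in half_type on tensors shaped {1, 12, 348, 348}):

```python
import numpy as np

def fused_masked_softmax(raw_scores, mask, scale):
    """Reference for the fused op chain:
    mul -> where -> reduce_max -> sub -> exp -> reduce_sum -> div.

    raw_scores: [batch, heads, seq, seq] attention logits (Q @ K^T)
    mask:       boolean, True where attention is allowed
    scale:      softmax scaling factor (e.g. 1/sqrt(head_dim))
    """
    # mul: scale the raw attention scores
    scores = raw_scores * scale
    # where: mask out disallowed positions before the softmax
    masked = np.where(mask, scores, -np.inf)
    # reduce_max + sub: subtract the row max for numerical stability
    m = masked.max(axis=-1, keepdims=True)
    # exp, reduce_sum, div: normalize into attention weights
    e = np.exp(masked - m)
    return e / e.sum(axis=-1, keepdims=True)
```

Because the `where` sits between the score computation and the row-wise reductions, it has to be absorbed into the flash-attention kernel rather than run as a separate pass.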

Deliverables:

  • Create tickets on the MLIR repo:
    • Confirm that `where` is supported by MLIR
    • Provide an example of this fusion use case with flash attention
  • Have the `where` fused with the MLIR flash attention kernel
