LLMRoofline
On the Memory Access calculation for flash attention in LLM-Viewer
A question: according to FlashAttention's theoretical I/O complexity analysis, the number of memory accesses is O(N^2 d^2 / M), so the memory access should be very low. However, the code below computes a very large memory access, and it uses T_r rather than T_c, which does not match FlashAttention's theoretical analysis. How should I understand the computation below?
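For context, here is my reading of the tiling in the FlashAttention papers (the constants below are from Algorithm 1 of the papers, not from this repo), with N the sequence length, d the head dimension, and M the on-chip SRAM size in elements:

$$
B_r = \min\left(\left\lceil \frac{M}{4d} \right\rceil,\ d\right), \qquad
T_r = \left\lceil \frac{N}{B_r} \right\rceil, \qquad
\text{K/V traffic} \approx 2\, T_r N d \ \text{elements},
$$

since FlashAttention-2 streams all of K and V once per query (row) block. When M >= 4d^2 the min caps B_r at d, so T_r = N/d and the K/V traffic bottoms out at about 2N^2 elements no matter how large M is; the N^2 d^2 / M scaling only applies in the SRAM-limited regime M < 4d^2.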
```python
if use_flashattention:
    name = f"fused_attention"
    bandwidth, max_OPS, onchip_buffer = self.get_hardware_info()
    # flashattention-2 https://arxiv.org/pdf/2307.08691.pdf
    block_size_r = min(math.ceil(onchip_buffer / (kv_byte * head_size)), head_size)
    n_blocks_r = math.ceil(seqlen / block_size_r)
    q_numel = seqlen * head_size * batchsize * num_attention_heads * a_byte
    o_numel = seqlen * seqlen * batchsize * num_attention_heads * a_byte
    self._analyze_to_results(
        "prefill",
        name,
        OPs=qk_matmul_OPs + sv_matmul_OPs + softmax_OPs,
        load_weight=0,
        load_act=q_numel,
        store_act=o_numel * 2,  # initialize O and save O
        load_kv_cache=n_blocks_r * seqlen * head_size * batchsize * num_attention_heads * kv_byte * 2,
        store_kv_cache=0,
    )
```