
[Performance Issue] FlashInfer shows no performance improvement with FP8 compared to BF16 in BatchDecodeWithPagedKVCacheWrapper with page_size=1

Open · cscyuge opened this issue 7 months ago · 1 comment

Description

When running decode with page_size=1 in FlashInfer, no performance improvement is observed with an FP8 KV cache compared to BF16. This contradicts expectations: decode is typically memory-bandwidth-bound, so halving the KV-cache element size should noticeably reduce latency.
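
As a rough sanity check (my own back-of-envelope estimate, not part of the library), the KV-cache traffic per decode step for the configuration used below (batch_size=64, kv_len=8192, num_kv_heads=8, head_dim=128) should roughly halve when going from BF16 to FP8:

# Hypothetical estimate of KV-cache bytes read per decode step
batch_size, kv_len, num_kv_heads, head_dim = 64, 8192, 8, 128
elems = batch_size * kv_len * 2 * num_kv_heads * head_dim   # K and V
print(f"bf16 KV read: {elems * 2 / 1e9:.2f} GB")            # ~2.15 GB
print(f"fp8  KV read: {elems * 1 / 1e9:.2f} GB")            # ~1.07 GB

If the decode kernel is bandwidth-bound, FP8 would therefore be expected to run close to 2x faster, which is not what the benchmark below shows.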

Reproduction

import torch
import triton
import flashinfer

class FlashInferWrapper:
    def __init__(self, kv_layout="NHD", kv_dtype=torch.bfloat16):
        self.page_size = 1
        self.kv_layout = kv_layout
        self.kv_dtype = kv_dtype
        self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.uint8, device="cuda")
        self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
            self.workspace_buffer, self.kv_layout, use_tensor_cores=True
        )
        
    def prepare_kv_cache(self, batch_size, kv_len, head_dim, num_kv_heads):
        k_data = torch.randn(batch_size * kv_len, num_kv_heads, head_dim, dtype=torch.bfloat16).to("cuda")
        v_data = torch.randn(batch_size * kv_len, num_kv_heads, head_dim, dtype=torch.bfloat16).to("cuda")
        kv_indptr = torch.arange(0, batch_size + 1).int().to("cuda") * kv_len
        kv_indices = torch.arange(0, batch_size * kv_len).int().to("cuda")
        kv_last_page_len = torch.full(
            (batch_size,), (kv_len - 1) % self.page_size + 1, dtype=torch.int32
        ).to("cuda")
        return k_data, v_data, kv_indptr, kv_indices, kv_last_page_len
    
    def update_decode(self, batch_size, kv_len, head_dim, num_heads, num_kv_heads):
        k_data, v_data, kv_indptr, kv_indices, kv_last_page_len = self.prepare_kv_cache(batch_size, kv_len, head_dim, num_kv_heads)
        self.decode_wrapper.plan(
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            num_heads,
            num_kv_heads,
            head_dim,
            self.page_size,
            q_data_type=torch.bfloat16,  # queries stay BF16; only the KV-cache dtype varies
            kv_data_type=self.kv_dtype,
            non_blocking=True,
        )
        return k_data, v_data
    
    def forward_decode(self, q, kv_cache):
        return self.decode_wrapper.run(q, kv_cache)


def perf_flashinfer_decode(batch_size, kv_len, d, H, H_kv):
    flashinfer_wrapper = FlashInferWrapper(kv_dtype=torch.bfloat16)
    q = torch.randn(batch_size, H, d).bfloat16().to("cuda")
    k_data, v_data = flashinfer_wrapper.update_decode(batch_size, kv_len, d, H, H_kv)
    kv_cache = torch.cat([k_data, v_data], dim=1)
    # One token per page: (num_pages, 2, page_size=1, num_kv_heads, head_dim) in NHD layout
    kv_cache = kv_cache.view(batch_size*kv_len, 2, 1, H_kv, d)
    
    def fn():
        o = flashinfer_wrapper.forward_decode(q, kv_cache)
    ms = triton.testing.do_bench(fn, quantiles=[0.5])
    return ms/1e3

def perf_flashinfer_decode_fp8(batch_size, kv_len, d, H, H_kv):
    flashinfer_wrapper = FlashInferWrapper(kv_dtype=torch.float8_e4m3fn)
    q = torch.randn(batch_size, H, d).bfloat16().to("cuda")
    k_data, v_data = flashinfer_wrapper.update_decode(batch_size, kv_len, d, H, H_kv)
    kv_cache = torch.cat([k_data, v_data], dim=1)
    kv_cache = kv_cache.view(batch_size*kv_len, 2, 1, H_kv, d)
    kv_cache = kv_cache.to(torch.float8_e4m3fn)  # cast the BF16 cache to FP8 (e4m3) for the FP8 run
    
    def fn():
        o = flashinfer_wrapper.forward_decode(q, kv_cache)
    
    ms = triton.testing.do_bench(fn, quantiles=[0.5])
    return ms/1e3

if __name__ == "__main__":
    batch_size = 64
    kv_len = 8192
    d = 128  
    H = 40   
    H_kv = 8
    time_ms = perf_flashinfer_decode(batch_size, kv_len, d, H, H_kv)
    print(f"bf16 time cost: {time_ms*1e3:.3f} ms")
    time_ms = perf_flashinfer_decode_fp8(batch_size, kv_len, d, H, H_kv)
    print(f"fp8 time cost: {time_ms*1e3:.3f} ms")

Results:

bf16 time cost: 0.703 ms
fp8 time cost: 0.710 ms
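
A rough interpretation of these numbers (my own arithmetic, assuming decode time is dominated by KV-cache reads):

# Hypothetical effective-bandwidth estimate from the measured times above
kv_bytes_bf16 = 64 * 8192 * 2 * 8 * 128 * 2   # ~2.15 GB of KV cache in BF16
kv_bytes_fp8  = kv_bytes_bf16 // 2            # ~1.07 GB in FP8
print(f"bf16: {kv_bytes_bf16 / 0.703e-3 / 1e12:.2f} TB/s effective")  # ~3.0 TB/s
print(f"fp8 : {kv_bytes_fp8  / 0.710e-3 / 1e12:.2f} TB/s effective")  # ~1.5 TB/s

The BF16 run already sustains a large fraction of the H20's peak HBM bandwidth, while the FP8 run moves roughly half the data in the same time, i.e. the FP8 decode path does not appear to be bandwidth-bound here.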

Environment Information

  • GPU: H20
  • FlashInfer Version: flashinfer-python 0.2.3+cu124torch2.5
  • CUDA Version: 12.4

cscyuge · Apr 18 '25

I have the same question regarding performance on L20, see #914.

yongchaoding · Apr 21 '25