[Performance Issue] FlashInfer shows no performance improvement with FP8 compared to BF16 in BatchDecodeWithPagedKVCacheWrapper with page_size=1
Description
When running decode with FlashInfer and page_size=1, an FP8 KV cache shows no performance improvement over BF16. This contradicts the theoretical expectation: decode is typically memory-bandwidth bound, and an FP8 KV cache halves the bytes the kernel has to read compared to BF16.
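As a rough illustration of that expectation (a back-of-the-envelope sketch of my own, not from FlashInfer; the ~4 TB/s figure is assumed as the nominal HBM bandwidth of H20, and the helper below is hypothetical), the time to read the KV cache once should roughly halve when going from BF16 to FP8:

# Back-of-the-envelope: time to read the KV cache once, assuming a purely
# memory-bound kernel and ~4 TB/s HBM bandwidth (H20 nominal, assumed here).
def kv_read_time_ms(batch_size, kv_len, num_kv_heads, head_dim, bytes_per_elem, bw_tbps=4.0):
    kv_bytes = batch_size * kv_len * 2 * num_kv_heads * head_dim * bytes_per_elem  # K and V
    return kv_bytes / (bw_tbps * 1e12) * 1e3

# Shapes from the reproduction below (batch=64, kv_len=8192, H_kv=8, d=128):
print(kv_read_time_ms(64, 8192, 8, 128, bytes_per_elem=2))  # bf16: ~0.54 ms
print(kv_read_time_ms(64, 8192, 8, 128, bytes_per_elem=1))  # fp8:  ~0.27 ms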
Reproduction
import torch
import triton
import flashinfer


class FlashInferWrapper:
    def __init__(self, kv_layout="NHD", kv_dtype=torch.bfloat16):
        self.page_size = 1
        self.kv_layout = kv_layout
        self.kv_dtype = kv_dtype
        self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.uint8, device="cuda")
        self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
            self.workspace_buffer, self.kv_layout, use_tensor_cores=True
        )

    def prepare_kv_cache(self, batch_size, kv_len, head_dim, num_kv_heads):
        # KV data plus page-table metadata for page_size=1 (one token per page).
        k_data = torch.randn(batch_size * kv_len, num_kv_heads, head_dim, dtype=torch.bfloat16).to("cuda")
        v_data = torch.randn(batch_size * kv_len, num_kv_heads, head_dim, dtype=torch.bfloat16).to("cuda")
        kv_indptr = torch.arange(0, batch_size + 1).int().to("cuda") * kv_len
        kv_indices = torch.arange(0, batch_size * kv_len).int().to("cuda")
        kv_last_page_len = torch.full(
            (batch_size,), (kv_len - 1) % self.page_size + 1, dtype=torch.int32
        ).to("cuda")
        return k_data, v_data, kv_indptr, kv_indices, kv_last_page_len

    def update_decode(self, batch_size, kv_len, head_dim, num_heads, num_kv_heads):
        k_data, v_data, kv_indptr, kv_indices, kv_last_page_len = self.prepare_kv_cache(
            batch_size, kv_len, head_dim, num_kv_heads
        )
        self.decode_wrapper.plan(
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            num_heads,
            num_kv_heads,
            head_dim,
            self.page_size,
            q_data_type=torch.bfloat16,
            kv_data_type=self.kv_dtype,
            non_blocking=True,
        )
        return k_data, v_data

    def forward_decode(self, q, kv_cache):
        return self.decode_wrapper.run(q, kv_cache)


def perf_flashinfer_decode(batch_size, kv_len, d, H, H_kv):
    flashinfer_wrapper = FlashInferWrapper(kv_dtype=torch.bfloat16)
    q = torch.randn(batch_size, H, d).bfloat16().to("cuda")
    k_data, v_data = flashinfer_wrapper.update_decode(batch_size, kv_len, d, H, H_kv)
    # Combined paged KV cache: (num_pages, 2, page_size, num_kv_heads, head_dim)
    kv_cache = torch.cat([k_data, v_data], dim=1)
    kv_cache = kv_cache.view(batch_size * kv_len, 2, 1, H_kv, d)

    def fn():
        o = flashinfer_wrapper.forward_decode(q, kv_cache)

    ms = triton.testing.do_bench(fn, quantiles=[0.5])  # median latency in ms
    return ms / 1e3


def perf_flashinfer_decode_fp8(batch_size, kv_len, d, H, H_kv):
    flashinfer_wrapper = FlashInferWrapper(kv_dtype=torch.float8_e4m3fn)
    q = torch.randn(batch_size, H, d).bfloat16().to("cuda")
    k_data, v_data = flashinfer_wrapper.update_decode(batch_size, kv_len, d, H, H_kv)
    kv_cache = torch.cat([k_data, v_data], dim=1)
    kv_cache = kv_cache.view(batch_size * kv_len, 2, 1, H_kv, d)
    # Same cache, cast to FP8 (e4m3); the query stays in bf16.
    kv_cache = kv_cache.to(torch.float8_e4m3fn)

    def fn():
        o = flashinfer_wrapper.forward_decode(q, kv_cache)

    ms = triton.testing.do_bench(fn, quantiles=[0.5])  # median latency in ms
    return ms / 1e3


if __name__ == "__main__":
    batch_size = 64
    kv_len = 8192
    d = 128
    H = 40
    H_kv = 8
    time_ms = perf_flashinfer_decode(batch_size, kv_len, d, H, H_kv)
    print(f"bf16 time cost: {time_ms*1e3:.3f} ms")
    time_ms = perf_flashinfer_decode_fp8(batch_size, kv_len, d, H, H_kv)
    print(f"fp8 time cost: {time_ms*1e3:.3f} ms")
Results
bf16 time cost: 0.703 ms
fp8 time cost: 0.710 ms
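Reading these numbers back through the same bandwidth model (my own estimate, assuming the kernel reads the full KV cache once per decode step), the BF16 run already sustains roughly 3 TB/s of effective KV-read bandwidth, while the FP8 run only reaches about 1.5 TB/s, so the halved memory footprint is not translating into lower latency:

kv_bytes_bf16 = 64 * 8192 * 2 * 8 * 128 * 2   # ~2.15 GB of KV data in bf16
kv_bytes_fp8 = kv_bytes_bf16 // 2             # ~1.07 GB in fp8
print(kv_bytes_bf16 / 0.703e-3 / 1e12)        # ~3.05 TB/s effective for bf16
print(kv_bytes_fp8 / 0.710e-3 / 1e12)         # ~1.5 TB/s effective for fp8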
Environment Information
- GPU: H20
- FlashInfer Version: flashinfer-python 0.2.3+cu124torch2.5
- CUDA Version: 12.4
I see the same behavior on L20 as well; see #914.