flashinfer icon indicating copy to clipboard operation
flashinfer copied to clipboard

FlashInfer FP8 prefill shows no speedup over FP16 on L20

Open yongchaoding opened this issue 9 months ago • 0 comments

I tried the FP8 prefill attention kernel and found that on an L20 GPU it gives no speedup over FP16 — it is even slightly slower. Is this expected?

import argparse

import torch
from flash_attn.utils.benchmark import benchmark_forward
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
import flashinfer

# Command-line configuration for the benchmark. Query/output heads and
# key/value heads are separate options so grouped-query layouts can be tried.
parser = argparse.ArgumentParser(description='Benchmark FlashInfer')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
parser.add_argument('--num_kv_heads', type=int, default=32, help='Number of key/value heads')
parser.add_argument('--num_qo_heads', type=int, default=32, help='Number of query/output heads')
parser.add_argument('--head_dim', type=int, default=128, help='Head dimension')
args = parser.parse_args()

# Unpack the CLI arguments into the short names used by the benchmark loops.
qo_head = args.num_qo_heads
kv_head = args.num_kv_heads
batch = args.batch_size
headdim = args.head_dim

print("FlashInfer Benchmark")
print(f"batch: {batch}, qo_head: {qo_head}, kv_head: {kv_head}, headdim: {headdim}")

# 512 MiB int8 workspace for the FlashInfer wrapper. It is allocated directly
# on cuda:0 rather than built on the host and then copied to the device.
kv_layout = "NHD"
workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer, kv_layout
)

# FP16 baseline with the ragged (contiguous) KV wrapper: first without, then
# with, a causal mask.
#
# Sequence lengths are a tuple, not a set literal: set iteration order is an
# implementation detail, so a set does not guarantee ascending output order.
for is_causal in (False, True):
    print(f"is_causal: {is_causal}")
    for seq_len in (1024, 2048, 4096, 8192, 16384, 32768):
        # Two matmuls (Q @ K^T and P @ V), each seq_len^2 * headdim multiply-adds
        # (2 FLOPs each) per head and batch entry; causal masking counts as half.
        flops = 4 * qo_head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)

        # Ragged layout: all `batch` sequences are packed along dim 0.
        q = torch.randn(batch * seq_len, qo_head, headdim, dtype=torch.float16, device="cuda")
        k = torch.randn(batch * seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")
        v = torch.randn(batch * seq_len, kv_head, headdim, dtype=torch.float16, device="cuda")

        # Every request has exactly seq_len query tokens and seq_len KV tokens,
        # so both offset arrays are [0, seq_len, 2*seq_len, ...] as int32.
        q_indptr = torch.arange(batch + 1, dtype=torch.int32, device="cuda:0") * seq_len
        kv_indptr = torch.arange(batch + 1, dtype=torch.int32, device="cuda:0") * seq_len
        wrapper.plan(
            q_indptr,
            kv_indptr,
            qo_head,
            kv_head,
            headdim,
            causal=is_causal,
        )

        # Untimed warm-up calls before measuring.
        for _ in range(6):
            wrapper.run(q, k, v)
        torch.cuda.synchronize()

        _, measurement = benchmark_forward(
            wrapper.run, q, k, v, repeats=100, verbose=False, desc='Flashinfer'
        )
        # measurement.mean is presumably seconds per call (torch.utils.benchmark),
        # so the printed value is TFLOP/s despite the "flops" label.
        print(f'{seq_len} flops:{flops/measurement.mean*1e-12}')

# FP8 (e5m2) KV cache through the paged-KV wrapper, causal only.
#
# NOTE(review): this is not an apples-to-apples comparison with the FP16 runs:
# those use the ragged wrapper while this uses the paged one, and q stays FP16,
# so only K/V are stored in FP8. Whether the kernel does its matmuls in FP8 or
# dequantizes K/V to FP16 depends on the FlashInfer backend -- confirm before
# reading these numbers as "FP8 vs FP16 compute".
kv_layout = "NHD"
workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout
)

is_causal = True
dtype = torch.float8_e5m2
page_size = 1
print(f"is_causal: {is_causal}")
for seq_len in (1024, 2048, 4096, 8192, 16384, 32768):
    flops = 4 * qo_head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
    q = torch.randn(batch * seq_len, qo_head, headdim, dtype=torch.float16, device="cuda")

    if seq_len % page_size:
        raise ValueError(f"seq_len={seq_len} must be a multiple of page_size={page_size}")
    # Each request owns `pages_per_seq` consecutive pages. With the NHD layout the
    # cache tensors are (num_pages, page_size, kv_head, headdim) and contiguous.
    pages_per_seq = seq_len // page_size
    num_pages = batch * pages_per_seq
    k = torch.randn(num_pages, page_size, kv_head, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(num_pages, page_size, kv_head, headdim, dtype=torch.float16, device="cuda")

    # Per-tensor scales: divide so that |k|, |v| map into about [-256, 256]
    # before the cast. The scales are passed to run() below, presumably so the
    # kernel can undo the division.
    k_scale = k.abs().amax().item() / 256
    v_scale = v.abs().amax().item() / 256
    k_fp8 = (k / k_scale).to(dtype)
    v_fp8 = (v / v_scale).to(dtype)

    # q offsets count tokens, while kv offsets count pages. kv_indptr indexes
    # into kv_indices, so kv_indices needs one entry for every page in the batch.
    q_indptr = torch.arange(batch + 1, dtype=torch.int32, device="cuda:0") * seq_len
    kv_indptr = torch.arange(batch + 1, dtype=torch.int32, device="cuda:0") * pages_per_seq
    kv_indices = torch.arange(num_pages, dtype=torch.int32, device="cuda:0")
    kv_last_page_len = torch.full(
        (batch,), (seq_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
    )

    wrapper.plan(
            q_indptr,
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            qo_head,
            kv_head,
            headdim,
            page_size,
            causal=is_causal,
            q_data_type=torch.float16,
            kv_data_type=dtype,
            use_fp16_qk_reduction=False,
        )

    # Untimed warm-up calls before measuring.
    for _ in range(6):
        wrapper.run(q, (k_fp8, v_fp8), k_scale, v_scale)
    torch.cuda.synchronize()

    _, measurement = benchmark_forward(
        wrapper.run, q, (k_fp8, v_fp8), k_scale, v_scale,
        repeats=100, verbose=False, desc='Flashinfer FP8',
    )
    # Printed value is FLOPs / mean time * 1e-12 (TFLOP/s if mean is in seconds).
    print(f'fp8 {seq_len} flops:{flops/measurement.mean*1e-12}')

The results are as follows (each printed `flops` value is FLOPs / mean time × 1e-12, i.e. TFLOP/s):

is_causal: False
1024 flops:91.40299148402637
2048 flops:107.07696479819627
4096 flops:108.43416398791643
8192 flops:109.00693589495772
16384 flops:109.32962903647683
32768 flops:109.45145192913175

is_causal: True
1024 flops:68.55602643189259
2048 flops:81.8157771148776
4096 flops:95.28764360811904
8192 flops:102.80001359824226
16384 flops:106.81962644318786
32768 flops:108.20337146264986

is_causal: True
fp8 1024 flops:60.862217383540795
fp8 2048 flops:77.45613669425248
fp8 4096 flops:90.72903913564886
fp8 8192 flops:98.91800674555174
fp8 16384 flops:101.95377017553793
fp8 32768 flops:103.3322699046888

yongchaoding avatar Mar 06 '25 07:03 yongchaoding