Add support for small page sizes

Open skrider opened this issue 2 years ago • 23 comments

Recently, support was added for paged attention with large page sizes of 256 tokens. However, projects that use paged attention prefer smaller page sizes of around 16 tokens. This PR adds support for smaller page sizes by reshaping the GMEM -> SMEM (global to shared memory) copy so that, in each iteration of the mainloop, each thread fetches from only a single page. Hence physical page addresses need only be resolved at the beginning of each mainloop iteration, and they can be resolved per thread rather than per CTA.
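To make the per-thread lookup concrete, here is a rough Python model of the idea (not the CUDA kernel). It assumes, hypothetically, that each thread copies a contiguous run of `rows_per_thread` rows of a `kBlockN`-row K/V tile, and that `rows_per_thread` divides both the page size and the tile size, so no thread's rows straddle a page boundary. The function and parameter names are illustrative only.

```python
def resolve_thread_pages(block_table_row, n_block, kBlockN, rows_per_thread, page_size):
    """Return (physical_page, offset_in_page) for each thread's rows in K/V tile n_block.

    Hypothetical model: thread t copies rows [t*rows_per_thread, (t+1)*rows_per_thread)
    of the tile, so its first row determines the only page it touches this iteration.
    """
    if page_size % rows_per_thread != 0 or kBlockN % rows_per_thread != 0:
        raise ValueError("rows_per_thread must divide both page_size and kBlockN")
    pages = []
    for t in range(kBlockN // rows_per_thread):
        first_token = n_block * kBlockN + t * rows_per_thread
        virtual_page = first_token // page_size   # index into this sequence's block table
        page_offset = first_token % page_size     # first row within that page
        pages.append((block_table_row[virtual_page], page_offset))
    return pages


# Example: 16-token pages, a 32-row tile, 8 rows per thread.
block_table_row = [7, 3, 9, 1]
print(resolve_thread_pages(block_table_row, n_block=1, kBlockN=32, rows_per_thread=8, page_size=16))
```

With `page_size=256` and a tile size that divides it, every thread in a tile resolves the same page, so one lookup per CTA would suffice; with 16-token pages a tile spans several pages, which is why the lookup moves into each thread.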

Preliminary benchmarking with ncu on the unit testing suite shows no degradation in performance.

skrider, Feb 13 '24

Thanks for your great work! Small page sizes are important for LLM inference frameworks. I hope this PR can be merged soon.

zhaoyang-star, Feb 20 '24

Fixed an issue with fused RoPE embeddings; this should be ready for review.

skrider, Feb 26 '24

Hi, I am waiting for this PR! Is it planned to be merged soon? Also, may I ask when it is planned to be released?

rkooo567, Feb 28 '24

Not sure. @tridao, if you have time, I would greatly appreciate a review so I can make any changes necessary to get this PR merged!

skrider, Feb 29 '24

Hi @skrider, thanks for the great work! Based on my tests, this kernel is 1.5-4x faster than the Triton equivalent. However, when I use it for end-to-end testing in vLLM, I hit RuntimeError: CUDA error: an illegal memory access was encountered. Below is minimal code to reproduce it:

import torch
from flash_attn import flash_attn_with_kvcache

def cdiv(a, b):
    return (a + b - 1) // b

block_size = 16
num_blocks = 1000*16//block_size
bs = 4
seq_len = 170
num_heads = 32
head_dim = 128

key_cache = torch.rand([num_blocks, block_size, num_heads, head_dim]).half().cuda()
value_cache = torch.rand([num_blocks, block_size, num_heads, head_dim]).half().cuda()
cache_seqlens = torch.zeros(bs, dtype=torch.int32).cuda()

for _ in range(1000):
    query = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
    key = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
    value = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
    block_tables = torch.randint(0, num_blocks, size=(bs, cdiv(seq_len, block_size)), dtype=torch.int32, device="cuda")
    output = flash_attn_with_kvcache(
        query,
        key_cache,
        value_cache,
        k=key,
        v=value,
        cache_seqlens=cache_seqlens,
        block_table=block_tables,
        causal=True,
    )

Error message:

Traceback (most recent call last):
  File "/home/ubuntu/src/vllm-test/debug.py", line 21, in <module>
    value = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Some observations:

  • This only occurs in the prefill stage, and it happens sporadically. Using it for decoding (single-query or multi-query) seems fine.
  • The error goes away after increasing block_size to 256.
  • The error persists after removing k=key and v=value, so the illegal memory access may happen when reading from the paged blocks.

ymwangg, Mar 02 '24

@ymwangg Thanks for the heads up - I will look into it.

Reproducing with the provided code, I believe the error comes from an async copy that is not properly awaited. Synchronizing after every launch and setting a manual seed do not remove the nondeterminism. Additionally, if I only run the kernel every other iteration, the kernel never errors, and a small num_heads also makes the issue go away. To me this suggests that the error is correlated with the state of the L2 cache. Signing off for tonight, but I will revisit when I have time.

skrider, Mar 02 '24

@ymwangg, is it possible that passing k=key and v=value in each iteration appends seq_len=170 tokens to the KV cache every time, so that it overflows after a couple of iterations?

gnap, Mar 11 '24

My understanding is that this function does not allocate new memory; rather, it uses block_table to identify the memory addresses to read and write. So as long as every block ID in block_table is valid, it should not cause an overflow.
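The condition described above can be checked on the host before launching the kernel. Below is a minimal pure-Python sketch; `check_paged_kv_args` is a hypothetical helper, not part of flash_attn, and it assumes the new K/V tokens of sequence `b` are written at positions `cache_seqlens[b]` through `cache_seqlens[b] + seqlen_new - 1`.

```python
def check_paged_kv_args(block_table, cache_seqlens, seqlen_new, page_size, num_blocks):
    """Raise ValueError if a paged-KV call could read or write outside its pages.

    Hypothetical helper. block_table: per-sequence lists of physical page ids;
    cache_seqlens: tokens already cached per sequence; seqlen_new: tokens appended.
    """
    for b, (pages, cached) in enumerate(zip(block_table, cache_seqlens)):
        bad = [p for p in pages if not 0 <= p < num_blocks]
        if bad:
            raise ValueError(f"sequence {b}: page ids out of range: {bad}")
        capacity = len(pages) * page_size   # token slots covered by this sequence's pages
        needed = cached + seqlen_new        # last slot written, plus one
        if needed > capacity:
            raise ValueError(f"sequence {b}: needs {needed} slots, block table covers {capacity}")


def cdiv(a, b):
    return (a + b - 1) // b


# Values from the reproduction script: 16-token pages, 170 new tokens, empty cache.
page_size, seq_len, num_blocks, bs = 16, 170, 1000, 4
block_table = [[0] * cdiv(seq_len, page_size) for _ in range(bs)]
check_paged_kv_args(block_table, [0] * bs, seq_len, page_size, num_blocks)  # 11 pages = 176 slots
```

Because the reproduction script never changes cache_seqlens itself, each call needs 170 of the 176 slots that 11 pages provide, and every page ID drawn by `torch.randint(0, num_blocks, ...)` is in range, which is consistent with the reply above.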

ymwangg, Mar 11 '24

After inspecting it locally, I found that the illegal memory access is caused by the table_diff calculation overflowing, and the bad offset then propagates to later iterations.

Since n_block is iterated in reverse order, the computed virtual_page_idx_next may be larger than the table allocated in the first round. This yields undefined table_diff values, which are never corrected because tKgK.data() is only advanced relatively.

So the fix is straightforward: use init_thread_kv_page_slice_offset(..., n_block, ...) to calculate the absolute offset and add it to gK.data()/gV.data() directly. copy_w_min_idx() guarantees that only rows in range are copied.
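A rough Python model of this fix follows: instead of carrying a running pointer forward by page-index differences, each iteration recomputes every row's absolute offset from n_block, so a bad lookup in one iteration cannot leak into the next. The function name, the `page_stride`/`row_stride` layout, and the use of None for skipped rows are assumptions for illustration; in the kernel, skipping out-of-range rows is what copy_w_min_idx() provides, per the comment above.

```python
def tile_row_offsets(block_table_row, n_block, kBlockN, page_size, seqlen_k,
                     page_stride, row_stride):
    """Absolute element offsets for the rows of K/V tile n_block (hypothetical model).

    Nothing carries over from the previous iteration: each offset is rebuilt from
    the block table. Rows at or beyond seqlen_k map to None and never index the table.
    """
    offsets = []
    for row in range(kBlockN):
        token = n_block * kBlockN + row
        if token >= seqlen_k:
            offsets.append(None)  # out of range: skipped, like a bounds-checked copy
            continue
        page = block_table_row[token // page_size]
        offsets.append(page * page_stride + (token % page_size) * row_stride)
    return offsets


# Iterate tiles in reverse order, as described above.
block_table_row = [5, 2]  # two 4-token pages
for n_block in reversed(range(2)):
    print(n_block, tile_row_offsets(block_table_row, n_block, kBlockN=4, page_size=4,
                                    seqlen_k=6, page_stride=100, row_stride=10))
```

The trade-off, as noted later in the thread, is that the difference-based approach may save a register, while the absolute form removes any dependence on the previous iteration's state.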

I tested it locally: the illegal access is gone and the flash_attn_kvcache tests pass.

gnap, Mar 14 '24

Btw, do we plan to merge this soon?

rkooo567, Mar 19 '24

The fix mentioned by @gnap was implemented by @ymwangg in this commit: https://github.com/ymwangg/flash-attention/commit/73541983dec952980b43aac36da296e6bc517211.

@gnap, would you mind checking it to verify that's what you had in mind? @ymwangg told me the illegal access is gone with that commit on top of this PR.

Could someone pull that fix into this PR and resolve the conflicts so this can be merged?

davidthomas426, Mar 20 '24

@davidthomas426 @ymwangg I have checked that commit, and the change is identical to my local one. I am currently running more tests with our internal inference engine, but if the vLLM community's tests pass, feel free to commit it or ask @skrider to update this PR.

gnap, Mar 21 '24

Thanks for your great work! Does this PR support varlen with KV blocks? See https://github.com/Dao-AILab/flash-attention/commit/2a15840f09905a55adbb2d218c3cd8244d8b922d

mjp9527, Mar 21 '24

Thank you, everyone, for all the help! I will review locally and push the fix. I used the difference between page indices rather than calculating the offset directly because that's how it was done originally. Besides saving a register, I am not sure there is any advantage to doing it this way.

@gnap, I'm curious: what was your process for finding the bug?

Quoting @mjp9527: "Thanks for your great work! Does this PR support varlen with KV block: https://github.com/Dao-AILab/flash-attention/commit/2a15840f09905a55adbb2d218c3cd8244d8b922d"

In progress; expect it sometime next week.

skrider, Mar 22 '24

These changes pass the unit tests for both the standard and varlen APIs, as well as the example provided above by @ymwangg.

skrider, Mar 26 '24

Quoting @skrider: "@gnap curious what your process was for finding the bug?"

I ran compute-sanitizer --tool memcheck against the reproduction code @ymwangg provided, which showed that some threads accessed memory addresses far below the gmem_ptr of gK and gV. Then, with some debug printing, I found that table_diff could be larger than the strides of the partitioned copy tile.

gnap, Mar 28 '24

Thanks so much for your work @skrider. Can you rebase and then I'll merge?

tridao, Apr 10 '24

@skrider Are you going to rebase this so it can get merged?

davidthomas426, Apr 29 '24

@tridao absolutely! Sorry, just seeing this. Notification fell through the cracks.

skrider, May 02 '24

@skrider, it looks like rebasing is pretty easy. Would you mind if I just opened a PR against your branch? (I just ran git merge main, and there were no conflicts.)

rkooo567, May 03 '24

@skrider Hi, any updates on this PR?

yangelaboy, Jun 11 '24

Any updates on this PR, @skrider?

jorgeantonio21, Aug 29 '24

This PR has already been merged into this repository and is now part of the version of flash_attention that vLLM depends on.

itsliupeng, Aug 30 '24