Add support for small page sizes
Recently, support was added for paged attention with a large page size of 256 tokens. However, projects that use paged attention typically prefer smaller page sizes of around 16 tokens. This PR adds support for smaller page sizes by reshaping the GMEM -> SMEM copy so that, in each iteration of the mainloop, each thread fetches from only a single page. As a result, physical page addresses need to be resolved only at the beginning of each mainloop iteration, and can be resolved per-thread rather than per-CTA.
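For intuition, here is a hypothetical pure-Python sketch of the per-thread page resolution described above. The identifiers (`block_table`, `kBlockN`, `page_size`, `resolve_thread_page`) are illustrative, not the kernel's actual names: the idea is that when the page size divides the tile layout cleanly, each thread's rows in a given mainloop iteration land inside one page, so a single `block_table` lookup per thread per iteration suffices.

```python
# Illustrative sketch only -- all names are assumptions, not the kernel's
# actual identifiers. Models mapping one thread's row within a KV tile to a
# (physical_page, row_in_page) pair via the page table.

def resolve_thread_page(block_table, n_block, kBlockN, thread_row, page_size):
    """Resolve which physical page backs a thread's row of tile n_block."""
    virtual_row = n_block * kBlockN + thread_row   # row in the virtual KV sequence
    virtual_page = virtual_row // page_size        # logical page containing that row
    physical_page = block_table[virtual_page]      # one table lookup per thread
    return physical_page, virtual_row % page_size

# Example: page_size=16 with a kBlockN=64 tile -> the tile spans 4 whole
# pages, and each 16-row slice of the tile touches exactly one page.
block_table = [7, 3, 9, 0]  # virtual page index -> physical page index
print(resolve_thread_page(block_table, n_block=0, kBlockN=64, thread_row=17, page_size=16))  # -> (3, 1)
```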
Preliminary benchmarking with ncu on the unit testing suite shows no degradation in performance.
Thanks for your great work! Small page sizes are important for LLM inference frameworks. I hope this PR can be merged soon.
Fixed issue with fused RoPE embeddings - should be ready for review.
Hi, I am waiting for this PR! Is it planned to be merged soon? Also, may I ask when it is planned to be released?
Not sure - @tridao if you have time, would greatly appreciate a review so I can make any changes necessary to get this PR merged!
Hi @skrider, thanks for the great work! Based on my tests, this kernel is 1.5-4x faster than the Triton equivalent. But when I use it for end-to-end testing in vLLM, I hit RuntimeError: CUDA error: an illegal memory access was encountered.
Below is the minimum code to reproduce:
import torch
from flash_attn import flash_attn_with_kvcache

def cdiv(a, b):
    return (a + b - 1) // b

block_size = 16
num_blocks = 1000 * 16 // block_size
bs = 4
seq_len = 170
num_heads = 32
head_dim = 128
key_cache = torch.rand([num_blocks, block_size, num_heads, head_dim]).half().cuda()
value_cache = torch.rand([num_blocks, block_size, num_heads, head_dim]).half().cuda()
cache_seqlens = torch.zeros(bs, dtype=torch.int32).cuda()
for _ in range(1000):
    query = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
    key = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
    value = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
    block_tables = torch.randint(0, num_blocks, size=(bs, cdiv(seq_len, block_size)), dtype=torch.int32, device="cuda")
    output = flash_attn_with_kvcache(
        query,
        key_cache,
        value_cache,
        k=key,
        v=value,
        cache_seqlens=cache_seqlens,
        block_table=block_tables,
        causal=True,
    )
Error message:
Traceback (most recent call last):
File "/home/ubuntu/src/vllm-test/debug.py", line 21, in <module>
value = torch.rand([bs, seq_len, num_heads, head_dim], dtype=torch.float16, device="cuda")
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Some observations:
- This only occurs in the prefill stage, and it happens sporadically. Using it for decoding (single-query or multi-query) seems fine.
- The error is gone after increasing the block_size to 256.
- The error still exists after removing (k=key, v=value). So the illegal memory access may happen when reading from page blocks.
@ymwangg Thanks for the heads up - I will look into it.
Reproducing with the provided code, I believe the error is with an async copy not being properly awaited. Synchronizing after every launch and setting a manual seed does not get rid of the nondeterminism. Additionally, if I only run the kernel every two iterations, then the kernel never errors. Small num_heads also gets rid of the issue. To me this suggests that the state of the L2 cache is correlated with the error. Signing off for tonight but will revisit when I have time.
@ymwangg is it possible that passing k=key and v=value each iteration causes seq_len=170 tokens to be appended to the kvcache each time, which overflows after a couple of iterations?
> @ymwangg is it possible that passing k=key and v=value each iteration causes seq_len=170 tokens to be appended to the kvcache each time, which overflows after a couple of iterations?
My understanding is that this function does not allocate new memory but rather uses block_table to identify the memory addresses to read/write. So as long as the block IDs in block_table are valid, it should not cause an overflow issue.
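To illustrate that point, here is a toy pure-Python model of paged-KV addressing (the names, shapes, and `write_token` helper are invented for illustration and are not vLLM's or flash-attn's actual API): the cache is a fixed, pre-allocated pool of blocks, and writes index through block_table rather than allocating, so valid block IDs cannot overflow the pool.

```python
# Toy model of a paged KV cache (illustrative only; the real caches are
# GPU-resident tensors). The pool of blocks is allocated once up front;
# block_table says which physical block backs each block_size-token slice
# of the sequence. Writing a token never allocates -- it only indexes.

def write_token(cache, block_table, seq_pos, token, block_size):
    physical_block = block_table[seq_pos // block_size]  # table lookup
    cache[physical_block][seq_pos % block_size] = token  # in-place write

block_size = 4
cache = [[None] * block_size for _ in range(8)]  # 8 pre-allocated blocks
block_table = [5, 2]  # this sequence is backed by blocks 5, then 2
write_token(cache, block_table, seq_pos=6, token="k6", block_size=block_size)
print(cache[2][2])  # -> k6  (seq position 6 lands in block 2, slot 2)
```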
After inspecting locally, I found that the illegal memory access is caused by an overflow in the table_diff calculation, which then propagates to subsequent iterations.
Since n_block is iterated in reverse order, the calculated virtual_page_idx_next into the page_table may be larger than the table allocated in the first round, producing undetermined table_diffs that are never corrected by advancing tKgK.data() relatively.
So the fix is straightforward: use init_thread_kv_page_slice_offset(..., n_block, ...) to calculate the absolute offset and add it to gK.data()/gV.data() directly. copy_w_min_idx() guarantees that only rows in range are copied.
Tested locally: the illegal access is gone and the flash_attn_kvcache tests pass.
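A hypothetical Python sketch of the two addressing schemes (the real code is CUTLASS/CuTe C++; `absolute_offset` and `relative_step` are loose paraphrases of init_thread_kv_page_slice_offset and the table_diff logic, not the actual functions): the relative scheme advances the pointer by the difference between consecutive physical page indices and breaks when the next virtual page index runs past the table, while the absolute scheme recomputes the offset from the current tile alone, so an out-of-range index can never poison later iterations.

```python
# Illustrative sketch only -- names are assumptions, not the PR's code.

def absolute_offset(block_table, n_block, kBlockN, page_size, page_stride):
    """Fixed scheme: derive tile n_block's offset from scratch each
    iteration, using only a virtual page index that is in range."""
    virtual_page_idx = (n_block * kBlockN) // page_size
    return block_table[virtual_page_idx] * page_stride

def relative_step(block_table, virtual_page_idx, virtual_page_idx_next):
    """Buggy scheme (simplified): advance the pointer by the diff between
    two pages' physical indices. If virtual_page_idx_next exceeds the
    table, the GPU reads garbage; in Python this raises IndexError, and on
    device the bad diff silently corrupts every later iteration."""
    return block_table[virtual_page_idx_next] - block_table[virtual_page_idx]

block_table = [7, 3, 9]  # only 3 pages allocated
print(absolute_offset(block_table, n_block=2, kBlockN=16, page_size=16, page_stride=1000))  # -> 9000
```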
Btw, do we plan to merge this soon?
The fix mentioned by @gnap was implemented by @ymwangg in this commit: https://github.com/ymwangg/flash-attention/commit/73541983dec952980b43aac36da296e6bc517211.
@gnap, would you mind checking it to verify that's what you had in mind? @ymwangg told me the illegal access is gone with that commit on top of this PR.
Could someone pull that fix into this PR and fix the conflicts so this can be merged?
@davidthomas426 @ymwangg I have checked that commit and the modification is identical to my local change. I am currently running more tests with our internal inference engine, but if the vLLM community's tests pass, feel free to commit it or ask @skrider to update this PR.
Thanks for your great work! Does this PR support varlen with KV block: https://github.com/Dao-AILab/flash-attention/commit/2a15840f09905a55adbb2d218c3cd8244d8b922d
Thank you everyone for all the help! I will review locally and push the fix. I used the difference between page indices rather than calculating the offset directly because that's how it was done originally. Besides saving a register, I am not sure there are any advantages to doing it that way.
@gnap curious what your process was for finding the bug?
> Thanks for your great work! Does this PR support varlen with KV block: https://github.com/Dao-AILab/flash-attention/commit/2a15840f09905a55adbb2d218c3cd8244d8b922d
In progress, expect it sometime next week
These changes pass unit tests for standard and varlen APIs as well as the example provided above by @ymwangg
> @gnap curious what your process was for finding the bug?
I ran compute-sanitizer --tool memcheck against the reproduction code @ymwangg provided, which showed that some threads accessed memory addresses far below gK's and gV's gmem_ptr. Then, with some printf debugging, I found that table_diffs could be larger than the partitioned copy tile's strides.
Thanks so much for your work @skrider. Can you rebase and then I'll merge?
@skrider Are you going to rebase this so it can get merged?
@tridao absolutely! Sorry, just seeing this. Notification fell through the cracks.
@skrider It looks like rebasing is pretty easy. Would you mind if I just create a PR to your branch? (I ran git merge main and hit no conflicts.)
@skrider Hi, any updates on this PR?
Any updates on the current PR ? @skrider
This PR has already been merged into this repository and is now part of the flash-attn version that vLLM depends on.