Update flash attention to integrate flash decoding with paged KV cache
Flash Attention recently added support for loading from a paged KV cache in https://github.com/Dao-AILab/flash-attention/commit/54e80a3829c6d2337570d01e78ebd9529c02d342. The support was added to the Flash-Decoding kernel, which we haven't used so far.
This PR makes Flash Decoding with paged KV cache support usable from TVM. We already use other Flash Attention kernels via BYOC, but due to the specialized nature of this kernel, it is exposed as a contrib kernel (similar to vllm).
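As a rough illustration of what "contrib kernel" means in practice: the kernel is reached as a packed function rather than through the BYOC pattern-matching path. The global function name below is only a placeholder, not necessarily what this PR registers; see the contrib sources for the actual symbol and argument order.

```python
# Sketch only: the global function name is a placeholder; check the contrib
# sources for the registered symbol and the exact argument list.
import tvm

flash_decode = tvm.get_global_func(
    "tvm.contrib.flash_attn.flash_decoding_with_paged_kvcache"  # hypothetical name
)

# Arguments are GPU NDArrays whose shapes/dtypes follow the requirements
# documented as comments next to the kernel wrapper, e.g.:
# flash_decode(q, k_cache, v_cache, block_table, cache_seqlens, out)
```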
See also https://github.com/tlc-pack/libflash_attn/pull/9
@vinx13 @sunggg
LGTM! When passing the paged KV cache, are there any assumptions, e.g., about the layout?
Yes, I added shape and dtype requirements as comments.
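For anyone skimming, here is a rough sketch of the layout the upstream paged-KV path expects, based on the upstream `flash_attn_with_kvcache` Python API; the exact argument names and constraints on the contrib side may differ, so treat the shapes below as illustrative rather than as this PR's interface.

```python
# Illustrative shapes/dtypes for the paged-KV path in upstream flash-attention.
import torch
from flash_attn import flash_attn_with_kvcache

batch, seqlen_q, num_heads, num_kv_heads, head_dim = 2, 1, 32, 8, 128
num_blocks, page_size = 64, 256        # page_size = tokens per KV block
max_blocks_per_seq = 8

# Query for the decode step: (batch, seqlen_q, num_heads, head_dim), fp16/bf16.
q = torch.randn(batch, seqlen_q, num_heads, head_dim,
                dtype=torch.float16, device="cuda")

# Paged KV cache: (num_blocks, page_size, num_kv_heads, head_dim), same dtype as q.
k_cache = torch.randn(num_blocks, page_size, num_kv_heads, head_dim,
                      dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# Per-sequence mapping from logical pages to physical blocks: int32,
# shape (batch, max_blocks_per_seq).
block_table = torch.arange(batch * max_blocks_per_seq, dtype=torch.int32,
                           device="cuda").reshape(batch, max_blocks_per_seq)

# Current KV length of each sequence: int32, shape (batch,).
cache_seqlens = torch.tensor([500, 1000], dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
```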
@vinx13 Does this CI failure look like a compilation timeout? https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-gpu/detail/PR-16474/30/pipeline I remember you hitting something like this before.
@masahi This is probably the case, though it didn't happen for this kernel before.
Oof. It's not surprising, since I added a new kernel variant (Flash Decoding) with yet more explicit template instantiations: https://github.com/tlc-pack/libflash_attn/pull/9
We can move the GPU builder to a larger CPU instance if needed.