
Update flash attention to integrate flash decoding with paged KV cache

Open masahi opened this issue 1 year ago • 5 comments

Flash attention recently added support for loading from paged KV cache in https://github.com/Dao-AILab/flash-attention/commit/54e80a3829c6d2337570d01e78ebd9529c02d342. The support was added to the Flash-Decoding kernel, which we haven't used so far.

This PR lets us use Flash Decoding with paged KV cache support from TVM. We already use other kernels from Flash attention via BYOC, but due to the specialized nature of this kernel, it is supported as a contrib kernel (similar to vllm).
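For readers unfamiliar with paged KV caches, here is a minimal sketch of the lookup such a kernel has to perform. It assumes a generic layout where a per-sequence block table maps logical blocks to physical blocks of fixed size. The function name `gather_paged_kv`, the toy data, and the layout itself are illustrative assumptions, not the shapes or API required by this PR or by flash-attention.

```python
def gather_paged_kv(cache, block_table, seq_len, block_size):
    """Collect the first `seq_len` token entries of one sequence from a paged cache.

    cache:       list of physical blocks; each block holds `block_size` token entries
    block_table: physical block ids for this sequence, in logical order
    (hypothetical layout, for illustration only)
    """
    tokens = []
    for pos in range(seq_len):
        # Split the token position into a logical block index and an offset within it.
        logical_block, offset = divmod(pos, block_size)
        # The block table translates the logical block to its physical location.
        physical_block = block_table[logical_block]
        tokens.append(cache[physical_block][offset])
    return tokens


# Toy cache: 4 physical blocks with 2 slots each. Each string stands in for
# the per-token key (or value) data that a real cache would store.
cache = [["a0", "a1"], ["b0", "b1"], ["c0", "c1"], ["d0", "d1"]]
block_table = [2, 0, 3]  # logical blocks 0, 1, 2 live in physical blocks 2, 0, 3
result = gather_paged_kv(cache, block_table, seq_len=5, block_size=2)
```

The point of the indirection is that a sequence's blocks need not be contiguous in memory, so the attention kernel must read the block table while loading keys and values.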

See also https://github.com/tlc-pack/libflash_attn/pull/9

@vinx13 @sunggg

masahi • Jan 26 '24 03:01

LGTM! When passing the paged KV cache, are there any assumptions, e.g., about its layout?

Yes, I added shape and dtype requirements as comments.

masahi • Jan 26 '24 20:01

@vinx13 Does this CI failure look like a compilation timeout? https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-gpu/detail/PR-16474/30/pipeline I remember you hit something like this before.

masahi • Jan 26 '24 21:01

@masahi That is probably the case, although it didn't happen for this kernel before.

vinx13 • Jan 26 '24 22:01

Oof. It's not surprising, since I added a new kernel variant (flash decoding) that brings yet another large set of explicit template instantiations: https://github.com/tlc-pack/libflash_attn/pull/9

masahi • Jan 26 '24 22:01

We can move the GPU builder to a larger CPU instance if needed.

vinx13 • Jan 29 '24 17:01