flash-attention
torch.compile(fullgraph=True) support for flash-decoding
I was optimizing inference for a CPU-overhead-bound model and ran into problems compiling the model into a CUDA graph with flash-decoding. I've tried to wrap the flash-decoding kernel with torch.library, but the path suggested in the error message involves adding functionalisation to the C++ code. Is there an alternative to making this modification and recompiling the flash-attention library?
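For context, the end goal is roughly the sketch below (the linear layer is just a stand-in for the real decode step): compile the decode step with Inductor's reduce-overhead mode so that it gets captured into a CUDA graph, which is why flash-decoding needs to be traceable by torch.compile.

import torch

# Sketch only: `model` is a placeholder for the actual decode step of my network.
# mode="reduce-overhead" asks Inductor to capture the compiled region into a CUDA graph.
model = torch.nn.Linear(16, 16, device="cuda", dtype=torch.bfloat16)
compiled_decode = torch.compile(model, mode="reduce-overhead", fullgraph=True)
out = compiled_decode(torch.randn(1, 16, device="cuda", dtype=torch.bfloat16))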
Vanilla compilation
Vanilla compilation doesn't work
import torch
from flash_attn import flash_attn_with_kvcache

with torch.device("cuda"):
    q = torch.randn((1, 2, 2, 4), dtype=torch.bfloat16)
    k_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    v_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    k = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    v = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    cache_seqlens = torch.tensor([3], dtype=torch.int32)

# The eager call works.
flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens)

# Compiling the same call with fullgraph=True does not.
torch.compile(flash_attn_with_kvcache, fullgraph=True)(
    q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens
)
This leads to the following error:
torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(fwd_kvcache) __call__ [TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable(), LazyVariableTracker(), LazyVariableTracker(), ConstantVariable(NoneType), ConstantVariable(NoneType), LazyVariableTracker(), ConstantVariable(NoneType), ConstantVariable(float), LazyVariableTracker(), ConstantVariable(int), ConstantVariable(int), LazyVariableTracker(), LazyVariableTracker()] {}
from user code:
File "/home/vatsal/miniconda3/envs/os/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 1189, in flash_attn_with_kvcache
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
torch.library compilation
import torch
from flash_attn import flash_attn_with_kvcache

with torch.device("cuda"):
    q = torch.randn((1, 2, 2, 4), dtype=torch.bfloat16)
    k_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    v_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    k = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    v = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    cache_seqlens = torch.tensor([3], dtype=torch.int32)

# The (a!) annotations mark k_cache and v_cache as mutated in place.
torch.library.define(
    "mylib::custom_func",
    "(Tensor q, Tensor(a!) k_cache, Tensor(a!) v_cache, Tensor k, Tensor v, Tensor cache_seqlens) -> Tensor",
)

@torch.library.impl("mylib::custom_func", "cuda")
def custom_func(q, k_cache, v_cache, k, v, cache_seqlens):
    return flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens
    )

# Abstract (meta) implementation so torch.compile can infer the output shape
# without running the CUDA kernel.
@torch.library.impl_abstract("mylib::custom_func")
def custom_func_abstract(q, k_cache, v_cache, k, v, cache_seqlens):
    return torch.empty_like(q)

# The custom op matches the eager result in eager mode...
assert torch.allclose(
    flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens),
    torch.ops.mylib.custom_func(q, k_cache, v_cache, k, v, cache_seqlens),
)

# ...but compiling it still fails.
torch.compile(torch.ops.mylib.custom_func, fullgraph=True)(
    q, k_cache, v_cache, k, v, cache_seqlens
)
This raises the following error:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Found a custom (non-ATen) operator that either mutates or aliases its inputs: mylib::custom_func.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa. Please file a GitHub issue if you run into any problems.
While executing %custom_func : [num_users=1] = call_function[target=torch.ops.mylib.custom_func](args = (%l_args_0_, %l_args_1_, %l_args_2_, %l_args_3_, %l_args_4_, %l_args_5_), kwargs = {})
I've been trying to make this work, but I'm not experienced with torch.compile. If you figure something out, please let me know.
Have you also tried the functionalization approach mentioned in their message, or do you think it's worth trying? (https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa)
It seems to be recommending creating an out-of-place copy of the op, which sounds like it might not be a great thing to do given the size of the kv-cache?
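If I'm reading the gist right, the out-of-place variant would have to look roughly like the sketch below (untested; the name custom_func_functional is made up, and you'd additionally need the functionalization kernel that maps the mutable op onto it). The clones are exactly the per-step copy of the caches I'm worried about:

import torch
from flash_attn import flash_attn_with_kvcache

# Out-of-place variant: no (a!) annotations; the updated caches are returned instead.
torch.library.define(
    "mylib::custom_func_functional",
    "(Tensor q, Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor cache_seqlens)"
    " -> (Tensor, Tensor, Tensor)",
)

@torch.library.impl("mylib::custom_func_functional", "cuda")
def custom_func_functional(q, k_cache, v_cache, k, v, cache_seqlens):
    # Clone the caches so the inputs stay untouched, run the in-place kernel on the
    # copies, and return the copies alongside the attention output.
    k_cache_out = k_cache.clone()
    v_cache_out = v_cache.clone()
    out = flash_attn_with_kvcache(
        q, k_cache_out, v_cache_out, k=k, v=v, cache_seqlens=cache_seqlens
    )
    return out, k_cache_out, v_cache_out

@torch.library.impl_abstract("mylib::custom_func_functional")
def custom_func_functional_abstract(q, k_cache, v_cache, k, v, cache_seqlens):
    return torch.empty_like(q), torch.empty_like(k_cache), torch.empty_like(v_cache)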
I followed up with them in https://github.com/pytorch/pytorch/issues/120441, and the auto-functionalisation approach in the PyTorch nightlies fixes the error; however, it has performance problems, as I've outlined in https://github.com/pytorch/pytorch/issues/120441#issuecomment-1961550554
Thanks for the investigation. So right now it sounds like it's hard to do in-place ops with torch.compile.
Hi @tridao :) I saw that xformers already added torch.compile support for flash-attention.
I guess it shouldn't be hard to port that support back here:
https://github.com/facebookresearch/xformers/commit/5d590237a0544e6069a79add0800168af319f1a8
https://github.com/facebookresearch/xformers/releases/tag/v0.0.26.post1
Sure, can you point me to how they do it?
Yes :) Generally by using https://pytorch.org/docs/stable/library.html, like they did in this commit: https://github.com/facebookresearch/xformers/commit/5d590237a0544e6069a79add0800168af319f1a8
Hey! Are there plans to implement this? From @TamirFriedman-RecoLabs's message it seems like it's only a matter of correctly defining the flash-attention operators through PyTorch's library API.
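Something like the following is what I have in mind. It's a rough, untested sketch: it assumes PyTorch >= 2.4, where torch.library.custom_op exists and ops that declare their mutated arguments get auto-functionalised under torch.compile, and the wrapper name is made up.

import torch
from flash_attn import flash_attn_with_kvcache

# Declaring k_cache/v_cache as mutated lets torch.compile auto-functionalise the op.
@torch.library.custom_op("mylib::fwd_kvcache_wrapper", mutates_args=("k_cache", "v_cache"))
def fwd_kvcache_wrapper(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cache_seqlens: torch.Tensor,
) -> torch.Tensor:
    return flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens
    )

# Fake (meta) implementation so the compiler can infer the output shape without
# running the CUDA kernel.
@fwd_kvcache_wrapper.register_fake
def _(q, k_cache, v_cache, k, v, cache_seqlens):
    return torch.empty_like(q)

def decode_step(q, k_cache, v_cache, k, v, cache_seqlens):
    return fwd_kvcache_wrapper(q, k_cache, v_cache, k, v, cache_seqlens)

compiled_decode_step = torch.compile(decode_step, fullgraph=True)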