torch.compile(fullgraph=True) support for flash-decoding
I was optimizing inference for a CPU-overhead-bound model and ran into problems compiling the model into a CUDA graph when using flash-decoding. I've tried to wrap the flash-decoding kernel in a torch.library custom op,
but the path suggested in the error message involves adding functionalization to the C++ code. Is there an alternative to making this modification and recompiling the flash-attention library?
Vanilla compilation
Compiling flash_attn_with_kvcache directly doesn't work:
import torch
from flash_attn import flash_attn_with_kvcache

with torch.device("cuda"):
    q = torch.randn((1, 2, 2, 4), dtype=torch.bfloat16)
    k_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    v_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    k = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    v = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    cache_seqlens = torch.tensor([3], dtype=torch.int32)

# The eager call works.
flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens)

# Compiling the same call does not.
torch.compile(flash_attn_with_kvcache, fullgraph=True)(
    q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens
)
It fails with the following error:
torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(fwd_kvcache) __call__ [TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable(), LazyVariableTracker(), LazyVariableTracker(), ConstantVariable(NoneType), ConstantVariable(NoneType), LazyVariableTracker(), ConstantVariable(NoneType), ConstantVariable(float), LazyVariableTracker(), ConstantVariable(int), ConstantVariable(int), LazyVariableTracker(), LazyVariableTracker()] {}
from user code:
File "/home/vatsal/miniconda3/envs/os/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 1189, in flash_attn_with_kvcache
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
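The failure itself is expected: Dynamo cannot trace into the pybind11 extension call flash_attn_cuda.fwd_kvcache. If graph breaks were acceptable, the call could simply be excluded from compilation, but that rules out fullgraph=True and CUDA-graph capture, hence the custom-op attempt below. A sketch for completeness (the decode_step wrapper name is mine):

import torch
from flash_attn import flash_attn_with_kvcache

# Wrap the opaque extension call so torch.compile skips it and falls back to
# eager at this point. This inserts a graph break, so it only works with
# fullgraph=False and does not help with CUDA-graph capture.
flash_attn_with_kvcache_eager = torch.compiler.disable(flash_attn_with_kvcache)

def decode_step(q, k_cache, v_cache, k, v, cache_seqlens):
    return flash_attn_with_kvcache_eager(
        q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens
    )

out = torch.compile(decode_step)(q, k_cache, v_cache, k, v, cache_seqlens)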
torch.library compilation
import torch
from flash_attn import flash_attn_with_kvcache

with torch.device("cuda"):
    q = torch.randn((1, 2, 2, 4), dtype=torch.bfloat16)
    k_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    v_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    k = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    v = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    cache_seqlens = torch.tensor([3], dtype=torch.int32)

# The Tensor(a!) annotations declare that the op mutates the KV caches in place.
torch.library.define(
    "mylib::custom_func",
    "(Tensor q, Tensor(a!) k_cache, Tensor(a!) v_cache, Tensor k, Tensor v, Tensor cache_seqlens) -> Tensor",
)

@torch.library.impl("mylib::custom_func", "cuda")
def custom_func(q, k_cache, v_cache, k, v, cache_seqlens):
    return flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens
    )

# Meta/fake implementation: the output has the same shape and dtype as q.
@torch.library.impl_abstract("mylib::custom_func")
def custom_func_abstract(q, k_cache, v_cache, k, v, cache_seqlens):
    return torch.empty_like(q)

# Eager: the custom op matches the direct call.
assert torch.allclose(
    flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens),
    torch.ops.mylib.custom_func(q, k_cache, v_cache, k, v, cache_seqlens),
)

# Compiling the custom op fails during functionalization.
torch.compile(torch.ops.mylib.custom_func, fullgraph=True)(
    q, k_cache, v_cache, k, v, cache_seqlens
)
This raises the following error:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: mylib::custom_func.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa. Please file a GitHub issue if you run into any problems.
While executing %custom_func : [num_users=1] = call_function[target=torch.ops.mylib.custom_func](args = (%l_args_0_, %l_args_1_, %l_args_2_, %l_args_3_, %l_args_4_, %l_args_5_), kwargs = {})
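For what it's worth, newer PyTorch releases (2.4+) expose torch.library.custom_op, which lets you declare the mutated cache tensors via mutates_args; torch.compile then auto-functionalizes the op, so no out-of-place variant, C++ functionalization kernel, or rebuild of flash-attention should be needed. A minimal sketch, assuming PyTorch 2.4+ and reusing the tensors from the snippet above; the op name mylib::fa_kvcache and the decode_step wrapper are placeholders of mine:

import torch
from flash_attn import flash_attn_with_kvcache

# Assumes PyTorch >= 2.4: custom_op with mutates_args enables automatic
# functionalization under torch.compile, replacing the manual schema,
# impl, and impl_abstract registrations above.
@torch.library.custom_op(
    "mylib::fa_kvcache", mutates_args={"k_cache", "v_cache"}, device_types="cuda"
)
def fa_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cache_seqlens: torch.Tensor,
) -> torch.Tensor:
    return flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens
    )

@fa_kvcache.register_fake
def _(q, k_cache, v_cache, k, v, cache_seqlens):
    # Output has the same shape and dtype as q, like the abstract impl above.
    return torch.empty_like(q)

def decode_step(q, k_cache, v_cache, k, v, cache_seqlens):
    return torch.ops.mylib.fa_kvcache(q, k_cache, v_cache, k, v, cache_seqlens)

# Reuses the tensors created in the snippet above.
out = torch.compile(decode_step, fullgraph=True)(q, k_cache, v_cache, k, v, cache_seqlens)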