
torch.compile(fullgraph=True) support for flash-decoding

Open vatsalaggarwal opened this issue 1 year ago • 8 comments

I was optimizing inference for a model that is bound by CPU overhead and ran into problems compiling it into a CUDA graph with flash decoding. I tried wrapping the flash-decoding kernel with torch.library, but the resulting error message suggests adding functionalisation to the C++ code. Is there an alternative to making this modification and recompiling the flash-attention library?

Vanilla compilation

Compiling flash_attn_with_kvcache directly with fullgraph=True doesn't work:

import torch
from flash_attn import flash_attn_with_kvcache

with torch.device("cuda"):
    q = torch.randn((1, 2, 2, 4), dtype=torch.bfloat16)
    k_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    v_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    k = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    v = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    cache_seqlens = torch.tensor([3], dtype=torch.int32)


flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens)
torch.compile(flash_attn_with_kvcache, fullgraph=True)(
    q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens
)

It leads to the following error:

torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(fwd_kvcache) __call__ [TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable(), LazyVariableTracker(), LazyVariableTracker(), ConstantVariable(NoneType), ConstantVariable(NoneType), LazyVariableTracker(), ConstantVariable(NoneType), ConstantVariable(float), LazyVariableTracker(), ConstantVariable(int), ConstantVariable(int), LazyVariableTracker(), LazyVariableTracker()] {}

from user code:
   File "/home/vatsal/miniconda3/envs/os/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 1189, in flash_attn_with_kvcache
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

torch.library compilation

import torch
from flash_attn import flash_attn_with_kvcache

with torch.device("cuda"):
    q = torch.randn((1, 2, 2, 4), dtype=torch.bfloat16)
    k_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    v_cache = torch.randn((1, 5, 2, 4), dtype=torch.bfloat16)
    k = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    v = torch.randn((1, 1, 2, 4), dtype=torch.bfloat16)
    cache_seqlens = torch.tensor([3], dtype=torch.int32)

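# Declare the op schema. The "!" marks k_cache and v_cache as mutated in place;
# each mutated argument gets its own alias set (a, b).
torch.library.define(
    "mylib::custom_func",
    "(Tensor q, Tensor(a!) k_cache, Tensor(b!) v_cache, Tensor k, Tensor v, Tensor cache_seqlens) -> Tensor",
)

# CUDA implementation: forward to the flash-decoding kernel.
@torch.library.impl("mylib::custom_func", "cuda")
def custom_func(q, k_cache, v_cache, k, v, cache_seqlens):
    return flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens
    )

# Abstract (fake-tensor) implementation: only describes the output's shape and dtype.
@torch.library.impl_abstract("mylib::custom_func")
def custom_func_abstract(q, k_cache, v_cache, k, v, cache_seqlens):
    return torch.empty_like(q)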

assert torch.allclose(
    flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens),
    torch.ops.mylib.custom_func(q, k_cache, v_cache, k, v, cache_seqlens),
)
torch.compile(torch.ops.mylib.custom_func, fullgraph=True)(
    q, k_cache, v_cache, k, v, cache_seqlens
)

This raises the following error:

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: mylib::custom_func.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa. Please file a GitHub issue if you run into any problems.

While executing %custom_func : [num_users=1] = call_function[target=torch.ops.mylib.custom_func](args = (%l_args_0_, %l_args_1_, %l_args_2_, %l_args_3_, %l_args_4_, %l_args_5_), kwargs = {})
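
To make the error message's suggestion concrete, here is a minimal conceptual sketch of the pattern it describes: a mutating op, an out-of-place variant of it, and a wrapper that maps the first onto the second and writes the result back. It uses plain Python lists as a stand-in for the KV cache rather than tensors, and the names append_kv_inplace, append_kv_functional and append_kv_functionalized are hypothetical. It is not PyTorch's actual functionalization mechanism, and it assumes the only mutation is writing the new rows at position cache_seqlens.

```python
from typing import List

Cache = List[List[float]]  # toy stand-in for a (seqlen, head_dim) cache tensor


def append_kv_inplace(cache: Cache, new_rows: Cache, seqlen: int) -> None:
    """Mutating op (hypothetical): write new_rows into cache starting at row seqlen."""
    for i, row in enumerate(new_rows):
        cache[seqlen + i] = list(row)


def append_kv_functional(cache: Cache, new_rows: Cache, seqlen: int) -> Cache:
    """Out-of-place variant: leave cache untouched and return an updated copy."""
    updated = [list(row) for row in cache]
    append_kv_inplace(updated, new_rows, seqlen)
    return updated


def append_kv_functionalized(cache: Cache, new_rows: Cache, seqlen: int) -> None:
    """Map the mutating op onto the out-of-place one, then copy the result back."""
    updated = append_kv_functional(cache, new_rows, seqlen)
    cache[:] = updated  # write-back into the caller's buffer


# Same sizes as the repro above: 5 cache slots, 3 already filled, 1 new token.
k_cache = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [0.0, 0.0], [0.0, 0.0]]
k_new = [[9.0, 9.0]]

# Reference result from the mutating op, computed on a separate copy.
expected = [list(row) for row in k_cache]
append_kv_inplace(expected, k_new, 3)

# The out-of-place variant must not touch its input.
snapshot = [list(row) for row in k_cache]
pure_result = append_kv_functional(k_cache, k_new, 3)
assert k_cache == snapshot

# The wrapper has the same observable effect as the mutating op.
append_kv_functionalized(k_cache, k_new, 3)
assert k_cache == expected == pure_result
```

The trade-off in this sketch is the full copy of the cache on every call. Applied to real tensors, that copy would cost memory bandwidth at each decoding step, so it illustrates the semantics only and should not be read as a performance-neutral workaround.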

vatsalaggarwal • Feb 22 '24 16:02