
Chunk-wise linear attention kernel does not work with torch compile (returns incorrect values / NaNs)

Open juankost opened this issue 1 year ago • 9 comments

I've been trying to use the linear attention kernels in a model that I am compiling; however, the Triton kernel does not seem to work with torch.compile. Specifically, when comparing the output of the kernel to a reference implementation, they match when torch.compile is not used, but they show huge discrepancies when the kernel is wrapped with torch.compile.

Here is a script to reproduce the error:

import torch
import sys
import os
from fla.ops.linear_attn.chunk_fuse import FusedChunkLinearAttentionFunction
fused_chunk_linear_attn = FusedChunkLinearAttentionFunction.apply

def reference_implementation(q, k, v, scale, initial_state, output_final_state):
    # q,k,v: B, H, L, E
    q = q * scale
    attn_weights = torch.matmul(q, k.transpose(2, 3))
    causal_mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2])).to(q.device).bool()
    attn_weights.masked_fill_(~causal_mask, 0.0)
    attn_output = torch.matmul(attn_weights, v)

    if initial_state is not None:
        state_contribution = torch.matmul(q, initial_state)
        attn_output += state_contribution

    hidden_state = None
    if output_final_state:
        hidden_state = torch.einsum("bhle,bhlf->bhef", k, v)
        if initial_state is not None:
            hidden_state += initial_state
    return attn_output, hidden_state


# SE (02/26): borrowing these tolerances from Mamba's test_selective_state_update
# https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/tests/ops/triton/test_selective_state_update.py#L24C1-L26C32
DTYPE_TO_ATOL = {
    torch.float32: 1e-3,
    torch.float16: 1e-2,
    torch.bfloat16: 5e-2,
}
DTYPE_TO_RTOL = {
    torch.float32: 3e-4,
    torch.float16: 5e-4,
    torch.bfloat16: 1e-2,
}


def compare_two_outputs(output1, output2, atol=1e-2, rtol=0):
    def compute_stats(a, b, atol=1e-2, rtol=0):
        abs_diff = torch.abs(a - b)
        mean_abs_diff = torch.mean(abs_diff).item()
        max_abs_diff = torch.max(abs_diff).item()
        not_close_mask = ~torch.isclose(a, b, atol=atol, rtol=rtol)
        mean_abs_diff_not_close = (
            torch.mean(abs_diff[not_close_mask]).item() if not_close_mask.any() else 0
        )
        mean_rel_diff_not_close = (
            torch.mean(abs_diff[not_close_mask] / (torch.abs(b[not_close_mask]) + 1e-7)).item()
            if not_close_mask.any()
            else 0
        )
        return {
            "mean_abs_diff": mean_abs_diff,
            "max_abs_diff": max_abs_diff,
            "mean_abs_diff_not_close": mean_abs_diff_not_close,
            "mean_rel_diff_not_close": mean_rel_diff_not_close,
            "percent_not_close": 100 * not_close_mask.float().mean().item(),
        }

    # if not torch.allclose(output1, output2, atol=atol, rtol=rtol):
    if not torch.allclose(output1, output2):
        output_stats = compute_stats(output1, output2, atol=atol, rtol=rtol)
        print("Different outputs")
        print(f"Mean absolute difference: {output_stats['mean_abs_diff']:.6e}")
        print(f"Max absolute difference: {output_stats['max_abs_diff']:.6e}")
        print(
            f"Mean absolute difference (not close): {output_stats['mean_abs_diff_not_close']:.6e}"
        )
        print(
            f"Relative absolute difference (not close): {output_stats['mean_rel_diff_not_close']:.6e}"
        )
        print(f"Percent not close: {output_stats['percent_not_close']:.2f}%")
    else:
        print("Outputs match")
    print("-" * 80)  # Separator for readability
    return torch.allclose(output1, output2, atol=atol, rtol=rtol)


def test_torch_compile(dtype):
    # Set up test parameters
    B, H, L, E = 1, 8, 2048, 64
    scale = 0.125
    output_final_state = True

    # Create input tensors
    q = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    k = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    v = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    initial_state = (
        torch.empty((B, H, E, E), dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .contiguous()
    )

    # Define a function that uses the kernel
    def attention_fn_kernel(q, k, v, initial_state):  # -> Any:
        return fused_chunk_linear_attn(q, k, v, scale, initial_state, output_final_state)

    def attention_fn_pytorch(q, k, v, initial_state):  # -> Any:
        return reference_implementation(q, k, v, scale, initial_state, output_final_state)

    # Compile/warmup the kernels
    compiled_fn_kernel = torch.compile(attention_fn_kernel)
    compiled_fn_pytorch = torch.compile(attention_fn_pytorch)
    with torch.autocast(device_type="cuda", dtype=dtype):
        output1, _ = compiled_fn_kernel(q, k, v, initial_state)
        output2, _ = compiled_fn_pytorch(q, k, v, initial_state)

    # Run the forward passes
    with torch.autocast(device_type="cuda", dtype=dtype):
        output_kernel_comp, _ = compiled_fn_kernel(q, k, v, initial_state)
        output_pytorch_comp, _ = compiled_fn_pytorch(q, k, v, initial_state)
        output_kernel, _ = attention_fn_kernel(q, k, v, initial_state)
        output_pytorch, _ = attention_fn_pytorch(q, k, v, initial_state)

    # Check if any of the compiled outputs match the original
    atol = DTYPE_TO_ATOL[dtype]
    rtol = DTYPE_TO_RTOL[dtype]

    print("Compiled vs non compiled Pytorch")
    compare_two_outputs(output_pytorch_comp, output_pytorch, atol=atol, rtol=rtol)
    print("Triton (non compiled) vs Pytorch (non compiled)")
    compare_two_outputs(output_kernel.float(), output_pytorch.float(), atol=atol, rtol=rtol)
    print("Triton (compiled) vs Pytorch (non compiled)")
    compare_two_outputs(output_kernel_comp.float(), output_pytorch.float(), atol=atol, rtol=rtol)
    print("Test compile completed")


if __name__ == "__main__":

    for dtype in [torch.float16, torch.bfloat16, torch.float32]:
        print(f"Testing torc compile for dtype {dtype}")
        test_torch_compile(dtype)
        print("=" * 80)

The output of the script (running on A100 40GB GPU):

Testing torch compile for dtype torch.float16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.049560e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.805531e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 3.020972e+00
Max absolute difference: 2.664062e+01
Mean absolute difference (not close): 3.029570e+00
Relative absolute difference (not close): 9.999998e-01
Percent not close: 99.72%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.bfloat16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.357576e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.615234e-02
Relative absolute difference (not close): 1.953556e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: nan
Max absolute difference: nan
Mean absolute difference (not close): nan
Relative absolute difference (not close): nan
Percent not close: 100.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.float32
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.531446e-06
Max absolute difference: 5.340576e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: nan
Max absolute difference: nan
Mean absolute difference (not close): nan
Relative absolute difference (not close): nan
Percent not close: 100.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================

Relevant libraries from environment:

Versions of relevant libraries:                                                                                                                    
[pip3] numpy==1.24.3
[pip3] torch==2.4.0.dev20240523                                                                                                                    
[pip3] torchaudio==2.2.0.dev20240523
[pip3] torchvision==0.19.0.dev20240523                                                                                                            
[pip3] triton==3.0.0                                                                                                                               
[pip3] triton-nightly==3.0.0.post20240522224832
[conda] blas                      1.0                         mkl    intel
[conda] brotlipy                  0.7.0           py311h9bf148f_1002    pytorch-nightly
[conda] cffi                      1.15.1          py311h9bf148f_3    pytorch-nightly
[conda] cryptography              38.0.4          py311h46ebde7_0    pytorch-nightly
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch-nightly
[conda] mkl                       2021.4.0              intel_640    intel
[conda] mkl-service               2.4.0           py311h9bf148f_0    pytorch-nightly
[conda] mkl_fft                   1.3.1           py311hc796f24_0    pytorch-nightly
[conda] mkl_random                1.2.2           py311hbba84a0_0    pytorch-nightly
[conda] mpmath                    1.2.1                   py311_0    pytorch-nightly
[conda] numpy                     1.26.4                   pypi_0    pypi                                                                         
[conda] numpy-base                1.24.3          py311hfd5febd_0
[conda] pysocks                   1.7.1                   py311_0    pytorch-nightly
[conda] pytorch                   2.4.0.dev20240523 py3.11_cuda12.4_cudnn8.9.2_0    pytorch-nightly
[conda] pytorch-cuda              12.4                 hc786d27_6    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchaudio                2.2.0.dev20240523     py311_cu124    pytorch-nightly
[conda] torchtriton               3.0.0+45fff310c8           py311    pytorch-nightly
[conda] torchvision               0.19.0.dev20240523     py311_cu124    pytorch-nightly
[conda] triton-nightly            3.0.0.post20240522224832          pypi_0    pypi
[conda] urllib3                   1.26.14                 py311_0    pytorch-nightly 

juankost avatar Aug 20 '24 11:08 juankost

I think I have localized the issue. It seems to arise from line 249, where the outputs of each block are summed up.

Note, I tested the following to determine the cause of the error:

        # The current approach in the code
        o = o.sum(0)
        return o.to(q.dtype), final_state
        
        # When NK == 1, we do not need to sum. 
        # return o[0].to(q.dtype), final_state  # --> this removes the large discrepancies in the output

"Workaround" (not sure how this impact on the torch compile performance):

# Adding the print statement removes the large discprepancies observed in the torch compile
print("")
o = o.sum(0)

I am not sure exactly why this works, but it may help in coming up with a better solution.
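If the goal is only to force a break at that point, an explicit Dynamo graph break may be a cleaner stand-in for the print. A minimal sketch against the same spot in chunk_fuse.py, assuming a torch 2.x Dynamo; behaviour at this exact spot is unverified here:

import torch
import torch._dynamo

# Sketch: replace the print side effect with an explicit graph break, so the
# reduction is compiled separately from the Triton kernel launch.
# torch._dynamo.graph_break() is a no-op in eager mode and only takes effect
# while Dynamo is tracing.
torch._dynamo.graph_break()
o = o.sum(0)
return o.to(q.dtype), final_state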

juankost avatar Aug 20 '24 12:08 juankost

@juankost Hi, could you report your torch & triton versions? I ran into some errors when running your code snippet.

yzhangcs avatar Aug 20 '24 12:08 yzhangcs

@yzhangcs

Pytorch 2.4.0.dev20240523 Triton 3.0.0+45fff310c8

What type of error are you getting?

I think the problem is that torch.compile tries to fuse the o.sum(0) operation with the Triton kernel, and the sum might be executed before all the thread blocks have finished computing the output. The print statement basically breaks the torch.compile graph before the sum, which probably ensures that the sum is only done after the Triton kernel has fully finished.
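If the issue really is Inductor fusing/reordering around the Triton launch, another possible direction is to make the whole kernel call opaque to the compiler by registering it as a custom op. An untested sketch, assuming torch >= 2.4; the op name "fla_repro::fused_chunk_linear_attn" and the wrapper function are made up for illustration, and this covers the forward pass only:

import torch
from fla.ops.linear_attn.chunk_fuse import FusedChunkLinearAttentionFunction

# Untested sketch: expose the kernel as an opaque custom op so Inductor cannot
# fuse or reorder anything across the Triton launch. Forward-only; the backward
# would still need torch.library.register_autograd.
@torch.library.custom_op("fla_repro::fused_chunk_linear_attn", mutates_args=())
def fused_chunk_linear_attn_op(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
) -> torch.Tensor:
    o, _ = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, False)
    return o

@fused_chunk_linear_attn_op.register_fake
def _(q, k, v, scale, initial_state):
    # Shape/dtype-only stub used while tracing under torch.compile.
    return torch.empty_like(v)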

juankost avatar Aug 20 '24 12:08 juankost

Hi, mine is torch 2.3 & triton 2.3

TypeError: launcher() got an unexpected keyword argument 'num_warps'

not sure what exactly happened.

I just provided a quick fix, could you check it again?

yzhangcs avatar Aug 20 '24 13:08 yzhangcs

I tested it; it does not work when NK > 1 (i.e., when the sum operation actually gets executed).

Running torch._dynamo.explain on the function shows that it still compiles it into a single graph:

Graph Count: 1
Graph Break Count: 0
Op Count: 4
Break Reasons:
Ops per Graph:
  Ops 1:
    <class 'torch.autograd.function.FunctionCtx'>
    autograd_function_apply
    <built-in function getitem>
    <built-in function getitem>

With the print statement, on the other hand, it does indeed break the compilation graph. I have not figured out a better way to break the graph (or why this is a problem in the first place). Maybe we could raise an issue with the PyTorch team to get their take on it.
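For reference, a minimal sketch of how the numbers above can be obtained, assuming torch >= 2.1 and the attention_fn_kernel / inputs from the repro script at the top of this issue:

import torch

# Sketch: ask Dynamo how many graphs and graph breaks it produces for the
# kernel wrapper from the repro script.
explanation = torch._dynamo.explain(attention_fn_kernel)(q, k, v, initial_state)
print(explanation.graph_count)
print(explanation.graph_break_count)
print(explanation.break_reasons)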

juankost avatar Aug 20 '24 13:08 juankost

@juankost Here is my current output with torch2.4 and triton3.0

Testing torch compile for dtype torch.float16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.049412e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.618346e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.049412e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.618346e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.float16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.050618e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.841894e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.050618e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.841894e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.bfloat16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.342735e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.664062e-02
Relative absolute difference (not close): 1.106870e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.342735e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.664062e-02
Relative absolute difference (not close): 1.106870e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.bfloat16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.370323e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.322266e-02
Relative absolute difference (not close): 5.157646e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.370323e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.322266e-02
Relative absolute difference (not close): 5.157646e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.float32
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.526679e-06
Max absolute difference: 4.577637e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.526679e-06
Max absolute difference: 4.577637e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.float32
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.525714e-06
Max absolute difference: 5.340576e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.525714e-06
Max absolute difference: 5.340576e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================

yzhangcs avatar Aug 20 '24 14:08 yzhangcs

@yzhangcs What input dimensions are you using?

If you are just running my script from above, your proposed fix will work, since NK == 1 (the input head dimension is 64, so NK will be 1).

If you increase the input head dimension as in the snippet below, then it will still fail, no?

def test_torch_compile(dtype):
    # Set up test parameters
    B, H, L, E = 1, 8, 2048, 128  # change E from 64 to 128 so that NK > 1
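For context, a rough sketch of how the number of key blocks comes about; the block size of 64 is an assumption, and the exact values live in fla/ops/linear_attn/chunk_fuse.py:

import triton

# Assumed block-size logic: the key/head dimension is tiled into blocks of at
# most 64, and NK counts those blocks.
K = 128                                   # head dimension E in the script above
BK = min(triton.next_power_of_2(K), 64)   # assumed key block size -> 64
NK = triton.cdiv(K, BK)                   # -> 2, so the kernel output needs o.sum(0)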

juankost avatar Aug 20 '24 14:08 juankost

I'm just using the above code. If it's still failing, we might consider other solutions for this issue.

yzhangcs avatar Aug 20 '24 14:08 yzhangcs

Yeah, it still fails for me when increasing the head dimension.

I asked on the PyTorch forum whether they have any suggestions or explanations for this: https://discuss.pytorch.org/t/custom-triton-kernel-with-torch-compile-returns-incorrect-outputs/208355

juankost avatar Aug 20 '24 14:08 juankost

@juankost Closing this as answered by the link you provided. We haven't experimented much with torch.compile yet. Any PRs are welcome if you fix this issue.

yzhangcs avatar Aug 22 '24 06:08 yzhangcs

@yzhangcs For now, I simply worked around it by removing the parallelization over the K dimension, so that the operation

o = o.sum(0) 

at the end of the kernel can be dropped. This "fixes" the problem with torch.compile, but it is not ideal in terms of optimization. I'm following up on the issue on the forum; if we find a solution, I'll submit a PR.
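The idea of the workaround, expressed with the assumed BK/NK names from the sketch further above (not the actual patch):

import triton

# Sketch of the workaround: use a single key block covering the whole head
# dimension, so there are no per-block partial outputs to reduce.
K = 128                          # head dimension
BK = triton.next_power_of_2(K)   # one block spans the full head dimension
NK = triton.cdiv(K, BK)          # always 1
# With NK == 1 the final `o = o.sum(0)` can be replaced by `o = o[0]`, removing
# the reduction that torch.compile was mis-scheduling around the Triton kernel.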

juankost avatar Aug 28 '24 07:08 juankost