
Chunkwise linear attention kernel does not work with torch.compile (returns incorrect values / NaNs)

Open juankost opened this issue 1 year ago • 9 comments

I've been trying to use the linear attention kernels in a model that I am compiling, but the Triton kernel does not seem to work with torch.compile. Specifically, when comparing the output of the kernel to a reference implementation, the two match when torch.compile is not used, but they diverge wildly when the kernel is wrapped with torch.compile.

Here is a script to reproduce the error:

import torch

from fla.ops.linear_attn.chunk_fuse import FusedChunkLinearAttentionFunction

fused_chunk_linear_attn = FusedChunkLinearAttentionFunction.apply

def reference_implementation(q, k, v, scale, initial_state, output_final_state):
    # q,k,v: B, H, L, E
    q = q * scale
    attn_weights = torch.matmul(q, k.transpose(2, 3))
    causal_mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2])).to(q.device).bool()
    attn_weights.masked_fill_(~causal_mask, 0.0)
    attn_output = torch.matmul(attn_weights, v)

    if initial_state is not None:
        state_contribution = torch.matmul(q, initial_state)
        attn_output += state_contribution

    hidden_state = None
    if output_final_state:
        hidden_state = torch.einsum("bhle,bhlf->bhef", k, v)
        if initial_state is not None:
            hidden_state += initial_state
    return attn_output, hidden_state
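
# Quick sanity check of the reference (illustrative, not part of the repro):
# the einsum used for the final state is just an explicit k^T @ v per
# (batch, head), summed over the sequence dimension.
_k = torch.randn(1, 2, 16, 8)
_v = torch.randn(1, 2, 16, 8)
assert torch.allclose(
    torch.einsum("bhle,bhlf->bhef", _k, _v),
    _k.transpose(-2, -1) @ _v,
    atol=1e-5,
)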


# SE (02/26): borrowing these tolerances from Mamba's test_selective_state_update
# https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/tests/ops/triton/test_selective_state_update.py#L24C1-L26C32
DTYPE_TO_ATOL = {
    torch.float32: 1e-3,
    torch.float16: 1e-2,
    torch.bfloat16: 5e-2,
}
DTYPE_TO_RTOL = {
    torch.float32: 3e-4,
    torch.float16: 5e-4,
    torch.bfloat16: 1e-2,
}
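# With torch.isclose semantics an element a counts as "close" to b when
# |a - b| <= atol + rtol * |b|; e.g. for float32 the tolerances above mean
# |a - b| <= 1e-3 + 3e-4 * |b|.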


def compare_two_outputs(output1, output2, atol=1e-2, rtol=0):
    def compute_stats(a, b, atol=1e-2, rtol=0):
        abs_diff = torch.abs(a - b)
        mean_abs_diff = torch.mean(abs_diff).item()
        max_abs_diff = torch.max(abs_diff).item()
        not_close_mask = ~torch.isclose(a, b, atol=atol, rtol=rtol)
        mean_abs_diff_not_close = (
            torch.mean(abs_diff[not_close_mask]).item() if not_close_mask.any() else 0
        )
        mean_rel_diff_not_close = (
            torch.mean(abs_diff[not_close_mask] / (torch.abs(b[not_close_mask]) + 1e-7)).item()
            if not_close_mask.any()
            else 0
        )
        return {
            "mean_abs_diff": mean_abs_diff,
            "max_abs_diff": max_abs_diff,
            "mean_abs_diff_not_close": mean_abs_diff_not_close,
            "mean_rel_diff_not_close": mean_rel_diff_not_close,
            "percent_not_close": 100 * not_close_mask.float().mean().item(),
        }

    # The pass/fail check here keeps torch.allclose's strict default
    # tolerances, while the stats above use the looser per-dtype atol/rtol;
    # that is why "Percent not close" can be 0.00% even when this branch fires.
    if not torch.allclose(output1, output2):
        output_stats = compute_stats(output1, output2, atol=atol, rtol=rtol)
        print("Different outputs")
        print(f"Mean absolute difference: {output_stats['mean_abs_diff']:.6e}")
        print(f"Max absolute difference: {output_stats['max_abs_diff']:.6e}")
        print(
            f"Mean absolute difference (not close): {output_stats['mean_abs_diff_not_close']:.6e}"
        )
        print(
            f"Relative absolute difference (not close): {output_stats['mean_rel_diff_not_close']:.6e}"
        )
        print(f"Percent not close: {output_stats['percent_not_close']:.2f}%")
    else:
        print("Outputs match")
    print("-" * 80)  # Separator for readability
    return torch.allclose(output1, output2, atol=atol, rtol=rtol)


def test_torch_compile(dtype):
    # Set up test parameters
    B, H, L, E = 1, 8, 2048, 64
    scale = 0.125
    output_final_state = True

    # Create input tensors
    q = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    k = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    v = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    initial_state = (
        torch.empty((B, H, E, E), dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .contiguous()
    )
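    # Note: torch.empty(...) already returns contiguous tensors, so the
    # trailing .contiguous() calls above return the same leaf tensors.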

    # Wrappers around the Triton kernel and the PyTorch reference
    def attention_fn_kernel(q, k, v, initial_state):
        return fused_chunk_linear_attn(q, k, v, scale, initial_state, output_final_state)

    def attention_fn_pytorch(q, k, v, initial_state):
        return reference_implementation(q, k, v, scale, initial_state, output_final_state)

    # Compile/warmup the kernels
    compiled_fn_kernel = torch.compile(attention_fn_kernel)
    compiled_fn_pytorch = torch.compile(attention_fn_pytorch)
    with torch.autocast(device_type="cuda", dtype=dtype):
        output1, _ = compiled_fn_kernel(q, k, v, initial_state)
        output2, _ = compiled_fn_pytorch(q, k, v, initial_state)

    # Run the forward passes
    with torch.autocast(device_type="cuda", dtype=dtype):
        output_kernel_comp, _ = compiled_fn_kernel(q, k, v, initial_state)
        output_pytorch_comp, _ = compiled_fn_pytorch(q, k, v, initial_state)
        output_kernel, _ = attention_fn_kernel(q, k, v, initial_state)
        output_pytorch, _ = attention_fn_pytorch(q, k, v, initial_state)

    # Check if any of the compiled outputs match the original
    atol = DTYPE_TO_ATOL[dtype]
    rtol = DTYPE_TO_RTOL[dtype]

    print("Compiled vs non compiled Pytorch")
    compare_two_outputs(output_pytorch_comp, output_pytorch, atol=atol, rtol=rtol)
    print("Triton (non compiled) vs Pytorch (non compiled)")
    compare_two_outputs(output_kernel.float(), output_pytorch.float(), atol=atol, rtol=rtol)
    print("Triton (compiled) vs Pytorch (non compiled)")
    compare_two_outputs(output_kernel_comp.float(), output_pytorch.float(), atol=atol, rtol=rtol)
    print("Test compile completed")


if __name__ == "__main__":

    for dtype in [torch.float16, torch.bfloat16, torch.float32]:
        print(f"Testing torc compile for dtype {dtype}")
        test_torch_compile(dtype)
        print("=" * 80)

The output of the script (run on an A100 40GB GPU):

Testing torch.compile for dtype torch.float16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.049560e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.805531e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 3.020972e+00
Max absolute difference: 2.664062e+01
Mean absolute difference (not close): 3.029570e+00
Relative absolute difference (not close): 9.999998e-01
Percent not close: 99.72%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch.compile for dtype torch.bfloat16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.357576e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.615234e-02
Relative absolute difference (not close): 1.953556e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: nan
Max absolute difference: nan
Mean absolute difference (not close): nan
Relative absolute difference (not close): nan
Percent not close: 100.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch.compile for dtype torch.float32
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.531446e-06
Max absolute difference: 5.340576e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: nan
Max absolute difference: nan
Mean absolute difference (not close): nan
Relative absolute difference (not close): nan
Percent not close: 100.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================

Relevant libraries from the environment:

[pip3] numpy==1.24.3
[pip3] torch==2.4.0.dev20240523
[pip3] torchaudio==2.2.0.dev20240523
[pip3] torchvision==0.19.0.dev20240523
[pip3] triton==3.0.0
[pip3] triton-nightly==3.0.0.post20240522224832
[conda] blas                      1.0                         mkl    intel
[conda] brotlipy                  0.7.0           py311h9bf148f_1002    pytorch-nightly
[conda] cffi                      1.15.1          py311h9bf148f_3    pytorch-nightly
[conda] cryptography              38.0.4          py311h46ebde7_0    pytorch-nightly
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch-nightly
[conda] mkl                       2021.4.0              intel_640    intel
[conda] mkl-service               2.4.0           py311h9bf148f_0    pytorch-nightly
[conda] mkl_fft                   1.3.1           py311hc796f24_0    pytorch-nightly
[conda] mkl_random                1.2.2           py311hbba84a0_0    pytorch-nightly
[conda] mpmath                    1.2.1                   py311_0    pytorch-nightly
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] numpy-base                1.24.3          py311hfd5febd_0
[conda] pysocks                   1.7.1                   py311_0    pytorch-nightly
[conda] pytorch                   2.4.0.dev20240523 py3.11_cuda12.4_cudnn8.9.2_0    pytorch-nightly
[conda] pytorch-cuda              12.4                 hc786d27_6    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchaudio                2.2.0.dev20240523     py311_cu124    pytorch-nightly
[conda] torchtriton               3.0.0+45fff310c8           py311    pytorch-nightly
[conda] torchvision               0.19.0.dev20240523     py311_cu124    pytorch-nightly
[conda] triton-nightly            3.0.0.post20240522224832          pypi_0    pypi
[conda] urllib3                   1.26.14                 py311_0    pytorch-nightly
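
For what it's worth, a possible mitigation (a hedged sketch, not something I've verified) is to keep the Triton-backed autograd.Function out of the traced graph entirely, so Dynamo falls back to eager around it while the rest of the model is still compiled:

import torch
from fla.ops.linear_attn.chunk_fuse import FusedChunkLinearAttentionFunction

fused_chunk_linear_attn = FusedChunkLinearAttentionFunction.apply

# Untested sketch: torch.compiler.disable forces a graph break here, so the
# kernel runs exactly as in the non-compiled case.
@torch.compiler.disable
def linear_attn_eager(q, k, v, scale, initial_state, output_final_state):
    return fused_chunk_linear_attn(q, k, v, scale, initial_state, output_final_state)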

juankost · Aug 20 '24 11:08