flash-linear-attention
Chunk-wise linear attention kernel does not work with torch.compile (returns incorrect values / NaNs)
I've been trying to use the linear attention kernels in a model that I compile, but the Triton kernel does not seem to work with torch.compile. Specifically, when comparing the kernel's output against a reference implementation, the two match when torch.compile is not used, but show huge discrepancies when the kernel is wrapped with torch.compile.
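In essence, the comparison boils down to the following (a condensed sketch of the full repro script below, reusing its tensors and the same call signature):

# Condensed sketch; q, k, v, initial_state and fused_chunk_linear_attn are defined as in the script below.
fn = lambda q, k, v, h0: fused_chunk_linear_attn(q, k, v, 0.125, h0, True)
out_eager, _ = fn(q, k, v, initial_state)                    # close to the reference implementation
out_compiled, _ = torch.compile(fn)(q, k, v, initial_state)  # large errors (fp16) or NaNs (bf16/fp32)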
Here is a script to reproduce the error:
import torch
import sys
import os
from fla.ops.linear_attn.chunk_fuse import FusedChunkLinearAttentionFunction
fused_chunk_linear_attn = FusedChunkLinearAttentionFunction.apply
def reference_implementation(q, k, v, scale, initial_state, output_final_state):
    # q, k, v: B, H, L, E
    q = q * scale
    attn_weights = torch.matmul(q, k.transpose(2, 3))
    causal_mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2])).to(q.device).bool()
    attn_weights.masked_fill_(~causal_mask, 0.0)
    attn_output = torch.matmul(attn_weights, v)
    if initial_state is not None:
        state_contribution = torch.matmul(q, initial_state)
        attn_output += state_contribution
    hidden_state = None
    if output_final_state:
        hidden_state = torch.einsum("bhle,bhlf->bhef", k, v)
        if initial_state is not None:
            hidden_state += initial_state
    return attn_output, hidden_state
# SE (02/26): borrowing these tolerances from Mamba's test_selective_state_update
# https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/tests/ops/triton/test_selective_state_update.py#L24C1-L26C32
DTYPE_TO_ATOL = {
    torch.float32: 1e-3,
    torch.float16: 1e-2,
    torch.bfloat16: 5e-2,
}
DTYPE_TO_RTOL = {
    torch.float32: 3e-4,
    torch.float16: 5e-4,
    torch.bfloat16: 1e-2,
}
def compare_two_outputs(output1, output2, atol=1e-2, rtol=0):
    def compute_stats(a, b, atol=1e-2, rtol=0):
        abs_diff = torch.abs(a - b)
        mean_abs_diff = torch.mean(abs_diff).item()
        max_abs_diff = torch.max(abs_diff).item()
        not_close_mask = ~torch.isclose(a, b, atol=atol, rtol=rtol)
        mean_abs_diff_not_close = (
            torch.mean(abs_diff[not_close_mask]).item() if not_close_mask.any() else 0
        )
        mean_rel_diff_not_close = (
            torch.mean(abs_diff[not_close_mask] / (torch.abs(b[not_close_mask]) + 1e-7)).item()
            if not_close_mask.any()
            else 0
        )
        return {
            "mean_abs_diff": mean_abs_diff,
            "max_abs_diff": max_abs_diff,
            "mean_abs_diff_not_close": mean_abs_diff_not_close,
            "mean_rel_diff_not_close": mean_rel_diff_not_close,
            "percent_not_close": 100 * not_close_mask.float().mean().item(),
        }

    # if not torch.allclose(output1, output2, atol=atol, rtol=rtol):
    if not torch.allclose(output1, output2):
        output_stats = compute_stats(output1, output2, atol=atol, rtol=rtol)
        print("Different outputs")
        print(f"Mean absolute difference: {output_stats['mean_abs_diff']:.6e}")
        print(f"Max absolute difference: {output_stats['max_abs_diff']:.6e}")
        print(
            f"Mean absolute difference (not close): {output_stats['mean_abs_diff_not_close']:.6e}"
        )
        print(
            f"Relative absolute difference (not close): {output_stats['mean_rel_diff_not_close']:.6e}"
        )
        print(f"Percent not close: {output_stats['percent_not_close']:.2f}%")
    else:
        print("Outputs match")
    print("-" * 80)  # Separator for readability
    return torch.allclose(output1, output2, atol=atol, rtol=rtol)
def test_torch_compile(dtype):
    # Set up test parameters
    B, H, L, E = 1, 8, 2048, 64
    scale = 0.125
    output_final_state = True

    # Create input tensors
    q = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    k = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    v = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    initial_state = (
        torch.empty((B, H, E, E), dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .contiguous()
    )

    # Define functions that use the kernel and the reference implementation
    def attention_fn_kernel(q, k, v, initial_state):
        return fused_chunk_linear_attn(q, k, v, scale, initial_state, output_final_state)

    def attention_fn_pytorch(q, k, v, initial_state):
        return reference_implementation(q, k, v, scale, initial_state, output_final_state)

    # Compile/warm up the functions
    compiled_fn_kernel = torch.compile(attention_fn_kernel)
    compiled_fn_pytorch = torch.compile(attention_fn_pytorch)
    with torch.autocast(device_type="cuda", dtype=dtype):
        output1, _ = compiled_fn_kernel(q, k, v, initial_state)
        output2, _ = compiled_fn_pytorch(q, k, v, initial_state)

    # Run the forward passes
    with torch.autocast(device_type="cuda", dtype=dtype):
        output_kernel_comp, _ = compiled_fn_kernel(q, k, v, initial_state)
        output_pytorch_comp, _ = compiled_fn_pytorch(q, k, v, initial_state)
        output_kernel, _ = attention_fn_kernel(q, k, v, initial_state)
        output_pytorch, _ = attention_fn_pytorch(q, k, v, initial_state)

    # Check whether the compiled outputs match the non-compiled ones
    atol = DTYPE_TO_ATOL[dtype]
    rtol = DTYPE_TO_RTOL[dtype]
    print("Compiled vs non compiled Pytorch")
    compare_two_outputs(output_pytorch_comp, output_pytorch, atol=atol, rtol=rtol)
    print("Triton (non compiled) vs Pytorch (non compiled)")
    compare_two_outputs(output_kernel.float(), output_pytorch.float(), atol=atol, rtol=rtol)
    print("Triton (compiled) vs Pytorch (non compiled)")
    compare_two_outputs(output_kernel_comp.float(), output_pytorch.float(), atol=atol, rtol=rtol)
    print("Test compile completed")
if __name__ == "__main__":
    for dtype in [torch.float16, torch.bfloat16, torch.float32]:
        print(f"Testing torch compile for dtype {dtype}")
        test_torch_compile(dtype)
        print("=" * 80)
The output of the script (running on an A100 40GB GPU):
Testing torch compile for dtype torch.float16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.049560e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.805531e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 3.020972e+00
Max absolute difference: 2.664062e+01
Mean absolute difference (not close): 3.029570e+00
Relative absolute difference (not close): 9.999998e-01
Percent not close: 99.72%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.bfloat16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.357576e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.615234e-02
Relative absolute difference (not close): 1.953556e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: nan
Max absolute difference: nan
Mean absolute difference (not close): nan
Relative absolute difference (not close): nan
Percent not close: 100.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torch compile for dtype torch.float32
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.531446e-06
Max absolute difference: 5.340576e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: nan
Max absolute difference: nan
Mean absolute difference (not close): nan
Relative absolute difference (not close): nan
Percent not close: 100.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
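In case it helps anyone hitting the same thing: a possible temporary workaround (just a sketch, and assuming the problem comes from torch.compile tracing into the custom autograd.Function rather than treating it as opaque) is to keep the kernel call itself outside the compiled region while still compiling the surrounding model, e.g. with torch.compiler.disable:

import torch
from fla.ops.linear_attn.chunk_fuse import FusedChunkLinearAttentionFunction

# Hypothetical wrapper: exclude the Triton kernel from Dynamo tracing so the rest of
# the model can still be compiled. This is a sketch of a possible workaround, not a fix.
@torch.compiler.disable
def fused_chunk_linear_attn_eager(q, k, v, scale, initial_state, output_final_state):
    return FusedChunkLinearAttentionFunction.apply(
        q, k, v, scale, initial_state, output_final_state
    )

I have not checked what this costs in performance inside a compiled model; it simply forces the kernel call to run eagerly.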
Relevant libraries from my environment:
Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.4.0.dev20240523
[pip3] torchaudio==2.2.0.dev20240523
[pip3] torchvision==0.19.0.dev20240523
[pip3] triton==3.0.0
[pip3] triton-nightly==3.0.0.post20240522224832
[conda] blas 1.0 mkl intel
[conda] brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
[conda] cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
[conda] cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch-nightly
[conda] mkl 2021.4.0 intel_640 intel
[conda] mkl-service 2.4.0 py311h9bf148f_0 pytorch-nightly
[conda] mkl_fft 1.3.1 py311hc796f24_0 pytorch-nightly
[conda] mkl_random 1.2.2 py311hbba84a0_0 pytorch-nightly
[conda] mpmath 1.2.1 py311_0 pytorch-nightly
[conda] numpy 1.26.4 pypi_0 pypi
[conda] numpy-base 1.24.3 py311hfd5febd_0
[conda] pysocks 1.7.1 py311_0 pytorch-nightly
[conda] pytorch 2.4.0.dev20240523 py3.11_cuda12.4_cudnn8.9.2_0 pytorch-nightly
[conda] pytorch-cuda 12.4 hc786d27_6 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 2.2.0.dev20240523 py311_cu124 pytorch-nightly
[conda] torchtriton 3.0.0+45fff310c8 py311 pytorch-nightly
[conda] torchvision 0.19.0.dev20240523 py311_cu124 pytorch-nightly
[conda] triton-nightly 3.0.0.post20240522224832 pypi_0 pypi
[conda] urllib3 1.26.14 py311_0 pytorch-nightly