flash-linear-attention
Chunk-wise linear attention kernel does not work with torch.compile (returns incorrect values / NaNs)
I've been trying to use the linear attention kernels in a model that I am compiling, but the Triton kernel does not seem to work with torch.compile. Specifically, when comparing the output of the kernel against a reference implementation, the two match when torch.compile is not used, and they show huge discrepancies when the kernel is wrapped with torch.compile.
Here is a script to reproduce the error:
import torch
import sys
import os

from fla.ops.linear_attn.chunk_fuse import FusedChunkLinearAttentionFunction

fused_chunk_linear_attn = FusedChunkLinearAttentionFunction.apply


def reference_implementation(q, k, v, scale, initial_state, output_final_state):
    # q, k, v: B, H, L, E
    q = q * scale
    attn_weights = torch.matmul(q, k.transpose(2, 3))
    causal_mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2])).to(q.device).bool()
    attn_weights.masked_fill_(~causal_mask, 0.0)
    attn_output = torch.matmul(attn_weights, v)
    if initial_state is not None:
        state_contribution = torch.matmul(q, initial_state)
        attn_output += state_contribution
    hidden_state = None
    if output_final_state:
        hidden_state = torch.einsum("bhle,bhlf->bhef", k, v)
        if initial_state is not None:
            hidden_state += initial_state
    return attn_output, hidden_state


# SE (02/26): borrowing these tolerances from Mamba's test_selective_state_update
# https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/tests/ops/triton/test_selective_state_update.py#L24C1-L26C32
DTYPE_TO_ATOL = {
    torch.float32: 1e-3,
    torch.float16: 1e-2,
    torch.bfloat16: 5e-2,
}
DTYPE_TO_RTOL = {
    torch.float32: 3e-4,
    torch.float16: 5e-4,
    torch.bfloat16: 1e-2,
}


def compare_two_outputs(output1, output2, atol=1e-2, rtol=0):
    def compute_stats(a, b, atol=1e-2, rtol=0):
        abs_diff = torch.abs(a - b)
        mean_abs_diff = torch.mean(abs_diff).item()
        max_abs_diff = torch.max(abs_diff).item()
        not_close_mask = ~torch.isclose(a, b, atol=atol, rtol=rtol)
        mean_abs_diff_not_close = (
            torch.mean(abs_diff[not_close_mask]).item() if not_close_mask.any() else 0
        )
        mean_rel_diff_not_close = (
            torch.mean(abs_diff[not_close_mask] / (torch.abs(b[not_close_mask]) + 1e-7)).item()
            if not_close_mask.any()
            else 0
        )
        return {
            "mean_abs_diff": mean_abs_diff,
            "max_abs_diff": max_abs_diff,
            "mean_abs_diff_not_close": mean_abs_diff_not_close,
            "mean_rel_diff_not_close": mean_rel_diff_not_close,
            "percent_not_close": 100 * not_close_mask.float().mean().item(),
        }

    # if not torch.allclose(output1, output2, atol=atol, rtol=rtol):
    if not torch.allclose(output1, output2):
        output_stats = compute_stats(output1, output2, atol=atol, rtol=rtol)
        print("Different outputs")
        print(f"Mean absolute difference: {output_stats['mean_abs_diff']:.6e}")
        print(f"Max absolute difference: {output_stats['max_abs_diff']:.6e}")
        print(
            f"Mean absolute difference (not close): {output_stats['mean_abs_diff_not_close']:.6e}"
        )
        print(
            f"Relative absolute difference (not close): {output_stats['mean_rel_diff_not_close']:.6e}"
        )
        print(f"Percent not close: {output_stats['percent_not_close']:.2f}%")
    else:
        print("Outputs match")
    print("-" * 80)  # Separator for readability
    return torch.allclose(output1, output2, atol=atol, rtol=rtol)


def test_torch_compile(dtype):
    # Set up test parameters
    B, H, L, E = 1, 8, 2048, 64
    scale = 0.125
    output_final_state = True

    # Create input tensors
    q = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    k = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    v = (
        torch.empty(B, H, L, E, dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
        .contiguous()
    )
    initial_state = (
        torch.empty((B, H, E, E), dtype=dtype, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .contiguous()
    )

    # Define functions that call the kernel and the reference implementation
    def attention_fn_kernel(q, k, v, initial_state):
        return fused_chunk_linear_attn(q, k, v, scale, initial_state, output_final_state)

    def attention_fn_pytorch(q, k, v, initial_state):
        return reference_implementation(q, k, v, scale, initial_state, output_final_state)

    # Compile/warm up the kernels
    compiled_fn_kernel = torch.compile(attention_fn_kernel)
    compiled_fn_pytorch = torch.compile(attention_fn_pytorch)
    with torch.autocast(device_type="cuda", dtype=dtype):
        output1, _ = compiled_fn_kernel(q, k, v, initial_state)
        output2, _ = compiled_fn_pytorch(q, k, v, initial_state)

    # Run the forward passes
    with torch.autocast(device_type="cuda", dtype=dtype):
        output_kernel_comp, _ = compiled_fn_kernel(q, k, v, initial_state)
        output_pytorch_comp, _ = compiled_fn_pytorch(q, k, v, initial_state)
        output_kernel, _ = attention_fn_kernel(q, k, v, initial_state)
        output_pytorch, _ = attention_fn_pytorch(q, k, v, initial_state)

    # Compare the compiled outputs against the non-compiled references
    atol = DTYPE_TO_ATOL[dtype]
    rtol = DTYPE_TO_RTOL[dtype]
    print("Compiled vs non compiled Pytorch")
    compare_two_outputs(output_pytorch_comp, output_pytorch, atol=atol, rtol=rtol)
    print("Triton (non compiled) vs Pytorch (non compiled)")
    compare_two_outputs(output_kernel.float(), output_pytorch.float(), atol=atol, rtol=rtol)
    print("Triton (compiled) vs Pytorch (non compiled)")
    compare_two_outputs(output_kernel_comp.float(), output_pytorch.float(), atol=atol, rtol=rtol)
    print("Test compile completed")


if __name__ == "__main__":
    for dtype in [torch.float16, torch.bfloat16, torch.float32]:
        print(f"Testing torc compile for dtype {dtype}")
        test_torch_compile(dtype)
        print("=" * 80)
The output of the script (running on an A100 40GB GPU):
Testing torc compile for dtype torch.float16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.049560e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.805531e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 3.020972e+00
Max absolute difference: 2.664062e+01
Mean absolute difference (not close): 3.029570e+00
Relative absolute difference (not close): 9.999998e-01
Percent not close: 99.72%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torc compile for dtype torch.bfloat16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.357576e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.615234e-02
Relative absolute difference (not close): 1.953556e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: nan
Max absolute difference: nan
Mean absolute difference (not close): nan
Relative absolute difference (not close): nan
Percent not close: 100.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torc compile for dtype torch.float32
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.531446e-06
Max absolute difference: 5.340576e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: nan
Max absolute difference: nan
Mean absolute difference (not close): nan
Relative absolute difference (not close): nan
Percent not close: 100.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Relevant library versions from my environment:
[pip3] numpy==1.24.3
[pip3] torch==2.4.0.dev20240523
[pip3] torchaudio==2.2.0.dev20240523
[pip3] torchvision==0.19.0.dev20240523
[pip3] triton==3.0.0
[pip3] triton-nightly==3.0.0.post20240522224832
[conda] blas 1.0 mkl intel
[conda] brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
[conda] cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
[conda] cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch-nightly
[conda] mkl 2021.4.0 intel_640 intel
[conda] mkl-service 2.4.0 py311h9bf148f_0 pytorch-nightly
[conda] mkl_fft 1.3.1 py311hc796f24_0 pytorch-nightly
[conda] mkl_random 1.2.2 py311hbba84a0_0 pytorch-nightly
[conda] mpmath 1.2.1 py311_0 pytorch-nightly
[conda] numpy 1.26.4 pypi_0 pypi
[conda] numpy-base 1.24.3 py311hfd5febd_0
[conda] pysocks 1.7.1 py311_0 pytorch-nightly
[conda] pytorch 2.4.0.dev20240523 py3.11_cuda12.4_cudnn8.9.2_0 pytorch-nightly
[conda] pytorch-cuda 12.4 hc786d27_6 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 2.2.0.dev20240523 py311_cu124 pytorch-nightly
[conda] torchtriton 3.0.0+45fff310c8 py311 pytorch-nightly
[conda] torchvision 0.19.0.dev20240523 py311_cu124 pytorch-nightly
[conda] triton-nightly 3.0.0.post20240522224832 pypi_0 pypi
[conda] urllib3 1.26.14 py311_0 pytorch-nightly
I think I have localized the issue. It seems to arise from line 249, where the outputs of each block are summed up.
Note: here is what I tested to determine the cause of the error:
# The current approach in the code
o = o.sum(0)
return o.to(q.dtype), final_state
# When NK == 1, we do not need to sum.
# return o[0].to(q.dtype), final_state # --> this removes the large discrepancies in the output
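For context, this is my rough understanding of the shapes involved (a sketch from reading the kernel wrapper, not the actual code; the NK == 1 shortcut is the same one commented out above):

# o holds one partial output per key-dimension block, roughly [NK, B, H, L, E],
# and the partial results are reduced over the first axis.
# When NK == 1 there is nothing to reduce, so the sum can be skipped:
o = o[0] if o.shape[0] == 1 else o.sum(0)
return o.to(q.dtype), final_state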
"Workaround" (not sure how this impact on the torch compile performance):
# Adding this print statement removes the large discrepancies observed with torch.compile
print("")
o = o.sum(0)
I am not sure exactly why this works, but it may help in coming up with a better solution.
@juankost Hi, could you report your torch & triton versions? I ran into some errors when running your code snippet.
@yzhangcs
PyTorch 2.4.0.dev20240523
Triton 3.0.0+45fff310c8
What type of error are you getting?
I think the problem is that torch.compile tries to fuse the o.sum(0) operation with the Triton kernel, so the sum may be executed before all the thread blocks have finished computing the output. The print statement essentially breaks the torch.compile graph before the sum, which presumably ensures that the sum only runs after the Triton kernel has fully finished.
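If graph breaking is what actually makes the print hack work, an explicit break might be a cleaner way to express the same thing. This is only a sketch of the idea (untested; it also assumes dynamo honors the break inside the autograd Function's forward):

import torch._dynamo

# Explicit graph break in place of the print("") hack, so the reduction is
# not traced into the same compiled graph as the Triton kernel launch.
torch._dynamo.graph_break()
o = o.sum(0)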
Hi, mine is torch 2.3 & triton 2.3
TypeError: launcher() got an unexpected keyword argument 'num_warps'
Not sure what exactly happened.
I just provided a quick fix, could you check it again?
I tested it; it does not work when NK > 1 (i.e., when the sum operation actually gets executed).
Running torch._dynamo.explain on the function shows that it still compiles everything into a single graph:
Graph Count: 1
Graph Break Count: 0
Op Count: 4
Break Reasons:
Ops per Graph:
Ops 1:
<class 'torch.autograd.function.FunctionCtx'> autograd_function_apply <built-in function getitem> <built-in function getitem>
With the print statement, on the other hand, it does break the compilation graph. I have not figured out a better way to break the graph (or why this is a problem in the first place). Maybe we could raise an issue with the PyTorch team to get their take on it.
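For reference, this is roughly how the graph statistics above were obtained (using the function names from my reproduction script; torch._dynamo.explain is the stock way to inspect graph breaks in torch 2.x):

import torch._dynamo as dynamo

# Prints Graph Count, Graph Break Count, Op Count, Break Reasons, Ops per Graph, ...
explanation = dynamo.explain(attention_fn_kernel)(q, k, v, initial_state)
print(explanation)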
@juankost Here is my current output with torch 2.4 and triton 3.0:
Testing torc compile for dtype torch.float16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.049412e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.618346e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.049412e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.618346e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torc compile for dtype torch.float16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.050618e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.841894e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 1.050618e-03
Max absolute difference: 1.562500e-02
Mean absolute difference (not close): 1.562500e-02
Relative absolute difference (not close): 1.841894e-03
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torc compile for dtype torch.bfloat16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.342735e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.664062e-02
Relative absolute difference (not close): 1.106870e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.342735e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.664062e-02
Relative absolute difference (not close): 1.106870e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torc compile for dtype torch.bfloat16
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.370323e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.322266e-02
Relative absolute difference (not close): 5.157646e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 8.370323e-03
Max absolute difference: 1.250000e-01
Mean absolute difference (not close): 5.322266e-02
Relative absolute difference (not close): 5.157646e-01
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torc compile for dtype torch.float32
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.526679e-06
Max absolute difference: 4.577637e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.526679e-06
Max absolute difference: 4.577637e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
Testing torc compile for dtype torch.float32
Compiled vs non compiled Pytorch
Outputs match
--------------------------------------------------------------------------------
Triton (non compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.525714e-06
Max absolute difference: 5.340576e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Triton (compiled) vs Pytorch (non compiled)
Different outputs
Mean absolute difference: 2.525714e-06
Max absolute difference: 5.340576e-05
Mean absolute difference (not close): 0.000000e+00
Relative absolute difference (not close): 0.000000e+00
Percent not close: 0.00%
--------------------------------------------------------------------------------
Test compile completed
================================================================================
@yzhangcs What input dimensions are you using?
If you are just running my script from above, your proposed fix will work, since NK == 1 (the head dimension is 64, so NK will be 1).
If you increase the head dimension as in the code below, it will still fail, no?
def test_torch_compile(dtype):
    # Set up test parameters
    B, H, L, E = 1, 8, 2048, 128  # change E from 64 to 128 so that NK > 1
I’m just using the above code. If it is still failing, we might need to consider other solutions for this issue.
Yeah, it still fails for me when increasing the head dimension.
I asked on the PyTorch forum whether they have any suggestions or an explanation for this: https://discuss.pytorch.org/t/custom-triton-kernel-with-torch-compile-returns-incorrect-outputs/208355
@juankost closing this as answered by the link you provided. We've not experimented too much with torch.compile yet. Any PRs are welcome if you fix this issue.
@yzhangcs For now, I have simply worked around it by removing the parallelization over the K dimension, so that the operation
o = o.sum(0)
at the end of the kernel can be dropped. This "fixes" the problem with torch.compile, but it is not ideal in terms of optimization. I'm following up on the issue in the forum; if we find a solution, I'll submit a PR.
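Another option I have been considering, shown below as an untested sketch (the wrapper name is made up, and this is not part of the library): keep the kernel call out of the compiled region with torch.compiler.disable, so Inductor never gets a chance to fuse or reorder anything around the o.sum(0) inside it.

import torch
from fla.ops.linear_attn.chunk_fuse import FusedChunkLinearAttentionFunction

# Hypothetical wrapper: the surrounding model can still be torch.compile'd,
# but this function always runs eagerly (dynamo skips its frame).
@torch.compiler.disable
def fused_chunk_linear_attn_eager(q, k, v, scale, initial_state, output_final_state):
    return FusedChunkLinearAttentionFunction.apply(
        q, k, v, scale, initial_state, output_final_state
    )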