[BUG] triton.language.associative_scan returning incorrect results when `reverse=True`
Hi,
I believe triton.language.associative_scan returns incorrect results when reverse=True, or else I could not figure out the intended behaviour. In the original PR the reverse mode is described as "jax like", i.e. rev(scan(rev(x))). So I compared several variants and also ran a JAX reference.
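For concreteness, this is the convention I expected, sketched here with a plain cumulative sum in torch (my own illustration of rev(scan(rev(x))), not Triton code):

```python
import torch

x = torch.arange(4.0)                    # [0., 1., 2., 3.]
forward = x.cumsum(0)                    # [0., 1., 3., 6.]
reverse = x.flip(0).cumsum(0).flip(0)    # rev(cumsum(rev(x))) -> [6., 6., 5., 3.]
```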
Spotted a problem
Here I thought that maybe not using separate out pointers was causing some sort of race condition, since the result for the "exponent" f part is very clearly wrong. So I tested the two variants pair-wise, like so:
import torch
import triton.language as tl
import triton
# Setup
@triton.jit
def op(fl, xl, fr, xr) -> tuple[torch.float, torch.float]:
    """First order linear recurrence operation.
    source: https://srush.github.io/annotated-mamba/hard.html
    """
    f = fr * fl
    x = fr * xl + xr
    return f, x

@triton.jit
def kernel1(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op, reverse=reverse
    )
    tl.store(exp_ref + input_range, exp)
    tl.store(vals_ref + input_range, vals)

@triton.jit
def kernel2(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op, reverse=reverse
    )
    tl.store(out_exp_ref + input_range, exp)
    tl.store(out_vals_ref + input_range, vals)

def init():
    BS = 4
    device = torch.device("cuda")
    exp = torch.tensor([1.0, 1.5, 0.8, 2.0]).to(device)
    vals = torch.tensor([1.0, -1.0, 0.5, 2.0]).to(device)
    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)
    return BS, device, exp, vals, out_exp, out_vals
# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel1[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals
BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel2[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")
print()
print()
reverse=True
BS, device, exp, vals, out_exp, out_vals = init()
kernel1[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals
BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel2[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")
Output:
out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')
out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')
out_exp_reverse1=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 4.5500, 2.1000, 3.5000], device='cuda:0')
out_exp_reverse2=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse2=tensor([3.1500, 4.5500, 2.1000, 3.5000], device='cuda:0')
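A quick hand check: the forward results above are what I expect, e.g. out_vals1[3] = 2.0 * 0.9 + 2.0 = 3.8. The reverse gate column is the giveaway that something is off: since the gate part of op is a plain product, the reverse gates should be the suffix products of exp, i.e. [2.4, 2.4, 1.6, 2.0], not a constant 2.4.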
The reverse is incorrect.
What I believe is the correct result
Not trusting my own intuition, I created a manual version as a reference and compared it to the JAX implementation.
Manual reference:
import torch
def op(fl, xl, fr, xr):
    f = fr * fl
    x = fr * xl + xr
    return f, x

def forward_scan(exp, vals):
    exp = torch.tensor(exp)
    vals = torch.tensor(vals)
    state_exp = [torch.tensor(1.0)]
    state_vals = [torch.tensor(0.0)]
    for i in range(len(exp)):
        new_exp, new_val = op(state_exp[-1], state_vals[-1], exp[i], vals[i])
        state_exp.append(new_exp)
        state_vals.append(new_val)
    return torch.stack(state_exp), torch.stack(state_vals)

def reverse_scan(exp, vals):
    exp = torch.tensor(exp)
    vals = torch.tensor(vals)
    state_f = torch.tensor(1.0)
    state_x = torch.tensor(0.0)
    f_results = [state_f]
    x_results = [state_x]
    # iterate in reverse, taking the accumulated state as the left side
    # and the newly visited element as the right side
    for i in range(len(exp) - 1, -1, -1):
        state_f, state_x = op(state_f, state_x, exp[i], vals[i])
        f_results.append(state_f)
        x_results.append(state_x)
    return torch.stack(f_results[1:][::-1]), torch.stack(x_results[1:][::-1])
exp = [1.0, 1.5, 0.8, 2.0]
vals = [1.0, -1.0, 0.5, 2.0]
forward_exp, forward_vals = forward_scan(exp, vals)
backward_exp, backward_vals = reverse_scan(exp, vals)
print("Forward scan results:")
print("gates", forward_exp)
print("tokens", forward_vals)
print("Backward scan results:")
print("gates", backward_exp)
print("tokens", backward_vals)
output
Forward scan results:
gates tensor([1.0000, 1.0000, 1.5000, 1.2000, 2.4000])
tokens tensor([0.0000, 1.0000, 0.5000, 0.9000, 3.8000])
Backward scan results:
gates tensor([2.4000, 2.4000, 1.6000, 2.0000])
tokens tensor([3.1500, 2.1500, 2.1000, 2.0000])
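Hand-checking the backward token column against the recurrence: tokens[3] = 2.0, tokens[2] = 0.8 * 2.0 + 0.5 = 2.1, tokens[1] = 1.5 * 2.1 - 1.0 = 2.15, tokens[0] = 1.0 * 2.15 + 1.0 = 3.15, which matches the reference above.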
Jax reference
from jax import lax
import jax.numpy as jnp
result_add_1 = lax.associative_scan(jnp.add, jnp.arange(0, 4))
result_add_1_reverse = lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
print(f"{result_add_1=}")
print(f"{result_add_1_reverse=}")
def op(left, right) -> tuple[float, float]:
    fl, xl = left
    fr, xr = right
    f = fl * fr
    x = fr * xl + xr
    return f, x
exp = jnp.array([1.0, 1.5, 0.8, 2.0])
vals = jnp.array([1.0, -1.0, 0.5, 2.0])
result_jax_normal = lax.associative_scan(op, (exp, vals))
print(result_jax_normal)
result_jax_reversed = lax.associative_scan(op, (exp, vals), reverse=True)
print(result_jax_reversed)
output
result_add_1=Array([0, 1, 3, 6], dtype=int32)
result_add_1_reverse=Array([6, 6, 5, 3], dtype=int32)
(Array([1. , 1.5, 1.2, 2.4], dtype=float32), Array([1. , 0.5, 0.9, 3.8], dtype=float32))
(Array([2.4, 2.4, 1.6, 2. ], dtype=float32), Array([3.1499999, 2.1499999, 2.1 , 2.], dtype=float32))
My intuition seems to be correct: the manual reference and JAX yield the same results.
Any ideas? See also #3177 and #2930; kindly referencing @srush and the original usage in the mamba implementation.
Python: 3.10 and 3.12; Triton: 3.0 and triton-nightly
Two workarounds:
- use tl.flip() on the inputs and flip the results back
- read/write through a reversed index range
Investigation into the workarounds and further testing.
A naive flip of the axis is one possible workaround, but I do not trust it: why does it suddenly change behaviour compared to the out_ref variant?
@triton.jit
def kernel3(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (tl.flip(exp), tl.flip(vals)), axis=0, combine_fn=op,  # reverse=reverse
    ) if reverse else tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op
    )
    tl.store(exp_ref + input_range, tl.flip(exp) if reverse else exp)
    tl.store(vals_ref + input_range, tl.flip(vals) if reverse else vals)

@triton.jit
def kernel4(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (tl.flip(exp), tl.flip(vals)), axis=0, combine_fn=op,  # reverse=reverse
    ) if reverse else tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op
    )
    # NB: this still stores back through exp_ref/vals_ref; out_exp_ref/out_vals_ref
    # are never written, so the out_* buffers keep their uninitialized contents
    tl.store(exp_ref + input_range, tl.flip(exp) if reverse else exp)
    tl.store(vals_ref + input_range, tl.flip(vals) if reverse else vals)

def init():
    BS = 4
    device = torch.device("cuda")
    exp = torch.tensor([1.0, 1.5, 0.8, 2.0]).to(device)
    vals = torch.tensor([1.0, -1.0, 0.5, 2.0]).to(device)
    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)
    return BS, device, exp, vals, out_exp, out_vals
# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel3[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals
BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel4[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")
print()
print()
reverse=True
BS, device, exp, vals, out_exp, out_vals = init()
kernel3[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals
BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel4[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")
output
out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')
out_exp2=tensor([2.8924e+28, 2.0208e+00, 0.0000e+00, 0.0000e+00], device='cuda:0')
out_vals2=tensor([0., 0., 0., 0.], device='cuda:0')
out_exp_reverse1=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')
out_exp_reverse2=tensor([1.0842e-19, 2.1437e+00, 2.0000e+00, 2.2844e+00], device='cuda:0')
out_vals_reverse2=tensor([2.0000e+00, 2.2844e+00, 1.4013e-45, 0.0000e+00], device='cuda:0')
Somehow only the in-place variant produced the correct results.
Usage as in the mamba blog, with shifted gates:
@triton.jit
def kernel5(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op, reverse=reverse
    )
    tl.store(exp_ref + input_range, exp)
    tl.store(vals_ref + input_range, vals)

@triton.jit
def kernel6(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op, reverse=reverse
    )
    tl.store(out_exp_ref + input_range, exp)
    tl.store(out_vals_ref + input_range, vals)
# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel5[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals
BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel6[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")
print()
print()
reverse=True
BS, device, exp, vals, out_exp, out_vals = init()
kernel5[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals
BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
exp = torch.cat([exp[1:], torch.tensor([1]).to(device)])
kernel6[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")
output
out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')
out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')
out_exp_reverse1=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 4.5500, 2.1000, 3.5000], device='cuda:0')
out_exp_reverse2=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse2=tensor([4.9000, 4.2000, 3.5000, 2.1000], device='cuda:0')
Both reverse results are wrong.
Workaround 1: reversing the indexes
@triton.jit
def kernel7(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    input_range = (BS - 1 - input_range) if reverse else input_range
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op
    )
    tl.store(exp_ref + input_range, exp)
    tl.store(vals_ref + input_range, vals)

@triton.jit
def kernel8(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    input_range = (BS - 1 - input_range) if reverse else input_range
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op
    )
    tl.store(out_exp_ref + input_range, exp)
    tl.store(out_vals_ref + input_range, vals)
# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel7[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals
BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel8[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")
print()
print()
reverse=True
BS, device, exp, vals, out_exp, out_vals = init()
kernel7[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals
BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel8[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")
output
out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')
out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')
out_exp_reverse1=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')
out_exp_reverse2=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse2=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')
Both are correct.
Performance of the variants that yielded correct results:
import torch
import triton.language as tl
import triton
torch.random.manual_seed(1155)
def init_rand(seqlen):
    BS = 4  # note: BS stays fixed at 4 here; seqlen only sizes the allocated buffers
    device = torch.device("cuda")
    exp = torch.rand(seqlen).to(device)
    vals = torch.rand(seqlen).to(device)
    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)
    return BS, device, exp, vals, out_exp, out_vals

lines = ["in_place_flip", "in_place_rev", "out_ref_rev"]

@triton.testing.perf_report([
    triton.testing.Benchmark(
        x_names=["seqlen"],
        x_vals=[2**i for i in range(7, 20)],
        xlabel='sequence length',
        ylabel='ms',
        x_log=True,
        y_log=True,
        line_arg="benched",
        line_vals=lines,
        line_names=lines,
        plot_name="reversing scan",
        args={},
    ),
])
def bench(benched, seqlen, device="cuda"):
    BS, device, exp, vals, out_exp, out_vals = init_rand(seqlen)
    match benched:
        case "in_place_flip":
            subject = lambda: kernel3[(1,)](exp, vals, BS, True)
        case "in_place_rev":
            subject = lambda: kernel7[(1,)](exp, vals, BS, True)
        case "out_ref_rev":
            subject = lambda: kernel8[(1,)](exp, vals, out_exp, out_vals, BS, True)
    ms = triton.testing.do_bench(subject, warmup=1000, rep=200)
    print(f"{seqlen=};\t{benched=};{ms=}")
    return ms

bench.run(save_path="./bench_results", print_data=True)
output
reversing scan:
seqlen in_place_flip in_place_rev out_ref_rev
0 128.0 0.007101 0.004218 0.003976
1 256.0 0.004264 0.003972 0.003923
2 512.0 0.004228 0.003937 0.004008
3 1024.0 0.004271 0.004030 0.003985
4 2048.0 0.004283 0.004002 0.003961
5 4096.0 0.004266 0.004031 0.003958
6 8192.0 0.004274 0.003972 0.004049
7 16384.0 0.004264 0.003959 0.004006
8 32768.0 0.004260 0.003984 0.004011
9 65536.0 0.004251 0.004020 0.004054
10 131072.0 0.004317 0.004037 0.004085
11 262144.0 0.004282 0.004010 0.004020
12 524288.0 0.004318 0.004024 0.004000
And for reverse=False:
reversing scan:
seqlen in_place_flip in_place_rev out_ref_rev
0 128.0 0.006323 0.004214 0.004002
1 256.0 0.003979 0.004000 0.003952
2 512.0 0.003979 0.004011 0.004018
3 1024.0 0.004027 0.003972 0.004007
4 2048.0 0.003967 0.004033 0.003976
5 4096.0 0.004024 0.003985 0.004023
6 8192.0 0.004042 0.004050 0.004083
7 16384.0 0.003983 0.004045 0.003997
8 32768.0 0.004036 0.004020 0.004061
9 65536.0 0.004019 0.004021 0.004017
10 131072.0 0.004057 0.004016 0.004059
11 262144.0 0.004033 0.004002 0.004021
12 524288.0 0.003994 0.003968 0.004008
For reverse=True, the results on a T4 tend to favour the in-place reversed-index variant, but on an A100 they were inconclusive (mind you, this was run on Google Colab). The in-place flip does worse.
For reverse=False, they are all roughly the same.
We debugged it a bit and found that this issue only occurs with lengths < 32. I am going to send a PR to block these sequences for now and also try to debug why that happens.
Are there any updates on this? I have also observed this behavior with lengths < 32, on both AMD and NVIDIA GPUs.
@Hprairie Not yet as far as I know, but look at the workarounds in the details of the report. reverse=False works well, so one can load through the memory pointers in reversed order and store back the same way; or, for < 32 elements, you might be better off with a linear scan using a simple for loop.
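For illustration, a minimal sketch of what that for-loop fallback could look like for the recurrence above (untested; in-place like kernel1, and the kernel name is made up):

```python
import triton
import triton.language as tl

@triton.jit
def linear_reverse_scan_kernel(exp_ref, vals_ref, BS: tl.constexpr):
    # sequential right-to-left scan; result[BS - 1] equals the input element,
    # so it can stay in place
    f_acc = tl.load(exp_ref + BS - 1)
    x_acc = tl.load(vals_ref + BS - 1)
    # walk right to left, keeping the accumulated state as the left operand of op
    for i in range(1, BS):
        idx = BS - 1 - i
        f = tl.load(exp_ref + idx)
        x = tl.load(vals_ref + idx)
        new_f = f * f_acc       # fr * fl
        new_x = f * x_acc + x   # fr * xl + xr
        tl.store(exp_ref + idx, new_f)
        tl.store(vals_ref + idx, new_x)
        f_acc = new_f
        x_acc = new_x
```

Launched with a single program, e.g. `linear_reverse_scan_kernel[(1,)](exp, vals, BS)` on the init() buffers above, it should match kernel7 with reverse=True.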
Yeah, I think under 32 elements this is not going to be fast anyway. That being said, I am looking into a fix.
I have found that tl.flip doesn't fix the problem: it still persists when doing something like tl.flip(tl.scan(tl.flip())). It may also be happening for lengths > 32 when running a scan over a 2D tensor. I will work on getting a reproducible script to help with debugging.
I also have the same issue when using tl.cumsum (which uses the scan) with reverse=True.
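For what it's worth, a minimal repro for the cumsum case might look like the sketch below (my own example; it assumes tl.cumsum forwards reverse= to the same scan machinery, as described above):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def rev_cumsum_kernel(x_ptr, out_ptr, BS: tl.constexpr):
    offs = tl.arange(0, BS)
    x = tl.load(x_ptr + offs)
    # reverse cumulative sum via the scan machinery under discussion
    tl.store(out_ptr + offs, tl.cumsum(x, axis=0, reverse=True))

x = torch.arange(4, dtype=torch.float32, device="cuda")
out = torch.empty_like(x)
rev_cumsum_kernel[(1,)](x, out, 4)
# reference: rev(cumsum(rev(x))) -> [6., 6., 5., 3.]
print(out, x.flip(0).cumsum(0).flip(0))
```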
What is the status of this?