
[BUG] triton.language.associative_scan returning incorrect results when `reverse=True`

Open PheelaV opened this issue 1 year ago • 4 comments

Hi,

I believe triton.language.associative_scan returns incorrect results when reverse=True, or else I could not figure out the intended behaviour. In the original PR it is described as "jax like", i.e. rev(scan(rev(x))). So I compared several variants and also did a JAX run.
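For reference, a minimal sketch of the rev(scan(rev(x))) semantics, illustrated with a plain cumulative sum in PyTorch (this is just an illustration, not code from the PR):

import torch

x = torch.tensor([0., 1., 2., 3.])
forward = torch.cumsum(x, dim=0)
# reverse scan = flip, forward scan, flip back
reverse = torch.flip(torch.cumsum(torch.flip(x, dims=[0]), dim=0), dims=[0])
print(forward)  # tensor([0., 1., 3., 6.])
print(reverse)  # tensor([6., 6., 5., 3.])  -- matches jax.lax.associative_scan(jnp.add, x, reverse=True)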

Spotted a problem

At first I thought that not using separate output pointers might be causing some sort of race condition, since the result for the "exponent" f part is very clearly wrong. So I tested the two variants pair-wise, like so:

import torch
import triton.language as tl
import triton


# Setup
@triton.jit
def op(fl, xl, fr, xr) -> tuple[torch.float, torch.float]:
  """First order linear recurrence operation.
     source: https://srush.github.io/annotated-mamba/hard.html
    """
  f = fr * fl
  x = fr * xl + xr
  return f, x


@triton.jit
def kernel1(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op, reverse=reverse
  )

  tl.store(exp_ref + input_range, exp)
  tl.store(vals_ref + input_range, vals)

@triton.jit
def kernel2(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op, reverse=reverse
  )

  tl.store(out_exp_ref + input_range, exp)
  tl.store(out_vals_ref + input_range, vals)

def init():
    BS=4
    device = torch.device("cuda")
    exp = torch.tensor([1.0, 1.5, 0.8, 2.0]).to(device)
    vals = torch.tensor([1.0, -1.0, 0.5, 2.0]).to(device)

    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)

    return BS, device, exp, vals, out_exp, out_vals

# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel1[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals

BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel2[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")

print()
print()
reverse=True
BS, device, exp, vals, out_exp, out_vals = init()
kernel1[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals

BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel2[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")

Output:

out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')


out_exp_reverse1=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 4.5500, 2.1000, 3.5000], device='cuda:0')

out_exp_reverse2=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse2=tensor([3.1500, 4.5500, 2.1000, 3.5000], device='cuda:0')

The reverse is incorrect.

What I believe is the correct result

Not trusting my intuition, I created a manual version as a reference and compared it to the JAX implementation.

Manual reference:

import torch


def op(fl, xl, fr, xr):
  f = fr * fl
  x = fr * xl + xr
  return f, x


def forward_scan(exp, vals):
  exp = torch.tensor(exp)
  vals = torch.tensor(vals)
  state_exp = [torch.tensor(1.0)]
  state_vals = [torch.tensor(0.0)]
  for i in range(len(exp)):
    new_exp, new_val = op(state_exp[-1], state_vals[-1], exp[i], vals[i])
    state_exp.append(new_exp)
    state_vals.append(new_val)
  return torch.stack(state_exp), torch.stack(state_vals)

def reverse_scan(exp, vals):
  exp = torch.tensor(exp)
  vals = torch.tensor(vals)

  state_f = torch.tensor(1.0)
  state_x = torch.tensor(0.0)

  f_results = [state_f]
  x_results = [state_x]

  # iterate in reverse, taking the accumulated state as the left side
  # and the newly visited element as the right side
  for i in range(len(exp) - 1, -1, -1):
    state_f, state_x = op(state_f, state_x, exp[i], vals[i])
    f_results.append(state_f)
    x_results.append(state_x)

  return torch.stack(f_results[1:][::-1]), torch.stack(x_results[1:][::-1])


exp = [1.0, 1.5, 0.8, 2.0]
vals = [1.0, -1.0, 0.5, 2.0]

forward_exp, forward_vals = forward_scan(exp, vals)
backward_exp, backward_vals = reverse_scan(exp, vals)

print("Forward scan results:")
print("gates", forward_exp)
print("tokens", forward_vals)

print("Backward scan results:")
print("gates", backward_exp)
print("tokens", backward_vals)

output

Forward scan results:
gates tensor([1.0000, 1.0000, 1.5000, 1.2000, 2.4000])
tokens tensor([0.0000, 1.0000, 0.5000, 0.9000, 3.8000])
Backward scan results:
gates tensor([2.4000, 2.4000, 1.6000, 2.0000])
tokens tensor([3.1500, 2.1500, 2.1000, 2.0000])

Jax reference

from jax import lax
import jax.numpy as jnp

result_add_1 = lax.associative_scan(jnp.add, jnp.arange(0, 4))
result_add_1_reverse = lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
print(f"{result_add_1=}")
print(f"{result_add_1_reverse=}")


def op(left, right) -> tuple[float, float]:
    fl, xl = left
    fr, xr = right
    f = fl * fr
    x = fr * xl + xr
    return f, x


exp = jnp.array([1.0, 1.5, 0.8, 2.0])
vals = jnp.array([1.0, -1.0, 0.5, 2.0])

result_jax_normal = lax.associative_scan(op, (exp, vals))
print(result_jax_normal)
result_jax_reversed = lax.associative_scan(op, (exp, vals), reverse=True)
print(result_jax_reversed)

output

result_add_1=Array([0, 1, 3, 6], dtype=int32)
result_add_1_reverse=Array([6, 6, 5, 3], dtype=int32)
(Array([1. , 1.5, 1.2, 2.4], dtype=float32), Array([1. , 0.5, 0.9, 3.8], dtype=float32))
(Array([2.4, 2.4, 1.6, 2. ], dtype=float32), Array([3.1499999, 2.1499999, 2.1 , 2.], dtype=float32))

My intuition seems to be correct: the manual reference and JAX yield the same results.
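As a quick cross-check, a small editorial sketch (assuming the manual reference and the JAX snippet above were run in the same session, which the listings above do not do by themselves):

import numpy as np

# compare the manual PyTorch reverse scan against the JAX reverse scan
assert np.allclose(np.asarray(backward_exp), np.asarray(result_jax_reversed[0]), atol=1e-6)
assert np.allclose(np.asarray(backward_vals), np.asarray(result_jax_reversed[1]), atol=1e-6)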

Any ideas? See also #3177 and #2930; kindly referencing @srush and the original usage in the mamba implementation.

Python: 3.10 and 3.12; Triton: 3.0 and triton-nightly

Two workarounds:

  1. use flip()
  2. read/write using a reversed index range (a compact sketch of this follows; the full pair-wise experiments are further below)
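A compact sketch of the second workaround (a hypothetical kernel, using tl.cumsum for brevity rather than the recurrence op above):

@triton.jit
def reverse_cumsum_kernel(x_ref, out_ref, BS: tl.constexpr):
  # load and store through a reversed index range so a plain forward scan
  # produces the reverse-scan result
  idx = BS - 1 - tl.arange(0, BS)
  x = tl.load(x_ref + idx)
  y = tl.cumsum(x, axis=0)
  tl.store(out_ref + idx, y)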
Investigation into the workarounds and further testing.

naive flip of the axis, one possible workaround

but I do not trust it: why does the behaviour suddenly change between the in-place and the out_ref variants?

@triton.jit
def kernel3(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (tl.flip(exp), tl.flip(vals)), axis=0, combine_fn=op, #reverse=reverse
  ) if reverse else tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op
  )

  tl.store(exp_ref + input_range, tl.flip(exp) if reverse else exp)
  tl.store(vals_ref + input_range, tl.flip(vals) if reverse else vals)

@triton.jit
def kernel4(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (tl.flip(exp), tl.flip(vals)), axis=0, combine_fn=op, #reverse=reverse
  ) if reverse else tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op
  )

  tl.store(exp_ref + input_range, tl.flip(exp) if reverse else exp)
  tl.store(vals_ref + input_range, tl.flip(vals) if reverse else vals)

def init():
    BS=4
    device = torch.device("cuda")
    exp = torch.tensor([1.0, 1.5, 0.8, 2.0]).to(device)
    vals = torch.tensor([1.0, -1.0, 0.5, 2.0]).to(device)

    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)

    return BS, device, exp, vals, out_exp, out_vals

# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel3[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals

BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel4[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")

print()
print()
reverse=True
BS, device, exp, vals, out_exp, out_vals = init()
kernel3[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals

BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel4[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")

output

out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([2.8924e+28, 2.0208e+00, 0.0000e+00, 0.0000e+00], device='cuda:0')
out_vals2=tensor([0., 0., 0., 0.], device='cuda:0')


out_exp_reverse1=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')

out_exp_reverse2=tensor([1.0842e-19, 2.1437e+00, 2.0000e+00, 2.2844e+00], device='cuda:0')
out_vals_reverse2=tensor([2.0000e+00, 2.2844e+00, 1.4013e-45, 0.0000e+00], device='cuda:0')

somehow only the in-place variant (kernel3) produced the correct result. Note that kernel4 as written still stores to exp_ref/vals_ref rather than out_exp_ref/out_vals_ref, so the out_* tensors are never written, which would explain the garbage values above.
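For completeness, a hedged sketch (hypothetical, not part of the original experiments) of kernel4 with the stores redirected to the out pointers, which should make the out-ref flip variant behave like kernel3:

@triton.jit
def kernel4_fixed(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (tl.flip(exp), tl.flip(vals)), axis=0, combine_fn=op
  ) if reverse else tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op
  )

  # write to the dedicated output buffers instead of overwriting the inputs
  tl.store(out_exp_ref + input_range, tl.flip(exp) if reverse else exp)
  tl.store(out_vals_ref + input_range, tl.flip(vals) if reverse else vals)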

usage in the mamba blog with shifted gates

@triton.jit
def kernel5(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op, reverse=reverse
  )

  tl.store(exp_ref + input_range, exp)
  tl.store(vals_ref + input_range, vals)

@triton.jit
def kernel6(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op, reverse=reverse
  )

  tl.store(out_exp_ref + input_range, exp)
  tl.store(out_vals_ref + input_range, vals)

# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel5[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals

BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel6[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")

print()
print()
reverse=True
BS, device, exp, vals, out_exp, out_vals = init()
kernel5[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals

BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
exp = torch.cat([exp[1:], torch.tensor([1]).to(device)])
kernel6[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")

output

out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')


out_exp_reverse1=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 4.5500, 2.1000, 3.5000], device='cuda:0')

out_exp_reverse2=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse2=tensor([4.9000, 4.2000, 3.5000, 2.1000], device='cuda:0')

both reverse results are wrong

workaround: reversing the indices (read/write through a reversed range)

@triton.jit
def kernel7(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  input_range = (BS - 1 - input_range) if reverse else input_range
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op
  )

  tl.store(exp_ref + input_range, exp)
  tl.store(vals_ref + input_range, vals)

@triton.jit
def kernel8(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  input_range = (BS - 1 - input_range) if reverse else input_range

  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op
  )

  tl.store(out_exp_ref + input_range, exp)
  tl.store(out_vals_ref + input_range, vals)

# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel7[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals

BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel8[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")

print()
print()
reverse=True
BS, device, exp, vals, out_exp, out_vals = init()
kernel7[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals

BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel8[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")

output

out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')


out_exp_reverse1=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')

out_exp_reverse2=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse2=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')

both are correct

Performance of the variants that yielded correct results?

import torch
import triton.language as tl
import triton

torch.random.manual_seed(1155)

def init_rand(seqlen):
    # NOTE: BS stays fixed at 4, so each benchmarked kernel scans only the
    # first 4 elements of each buffer regardless of seqlen
    BS=4
    device = torch.device("cuda")
    exp = torch.rand(seqlen).to(device)
    vals = torch.rand(seqlen).to(device)

    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)

    return BS, device, exp, vals, out_exp, out_vals

lines=["in_place_flip", "in_place_rev", "out_ref_rev"]
@triton.testing.perf_report([
    triton.testing.Benchmark(
        x_names=["seqlen"],
        x_vals=[2**i for i in range(7,20)],
        xlabel='sequence length',
        ylabel='ms',
        x_log=True,
        y_log=True,
        line_arg="benched",
        line_vals=lines,
        line_names=lines,
        plot_name="reversing scan",
        args={
        }
    ),
])
def bench(benched, seqlen, device="cuda"):
        BS, device, exp, vals, out_exp, out_vals = init_rand(seqlen)
        match benched:
            case "in_place_flip":
                subject = lambda:kernel3[(1,)](exp, vals, BS, True)
            case "in_place_rev":
                subject = lambda:kernel7[(1,)](exp, vals, BS, True)
            case "out_ref_rev":
                subject = lambda:kernel8[(1,)](exp, vals, out_exp, out_vals, BS, True)
        ms = triton.testing.do_bench(subject, warmup=1000, rep=200)
        print(f"{seqlen=};\t{benched=};{ms=}")
        return ms

bench.run(save_path="./bench_results", print_data=True)

output

reversing scan:
      seqlen  in_place_flip  in_place_rev  out_ref_rev
0      128.0       0.007101      0.004218     0.003976
1      256.0       0.004264      0.003972     0.003923
2      512.0       0.004228      0.003937     0.004008
3     1024.0       0.004271      0.004030     0.003985
4     2048.0       0.004283      0.004002     0.003961
5     4096.0       0.004266      0.004031     0.003958
6     8192.0       0.004274      0.003972     0.004049
7    16384.0       0.004264      0.003959     0.004006
8    32768.0       0.004260      0.003984     0.004011
9    65536.0       0.004251      0.004020     0.004054
10  131072.0       0.004317      0.004037     0.004085
11  262144.0       0.004282      0.004010     0.004020
12  524288.0       0.004318      0.004024     0.004000

[benchmark plot, reverse=True]

and for reverse=False

reversing scan:
      seqlen  in_place_flip  in_place_rev  out_ref_rev
0      128.0       0.006323      0.004214     0.004002
1      256.0       0.003979      0.004000     0.003952
2      512.0       0.003979      0.004011     0.004018
3     1024.0       0.004027      0.003972     0.004007
4     2048.0       0.003967      0.004033     0.003976
5     4096.0       0.004024      0.003985     0.004023
6     8192.0       0.004042      0.004050     0.004083
7    16384.0       0.003983      0.004045     0.003997
8    32768.0       0.004036      0.004020     0.004061
9    65536.0       0.004019      0.004021     0.004017
10  131072.0       0.004057      0.004016     0.004059
11  262144.0       0.004033      0.004002     0.004021
12  524288.0       0.003994      0.003968     0.004008

[benchmark plot, reverse=False]

For reverse=True, the results on a T4 tend to favour the in-place reversed-index variant, but on an A100 they were inconclusive (mind you, this was run on Google Colab). The in-place flip variant does worse.

For reverse=False, all variants perform about the same.

PheelaV avatar Jul 21 '24 15:07 PheelaV

We debugged it a bit and found that this issue only occurs with length <32. I am going to send a PR to block these sequences for now and also try to debug why that happens.

srush avatar Jul 24 '24 19:07 srush

Are there any updates on this? I have also observed this behavior with lengths <32, on both AMD and NVIDIA GPUs.

Hprairie avatar Jul 28 '24 20:07 Hprairie

@Hprairie Not yet as far as I know, but look at the workarounds in the details of the report. reverse=False works well, so one can load through a reversed index range and store back the same way; alternatively, for <32 elements you might be better off with a linear scan using a simple for loop.
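A minimal sketch of that sequential fallback (a hypothetical kernel, not from the thread), running the first-order recurrence right-to-left with a plain for loop:

@triton.jit
def sequential_reverse_scan_kernel(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr) -> None:
  # seed the running state with the right-most element: op(1, 0, f, x) = (f, x)
  f = tl.load(exp_ref + (BS - 1))
  x = tl.load(vals_ref + (BS - 1))
  tl.store(out_exp_ref + (BS - 1), f)
  tl.store(out_vals_ref + (BS - 1), x)
  # walk leftwards, keeping the accumulated state on the left of the combine
  for i in range(1, BS):
    j = BS - 1 - i
    fr = tl.load(exp_ref + j)
    xr = tl.load(vals_ref + j)
    f = fr * f          # op: f = fr * fl
    x = fr * x + xr     # op: x = fr * xl + xr
    tl.store(out_exp_ref + j, f)
    tl.store(out_vals_ref + j, x)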

PheelaV avatar Jul 28 '24 22:07 PheelaV

Yeah, I think under 32 this is not going to be fast anyway. That being said, I am looking into a fix.

srush avatar Jul 29 '24 10:07 srush

I have found that tl.flip doesn't fix the problem and that it still persists when doing something like tl.flip(tl.scan(tl.flip())). It may also be happening for >32 lengths when running a scan with a 2d tensor. I will work on getting a reproducible script to help with debugging.

Hprairie avatar Aug 12 '24 05:08 Hprairie

I also have the same issue when using tl.cumsum (which uses the scan) with reverse=True.
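For reference, a minimal repro sketch for that case (a hypothetical kernel, following the same single-block pattern as the report above):

@triton.jit
def reverse_cumsum_repro(x_ref, out_ref, BS: tl.constexpr):
  idx = tl.arange(0, BS)
  x = tl.load(x_ref + idx)
  # tl.cumsum is built on the same associative-scan machinery
  y = tl.cumsum(x, axis=0, reverse=True)
  tl.store(out_ref + idx, y)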

anasiri avatar Aug 19 '24 18:08 anasiri

What is the status of this?

bhack avatar Dec 14 '24 22:12 bhack