                        Flip or reverse associative scan?
I am trying to figure out if there is any clever way to flip a tensor along an axis or run a right-to-left associative sum? I know I can load a tensor in the reverse direction, but ideally I would like to be able to do this without reloading all the tensors in reverse order.
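For concreteness, the reverse-direction load I mean looks roughly like this (a minimal 1D sketch; the kernel name and the one-block layout are just illustrative):
import triton
import triton.language as tl

@triton.jit
def load_reversed(x_ptr, z_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Compute the load offsets backwards: lane i reads element BLOCK - 1 - i.
    x_rev = tl.load(x_ptr + (BLOCK - 1 - offs))
    tl.store(z_ptr + offs, x_rev)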
It's not possible for now. May I know what your use case is?
https://github.com/srush/annotated-mamba/issues/1
It's not difficult to add a reverse option and modify our backend to support reverse=True. Unfortunately I'm occupied with other stuff; I'm not sure if @ThomasRaoux is interested?
@srush or do you have any student interested in adding the reverse=True option? It might be an interesting option since the support is not on the critical path, and the student would get a chance to learn more about triton's backend.
Let me see if I can find a workaround, otherwise, sure, that would be fun.
Is reading the memory in reverse inside ScanOpToLLVM a good idea? I am taking this route in my CUDA implementation and the algorithm should be close to the one in ScanOpToLLVM. I'd be down to write some MLIR C++ for this.
Here's an in-memory reverse hack that seems to work for me. Unfortunately with tl.dot I get a segfault. (any recs for debugging those?)
import triton
import triton.language as tl
L = 32
@triton.jit
def reverse(x_ptr, z_ptr):
    Ls = tl.arange(0, L)
    # Load the L x L tile and add a broadcast axis: shape (L, 1, L).
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)[:, None, :]
    # One-hot permutation matrix: y[i, j] = 1 iff j == L - 1 - i.
    y = (Ls[:, None] == L - Ls - 1)
    # Summing over the last axis picks z[b, i] = x[b, L - 1 - i].
    z = tl.sum(x * y, 2)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)
Thanks for the solution. It seems to use a similar idea to triton's sort.
tl.dot doesn't support 3d matrix multiplications.
Since this is a hack, either the tl.dot or the tl.sum version will be slower than a native reverse=True anyway.
Just to be clear, the tl.dot version is a 2D matmul (code below, which segfaults).
But yes, I agree that a reverse=True option is the best way. Unfortunately I am now running into a bug with tl.associative_scan giving the wrong answer on the forward pass, so I am trying to debug that :cry:
L = 32
@triton.jit
def reverse(x_ptr, z_ptr):
    Ls = tl.arange(0, L)
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)
    # One-hot permutation matrix, cast to float so it can feed tl.dot.
    y = (Ls[:, None] == L - Ls - 1).to(tl.float32)
    z = tl.dot(x, y)    # z[i, j] = x[i, L - 1 - j]
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)
Unfortunately I am now running into a bug with tl.associative_scan giving the wrong answer on the forward pass, so I am trying to debug that 😢
Please let us know if the bug is caused by triton
Although I don't fully understand how this hack works, I tested it and found that it works well. It seems that in @srush's code, $L$ must be a power of 2.
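The power-of-2 requirement most likely comes from tl.arange, which needs a power-of-two size. If a non-power-of-two length is needed, one workaround (a sketch only, not from this thread; the name reverse_row and the one-row-per-program layout are assumptions) is to round BLOCK up on the host with triton.next_power_of_2 and mask:
import triton
import triton.language as tl

@triton.jit
def reverse_row(x_ptr, z_ptr, n, BLOCK: tl.constexpr):
    # BLOCK is triton.next_power_of_2(n), computed on the host.
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    # One-hot permutation restricted to the first n entries: picks x[n - 1 - i].
    y = (offs[None, :] == n - 1 - offs[:, None])
    z = tl.sum(x[None, :] * y, 1)
    tl.store(z_ptr + offs, z, mask=mask)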
Sorry, shouldn't have called it a hack. Here's an explanation of what it is doing.
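Roughly: y is a one-hot permutation matrix with y[i, j] = 1 exactly when j = L - 1 - i, so multiplying by it (via broadcast-and-sum, or via tl.dot) selects the reversed element of each row. A rough PyTorch equivalent of what the tl.sum version computes (for illustration only, with the L rows playing the role of the batch B):
import torch

L = 32
x = torch.randn(L, L)
Ls = torch.arange(L)
y = (Ls[:, None] == L - 1 - Ls).float()   # y[i, j] = 1 iff j == L - 1 - i
z = (x[:, None, :] * y).sum(-1)           # z[b, i] = x[b, L - 1 - i]
assert torch.allclose(z, torch.flip(x, [-1]))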
However, as @Jokeren notes, my method requires creating a B x L x L intermediate. This wouldn't be a problem, but tl.dot seems pretty broken in that the following triton version segfaults for me.
import triton
import triton.language as tl
import torch
L = 32
@triton.jit
def reverse(x_ptr, z_ptr):
    Ls = tl.arange(0, L)
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)
    # One-hot permutation matrix: M[k, j] = 1 iff k == L - 1 - j.
    M = (Ls[:, None] == L - Ls - 1).to(tl.float32)
    # z[i, j] = x[i, L - 1 - j], i.e. each row reversed.
    z = tl.dot(x, M)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)

x = (torch.arange(L) + torch.zeros(L, L)).float().cuda()
z = (torch.arange(L) + torch.zeros(L, L)).float().cuda()
reverse[(1,)](x, z)
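If the launch succeeds, one quick sanity check is to compare against torch.flip:
# Each row of z should now be the corresponding row of x reversed.
print(torch.allclose(z, torch.flip(x, [-1])))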
Well, the above code works well for me. It did not report an error. Maybe you need to check your triton version?
What's your version? I was using nightly.
I am using triton 2.1.0, python 3.11.5. I directly copied your code and ran it in Jupyter notebook.