Flip or reverse associative scan?

Open srush opened this issue 1 year ago • 16 comments

Is there a clever way to flip a tensor along an axis, or to run a right-to-left associative scan? I know I can load a tensor in reverse order, but ideally I would like to do this without reloading all the tensors in reverse.

srush avatar Jan 13 '24 05:01 srush
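For readers following along, here is a minimal host-side sketch (not from the thread, and plain Python rather than Triton) of what a right-to-left associative scan computes when expressed as flip, scan, flip. The name `reverse_scan` is hypothetical. The argument swap matters only for non-commutative operators, where the left-hand element has to stay on the left.

```python
from itertools import accumulate

def reverse_scan(xs, op):
    """Right-to-left inclusive scan: out[i] = xs[i] op (xs[i+1] op (... op xs[-1]))."""
    flipped = list(reversed(xs))
    # accumulate calls f(running_total, new_element); swap the arguments so the
    # new element (which sits further left in the original order) stays on the left.
    scanned = accumulate(flipped, lambda acc, x: op(x, acc))
    return list(reversed(list(scanned)))

suffix_sums = reverse_scan([1, 2, 3, 4], lambda a, b: a + b)
suffix_concat = reverse_scan(["a", "b", "c"], lambda a, b: a + b)
```

A native `reverse=True` option would avoid both flips; this sketch only pins down the expected result.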

It's not possible for now. May I know what's your use case?

Jokeren avatar Jan 13 '24 15:01 Jokeren

https://github.com/srush/annotated-mamba/issues/1

srush avatar Jan 13 '24 22:01 srush

It's not difficult to add a reverse option and modify our backend to support reverse=True. Unfortunately I'm occupied with other stuff, not sure if @ThomasRaoux is interested?

Jokeren avatar Jan 14 '24 02:01 Jokeren

@srush, or do you have any students interested in adding the reverse=True option? It might be an interesting project, since the support is not on the critical path and a student would get a chance to learn more about Triton's backend.

Jokeren avatar Jan 14 '24 02:01 Jokeren

Let me see if I can find a workaround, otherwise, sure, that would be fun.

srush avatar Jan 14 '24 02:01 srush

Is reading the memory in reverse inside ScanOpToLLVM a good idea? I am taking this route in my CUDA implementation and the algorithm should be close to the one in ScanOpToLLVM. I'd be down to write some MLIR C++ for this.

proger avatar Jan 14 '24 12:01 proger

Here's an in-memory reverse hack that seems to work for me. Unfortunately with tl.dot I get a segfault. (any recs for debugging those?)

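import triton
import triton.language as tl

L = 32
@triton.jit
def reverse(x_ptr, z_ptr):
    Ls = tl.arange(0, L)
    # Load an L x L tile and add a middle axis: shape (L, 1, L).
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)[:, None, :]
    # Anti-diagonal mask: y[j, k] is true where j == L - 1 - k.
    y = (Ls[:, None] == L - Ls - 1)
    # z[i, j] = sum_k x[i, k] * y[j, k] = x[i, L - 1 - j]
    z = tl.sum(x * y, 2)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)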


srush avatar Jan 14 '24 22:01 srush
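The mask-and-sum trick in the kernel above can be checked on the host. This is a minimal pure-Python sketch of the same arithmetic, not the Triton kernel itself; `reverse_rows_by_mask` is a hypothetical name, and the nested comprehension stands in for the broadcast multiply followed by `tl.sum` over the last axis.

```python
def reverse_rows_by_mask(x):
    """Reverse each row of x with a one-hot mask and a sum, no reversed indexing."""
    n = len(x[0])
    # mask[j][k] is 1 exactly on the anti-diagonal, i.e. where j == n - 1 - k.
    mask = [[1 if j == n - 1 - k else 0 for k in range(n)] for j in range(n)]
    # z[i][j] = sum_k x[i][k] * mask[j][k] = x[i][n - 1 - j]
    return [[sum(row[k] * mask[j][k] for k in range(n)) for j in range(n)]
            for row in x]

flipped = reverse_rows_by_mask([[0, 1, 2, 3], [4, 5, 6, 7]])
```

Each output element touches all n inputs of its row, which is why this costs more than a native reversal.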

Thanks for the solution. It seems to use a similar idea to Triton's sort.

tl.dot doesn't support 3D matrix multiplications. Since this is a hack, either tl.dot or tl.sum will be slower than the native reverse=True version anyway.

Jokeren avatar Jan 14 '24 22:01 Jokeren

Just to be clear, the tl.dot version is a 2D matmul (the code below segfaults).

But yes, I agree that reverse=True is the best way. Unfortunately I am now running into a bug with tl.associative_scan giving the wrong answer on the forward pass, so I am trying to debug that 😢

L = 32
@triton.jit
def reverse(x_ptr, z_ptr):
    Ls = tl.arange(0, L)
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)
    y = (Ls[:, None] == L - Ls - 1).to(tl.float32) # forget the exact syntax
    z = tl.dot(x, y)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)

srush avatar Jan 15 '24 15:01 srush

> Unfortunately I am now running into a bug with tl.associative_scan giving the wrong answer on the forward pass, so I am trying to debug that 😢

Please let us know if the bug is caused by Triton.

Jokeren avatar Jan 15 '24 15:01 Jokeren

Although I don't understand how this hack works, I tested it and found that it works well. It seems that in @srush's code, L must be a power of 2.

confucianism72 avatar Jan 16 '24 15:01 confucianism72
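The power-of-2 restriction most likely comes from `tl.arange`, which requires a power-of-two range. A common way around it is to round the block size up and mask the extra lanes. The sketch below (not from the thread) shows that idea in plain Python; `next_power_of_2` and `reverse_padded` are illustrative names, the loop stands in for the lanes of a block, and Triton ships a similar helper as `triton.next_power_of_2`.

```python
def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()

def reverse_padded(xs):
    """Reverse xs using a power-of-two block, masking out the padded lanes."""
    n = len(xs)
    block = next_power_of_2(n)   # stands in for a power-of-two tl.arange range
    out = [None] * n
    for j in range(block):       # one iteration per lane of the block
        if j < n:                # mask: lanes past the real length do nothing
            out[j] = xs[n - 1 - j]
    return out

block_for_5 = next_power_of_2(5)
reversed_5 = reverse_padded([1, 2, 3, 4, 5])
```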

Sorry, shouldn't have called it a hack. Here's an explanation of what it is doing.

[image: explanation of the method]

However, as @Jokeren notes, my method requires creating a B x L x L intermediate. This wouldn't be a problem, but tl.dot seems pretty broken in that the following Triton version segfaults for me.

import triton
import triton.language as tl
import torch
L = 32
@triton.jit
def reverse(x_ptr, z_ptr):
    Ls = tl.arange(0, L)
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)
    M = (Ls[:, None] == L - Ls - 1).to(tl.float32)
    z = tl.dot(x, M)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)
    
x = (torch.arange(L) + torch.zeros(L, L)).float().cuda()
z = (torch.arange(L) + torch.zeros(L, L)).float().cuda()
reverse[(1,)](x, z)

srush avatar Jan 16 '24 15:01 srush
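The `tl.dot` variant above amounts to right-multiplying by an exchange matrix (ones on the anti-diagonal), so the only extra array is L x L rather than B x L x L. This is a minimal pure-Python sketch of that identity, not from the thread; `exchange_matrix` and `matmul` are illustrative helpers, not Triton API.

```python
def exchange_matrix(n):
    """n x n matrix with ones on the anti-diagonal: M[k][j] = 1 where k == n - 1 - j."""
    return [[1.0 if k == n - 1 - j else 0.0 for j in range(n)] for k in range(n)]

def matmul(a, b):
    """Plain 2D matrix product of nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

# Each row holds 0, 1, 2, 3, mirroring the arange-filled input in the thread.
x = [[float(c) for c in range(4)] for _ in range(3)]
z = matmul(x, exchange_matrix(4))  # z[i][j] = x[i][3 - j]
```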

Well, the above code works well for me. It did not report an error. Maybe you need to check your triton version?

confucianism72 avatar Jan 16 '24 16:01 confucianism72

What's your version? I was using nightly.

srush avatar Jan 16 '24 16:01 srush

I am using triton 2.1.0, python 3.11.5. I directly copied your code and ran it in Jupyter notebook.

confucianism72 avatar Jan 16 '24 16:01 confucianism72

[image]

confucianism72 avatar Jan 16 '24 16:01 confucianism72