
Dot Product Computes Wrong Values

Open • calclavia opened this issue 3 years ago • 22 comments

First off, this is a really cool library! I'm trying to implement the causal dot product kernel, but I'm running into some issues. It could be either a bug or my misunderstanding of the documentation.

Triton code: https://github.com/calclavia/Triton-Transformer/blob/master/ttx/attention/causal_product.py#L26

The code above implements Algorithm 1 from the "Transformers are RNNs" paper in Triton. In summary, I'm trying to batch-parallelize a for loop that computes a prefix sum. This is a simple O(n) implementation (not the more efficient O(log(n)) version).

The equivalent naive Pytorch implementation is here: https://github.com/calclavia/Triton-Transformer/blob/master/ttx/attention/causal_product.py#L6
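For readers without the repository at hand, here is a minimal NumPy sketch of what such an O(n) reference computes. It assumes the plain causal product from Algorithm 1 without the feature map or the normalizer, and the function name and NumPy port are illustrative rather than the repo's code. A running state S_i = sum over j <= i of k_j v_j^T is updated once per position, and each output row is q_i^T S_i. The last lines cross-check it against the equivalent quadratic form, a lower-triangular-masked (q k^T) v.

```python
import numpy as np


def causal_product_naive(q, k, v):
    """O(n) causal dot product (hypothetical NumPy port, no feature map/normalizer).

    q, k: arrays of shape [length, kdim]; v: array of shape [length, vdim].
    Returns out of shape [length, vdim], out[i] = q[i] @ sum_{j <= i} outer(k[j], v[j]).
    """
    length, kdim = k.shape
    vdim = v.shape[1]
    state = np.zeros((kdim, vdim), dtype=np.result_type(q, k, v))
    out = np.empty((length, vdim), dtype=state.dtype)
    for i in range(length):
        # Prefix sum of rank-1 updates: state = sum_{j <= i} k_j v_j^T
        state += np.outer(k[i], v[i])
        # Read out with the current query: [kdim] @ [kdim, vdim] -> [vdim]
        out[i] = q[i] @ state
    return out


# Cross-check against masked attention without softmax: tril(q k^T) v.
rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 5, 3))
expected = np.tril(q @ k.T) @ v
print(np.allclose(causal_product_naive(q, k, v), expected))  # True
```

The loop carries a single [kdim, vdim] state, which is exactly the sequential dependency a parallel kernel has to break up.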

When running a simple unit test, I'm getting very different values.

ref_output tensor([[[[0.0157, 0.6251, 0.6993],
          [0.7910, 2.1930, 2.1000],
          [0.7413, 1.7217, 1.4139],
          [0.8863, 1.4795, 1.5222],
          [1.2453, 2.5844, 1.9665]]]])
triton_output tensor([[[0.4268, 0.6251, 0.6993],
         [2.4132, 0.6186, 0.3389],
         [0.8975, 0.4929, 0.2288],
         [0.8470, 0.0080, 0.3058],
         [1.1330, 0.7073, 0.0776]]])

I also tested with a one-dimensional vector and still couldn't get matching values. The logic seems to be correct, but I suspect the issue is related to tl.dot. If anyone has insights, I would appreciate comments/feedback!

calclavia • Aug 17 '21 08:08

Hmm, this is interesting. Could you share a self-contained file that contains the content of causal_product_naive plus your failing unit test? Thanks!

ptillet • Aug 17 '21 16:08

@ptillet You can clone that repository and run the test via:

python -m unittest

See README. Only requires Pytorch and Triton. The repo currently contains nothing but causal dot product.

calclavia • Aug 18 '21 03:08

@ptillet Any updates on this? I'm hoping to figure out what's missing.

calclavia • Aug 22 '21 12:08

Hey! Sorry I haven't had time to look into it. I've been busy working on older issues (#170 and #176), which require quite a bit of work.

ptillet • Aug 22 '21 16:08

I made and pushed some fixes, and now I'm running into a strange situation where the first column of my output doesn't match, but the rest of the elements pass the test:

ref_output tensor([[[[-1.8713, -3.2907,  2.6836,  1.0433]]]], device='cuda:0')
triton_output tensor([[[ 7.9835, -3.2907,  2.6836,  1.0433]]], device='cuda:0')
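When only some entries disagree, it helps to have the test report which columns fail instead of eyeballing the printouts. The helper below is a hypothetical NumPy sketch, not part of the repo. It broadcasts the two outputs against each other (the reference has one more leading dimension than the Triton output, as in the printouts above), then reports the last-axis indices that mismatch in any row. Applied to the posted values, it flags only column 0.

```python
import numpy as np

# Values copied from the printouts above; note the differing number of leading dims.
ref = np.array([[[[-1.8713, -3.2907, 2.6836, 1.0433]]]])
out = np.array([[[7.9835, -3.2907, 2.6836, 1.0433]]])


def mismatched_columns(ref, out, atol=1e-3):
    """Indices along the last axis where ref and out disagree in any row."""
    close = np.isclose(ref, out, atol=atol)      # broadcasts [1,1,1,4] against [1,1,4]
    close = close.reshape(-1, close.shape[-1])   # flatten leading dims into rows
    return np.flatnonzero(~close.all(axis=0))    # columns with at least one mismatch


print(mismatched_columns(ref, out))  # [0]
```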

Replicate test on this commit:

python -m unittest tests.test_causal_dotprod.TestCausalDotProduct.test_causal_product_fwd_triton

Something seems to be wrong with how tl.dot handles the first column, and it's hard to debug why. I tried multiplying the dot product operands by constants to narrow down the cause:

Changing:

output = tl.dot(q[None, :], state)

into

output = tl.dot(q[None, :] * 100, state * 0)

Now, I would expect the above matmul to just zero out the entire output, but the malformed first column is still there!

ref_output tensor([[[[-1.8713, -3.2907,  2.6836,  1.0433]]]], device='cuda:0')
triton_output tensor([[[79834.8906,     0.0000,     0.0000,     0.0000]]], device='cuda:0')

It seems like tl.dot did not properly multiply the first element of the matrix, but did multiply the rest of them?

Another strange behavior: my kernel only "works" to its current extent if num_warps=1. If I don't set num_warps, it outputs all zeros.

calclavia • Aug 23 '21 08:08

@calclavia I'm really interested in this as well! Phil told me over email that it wasn't possible, but maybe I didn't explain it correctly.

lucidrains • Aug 31 '21 23:08

@lucidrains Great to see you're interested! I was hoping this could be a drop-in replacement for your Performers repo to solve this issue (https://github.com/lucidrains/performer-pytorch/issues/44)

calclavia • Sep 01 '21 09:09

@calclavia yes, me too... i feel like the whole "linear" attention line of research is being held back by this

lucidrains • Sep 01 '21 18:09

@calclavia have you tried cuda-python yet?

lucidrains • Sep 01 '21 18:09

Hey~ Sorry for the delay. I'll be looking into this tonight or tomorrow. Since the dot unit tests pass, I wonder what is causing the results to be incorrect in this case. Maybe some interaction between broadcasting and tensor cores.

ptillet • Sep 01 '21 20:09

@lucidrains I haven't tried CUDA Python. I'm not an expert at CUDA programming, hence Triton seems like a nice alternative that is easy to learn and can get the job done.

A brief look at EPFL's CUDA code suggests they use an O(n) implementation in their fallback code (https://github.com/idiap/fast-transformers/blob/master/fast_transformers/causal_product/causal_product_cuda.cu#L1204), but the optimized version is too complex for me to understand. It should be possible to implement the O(log(n)) version in Triton; it'll just take some time to figure out how to port the prefix sum algorithm described by NVIDIA.

calclavia • Sep 02 '21 02:09
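The O(log(n)) idea discussed above can be prototyped on the CPU before worrying about Triton. The NumPy sketch below is an illustration, not the fast-transformers code or NVIDIA's exact algorithm. It replaces the sequential loop with a Hillis-Steele style inclusive scan over the per-position outer products k_j v_j^T, taking ceil(log2(n)) vectorized passes instead of n sequential steps. The trade-off is O(n log n) total work and a materialized [length, kdim, vdim] tensor, which a real kernel would try to avoid.

```python
import numpy as np


def causal_product_scan(q, k, v):
    """Causal dot product via a log-step inclusive scan (Hillis-Steele style sketch).

    q, k: [length, kdim]; v: [length, vdim]. Returns [length, vdim].
    """
    # Per-position rank-1 terms k_j v_j^T, shape [length, kdim, vdim].
    states = k[:, :, None] * v[:, None, :]
    length = states.shape[0]
    offset = 1
    while offset < length:
        # Every position i >= offset adds the partial sum stored at i - offset.
        # The right-hand side is fully evaluated before assignment, so each pass
        # reads the previous pass's values, like a double-buffered parallel step.
        states[offset:] = states[offset:] + states[:-offset]
        offset *= 2
    # states[i] now holds sum_{j <= i} k_j v_j^T; contract each with q[i].
    return np.einsum("lk,lkv->lv", q, states)


rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 6, 4))
print(np.allclose(causal_product_scan(q, k, v), np.tril(q @ k.T) @ v))  # True
```

Each pass of the while loop is independent across positions, so in a GPU version it would map to one synchronized step across threads.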

@ptillet If it helps, I ran my tests on an NVIDIA RTX 3080 GPU, which uses the Ampere architecture.

calclavia • Sep 02 '21 02:09

Hey! I looked into this and there are multiple issues at play:

  • I think what you ran into is an issue for dot products when the blocks are too small. It's quite tricky to get the thread partitioning right when the compiler needs to distribute 16 elements over hundreds of threads. This is one of the areas where the compiler is buggy at the moment.
  • It seems like with the most recent dev version, your causal attention makes the shared memory allocator hang. So that's another issue that I'll have to look into

FWIW I think causal attention will be much easier to implement when triton.language contains prefix sum primitives.

ptillet • Sep 02 '21 19:09

@ptillet Thanks for the analysis. So I'm assuming these need to be fixed on the Triton compiler side, and aren't bugs in my Triton implementation.

In terms of building a prefix sum primitive - how difficult do you think it would be to implement and what types of performance gains a primitive would likely yield? Ideally, wouldn't the compiler be able to "detect" these types of loop and optimize it to use a primitive behind the scenes?

calclavia • Sep 03 '21 02:09

@calclavia It's a bit tricky. Our goal is actually to extend Triton-IR to support indexing operations; prefix sum could then be implemented using an algorithm like https://en.wikipedia.org/wiki/Prefix_sum#Parallel_algorithms without dedicated Triton-IR instructions. The compiler would look at all the indexing ops and add shared memory barriers (and warp shuffles) as needed to maximize performance without breaking correctness.

Pattern matching across loop boundaries tends to be pretty hard and brittle. So in order to avoid that, Triton tries really hard to find abstractions that make sense :p It'll take a while before indexing ops are part of the ecosystem (I'm thinking maybe by Jan - Feb), but hopefully we'll get a prototype before then that will be enough for causal attention.

Anyhow, the issues you're having are unrelated and I'll try to address them sooner than that :D

ptillet • Sep 03 '21 05:09

Same wrong values for small blocks here! Is there any plan to fix this bug?

tdzdog • Sep 30 '21 09:09

@ptillet I checked Triton v1.1 and I think this issue is still here (I pushed an update to my repo with the upgrade, so anyone can test it). Any updates on this?

calclavia • Feb 16 '22 06:02

this should be fixed on the latest master branch, or the latest dev wheel

ptillet • Feb 16 '22 22:02

@ptillet I've updated Triton, but I'm getting a segmentation fault on triton 2.0.0.dev20220211.

Here's what I'm trying to run...

Simple Pytorch function that does Causal Dot Product: https://github.com/calclavia/Triton-Transformer/blob/lrn/ttx/lrn/torch.py

Triton version (incomplete): https://github.com/calclavia/Triton-Transformer/blob/lrn/ttx/lrn/triton_2d.py

Running my unit test (python -m unittest tests.test_lrn) yields the following error, and I'm not sure what it means:

Traceback (most recent call last):
  File "/workspace/tests/test_lrn.py", line 25, in test_triton
    triton_output = lrn_fwd_triton(q, k, v, z1, z2, state)
  File "/workspace/ttx/lrn/triton.py", line 120, in lrn_fwd_triton
    V_BLOCK_SIZE=v_block_size
  File "/opt/conda/lib/python3.7/site-packages/triton/code_gen.py", line 783, in __call__
    return self.kernel(*wargs, **kwargs, grid=self.grid)
  File "/opt/conda/lib/python3.7/site-packages/triton/code_gen.py", line 774, in __call__
    self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
  File "/opt/conda/lib/python3.7/site-packages/triton/code_gen.py", line 724, in add_to_cache
    constants=constants,
  File "/opt/conda/lib/python3.7/site-packages/triton/code_gen.py", line 675, in _compile
    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
IndexError: map::at

calclavia • Feb 20 '22 10:02

@ptillet I did more digging to see which part of the kernel causes the error. Added some comments. In particular, if I try to load another vector q = tl.load(q_ptr + k_offsets[:, None], mask=k_mask[:, None], other=0) in the loop, I get a segfault. The tl.store at the end of the loop causes the IndexError: map::at error. The kernel seems to run if I comment out tl.store.

    for _ in range(0, length, 1):

        # Load a single row of K and V as matrices.
        # [K_BLOCK_SIZE, 1]
        k = tl.load(k_ptr + k_offsets,
                    mask=k_mask, other=0)[:, None]
        # [1, V_BLOCK_SIZE]
        v = tl.load(v_ptr + v_offsets, mask=v_mask, other=0)[None, :]

        # Compute context [V, 1] x [1, K] => [V, K]
        context = tl.dot(v, k)
        state += context

        # Load a single row of Q of shape [K, 1]
        # TODO: Loading this causes segfault
        # q = tl.load(q_ptr + k_offsets[:, None], mask=k_mask[:, None], other=0)

        # Compute output = S * Q: [V, K] x [K, 1] => [V, 1]
        # TODO: Correct equation
        # output = tl.dot(state, q)
        output = tl.dot(state, k)

        # TODO: Storing output causes IndexError: map::at
        tl.store(
            output_ptr + v_offsets[:, None],
            output,
            mask=v_mask[:, None]
        )

        # Move to next row
        k_offsets += kdim
        v_offsets += vdim

calclavia • Feb 23 '22 07:02
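In the loop posted above, the operand shapes don't line up with the "[V, 1] x [1, K]" comment: k is loaded as [K_BLOCK_SIZE, 1] and v as [1, V_BLOCK_SIZE], so tl.dot(v, k) is a [1, V] x [K, 1] product. That product is only defined when the two block sizes are equal, and even then it yields a [1, 1] result rather than a [V, K] context. The thread doesn't establish whether this contributes to the crash. The sketch below uses NumPy's @ purely as a stand-in for the 2-D shape rules of tl.dot, with made-up block sizes, to show which operand orientations give which shapes.

```python
import numpy as np

K_BLOCK_SIZE, V_BLOCK_SIZE = 16, 32  # hypothetical block sizes

k = np.ones((K_BLOCK_SIZE, 1))  # orientation of k[:, None] in the kernel
v = np.ones((1, V_BLOCK_SIZE))  # orientation of v[None, :] in the kernel


def dot_shape(a, b):
    """Shape of a @ b, or None if the inner dimensions disagree."""
    try:
        return (a @ b).shape
    except ValueError:
        return None


print(dot_shape(v, k))      # None: [1, V] x [K, 1] needs V == K
print(dot_shape(k, v))      # (16, 32): [K, 1] x [1, V] -> [K, V]
print(dot_shape(v.T, k.T))  # (32, 16): [V, 1] x [1, K] -> [V, K], as the comment intends
```

Loading v as a column and k as a row would match the comment's [V, 1] x [1, K] intent, though the placeholder tl.dot(state, k) line would then need a [K, 1] operand as well.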

I think this issue is related to https://github.com/openai/triton/issues/375. Using tl.dot inside a loop tends to trigger segmentation faults.

calclavia • Feb 24 '22 06:02

Hey!

First of all, thanks for digging into this issue. Sorry for having been unresponsive, I've been quite busy with other matters, and I am well aware of the instabilities of triton 😅 I will look into this problem. FWIW, we are planning a rewrite of the backend that should greatly improve stability on these kinds of issues.

ptillet • Feb 24 '22 17:02

Closing this, as the compiler now throws an error when the blocks are too small for tl.dot. Please submit a new issue with a repro if you have an alternative kernel that compiles.

ptillet • Feb 22 '23 16:02
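Since the compiler now rejects blocks that are too small for tl.dot, one possible workaround for kernels like the ones in this thread is to avoid tl.dot entirely. In these kernels every product has a single row or column as one operand. A rank-1 update is just an elementwise broadcast product, and a matrix-vector product is a broadcast product followed by a sum. In Triton these would presumably be written with elementwise * on [:, None] / [None, :] views and tl.sum(..., axis=...). The sketch below only checks the identities in NumPy with deliberately tiny sizes; the Triton lines in the comments are untested assumptions, and their performance relative to tl.dot is not measured here.

```python
import numpy as np

rng = np.random.default_rng(0)
K, V = 4, 3  # deliberately tiny block sizes
k = rng.standard_normal(K)
v = rng.standard_normal(V)
q = rng.standard_normal(K)
state = rng.standard_normal((V, K))

# Rank-1 update [V, 1] x [1, K] -> [V, K] as a broadcast product.
# Possible Triton analogue (untested): context = v[:, None] * k[None, :]
context = v[:, None] * k[None, :]
assert np.allclose(context, np.outer(v, k))

# Matrix-vector product [V, K] x [K] -> [V] as broadcast + reduction.
# Possible Triton analogue (untested): output = tl.sum(state * q[None, :], axis=1)
output = np.sum(state * q[None, :], axis=1)
assert np.allclose(output, state @ q)

print(context.shape, output.shape)  # (3, 4) (3,)
```

This keeps every intermediate in registers-sized blocks without relying on tensor-core tile shapes, at the cost of doing the reduction with ordinary arithmetic.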