
Blocksparse MatMul bugs

btyu opened this issue on Oct 30 '21 · 9 comments

Hi, I'm using the blocksparse matmul and I think I've run into some bugs.

Here's my code:

import torch
from sparse_matmul import matmul  # same as matmul.py in the current version of triton

B, H, L, D = 4, 8, 1024, 512
block = 16
device = 'cuda:2'
requires_grad = False

a = torch.rand((B, H, L, D), dtype=torch.float32, device=device, requires_grad=requires_grad)
b = torch.rand((B, H, L, D), dtype=torch.float32, device=device, requires_grad=requires_grad)

layout = torch.randint(0, 2, (H, L // block, L // block))

dot = matmul(layout, block, 'sdd', trans_a=False, trans_b=True)
c = dot(a, b)

assert c.shape[1] == layout.sum(), (c.shape[1], layout.sum().item())

Bug 1: Sometimes, the last assertion fails:

AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_30700/3882726527.py in <module>
----> 1 assert c.shape[1] == layout.sum(), (c.shape[1], layout.sum().item())

AssertionError: (16452, 16446)
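
For reference, the assertion checks that 'sdd' mode produced one block × block tile per nonzero entry of layout along dim 1. Spelled out in full (an editor's sketch, not part of the original script):

# Editor's sketch: the 'sdd' output should carry one (block, block) tile
# per nonzero entry of the layout, stacked along dim 1.
num_nonzero_blocks = int(layout.sum().item())
assert c.shape[1] == num_nonzero_blocks, (c.shape[1], num_nonzero_blocks)
assert c.shape[-2:] == (block, block)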

Bug 2: Sometimes the dot line cannot even run and raises the following error:

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_28701/2809911630.py in <module>
----> 1 r = dot(padded_q, padded_k)

~/ms/sparse_matmul.py in __call__(self, a, b)
    558         c_lut, c_num_locks, c_width, c_packs,\
    559         da_lut, da_num_locks, da_width, da_packs,\
--> 560         db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
    561 
    562         # If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior

~/ms/sparse_matmul.py in make_lut(self, dtype, device)
    505         # DA look-up table
    506         if self.mode == 'sdd':
--> 507             da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, True, device)
    508         elif self.mode == 'dsd':
    509             da_lut, da_num_locks, da_width, da_packs = sdd_lut(layout, block, device)

~/ms/sparse_matmul.py in dsd_lut(layout, block, step, trans, device)
    316     offsets  = offsets*2*div + 4*width
    317     segments = segments*div
--> 318     header   = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
    319     # create increments
    320     incs     = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()

RuntimeError: stack expects each tensor to be equal size, but got [384] at entry 0 and [364] at entry 2

(The file "~/ms/sparse_matmul.py" is just the matmul.py from the current version of Triton. I copied it out and ran it with the old Triton v1.0, since I cannot run Triton v1.1, as I mentioned here.)

Looking forward to your help!

btyu · Oct 30 '21

I'm hitting the same error in v1.1.1 (the latest version on pip).

xysmlx · Nov 03 '21

There are a few known issues in the case where the layout has a row/column full of zeros. Can you check whether this is the case?
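
A minimal way to check for that condition (an editor's sketch, assuming layout has the shape (H, L // block, L // block) used in the script above):

import torch

def has_empty_row_or_col(layout: torch.Tensor) -> bool:
    # layout: (H, num_block_rows, num_block_cols) binary block mask
    empty_rows = (layout.sum(dim=-1) == 0).any()
    empty_cols = (layout.sum(dim=-2) == 0).any()
    return bool(empty_rows or empty_cols)

print(has_empty_row_or_col(layout))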

ptillet · Nov 04 '21

Hi @ptillet, thank you for your suggestion.

This script still reports the problem in "bug 1" when the layout does not have a row/column full of zeros.

The problem in "bug 2" occurs because the layout has a row/column full of zeros.
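
For anyone who needs to sidestep "bug 2" until it is fixed, one workaround is to make sure every block row and block column of the layout has at least one nonzero entry, for example by forcing the diagonal blocks on (an editor's sketch, reusing layout, L, and block from the script above; it changes the sparsity pattern, so extra blocks get computed):

# Workaround sketch: turn on the diagonal blocks so that no block row or
# column of the layout is completely empty.
num_blocks = L // block
idx = torch.arange(num_blocks)
layout[:, idx, idx] = 1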

xysmlx · Nov 04 '21

@ptillet Yes, I agree with @xysmlx.

btyu · Nov 04 '21

Hi! I'm trying to run triton.ops.blocksparse.matmul but am struggling with the following error:

matmul = triton.ops.blocksparse.matmul(layout.cpu(), block_size, 'dds', trans_a=False, trans_b=True)
out = matmul(x, w)

Traceback (most recent call last):
  ....
  File "/slot/sandbox/j/.local/lib/python3.6/site-packages/triton/code_gen.py", line 563, in _compile
    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
RuntimeError: CUDA: Error- unknown

It seems there is a compatibility problem between the CUDA and Triton versions, or something like that? I have tried the following Triton versions: 1.0.0, 1.1.1, 2.0.0.dev20220305. None of them works.

NVIDIA-SMI 450.119.04
Driver Version: 450.119.04
CUDA Version: 11.0
Cuda compilation tools, release 11.1, V11.1.74
Build cuda_11.1.TC455_06.29069683_0

cydoroga · Mar 10 '22

Hi @cydoroga !

v1 is buggy, and the author of Triton once recommended that I try v2.0. So far, v2.0 works for me on CUDA 11.3.

btyu · Mar 10 '22

Hi @btyu! Thanks for the answer. As I said, I've tried v2.0, the latest available version, and it does not work.

I totally forgot to mention: I was able to run Triton v1.0.0 on another device with the same versions of CUDA and the CUDA toolkit, and it works fine. But on this particular machine it is not working for some reason.

cydoroga · Mar 10 '22

@cydoroga What I mean is that you may need CUDA v11.3.

As mentioned in another issue (https://github.com/openai/triton/issues/322#issuecomment-950734254), they solved their problem by upgrading to CUDA v11.4.2. In my experience, v11.3 also works fine. I actually don't know whether your problem has anything to do with the CUDA version, so this is just a personal suggestion.

As for v1.0.0 working on the other device: yes, v1.0.0 can indeed sometimes work on lower CUDA versions, but not always. As I said, v1 is buggy; sometimes it works and sometimes it does not. So the most direct route is to use v2.0 with CUDA v11.3+.

If this does not solve it, I suggest you open a new issue and ask one of the Triton developers for help :)

btyu · Mar 10 '22

@cydoroga Yep, @btyu is right: CUDA 11.4.0 segfaults while compiling Triton blocksparse kernels. It's a ptxas bug that was fixed in CUDA 11.4.2. I would recommend using CUDA 11.5 if possible, though, as it provides the best performance on top of increased stability.
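
For anyone double-checking which toolchain they are actually hitting, the relevant versions can be printed like this (an editor's sketch; it assumes Triton picks up the ptxas from the CUDA toolkit on the PATH):

import subprocess
import torch

# CUDA runtime PyTorch was built against, the driver version, and the
# ptxas on the PATH (the one affected by the ptxas bug mentioned above).
print("torch:", torch.__version__)
print("torch built with CUDA:", torch.version.cuda)
print(subprocess.check_output(["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"]).decode().strip())
print(subprocess.check_output(["ptxas", "--version"]).decode().strip())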

ptillet · Mar 14 '22

Closing, as it seems to work now.

ptillet · Feb 18 '23