
Blocksparse.matmul result does not align with torch

Open oleksost opened this issue 5 months ago • 0 comments

Hello, in the following code the results returned by triton.ops.blocksparse.matmul and torch.einsum do not match (please note that layout consists of all ones).

My understanding is that both outputs should be the same.

Thank you in advance!

import torch
from triton.ops.blocksparse.matmul import matmul

NumHeads, SeqLen, BlockSize, Embed = 1, 64, 32, 128
layout = torch.ones(NumHeads, Embed // BlockSize, SeqLen // BlockSize).long()
q = torch.randn((1, NumHeads, SeqLen, Embed), dtype=torch.float32).contiguous().to('cuda')
k = torch.randn((1, NumHeads, Embed, SeqLen), dtype=torch.float32).contiguous().to('cuda')

mm = matmul(layout, BlockSize, mode='dds', device='cuda')
o_tn = mm(q,k)
o_torch = torch.einsum('bhsd,bhdo->bhso', q, k)

print(torch.allclose(o_tn, o_torch, atol=1e-1))
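
A possible cause, stated as an assumption rather than something confirmed in this thread: in 'dds' mode the second operand is the sparse one, and Triton's block-sparse kernels may expect it in a block-compressed layout of shape (batch, nnz_blocks, block, block) rather than as a dense (batch, heads, rows, cols) tensor. The sketch below uses plain PyTorch on CPU. `to_block_compressed` is a hypothetical helper written for illustration, not a Triton API. It only shows that, even with an all-ones layout, the compressed tensor orders its elements differently from the dense one.

```python
import torch


def to_block_compressed(x, layout, block):
    """Gather the nonzero blocks of a dense (B, H, R, C) tensor.

    Hypothetical helper (not part of Triton): returns a tensor of shape
    (B, nnz, block, block), with blocks in the order layout.nonzero() lists them.
    """
    heads, rows, cols = layout.nonzero(as_tuple=True)
    out = torch.empty((x.size(0), heads.numel(), block, block),
                      dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(heads.tolist(), rows.tolist(), cols.tolist())):
        out[:, idx] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return out


# Same sizes as in the snippet above, kept on CPU for this illustration.
NumHeads, SeqLen, BlockSize, Embed = 1, 64, 32, 128
layout = torch.ones(NumHeads, Embed // BlockSize, SeqLen // BlockSize).long()
k = torch.randn((1, NumHeads, Embed, SeqLen), dtype=torch.float32)

k_blocks = to_block_compressed(k, layout, BlockSize)
print(k_blocks.shape)  # torch.Size([1, 8, 32, 32])

# Same number of elements, but not the same element order as the dense tensor.
same_order = torch.equal(k_blocks.reshape(-1), k.reshape(-1))
print(same_order)
```

If this assumption holds, passing the compressed tensor as the second operand to mm (with q left dense) would be the thing to try before comparing against torch.einsum.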

oleksost, Sep 11 '24 20:09