triton
blocksparse.matmul result does not match torch.einsum
Hello,
In the following code, the results returned by triton.ops.blocksparse.matmul (mode='dds')
and torch.einsum
do not match (please note that layout
consists of all ones, so every block is kept).
My understanding is that both outputs should be the same.
Thank you in advance!
import torch
from triton.ops.blocksparse.matmul import matmul

NumHeads, SeqLen, BlockSize, Embed = 1, 64, 32, 128
# Block layout of the sparse operand k: (heads, Embed // BlockSize, SeqLen // BlockSize), all blocks kept
layout = torch.ones(NumHeads, Embed // BlockSize, SeqLen // BlockSize, dtype=torch.long)
q = torch.randn((1, NumHeads, SeqLen, Embed), dtype=torch.float32, device='cuda')  # (batch, heads, SeqLen, Embed)
k = torch.randn((1, NumHeads, Embed, SeqLen), dtype=torch.float32, device='cuda')  # (batch, heads, Embed, SeqLen)
# 'dds': dense output = dense q @ block-sparse k
mm = matmul(layout, BlockSize, mode='dds', device='cuda')
o_tn = mm(q, k)
o_torch = torch.einsum('bhsd,bhdo->bhso', q, k)
print(torch.allclose(o_tn, o_torch, atol=1e-1))
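One possible explanation, stated as an assumption rather than a confirmed diagnosis: in 'dds' mode the sparse operand b (here k) may be expected in a compressed block format of shape (batch, nnz_blocks, BlockSize, BlockSize), not as the dense (batch, heads, Embed, SeqLen) tensor passed above. The sketch uses only PyTorch on the CPU. to_block_sparse, from_block_sparse and dds_reference are hypothetical helpers (not part of Triton), and the block ordering (that of layout.nonzero(): head, then block-row, then block-column) is also an assumption. It packs k, checks the round trip, and computes the 'dds' product block by block to show that it matches torch.einsum when every block is kept.

```python
import torch


def to_block_sparse(x, layout, block):
    # Hypothetical helper (not Triton API): pack the blocks of a dense
    # (Z, H, rows*block, cols*block) tensor where layout == 1 into a
    # (Z, nnz, block, block) tensor. Block order follows layout.nonzero(),
    # which is an assumption about the order the kernel expects.
    heads, rows, cols = layout.nonzero(as_tuple=True)
    blocks = [
        x[:, h, r * block:(r + 1) * block, c * block:(c + 1) * block]
        for h, r, c in zip(heads.tolist(), rows.tolist(), cols.tolist())
    ]
    return torch.stack(blocks, dim=1)


def from_block_sparse(packed, layout, block):
    # Hypothetical inverse of to_block_sparse: scatter the blocks back into a
    # dense tensor, leaving dropped blocks (layout == 0) as zeros.
    num_heads, n_rows, n_cols = layout.shape
    out = packed.new_zeros((packed.size(0), num_heads, n_rows * block, n_cols * block))
    heads, rows, cols = layout.nonzero(as_tuple=True)
    for idx, (h, r, c) in enumerate(zip(heads.tolist(), rows.tolist(), cols.tolist())):
        out[:, h, r * block:(r + 1) * block, c * block:(c + 1) * block] = packed[:, idx]
    return out


def dds_reference(a, b_packed, layout, block):
    # Hypothetical block-by-block reference for 'dds' (dense = dense @ sparse):
    # output column block c accumulates a[..., row block r] @ b_block(h, r, c).
    z, num_heads, m, _ = a.shape
    n_cols = layout.shape[2]
    out = a.new_zeros((z, num_heads, m, n_cols * block))
    heads, rows, cols = layout.nonzero(as_tuple=True)
    for idx, (h, r, c) in enumerate(zip(heads.tolist(), rows.tolist(), cols.tolist())):
        out[:, h, :, c * block:(c + 1) * block] += (
            a[:, h, :, r * block:(r + 1) * block] @ b_packed[:, idx]
        )
    return out


# Same shapes as in the issue, but on the CPU so the sketch runs without a GPU.
NumHeads, SeqLen, BlockSize, Embed = 1, 64, 32, 128
layout = torch.ones(NumHeads, Embed // BlockSize, SeqLen // BlockSize, dtype=torch.long)
q = torch.randn(1, NumHeads, SeqLen, Embed)
k = torch.randn(1, NumHeads, Embed, SeqLen)

k_packed = to_block_sparse(k, layout, BlockSize)
print(k_packed.shape)  # torch.Size([1, 8, 32, 32]): 4 x 2 blocks, all kept
print(torch.equal(from_block_sparse(k_packed, layout, BlockSize), k))  # True

o_ref = dds_reference(q, k_packed, layout, BlockSize)
o_torch = torch.einsum('bhsd,bhdo->bhso', q, k)
print(torch.allclose(o_ref, o_torch, atol=1e-4))  # True
```

If that assumption holds, the Triton call in the snippet above would presumably become mm(q, to_block_sparse(k, layout, BlockSize)), with k still on the GPU. Comparing against o_torch then isolates the kernel from the input format.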