triton
[TUTORIALS] Minor fix for tutorial 06
Based on `@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)])` and `BATCH, N_HEADS, HEAD_DIM = 4, 32, 64`, `HEAD_DIM` is 64 in both the pytest case and the benchmark. This triggers an assertion failure from `tl.static_assert(BLOCK_N <= HEAD_DIM)`, since `BLOCK_N` can be 128 in the autotuning configs.
Therefore, this PR changes the tuning sizes for `BN` to `[32, 64]`.
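A minimal sketch of why the narrower list fixes the failure, written in plain Python so it runs without a GPU. It builds an autotune-style grid as a cross product and filters out entries that would violate `BLOCK_N <= HEAD_DIM`. The dict layout, the helpers `make_configs` and `violating`, and the `BLOCK_M`, `num_stages`, and `num_warps` values are hypothetical stand-ins for the tutorial's `triton.Config` list; the pre-fix BN sizes are assumed to be `[64, 128]` for illustration.

```python
import itertools

# Hypothetical stand-in for the tutorial's autotune list: plain dicts instead of
# triton.Config. BLOCK_M, num_stages and num_warps values are assumptions.
def make_configs(bn_sizes):
    return [
        {"BLOCK_M": bm, "BLOCK_N": bn, "num_stages": s, "num_warps": w}
        for bm, bn, s, w in itertools.product([64, 128], bn_sizes, [3, 4], [4, 8])
    ]

def violating(configs, head_dim):
    """Return the configs that would trip tl.static_assert(BLOCK_N <= HEAD_DIM)."""
    return [c for c in configs if not c["BLOCK_N"] <= head_dim]

HEAD_DIM = 64  # value used by both the pytest case and the benchmark

old = make_configs([64, 128])  # assumed pre-fix BN sizes (includes 128)
new = make_configs([32, 64])   # BN sizes after this change

print(len(violating(old, HEAD_DIM)))  # 8: every config with BLOCK_N == 128
print(len(violating(new, HEAD_DIM)))  # 0
```

Filtering the grid against the actual `HEAD_DIM` instead of hard-coding the BN list is one possible alternative; this change takes the simpler route of narrowing the list so every candidate satisfies the assertion for `HEAD_DIM = 64`.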
Both the pytest case and the benchmark run fine on GPU after the fix.