Broadcasting issues with scalars
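```python
import triton
import triton.language as tl


@triton.jit
def kernel0(out_ptr2, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])  # shape [XBLOCK, 1]
    xmask = xindex < xnumel  # block-shaped mask; [[True]] when XBLOCK == 1
    tmp6 = 1.0  # scalar value
    # this version works (scalar pointer, scalar value, no mask)
    # tl.store(out_ptr2, tmp6)
    # this version: AttributeError: 'pointer_type' object has no attribute 'get_block_shapes'
    tl.store(out_ptr2, tmp6, xmask)


# launch a single program over a single element
kernel0[(1,)](..., xnumel=1, XBLOCK=1)
```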
We were hitting errors with single-element writes at the end of a reduction. In this case the mask is effectively a no-op (`[[True]]`), but the store still fails because the mask is block-shaped while the pointer is a scalar, so we need to generate workarounds for this.