triton
map::at error when loading two tensors with the same pointers and masks
repro: https://gist.github.com/pyjhzwh/2ba871a53c2eac6575948467317bafa1
matrix_x00 = tl.load(x00_ptrs, mask=mask_x00, other=0.)
matrix_x01 = tl.load(x01_ptrs, mask=mask_x01, other=0.)
Here x00_ptrs and x01_ptrs are identical, and mask_x00 and mask_x01 are identical. However, the kernel throws a map::at error unless I either uncomment line 110, acc += tl.dot(matrix_x01, matrix_w00), or uncomment line 95, mask_x01 = mask_x00.
Lines 86-93 should compute the same mask as line 95, so I am not sure why the current code throws a map::at error.
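For reference, here is a minimal NumPy emulation of what the two masked loads are expected to compute. It is an assumption-based sketch, not Triton's implementation: the helper masked_load, the flat memory buffer, and the 4x4 block shape are hypothetical stand-ins for tl.load(ptrs, mask=mask, other=0.) on the gist's tensors. The second mask is built by a separate but equivalent expression, mirroring lines 86-93 versus line 95, and both loads should return identical tensors.

```python
import numpy as np


def masked_load(memory, offsets, mask, other=0.0):
    """Hypothetical emulation of tl.load(ptrs, mask=mask, other=other).

    `offsets` stands in for a block of pointers, expressed as indices into
    the flat `memory` buffer. This is not Triton's actual code.
    """
    out = np.full(offsets.shape, other, dtype=memory.dtype)
    # Only dereference positions where the mask is on; the rest keep `other`.
    out[mask] = memory[offsets[mask]]
    return out


# Hypothetical 3x4 input stored as a flat buffer; we load a 4x4 block,
# so row 3 is out of bounds and must be masked off.
memory = np.arange(12, dtype=np.float32)
rows = np.arange(4)[:, None]
cols = np.arange(4)[None, :]

x00_offsets = rows * 4 + cols
x01_offsets = rows * 4 + cols  # same pointers, computed separately

mask_x00 = (rows < 3) & (cols < 4)
mask_x01 = (rows < 3) & (cols < 4)  # same mask, computed separately

matrix_x00 = masked_load(memory, x00_offsets, mask_x00, other=0.0)
matrix_x01 = masked_load(memory, x01_offsets, mask_x01, other=0.0)

# Identical pointers and masks should give identical tensors.
assert np.array_equal(matrix_x00, matrix_x01)
```

Under these semantics, computing mask_x01 independently or assigning mask_x01 = mask_x00 should make no difference to the loaded values, which is why the map::at error looks like a compiler-side issue rather than a problem with the kernel logic.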