Feature Request: `tl.atomic_add` for bfloat16
For additional context, see pytorch/pytorch#97016. `torch.index_put(..., accumulate=True)` currently fails for `torch.bfloat16` under `torch.compile` because `tl.atomic_add` doesn't support bfloat16.
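For reference, a minimal reproduction would look something like the sketch below (the helper name, shapes, and sizes are placeholders of mine, not taken from the linked issue):

```python
import torch

def index_accumulate(out, idx, src):
    # index_put with accumulate=True is a scatter-add; under torch.compile
    # the generated Triton kernel needs an atomic add on bf16 elements
    return torch.index_put(out, (idx,), src, accumulate=True)

compiled = torch.compile(index_accumulate)

out = torch.zeros(1024, device="cuda", dtype=torch.bfloat16)
idx = torch.randint(0, 1024, (4096,), device="cuda")
src = torch.randn(4096, device="cuda", dtype=torch.bfloat16)

compiled(out, idx, src)  # fails because tl.atomic_add lacks bf16 support
```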
The PTX instruction `atom.add.bf16` requires compute capability 9.0+; however, when you compile `atomicAdd` in CUDA for compute capability 8.0+ it generates a CAS loop instead. Would it be reasonable for Triton to do the same?
> Would it be reasonable for Triton to do the same?
I think so.
How is the work on this feature coming along?
+1
+1, a representative case: split-k bfloat16 GEMM.
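For concreteness, the accumulation step of a split-k kernel is roughly the following (a sketch with made-up names, not a full GEMM): each program holds the partial product of one k-slice for a tile of C, and the `tl.atomic_add` call at the end is exactly what rejects bf16 today.

```python
import triton
import triton.language as tl

@triton.jit
def splitk_accumulate(c_ptr, partial_ptr, numel, BLOCK: tl.constexpr):
    # grid = (tiles, k_slices): programs with the same tile id but different
    # k-slice ids add into the same elements of C, hence the atomic.
    tile = tl.program_id(0)
    kslice = tl.program_id(1)
    offs = tile * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numel
    partial = tl.load(partial_ptr + kslice * numel + offs, mask=mask, other=0.0)
    # When c_ptr points to bf16, this is the unsupported operation:
    tl.atomic_add(c_ptr + offs, partial, mask=mask)
```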
I started work on this one; some preliminaries are at: https://github.com/plotfi/triton/commit/a9d3ce59cfddc9917438727e4df8969bef46b597
One thing to note is that `atomicAdd` with bfloat16 is only supported natively on Hopper (sm_90). The CUDA library's `atomicAdd` does an `atomicCAS` fallback (which I am going to try to teach the Triton lowering to do):
Ampere Fallback:

```
$L__BB0_1:
// begin inline asm
{.reg .b16 c;
mov.b16 c, 0x3f80U;
fma.rn.bf16 %rs5,%rs2,c,%rs8;}
// end inline asm
atom.global.cas.b16 %rs4, [%rd1], %rs8, %rs5;
setp.ne.s16 %p1, %rs8, %rs4;
mov.u16 %rs8, %rs4;
@%p1 bra $L__BB0_1;
ret;
```
Hopper:

```
// begin inline asm
{ atom.add.noftz.bf16 %rs1,[%rd1],%rs2; }
// end inline asm
ret;
```
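In plain terms, the Ampere fallback is a 16-bit compare-and-swap retry loop. Here is a plain-Python model of that logic, purely for illustration (bf16 is emulated as the high 16 bits of a float32, with truncation instead of the hardware's round-to-nearest):

```python
import struct

def bf16_bits_to_float(bits: int) -> float:
    # a bf16 value is just the top 16 bits of an IEEE-754 float32
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

def float_to_bf16_bits(x: float) -> int:
    # truncate the low mantissa bits (hardware uses round-to-nearest-even)
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

def cas16(memory, addr, expected, desired):
    # model of atom.cas.b16: swap only if the word still equals `expected`,
    # and always return the old contents
    old = memory[addr]
    if old == expected:
        memory[addr] = desired
    return old

def atomic_add_bf16(memory, addr, value):
    # the retry loop: recompute the sum and re-attempt the CAS until no
    # other thread modified the word between our load and our CAS
    old = memory[addr]
    while True:
        assumed = old
        new = float_to_bf16_bits(bf16_bits_to_float(assumed) + value)
        old = cas16(memory, addr, assumed, new)
        if old == assumed:
            return bf16_bits_to_float(old)

mem = [float_to_bf16_bits(1.5)]
atomic_add_bf16(mem, 0, 2.0)
print(bf16_bits_to_float(mem[0]))  # 3.5
```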
Any update on the BF16 atomic add support for devices other than Hopper?
@plotfi here's a version in Triton that works, but it's very slow:

```python
import triton
import triton.language as tl

@triton.jit
def atomic_add_cas(ptr, value, Lock, mask=None, sem: tl.constexpr = 'release'):
    # spin until the lock is acquired
    while tl.atomic_cas(Lock, 0, 1, sem=sem) == 1:
        pass
    # ordinary read-modify-write while holding the lock
    tl.store(ptr, tl.load(ptr, mask=mask) + value, mask=mask)
    tl.debug_barrier()
    # release the lock
    tl.atomic_xchg(Lock, 0)
```
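A rough usage sketch (my own names and shapes, assuming `atomic_add_cas` above is in scope): several slices add their partial results into the same bf16 output through a single shared int32 lock, which is also why it's slow, since every read-modify-write is serialized through that one lock.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def combine_partials(out_ptr, partial_ptr, lock_ptr, numel, BLOCK: tl.constexpr):
    # grid = (tiles, slices): programs with the same tile id but different
    # slice ids update the same elements of out, guarded by one global lock
    tile = tl.program_id(0)
    sl = tl.program_id(1)
    offs = tile * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numel
    val = tl.load(partial_ptr + sl * numel + offs, mask=mask)
    atomic_add_cas(out_ptr + offs, val, lock_ptr, mask, 'release')

numel, slices, BLOCK = 4096, 4, 256
out = torch.zeros(numel, device="cuda", dtype=torch.bfloat16)
partials = torch.randn(slices, numel, device="cuda", dtype=torch.bfloat16)
lock = torch.zeros(1, dtype=torch.int32, device="cuda")
grid = (triton.cdiv(numel, BLOCK), slices)
combine_partials[grid](out, partials, lock, numel, BLOCK=BLOCK)
```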
Were you able to do it via inline PTX by any chance?
By the way, bfloat16 atomic addition also crashes on Hopper in Triton.
+1 on the case of split-k bfloat16 GEMM.
Would love to learn more about the engineering challenges and help figure them out.
This is done via commit 236f6b54ce337db009ea573915022dafdbf61b82.