
Feature Request: `tl.atomic_add` for bfloat16

Open peterbell10 opened this issue 2 years ago • 7 comments

For additional context, see pytorch/pytorch#97016. torch.index_put(..., accumulate=True) currently fails for torch.bfloat16 under torch.compile because tl.atomic_add doesn't support BFloat16.

The PTX instruction atom.add.bf16 requires compute capability 9.0+; however, when you compile atomicAdd in CUDA for compute capability 8.0 it generates a CAS loop instead. Would it be reasonable for Triton to do the same?
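
For reference, a minimal repro sketch of the failure described above (shapes are illustrative, and the compiled call is the path that failed at the time this was reported):

import torch

def index_add(dst, idx, src):
    # accumulate=True lowers to a scatter-add, which Inductor emits as tl.atomic_add.
    return torch.index_put(dst, (idx,), src, accumulate=True)

dst = torch.zeros(16, device='cuda', dtype=torch.bfloat16)
idx = torch.randint(0, 16, (128,), device='cuda')
src = torch.randn(128, device='cuda', dtype=torch.bfloat16)

index_add(dst, idx, src)                 # eager mode works
torch.compile(index_add)(dst, idx, src)  # compiled path hit the missing bf16 tl.atomic_add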

peterbell10 avatar Mar 22 '23 18:03 peterbell10

Would it be reasonable for Triton to do the same?

I think so.

How is the work on this feature coming along?

daadaada avatar May 09 '23 18:05 daadaada

+1

EricSteinberger avatar Oct 06 '23 02:10 EricSteinberger

+1, a representative case: split-k bfloat16 GEMM
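
To make that use case concrete, here is a stripped-down split-K sketch (illustrative names and tile sizes, no boundary masking, row-major inputs whose dimensions divide the block sizes; not code from this thread). Each program computes a partial product over its own K-slice and folds it into C with tl.atomic_add, which is exactly the call that needs bfloat16 support when C is bf16:

import torch
import triton
import triton.language as tl

@triton.jit
def splitk_matmul(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_k = tl.program_id(2)  # which K-slice this program accumulates
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Walk over this program's share of the K dimension.
    for k in range(pid_k * BLOCK_K, K, BLOCK_K * SPLIT_K):
        rk = k + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + rm[:, None] * K + rk[None, :])
        b = tl.load(b_ptr + rk[:, None] * N + rn[None, :])
        acc += tl.dot(a, b)
    # Partial tiles from different K-slices race on C, so the store must be atomic;
    # with a bfloat16 C this is the unsupported tl.atomic_add.
    tl.atomic_add(c_ptr + rm[:, None] * N + rn[None, :],
                  acc.to(c_ptr.dtype.element_ty))

a = torch.randn(256, 512, device='cuda', dtype=torch.bfloat16)
b = torch.randn(512, 256, device='cuda', dtype=torch.bfloat16)
c = torch.zeros(256, 256, device='cuda', dtype=torch.bfloat16)
grid = (256 // 64, 256 // 64, 4)  # (M tiles, N tiles, SPLIT_K)
splitk_matmul[grid](a, b, c, 256, 256, 512,
                    BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, SPLIT_K=4)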

LyricZhao avatar Oct 30 '23 03:10 LyricZhao

I started work on this one, some preliminaries are at: https://github.com/plotfi/triton/commit/a9d3ce59cfddc9917438727e4df8969bef46b597

One thing to note is that atomicAdd with bfloat16 is only natively supported on Hopper (sm_90). The CUDA library's atomicAdd does an atomicCAS fallback (which I am going to try to teach the Triton lowering to do):

Ampere Fallback:

$L__BB0_1:
        // begin inline asm
        {.reg .b16 c;
  mov.b16 c, 0x3f80U;
  fma.rn.bf16 %rs5,%rs2,c,%rs8;}
        // end inline asm
        atom.global.cas.b16     %rs4, [%rd1], %rs8, %rs5;
        setp.ne.s16     %p1, %rs8, %rs4;
        mov.u16         %rs8, %rs4;
        @%p1 bra        $L__BB0_1;
        ret;

Hopper:

        // begin inline asm
        { atom.add.noftz.bf16 %rs1,[%rd1],%rs2; }
        // end inline asm
        ret;

plotfi avatar Nov 17 '23 00:11 plotfi

Any update on the BF16 atomic add support for devices other than Hopper?

harveyp123 avatar Jul 01 '24 17:07 harveyp123

@plotfi here's a version with Triton that works but it's very slow:

@triton.jit
def atomic_add_cas(ptr, value, Lock, mask=None, sem: tl.constexpr = 'release'):
    # Spin until this program acquires the lock (CAS 0 -> 1).
    while tl.atomic_cas(Lock, 0, 1, sem=sem) == 1:
        pass
    # Non-atomic read-modify-write, serialized by the lock.
    tl.store(ptr, tl.load(ptr, mask=mask) + value, mask=mask)
    # Make sure the store is visible before releasing the lock.
    tl.debug_barrier()
    tl.atomic_xchg(Lock, 0)
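
For context, a hypothetical way this helper could be driven from a kernel (illustrative names and shapes, assuming the atomic_add_cas definition above is in scope; not code from this thread): each program reduces one block of bf16 values and folds its partial sum into out[0] under the lock.

import torch
import triton
import triton.language as tl

@triton.jit
def block_sum(out_ptr, lock_ptr, x_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs, mask=offs < n, other=0.0)
    # Reduce in fp32, then cast back to the output dtype (bf16 here).
    partial = tl.sum(x.to(tl.float32), axis=0).to(out_ptr.dtype.element_ty)
    atomic_add_cas(out_ptr, partial, lock_ptr)

x = torch.randn(1 << 16, device='cuda', dtype=torch.bfloat16)
out = torch.zeros(1, device='cuda', dtype=torch.bfloat16)
lock = torch.zeros(1, device='cuda', dtype=torch.int32)
block_sum[(triton.cdiv(x.numel(), 1024),)](out, lock, x, x.numel(), BLOCK=1024)

Every program contends on the same lock, which is why this approach is so much slower than a native atomic add.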

Were you able to do it via inline PTX by any chance?

mobicham avatar Mar 22 '25 14:03 mobicham

By the way, bfloat16 atomic addition also crashes on Hopper in Triton.

mobicham avatar Mar 24 '25 14:03 mobicham

+1 on the case of split-k bfloat16 GEMM.

Would love to learn more about the engineering challenges and help figure them out.

austin362667 avatar May 20 '25 10:05 austin362667

This is done via 236f6b54ce337db009ea573915022dafdbf61b82.

antiagainst avatar May 20 '25 15:05 antiagainst

This is done via 236f6b5.

Will this feature be released in the next version of Triton?

wenhaoli-xmu avatar Jun 06 '25 05:06 wenhaoli-xmu