
[BACKEND] BF16 atomic_add support

Open: plotfi opened this pull request 7 months ago

This PR adds BF16 support for atomics, which are less precise than their FP32 counterparts but cheaper

BF16 accumulators have proven useful in the context of split-K kernels, where cheaper atomic accumulation across two SMs is necessary
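As a rough illustration of that pattern (a sketch only; the kernel name, arguments, and shapes here are made up, not the PR's code), each split-K slice can atomically fold its partial result into a shared bf16 output:

    import triton
    import triton.language as tl

    @triton.jit
    def splitk_accumulate(out_ptr,      # bf16 output shared across K-slices
                          partial_ptr,  # this slice's fp32 partial results
                          n_elements, BLOCK: tl.constexpr):
        # Different K-slices (possibly running on different SMs) write the
        # same output tile, so the accumulation must be atomic.
        offs = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        partial = tl.load(partial_ptr + offs, mask=mask)
        # Cast down to bf16 and accumulate atomically; cheaper than a
        # separate reduction pass, at the cost of bf16 precision.
        tl.atomic_add(out_ptr + offs, partial.to(tl.bfloat16), mask=mask)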

BF16 atomics are also needed for some of the following AMD-related work:

  • AMD buffer atomics (ie BufferAtomicRMWOp)
  • There is also a path to add unit tests for bf16 atomics in AMD's backend

See BF16 atomics across A100, H100, and MI300 at: https://godbolt.org/z/jW3EMbxrG

@bertmaher @SamGinzburg @davidberard98

New contributor declaration

  • [X] I am not making a trivial change, such as fixing a typo in a comment.

  • [X] I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • [X] I have added tests.
      • /test for lit tests
      • /python/test for end-to-end tests

plotfi avatar Apr 17 '25 00:04 plotfi

@scxiao This is the PR to enable BF16 atomics in Triton

SamGinzburg avatar Apr 17 '25 15:04 SamGinzburg

I'd go ahead and publish this for review by the core team; we've had a sufficient number of asks for bf16 atomic_add that it feels like we should go ahead and introduce this.

bertmaher avatar Apr 25 '25 16:04 bertmaher

I'd go ahead and publish this for review by the core team; we've had a sufficient number of asks for bf16 atomic_add that it feels like we should go ahead and introduce this.

Triton on HIP backend also received a lot of requests to support bf16 atomic ops.

scxiao avatar Apr 25 '25 20:04 scxiao

Hi @ptillet, @ThomasRaoux, what do you all think of this? I know we've gone around on bf16 atomic_add before, but internal users have asked about it pretty regularly, and given that the hardware supports it, it seems reasonable for Triton to support it too, I'd think?

bertmaher avatar Apr 28 '25 17:04 bertmaher

@joviliast Feedback welcome

plotfi avatar Apr 28 '25 17:04 plotfi

@joviliast Feedback welcome

@plotfi, thanks for enabling bf16 atomics! The only thing is, you enabled every kind of atomic operation, but only add is tested. We definitely need to verify all the operations or limit bf16 operands to a subset of atomic kinds.
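For instance, a guard along these lines could keep the surface small until the rest are verified (purely illustrative; the names and the location of such a check are made up, not the PR's code):

    # Hypothetical check restricting bf16 atomics to verified op kinds.
    SUPPORTED_BF16_ATOMIC_OPS = {"fadd"}

    def check_bf16_atomic_op(op: str, dtype_name: str) -> None:
        if dtype_name == "bf16" and op not in SUPPORTED_BF16_ATOMIC_OPS:
            raise ValueError(f"atomic {op} is not supported for bf16 operands yet")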

joviliast avatar Apr 29 '25 11:04 joviliast

Hi @ptillet, @ThomasRaoux, what do you all think of this? I know we've gone around on bf16 atomic_add before, but internal users have asked about it pretty regularly, and given that the hardware supports it, it seems reasonable for Triton to support it too, I'd think?

I agree, I'm supportive of this change and unless @ptillet has a concern we should go ahead with this.

ThomasRaoux avatar Apr 29 '25 11:04 ThomasRaoux

@joviliast Feedback welcome

@plotfi, thanks for enabling bf16 atomics! The only thing is, you enabled every kind of atomic operation, but only add is tested. We definitely need to verify all the operations or limit bf16 operands to a subset of atomic kinds.

+1, I wonder if we can add a few cases to the existing test_atomic_rmw instead?
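Something like extending the dtype axis of the existing parametrization (a sketch; the real test_atomic_rmw parameter list differs):

    import pytest

    # Illustrative only: add bfloat16 to the dtypes the existing atomic
    # RMW test already sweeps, rather than adding a standalone test.
    @pytest.mark.parametrize("op", ["add", "max", "min"])
    @pytest.mark.parametrize("dtype_x_str", ["float16", "float32", "bfloat16"])
    def test_atomic_rmw(op, dtype_x_str, device):
        ...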

ThomasRaoux avatar Apr 29 '25 11:04 ThomasRaoux

@joviliast Feedback welcome

@plotfi, thanks for enabling bf16 atomics! The only thing is, you enabled every kind of atomic operation, but only add is tested. We definitely need to verify all the operations or limit bf16 operands to a subset of atomic kinds.

Hi @plotfi, @joviliast in his PR #6418 added a few other test cases, maybe you can make the same changes here to cover more cases.

scxiao avatar Apr 29 '25 13:04 scxiao

@joviliast Feedback welcome

@plotfi, thanks for enabling bf16 atomics! The only thing is, you enabled every kind of atomic operation, but only add is tested. We definitely need to verify all the operations or limit bf16 operands to a subset of atomic kinds.

+1, I wonder if we can add a few cases to the existing test_atomic_rmw instead?

@joviliast @ThomasRaoux I started adding test cases, and now I'm wondering: could bf16 work with any atomic operation other than fadd?

plotfi avatar May 01 '25 01:05 plotfi

@joviliast Feedback welcome

@plotfi, thanks for enabling bf16 atomics! The only thing is, you enabled every kind of atomic operation, but only add is tested. We definitely need to verify all the operations or limit bf16 operands to a subset of atomic kinds.

+1, I wonder if we can add a few cases to the existing test_atomic_rmw instead?

@joviliast @ThomasRaoux I started adding test cases, and now I'm wondering: could bf16 work with any atomic operation other than fadd?

the spec suggests that min/max should work as well? https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-atom

ThomasRaoux avatar May 01 '25 01:05 ThomasRaoux

the spec suggests that min/max should work as well? https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-atom

The PTX spec looks a little strange here with regard to min/max: it appears to be supported in the "Atomic operation with vector type" section but not in the "Atomic operation with scalar type" section. Am I missing something?

Maybe this is why for LLVM codegen, it falls back to a CAS here:

https://godbolt.org/z/7Yn3snxcs

Edit: Ah Ok, I see now how Triton handles min/max for float:

https://github.com/triton-lang/triton/blob/c7fc1e38d03982910e94125f28817dbd81162cbe/python/triton/language/semantic.py#L1460-L1462

Currently it seems to handle only 32/64-bit floats. I think it makes sense to add f16/bf16, but in a separate patch?
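For reference, the sign-split bit trick behind that semantic.py code can be sanity-checked outside Triton (an illustrative NumPy check, not code from the PR): non-negative IEEE floats order the same way as their bit patterns read as signed ints, and negative floats order in reverse, so an unsigned min finds the float max.

    import numpy as np

    pos = np.array([0.5, 1.5, 2.0], dtype=np.float32)
    assert np.argmax(pos.view(np.int32)) == np.argmax(pos)   # int max == float max

    neg = np.array([-0.5, -1.5, -2.0], dtype=np.float32)
    assert np.argmin(neg.view(np.uint32)) == np.argmax(neg)  # unsigned min == float max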

plotfi avatar May 01 '25 02:05 plotfi

Currently it seems to handle only 32/64-bit floats. I think it makes sense to add f16/bf16, but in a separate patch?

fine with me

ThomasRaoux avatar May 01 '25 13:05 ThomasRaoux

@ThomasRaoux Do things look good enough for an approval here? I looked at some of the existing tests for fp16, but they all appear to depend on NumPy, which does not have a bfloat16 type. I think test_bf16_atomics covers everything needed, though.

plotfi avatar May 01 '25 19:05 plotfi

@ThomasRaoux Do things look good enough for an approval here? I looked at some of the existing tests for fp16, but they all appear to depend on NumPy, which does not have a bfloat16 type. I think test_bf16_atomics covers everything needed, though.

any chance you can merge the test_core test with the existing atomic tests as suggested in this comment: https://github.com/triton-lang/triton/pull/6519#issuecomment-2838956756

rest looks good to me

ThomasRaoux avatar May 01 '25 19:05 ThomasRaoux

@ThomasRaoux Do things look good enough for an approval here? I looked at some of the existing tests for fp16, but they all appear to depend on NumPy, which does not have a bfloat16 type. I think test_bf16_atomics covers everything needed, though.

any chance you can merge the test_core test with the existing atomic tests as suggested in this comment: #6519 (comment)

rest looks good to me

Ah yeah, I did manage to get these tests working with bf16; I had to work around the NumPy issues as #6519 did. Are there any other tests you'd like to integrate? I will make sure the tests I enabled cover, at the very least, the same set that #6519 did.
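For the record, the usual shape of that workaround (not necessarily what #6519 does) is to build the reference in a dtype NumPy has and round it through torch's bfloat16:

    import numpy as np
    import torch

    # NumPy has no bfloat16, so build the reference in float32 and round
    # it through torch.bfloat16 to match the kernel's output precision.
    ref = np.random.rand(64).astype(np.float32)
    ref_bf16 = torch.from_numpy(ref).to(torch.bfloat16)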

plotfi avatar May 01 '25 22:05 plotfi

@ThomasRaoux Do things look good enough for an approval here? I looked at some of the existing tests for fp16, but they all appear to depend on NumPy, which does not have a bfloat16 type. I think test_bf16_atomics covers everything needed, though.

any chance you can merge the test_core test with the existing atomic tests as suggested in this comment: #6519 (comment)

rest looks good to me

Took a closer look at #6519; it seems what I did to handle bf16 with NumPy differs slightly but does about the same thing. I used FP16 for the NumPy reference, and the accuracy checks do pass this way, with one exception (where I modified the atol). Let me know if we want to keep accuracy checks when comparing np.float16 against CUDA bfloat16.

Edit: the failed AMD tests have given me my answer

plotfi avatar May 01 '25 22:05 plotfi

Tested on H100 and MI300: with the mantissa drop and the rtol=0.2/0.5 change, the tests pass and we still get some sense of the correctness of the result.
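Assuming "mantissa drop" here means truncating the fp32 reference to bf16 precision before comparing, one way to do it looks like this (illustrative, not necessarily the patch's code):

    import torch

    def drop_to_bf16_precision(x: torch.Tensor) -> torch.Tensor:
        # Zero the low 16 bits of each fp32 value, leaving exactly the
        # bits a bfloat16 can represent, so the reference only contains
        # values the bf16 kernel output could ever hold.
        return (x.view(torch.int32) & ~0xFFFF).view(torch.float32)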

@ThomasRaoux Does this seem good to you so far?

plotfi avatar May 05 '25 00:05 plotfi

Tested on H100 and MI300: with the mantissa drop and the rtol=0.2/0.5 change, the tests pass and we still get some sense of the correctness of the result.

@ThomasRaoux Does this seem good to you so far?

Trying to tune the relative tolerance. I still got a CI error on one of the tests.

plotfi avatar May 05 '25 18:05 plotfi

The nvidia-h100 tester appears to be stuck

plotfi avatar May 06 '25 20:05 plotfi

@ThomasRaoux Any other suggestions? I think the H100 CI is stuck, but otherwise things are passing everywhere else with the more integrated testing.

plotfi avatar May 06 '25 22:05 plotfi

looks like a real regression on AMD?

ThomasRaoux avatar May 08 '25 14:05 ThomasRaoux

looks like a real regression on AMD?

It looks like a slight precision discrepancy in 1.6% of the results:

        if dtype_x_str == 'bfloat16':
>           torch.testing.assert_close(dst_ref, dst, rtol=0.01, atol=1e-2)
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 1 / 64 (1.6%)
E           Greatest absolute difference: 0.015625 at index (3, 7) (up to 0.01 allowed)
E           Greatest relative difference: 0.050048828125 at index (3, 7) (up to 0.01 allowed)

I was going to up the rtol to 0.02 and the atol to 0.06. WDYT @ThomasRaoux?
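That is, the assertion above would become something like:

        if dtype_x_str == 'bfloat16':
            torch.testing.assert_close(dst_ref, dst, rtol=0.02, atol=0.06)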

plotfi avatar May 08 '25 20:05 plotfi

looks like a real regression on AMD?

It looks like a slight precision discrepancy in 1.6% of the results:

        if dtype_x_str == 'bfloat16':
>           torch.testing.assert_close(dst_ref, dst, rtol=0.01, atol=1e-2)
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 1 / 64 (1.6%)
E           Greatest absolute difference: 0.015625 at index (3, 7) (up to 0.01 allowed)
E           Greatest relative difference: 0.050048828125 at index (3, 7) (up to 0.01 allowed)

I was going to up the rtol to 0.02 and the atol to 0.06. WDYT @ThomasRaoux?

I see, yeah let's just update the tolerance for bf16 there

ThomasRaoux avatar May 08 '25 20:05 ThomasRaoux

@ThomasRaoux On my gfx942 machine this test seems to weirdly pass though. Having trouble reproducing this locally.

plotfi avatar May 08 '25 20:05 plotfi

@ThomasRaoux On my gfx942 machine this test seems to weirdly pass though. Having trouble reproducing this locally.

could it be because of non-deterministic add order?

ThomasRaoux avatar May 08 '25 20:05 ThomasRaoux

@ThomasRaoux On my gfx942 machine this test seems to weirdly pass though. Having trouble reproducing this locally.

could it be because of non-deterministic add order?

I was thinking that some non-determinism could be the culprit, but I tried running:

while true; do
    python3 -m pytest \
        "python/test/unit/language/test_core.py::test_tensor_atomic_add_access_patterns[shape213-decrease-4-1-bfloat16]"
done

But it seems to pass every time.

Edit:

Wait, I ran it all over again, waited a bit, and it hit the non-deterministic failure case. And it's not even at the same index as in the CI.
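A minimal illustration of why the interleaving matters for bf16 (values chosen so the two orders round differently; bf16 addition is not associative):

    import torch

    x = torch.tensor([1024.0, 2.0, 2.0], dtype=torch.bfloat16)
    print((x[0] + x[1]) + x[2])  # 1024.0: each +2 is lost to bf16 rounding
    print(x[0] + (x[1] + x[2]))  # 1028.0: the 2s combine before rounding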

plotfi avatar May 08 '25 20:05 plotfi

@ThomasRaoux On my gfx942 machine this test seems to weirdly pass though. Having trouble reproducing this locally.

could it be because of non-deterministic add order?

Do we want to alter this test to use sem='acquire' instead of sem='relaxed', or should I up the rtol/atol tolerances?

plotfi avatar May 08 '25 20:05 plotfi

sem does not make the atomics more deterministic; changing rtol/atol is the only thing that makes sense.

peterbell10 avatar May 08 '25 20:05 peterbell10

sem does not make the atomics more deterministic; changing rtol/atol is the only thing that makes sense.

Sorry, my mistake.

plotfi avatar May 09 '25 01:05 plotfi