
[Backend] Optimize membar insertion on hopper

Open • lijinpei opened this issue 6 months ago • 5 comments

New contributor declaration

  • [x] I am not making a trivial change, such as fixing a typo in a comment.

  • [x] I have written a PR description following these rules.

  • [x] I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • [ ] I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • [ ] This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • [x] I have not added any lit tests.
    • [ ] The lit tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

lijinpei • Oct 05 '25 17:10

Per my benchmarking, the redundant barriers have around a 10% performance impact on the included example. Before:

Benchmarking matmul_warp_specialized on hopper
====================================
    K  warp-specialized    cublas
  512            515.24    575.15
 1024            565.71    637.72
 2048            590.59    621.93
 4096            597.59    636.09
 8192            606.03    665.42
16384            629.67    643.29

After:

Benchmarking matmul_warp_specialized on hopper
====================================
    K  warp-specialized    cublas
  512            533.60    572.95
 1024            607.96    635.05
 2048            644.24    657.84
 4096            653.42    664.32
 8192            672.24    679.02
16384            680.23    671.07
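
A quick way to check the "around 10%" figure is to divide the two tables row by row. This only uses the warp-specialized numbers posted above. By this arithmetic, the gain ranges from about 3.6% at K=512 to about 10.9% at K=8192.

```python
# Throughput of matmul_warp_specialized copied from the "before" and "after"
# tables above; keys are K, values are in whatever units the benchmark reports.
before = {512: 515.24, 1024: 565.71, 2048: 590.59,
          4096: 597.59, 8192: 606.03, 16384: 629.67}
after = {512: 533.60, 1024: 607.96, 2048: 644.24,
         4096: 653.42, 8192: 672.24, 16384: 680.23}

# Relative speedup per K: after / before.
speedup = {k: after[k] / before[k] for k in before}

for k, s in speedup.items():
    print(f"K={k:>5}  speedup={s:.3f}x  ({(s - 1) * 100:+.1f}%)")
```
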

lijinpei • Oct 05 '25 17:10

Some functionality of this PR may duplicate https://github.com/triton-lang/triton/pull/7846/files. But some other parts, like "mbarrier.try_wait should function as a synchronization", are not, and that idea has a potential problem: other warps are not guaranteed to have reached mbarrier.try_wait.
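
To make the concern concrete, here is a toy Python model. All names are hypothetical, and it deliberately ignores transaction counts, phase parity and memory ordering. It only captures that an mbarrier phase completes once the expected number of arrivals has been recorded, whereas a CTA-wide bar.sync needs every warp to reach it.

```python
class ToyMBarrier:
    """Hypothetical toy mbarrier: a single phase with a fixed arrival count.
    Transaction counts and phase parity are deliberately ignored."""

    def __init__(self, expected_arrivals):
        self.pending = expected_arrivals

    def arrive(self):
        self.pending -= 1

    def try_wait(self):
        # Succeeds once enough arrivals were recorded; it says nothing
        # about warps that never arrive on this barrier.
        return self.pending <= 0


def bar_sync_complete(warps_at_barrier, num_warps):
    # A CTA-wide bar.sync completes only when every warp has reached it.
    return len(warps_at_barrier) == num_warps


NUM_WARPS = 4
bar = ToyMBarrier(expected_arrivals=1)

bar.arrive()                        # producer warp 0 arrives
consumer_proceeds = bar.try_wait()  # consumer warp 1 passes try_wait ...
reached = {0, 1}                    # ... while warps 2 and 3 may still be
                                    # touching shared memory from earlier code
cta_synced = bar_sync_complete(reached, NUM_WARPS)

print(consumer_proceeds, cta_synced)  # True False
```

In this model, treating the successful try_wait as if it were a bar.sync would let a membar-style analysis drop a barrier that warps 2 and 3 may still need.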

lijinpei • Oct 05 '25 17:10

I have tested https://github.com/triton-lang/triton/pull/7846 locally; its performance is not as good:

Benchmarking matmul_warp_specialized on hopper
====================================
    K  warp-specialized    cublas
  512            521.02    580.96
 1024            566.28    655.77
 2048            602.45    621.87
 4096            601.08    629.14
 8192            613.34    639.98
16384            623.64    690.13

Adding some printf calls to Membar.cpp shows that some barriers between mbarrier.try_wait and mbarrier.arrive are still not eliminated.

lijinpei • Oct 05 '25 17:10

I have also found that the mbarrier.init outside the ttg.warp-specialization-op propagates into the warp-specialization partitions and causes extra local_barrier insertions. Since there is an implicit __syncthreads() before entering and after leaving ttg.warp-specialization-op, this propagation should be stopped.
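
The propagation issue can be sketched with a toy, membar-style pass in Python. This is not Triton's Membar.cpp: the Op class, the cta_sync flag, the op labels and the buffer name are all hypothetical. The pass records shared-memory buffers read or written since the last synchronization and inserts a barrier before a conflicting access (read-after-write, write-after-read or write-after-write). Marking the warp-specialization boundary as an implicit CTA-wide sync clears that state, so an mbarrier.init before the op no longer forces a barrier inside a partition.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    name: str
    reads: set = field(default_factory=set)
    writes: set = field(default_factory=set)
    cta_sync: bool = False  # hypothetical flag: op implies __syncthreads()


def insert_barriers(ops):
    """Toy membar-style pass. Returns op names with 'barrier' entries
    inserted wherever a shared-memory hazard is not already ordered."""
    pending_reads, pending_writes = set(), set()
    out = []
    for op in ops:
        if op.cta_sync:
            # An implicit CTA-wide sync orders every earlier access.
            pending_reads.clear()
            pending_writes.clear()
            out.append(op.name)
            continue
        raw = op.reads & pending_writes
        war_waw = op.writes & (pending_reads | pending_writes)
        if raw or war_waw:
            out.append("barrier")
            pending_reads.clear()
            pending_writes.clear()
        out.append(op.name)
        pending_reads |= op.reads
        pending_writes |= op.writes
    return out


def program(ws_is_sync):
    # mbarrier.init outside the warp-specialization op, try_wait inside it.
    return [
        Op("mbarrier.init", writes={"bar"}),
        Op("enter warp_specialize", cta_sync=ws_is_sync),
        Op("mbarrier.try_wait", reads={"bar"}),
    ]


print(insert_barriers(program(ws_is_sync=False)))
# ['mbarrier.init', 'enter warp_specialize', 'barrier', 'mbarrier.try_wait']
print(insert_barriers(program(ws_is_sync=True)))
# ['mbarrier.init', 'enter warp_specialize', 'mbarrier.try_wait']
```

The same reasoning applies when leaving the op, since the implicit sync there orders partition accesses before the code that follows.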

lijinpei • Oct 06 '25 08:10

Yeah, the warp_specialization op can count as a bar sync.

ThomasRaoux • Oct 06 '25 16:10