[Backend] Optimize membar insertion on hopper
New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these rules.
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

Per my benchmarking, the redundant barriers have a measurable cost on the included example: removing them improves warp-specialized throughput by roughly 4% to 11%, depending on K.

Before:

Benchmarking matmul_warp_specialized on hopper

| K | warp-specialized | cublas |
|---|---|---|
| 512 | 515.24 | 575.15 |
| 1024 | 565.71 | 637.72 |
| 2048 | 590.59 | 621.93 |
| 4096 | 597.59 | 636.09 |
| 8192 | 606.03 | 665.42 |
| 16384 | 629.67 | 643.29 |

After:

Benchmarking matmul_warp_specialized on hopper

| K | warp-specialized | cublas |
|---|---|---|
| 512 | 533.60 | 572.95 |
| 1024 | 607.96 | 635.05 |
| 2048 | 644.24 | 657.84 |
| 4096 | 653.42 | 664.32 |
| 8192 | 672.24 | 679.02 |
| 16384 | 680.23 | 671.07 |

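
For reference, the per-K gain can be recomputed from the warp-specialized columns of the two tables above. This is only a small sketch: the numbers are copied from this description, and the names `BEFORE`, `AFTER`, and `relative_gain` are illustrative, not part of the benchmark script.

```python
# Warp-specialized throughput copied from the "before" and "after" tables.
# Units are whatever the benchmark script reports.
K_VALUES = [512, 1024, 2048, 4096, 8192, 16384]
BEFORE = [515.24, 565.71, 590.59, 597.59, 606.03, 629.67]
AFTER = [533.60, 607.96, 644.24, 653.42, 672.24, 680.23]


def relative_gain(before: float, after: float) -> float:
    """Return the fractional throughput gain of `after` over `before`."""
    return (after - before) / before


# Map each K to its relative gain after removing the redundant barriers.
GAINS = {k: relative_gain(b, a) for k, b, a in zip(K_VALUES, BEFORE, AFTER)}

for k, gain in GAINS.items():
    print(f"K={k:<6} gain={gain:.1%}")
```

The gain is smallest at K=512 and largest at K=8192.
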
Some functionality of this PR may duplicate https://github.com/triton-lang/triton/pull/7846/files .
Other parts, such as "mbarrier.try_wait should function as a synchronization", are not duplicates. This has a potential problem: other warps are not guaranteed to have reached mbarrier.try_wait.
I tested https://github.com/triton-lang/triton/pull/7846 locally, and its performance is not as good:

Benchmarking matmul_warp_specialized on hopper

| K | warp-specialized | cublas |
|---|---|---|
| 512 | 521.02 | 580.96 |
| 1024 | 566.28 | 655.77 |
| 2048 | 602.45 | 621.87 |
| 4096 | 601.08 | 629.14 |
| 8192 | 613.34 | 639.98 |
| 16384 | 623.64 | 690.13 |

Adding some printf calls to Membar.cpp shows that some barriers between mbarrier.try_wait and mbarrier.arrive are not eliminated.
I also found that the mbarrier.init outside the ttg.warp-specialization-op propagates into the warp specialization partitions and causes an extra local_barrier.
Because there are implicit __syncthreads() before entering and after leaving ttg.warp-specialization-op, this propagation should be stopped.
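
The effect of stopping that propagation can be illustrated with a toy model. This is a minimal sketch and not the logic in Membar.cpp: the `Op` class, the `insert_barriers` function, and the `warp_specialize.enter` label are hypothetical, and it assumes that mbarrier.init is modeled as a write and mbarrier.try_wait as a read of the same shared-memory barrier object.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """Hypothetical op record: a name plus the shared buffers it touches."""
    name: str
    reads: frozenset = frozenset()
    writes: frozenset = frozenset()
    is_sync: bool = False  # True if the op already synchronizes all threads


def insert_barriers(ops):
    """Return op names, adding 'local_barrier' wherever a hazard is pending."""
    pending_reads, pending_writes = set(), set()
    result = []
    for op in ops:
        if op.is_sync:
            # An existing sync point clears all pending accesses.
            pending_reads.clear()
            pending_writes.clear()
            result.append(op.name)
            continue
        raw = op.reads & pending_writes
        war_or_waw = op.writes & (pending_reads | pending_writes)
        if raw or war_or_waw:
            result.append("local_barrier")
            pending_reads.clear()
            pending_writes.clear()
        result.append(op.name)
        pending_reads |= op.reads
        pending_writes |= op.writes
    return result


def program(region_entry_is_sync):
    """mbarrier.init outside the region, mbarrier.try_wait inside it."""
    return [
        Op("mbarrier.init", writes=frozenset({"bar"})),
        # Illustrative label for entering the warp specialization region.
        Op("warp_specialize.enter", is_sync=region_entry_is_sync),
        Op("mbarrier.try_wait", reads=frozenset({"bar"})),
    ]


WITHOUT_SYNC = insert_barriers(program(False))
WITH_SYNC = insert_barriers(program(True))
print(WITHOUT_SYNC)
print(WITH_SYNC)
```

In this model, the pending write from mbarrier.init leaks into the partition and forces an extra local_barrier unless the region entry is treated as a synchronization point.
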
yeah warp_specialization op can count as a bar sync