FP8 Windowed Einsums with Multiple All-Gather Dots
Enables FP8 windowed einsums with all-gathers that have multiple dot users by shifting the dequantization of the FP8 operands to the output of the while loop.
@philipphack I don't fully understand this pass. I pinged you on Google Chat asking for help understanding it. Can you check your pings? :)
@frgossen this PR helps enable collective matmul for FP8 GEMMs, which is needed to overlap tensor-parallel communication with compute. The expected speedup depends on the model config, but for recent MoE models we expect a 10-20% speedup.
@frgossen the patterns are FP8 windowed einsums with input dequantizations (i.e. a type conversion to a wider type like FP16 or BF16 followed by scaling) and all-gathers that have multiple dot users. The dequantization is moved from before the all-gather feeding the additional dot to the output of the while loop that holds the cached all-gather result, as indicated in the comment. This enables gemm_rewriter.cc to create a specific FP8 GEMM custom call that replaces the dequantization and the dot.
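Roughly, the matched pattern looks like this before the pass runs (the names, shapes, and replica group below are made up for illustration and are not taken from the pass or the tests):

```
lhs_f8  = f8e4m3fn[512,4096] parameter(0)
scale   = bf16[] parameter(1)
rhs0    = bf16[4096,8192] parameter(2)
rhs1    = bf16[4096,8192] parameter(3)
// Dequantization: convert to the wider type, then multiply by the broadcast scale.
lhs_cvt = bf16[512,4096] convert(lhs_f8)
scale_b = bf16[512,4096] broadcast(scale), dimensions={}
lhs_deq = bf16[512,4096] multiply(lhs_cvt, scale_b)
// All-gather with two dot users.
lhs_ag  = bf16[4096,4096] all-gather(lhs_deq), replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}
dot0    = bf16[4096,8192] dot(lhs_ag, rhs0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
dot1    = bf16[4096,8192] dot(lhs_ag, rhs1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
```

The SPMD dot handler turns one of the dots into the windowed while loop and caches the full all-gather result in a loop output for the other dot.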
(Also rebased to resolve conflicts.)
@frgossen @cheshire @reedwm Could you please review the latest version of this PR?
@abhinavgoel95 It's great that we expect a 10-20% speedup. Could you please provide the exact model and hardware configuration and the speedup you are seeing with this change?
@toli-y Are you able to run GPT-5B, FP8 precision, with ICI_MESH = [1,1,8], on 8 H100 GPUs?
@frgossen can you PTAL?
@ezhulenev to nag reviewers or assign others.
@frgossen the motivation for this change is to extend the existing functionality for collective matmuls with multiple all-gather dots in the windowed einsum handler (see this comment) to support FP8 inputs. In this case, the dot operates on a dequantization of an FP8 operand, i.e. a type conversion from FP8 to a wider type followed by a multiplication by a broadcast scaling factor. This change adds functionality for moving the dequantization from between the input and the all-gather to between the windowed loop and the dot, as explained in the updated comment. The dequantization -> dot sequence is then rewritten in gemm_rewriter.cc into a specific FP8 GEMM custom call. The change does not interfere with the existing functionality of the windowed einsum handler.
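To make that concrete, continuing the illustrative example from above (names, shapes, and the tuple index are made up; the actual loop output layout is whatever the SPMD dot handler produces), the piece that gemm_rewriter.cc sees after this pass is the dequantization feeding the additional dot directly:

```
// 'loop' is the windowed while loop produced by the SPMD dot handler; the
// tuple index of the cached FP8 all-gather result is illustrative.
lhs_f8_ag = f8e4m3fn[4096,4096] get-tuple-element(loop), index=2
lhs_cvt   = bf16[4096,4096] convert(lhs_f8_ag)
scale_b   = bf16[4096,4096] broadcast(scale), dimensions={}
lhs_deq   = bf16[4096,4096] multiply(lhs_cvt, scale_b)
dot1      = bf16[4096,8192] dot(lhs_deq, rhs1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
```

gemm_rewriter.cc matches this convert -> multiply -> dot chain and replaces it with the FP8 GEMM custom call.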
The HLO of WindowedEinsumE2EAllgatherMultiConsumerF8 in collective_ops_e2e_test.cc can also be used for monitoring performance (FP8 requires Hopper or newer architectures).
@frgossen for some reason I can't directly respond to these two comments:
Do you know why they interfere with the outcome? Is there a fundamental reason, or can this be implemented more generically so that it composes well with the rest of XLA?
The dot merger pass has a threshold on the input dimensions, so it generally doesn't interfere in actual models. Since the test uses relatively small inputs, the pass is manually deactivated in the test to exercise the functionality.
One example here: I find it strange that the pattern matches a specific size tuple. If you rely on the SPMD partitioner to generate this specific pattern, that seems overly fragile. What if the partitioner changes? Do the transformations here depend on the number of values passed through the while loop?
I think the general purpose of the windowed einsum handler is to transform windowed einsums that were created by the SPMD dot handler; as such, it relies on the output of that pass by design. The alternative would be to modify the dot handler directly, which is already relatively complex.
If this pass is so tightly coupled with the dot handler, shouldn't there be tests that reflect this? Do you think it makes sense to add test cases that test the two passes together?
@frgossen that's what the end-to-end tests in collective_ops_e2e_test.cc do: the SPMD dot handler transforms the input HLO into a collective matmul/windowed einsum, which is then transformed again by the windowed einsum handler.
Can you add a comment?
Done.