[GPU] Enable cuDNN GEMM fusion level 3 by default.
Affects only Hopper+ and cuDNN 9+: https://github.com/openxla/xla/blob/fb41b76a8b08216b80abb49ceb5c07373d9c45c5/xla/service/gpu/gemm_fusion_autotuner.cc#L556.
Description of fusion level 1: https://github.com/openxla/xla/blob/fb41b76a8b08216b80abb49ceb5c07373d9c45c5/xla/xla.proto#L742.
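To try a non-default level locally, here is a minimal sketch, assuming the flag is passed via the standard XLA_FLAGS environment variable and using JAX only as an example XLA frontend:

```python
# Minimal sketch: opting into cuDNN GEMM fusion level 3 for one process.
# Assumption: the flag is forwarded via XLA_FLAGS; it must be set before the
# GPU backend initializes (i.e. before the first JAX computation).
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_cudnn_gemm_fusion_level=3"

import jax
import jax.numpy as jnp

# Any jitted dot is enough to exercise the GEMM fusion autotuner on Hopper+.
a = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
b = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
print(jax.jit(jnp.matmul)(a, b).sum())
```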
Approving this change so it can be imported. I'll run internal benchmarks to measure the runtime and compile-time impact and report back.
Here are the results across our internal benchmark suite. The tl;dr is that runtime is not improved, but compile time takes a hit. HLO passes are slower (speedup of 0.66x), which can be explained by the autotuning overhead. NVPTX and ptxas are used less and are therefore a bit faster. This makes it unattractive to enable xla_gpu_cudnn_gemm_fusion_level=1, in my opinion.
| Metric | Geomean Speedups/Size Reductions (H100) |
| --- | --- |
| HLO memory read+written | 1.00x |
| Buffer Allocations | 1.00x |
| NVPTX Compilation time | 1.04x |
| HLO Passes Time | 0.66x |
| Run Backend Time | 1.05x |
| Wall Time | 1.00x |
| Device Time | 1.00x |
| Device Memcpy Time | 0.98x |
Thank you! Given that there is an effect on NVPTX and ptxas compilation times, it looks like some fusions do switch to cuDNN, but the impact on these benchmarks is not large enough to see it in the wall time. And it seems like there are no failures. In this case I would propose to raise the level to 2 or 3 and re-evaluate.
> And it seems like there are no failures.
At least no crashes, that's right.
> In this case I would propose to raise the level to 2 or 3 and re-evaluate.
Okay, we can do that.
OK, I raised it to 3.
Could you please also tell us which exact cuDNN version you tried?
> Could you please also tell us which exact cuDNN version you tried?
We benchmarked this on cuDNN 9.0.0. Is this version recent enough?
In any case, we have benchmark results for xla_gpu_cudnn_gemm_fusion_level=2 on cuDNN 9.1.1.
The slowdown on HLO Passes Time can again be explained by the additional autotuning. The geomean wall time and device time are marginally improved. Most benchmarks are flat, but a few show speedups. I'll confirm the speedups are real.
| Metric | Geomean Speedups/Size Reductions (H100) |
| --- | --- |
| HLO memory read+written | 1.00x |
| Buffer Allocations | 1.00x |
| NVPTX Compilation time | 1.07x |
| HLO Passes Time | 0.45x |
| Run Backend Time | 1.09x |
| Wall Time | 1.01x |
| Device Time | 1.01x |
| Device Memcpy Time | 1.00x |
Next step will be xla_gpu_cudnn_gemm_fusion_level=3.
> Most benchmarks are flat, but a few show speedups. I'll confirm the speedups are real.
One internal 8B model is improved, but unfortunately not as much as the first run suggested. Re-running the model shows speedups around 1.02x.
> Next step will be xla_gpu_cudnn_gemm_fusion_level=3.
Results for xla_gpu_cudnn_gemm_fusion_level=3 with cuDNN 9.1.1 are below:
| Metric | Geomean Speedups/Size Reductions (H100) |
| --- | --- |
| HLO memory read+written | 1.00x |
| Buffer Allocations | 1.00x |
| NVPTX Compilation time | 1.02x |
| HLO Passes Time | 0.53x |
| Run Backend Time | 1.04x |
| Wall Time | 0.99x |
| Device Time | 1.00x |
| Device Memcpy Time | 0.99x |
Some benchmarks failed with xla_gpu_cudnn_gemm_fusion_level=3 and are not included in the results. I don't know the reasons for the failures yet, but I will take a look.
Overall, cuDNN GEMMs do not perform better on H100 than Triton on our internal benchmark suite. The suite contains about 100 models, primarily transformers but also other popular architectures. This result is surprising, because Triton does not achieve peak performance on H100. Some optimizations like pipelining are currently disabled due to bugs. Did you do your own comparisons of cuDNN GEMMs against Triton on H100?
> Overall, cuDNN GEMMs do not perform better on H100 than Triton on our internal benchmark suite. The suite contains about 100 models, primarily transformers but also other popular architectures. This result is surprising, because Triton does not achieve peak performance on H100. Some optimizations like pipelining are currently disabled due to bugs.
Do GEMM fusions take significant part of end-to-end time in your benchmarks?
> Did you do your own comparisons of cuDNN GEMMs against Triton on H100?
Yes. Here is what we get on a set of 19 benchmarks of various model configurations from the JAX toolbox with cuDNN 9.1.1:
- Total GEMM fusion kinds: 139
- Improved by cuDNN: 56
- Geomean improvement: 1.10x
- Max improvement: 1.70x
I verified that GemmFusionAutotuner runs cuDNN GEMMs as expected. The current VLOGing doesn't make it super obvious which GEMM alternative is being picked. Maybe I'm missing something; any pointers are appreciated!
> - Total GEMM fusion kinds: 139
> - Improved by cuDNN: 56
> - Geomean improvement: 1.10x
> - Max improvement: 1.70x
Did you compute the Geomean for end-to-end benchmarks or the GEMMs individually?
I verified that cuDNN GEMMs are being used in benchmarks as expected. I indeed see dot ops of "kind":"__cudnn$fusion" in the optimized HLO. The runtime benefits are small though, but I only checked a single benchmark so far. Doing this at scale will require better VLOGing to aggregate data across many runs.
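For a single run this can be checked without extra VLOGing; a rough sketch, assuming the standard --xla_dump_to / --xla_dump_hlo_as_text debug options and a placeholder JAX computation:

```python
# Sketch: dump the optimized HLO and look for cuDNN GEMM fusions by their
# "__cudnn$fusion" kind. Paths and the toy computation are placeholders.
import glob
import os

os.environ["XLA_FLAGS"] = (
    "--xla_gpu_cudnn_gemm_fusion_level=3 "
    "--xla_dump_to=/tmp/xla_dump "
    "--xla_dump_hlo_as_text"
)

import jax
import jax.numpy as jnp

@jax.jit
def step(a, b):
    return jax.nn.relu(a @ b)  # dot + epilogue, a GEMM fusion candidate

x = jnp.ones((2048, 2048), dtype=jnp.bfloat16)
step(x, x).block_until_ready()

hits = [
    p
    for p in glob.glob("/tmp/xla_dump/*after_optimizations*.txt")
    if "__cudnn$fusion" in open(p).read()
]
print(f"{len(hits)} optimized module(s) contain __cudnn$fusion")
```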
> Did you compute the Geomean for end-to-end benchmarks or the GEMMs individually?
Individual GEMM fusions.
> Doing this at scale will require better VLOGing to aggregate data across many runs.
@Amir-19 added an XLA flag --xla_gpu_dump_autotune_logs_to to gather such results. He is also working on a tool that summarizes these dumps into tables like this (example for one HLO module):
```
┏━━━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Fusion ┃ Count ┃ Backend ┃ cuDNN ┃ Plan ┃ Triton ┃ cuBLAS ┃ Improvement ┃
┡━━━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━┩
│ 1007   │ 1     │ cuBLAS  │ inf   │ None │ 9213   │ 3916   │ 1.00        │
│ 1572   │ 1     │ cuBLAS  │ inf   │ None │ 9322   │ 4126   │ 1.00        │
│ 1780   │ 1     │ cuDNN   │ 638   │ 0    │ 1227   │ 674    │ 1.06        │
│ 1781   │ 39    │ cuDNN   │ 645   │ 0    │ 1207   │ 676    │ 1.05        │
│ 2063   │ 40    │ cuDNN   │ 311   │ 0    │ 436    │ 584    │ 1.40        │
│ 2150   │ 40    │ cuDNN   │ 297   │ 9    │ 313    │ 365    │ 1.05        │
│ 2151   │ 40    │ cuDNN   │ 283   │ 2    │ 302    │ 443    │ 1.07        │
└────────┴───────┴─────────┴───────┴──────┴────────┴────────┴─────────────┘
```
Geomean improvement: 1.08x
Mean improvement: 1.09x
Max improvement: 1.40x
(Times are in microseconds; 'Backend' is the fastest backend for each fusion.)
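For reference, a rough sketch of the arithmetic behind the footer (not the actual summarizer tool), assuming 'Improvement' is the best non-cuDNN time divided by the best overall time, which matches the rows above:

```python
# Reproduce the geomean/mean/max improvement from the per-fusion times above.
# Times are in microseconds; math.inf marks fusions with no cuDNN plan.
import math

fusions = {
    # fusion id: (cuDNN, Triton, cuBLAS)
    1007: (math.inf, 9213, 3916),
    1572: (math.inf, 9322, 4126),
    1780: (638, 1227, 674),
    1781: (645, 1207, 676),
    2063: (311, 436, 584),
    2150: (297, 313, 365),
    2151: (283, 302, 443),
}

improvements = []
for cudnn, triton, cublas in fusions.values():
    best_other = min(triton, cublas)       # fastest non-cuDNN backend
    best_overall = min(cudnn, best_other)  # fastest backend overall
    improvements.append(best_other / best_overall)

geomean = math.exp(sum(map(math.log, improvements)) / len(improvements))
print(f"Geomean improvement: {geomean:.2f}x")                             # ~1.08x
print(f"Mean improvement: {sum(improvements) / len(improvements):.2f}x")  # ~1.09x
print(f"Max improvement: {max(improvements):.2f}x")                       # 1.40x
```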
> @Amir-19 added an XLA flag --xla_gpu_dump_autotune_logs_to to gather such results.
Perfect! This looks useful.
@sergachev, since this PR is not ready for a re-review yet, can we convert it to a draft for now? :)
I pushed 2 additional commits, just for this evaluation. I'll make them separate PRs if we decide to proceed with them. The first reduces the number of cuDNN plans per fusion and the compilation time accordingly. The second changes the definitions of the fusion levels according to the expected performance improvements.
We propose to evaluate level 1 again with cuDNN 9.2.1: https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-2-1.