Add a cudnn_fusion decorator that lowers computations to XLA cuDNN fusions.
This will require https://github.com/openxla/xla/pull/15399 to work.
Code for jax/_src/cudnn/fusion.py provided by @hawkinsp.
The required change in XLA is done, this one is ready.
Is there a minimum cudnn version for this test to pass?
cuDNN 9.0.
The change looks fine to me, but it crashes in CI.
Am I seeing this right: both failing checks are in non-GPU configurations?
Yeah, you're right. You probably need to raise an error when lowering on a non-CUDA platform?
I think just adding platform="cuda" to the register_lowering call should do it. Currently you're asserting that the lowering works on every platform.
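A toy model of why scoping the registration matters (the registry shape and function names here are illustrative, not JAX's real mlir.register_lowering API): an unscoped registration claims every platform, while a "cuda"-scoped one makes other platforms fail fast.

```python
# Toy platform-scoped lowering registry (illustrative names only).
_registry = {}  # (primitive_name, platform) -> lowering rule

def register_lowering(prim, rule, platform=None):
    # platform=None would claim every platform; "cuda" scopes the rule.
    _registry[(prim, platform)] = rule

def get_lowering(prim, platform):
    # Prefer a platform-specific rule; fall back to a generic one.
    return _registry.get((prim, platform)) or _registry.get((prim, None))

register_lowering("cudnn_fusion", lambda *a: "cudnn-fusion-hlo",
                  platform="cuda")

print(get_lowering("cudnn_fusion", "cuda") is not None)  # -> True
print(get_lowering("cudnn_fusion", "cpu"))               # -> None
```

With the rule scoped to "cuda", a CPU or TPU lowering finds no rule and can report a clear error instead of emitting a fusion the backend cannot compile.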
You should also skip the test when not on CUDA, with @jtu.run_on_devices("cuda"), iirc.
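The skip pattern can be sketched with plain unittest (this is a toy stand-in for jtu.run_on_devices, with a hard-coded device probe so it runs anywhere; a real test would ask JAX which backend is present):

```python
import unittest

def cuda_available():
    """Hypothetical device probe; hard-coded to mimic a CPU-only CI runner."""
    return False

def run_on_devices(*devices):
    """Toy stand-in for jtu.run_on_devices: skip unless a wanted device exists."""
    def deco(test_fn):
        def wrapper(self):
            if "cuda" in devices and not cuda_available():
                self.skipTest("test requires CUDA")
            return test_fn(self)
        return wrapper
    return deco

class CudnnFusionTest(unittest.TestCase):
    @run_on_devices("cuda")
    def test_fusion(self):
        self.fail("would only run on a CUDA device")

result = unittest.TextTestRunner().run(
    unittest.defaultTestLoader.loadTestsFromTestCase(CudnnFusionTest))
print(len(result.skipped))  # -> 1
```

On the CPU-only runner the test is skipped rather than failed, which is exactly what the non-GPU CI configurations need.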
Done.
Sorry it took me a long time to look at this. This test fails in our internal CI: it seems that on V100 (which we run in CI) the rewrite to a cuDNN fusion does not happen; instead, the post-optimization HLO ends up with a cuBLAS gemm. Is that intended? Should the test be gated on particular GPU generations?
It should run on H100. Is this https://github.com/google/jax/pull/22699/files#diff-77b54950a53c3196a56e8f570cb6dcd4eca602b5a8b4220f5cd2acb86f060e7fR1548 not sufficient to filter by GPU type?
Anyway, I looked at other tests and added a check with skipTest(). It actually works on Ampere and newer.
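Gating on GPU generation comes down to a compute-capability comparison (Ampere is capability 8.0). This sketch hard-codes the capability values so it runs anywhere; real code would query the device, e.g. via JAX's device objects:

```python
# Toy compute-capability gate: Ampere+ means capability >= 8.0.
# The capabilities below are hard-coded; a real test would read them
# from the attached GPU.

def is_at_least(capability: str, minimum: str) -> bool:
    """Compare "major.minor" compute-capability strings numerically."""
    return (tuple(map(int, capability.split(".")))
            >= tuple(map(int, minimum.split("."))))

for cc in ["7.0", "8.0", "9.0"]:       # V100, A100, H100
    print(cc, is_at_least(cc, "8.0"))  # V100 fails the gate; Ampere+ passes
```

A tuple comparison is used rather than a string comparison so that, say, capability "10.0" would not sort below "8.0".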
> It should run on H100. Is this https://github.com/google/jax/pull/22699/files#diff-77b54950a53c3196a56e8f570cb6dcd4eca602b5a8b4220f5cd2acb86f060e7fR1548 not sufficient to filter by GPU type?
It appears not. However, in general BUILD rules aren't enough, because we support running the tests via other means such as pytest. So a BUILD rule filter is helpful (it stops us from running pointless tests), but the test should also skip itself if the hardware it needs isn't present.
The rewrite also seems to fail on A100?
I tested it on A100.
I'm still finding this to fail in CI. It looks like the cudnn fusion is produced in the HLO fed to the compiler, but for some reason it gets rewritten away.
Are we guaranteed that the fusion will be emitted, or can it sometimes be autotuned away or something? Are there any other circumstances under which the fusion will fall back?
Indeed. I examined the tests we have (https://github.com/openxla/xla/blob/main/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc#L27, https://github.com/openxla/xla/blob/main/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc#L709) and realised that the latter one relies on xla_gpu_cublas_fallback(false).
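For reference, xla_gpu_cublas_fallback is an XLA debug option; from a Python process it is commonly passed through the XLA_FLAGS environment variable, which has to be set before the XLA backend initializes. A minimal sketch (appending to any flags already set):

```python
import os

# Disable the cuBLAS fallback in the gemm-fusion autotuner so the cuDNN
# fusion is not autotuned away in favor of a cublas gemm. XLA_FLAGS must
# be set before the XLA backend is initialized (e.g. before importing jax).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_cublas_fallback=false"
).strip()

print(os.environ["XLA_FLAGS"])
```

Whether the JAX fix ultimately sets this flag or takes another route is determined by the linked PR; this just shows the mechanism the autotuner test relies on.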
Fix: https://github.com/google/jax/pull/23505