
Add a cudnn_fusion decorator that lowers computations to XLA cuDNN fusions.

Open · sergachev opened this issue · 6 comments

This will require https://github.com/openxla/xla/pull/15399 to work.

Code for jax/_src/cudnn/fusion.py provided by @hawkinsp.
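For context, a minimal usage sketch based on the description above; the import path follows the file named in this issue, and the computation is only illustrative:

```python
# Minimal sketch, assuming the decorator lives at the path named above.
import jax
import jax.numpy as jnp
from jax._src.cudnn.fusion import cudnn_fusion

@cudnn_fusion  # ask XLA to lower this computation as a single cuDNN fusion
def comp(x, y, z):
  # A mixed-precision dot + add: the kind of pattern cuDNN can fuse.
  return jnp.dot(x, y).astype(jnp.float32) + z

x = jnp.ones((32, 16), dtype=jnp.float16)
y = jnp.ones((16, 32), dtype=jnp.float16)
z = jnp.ones((32, 32), dtype=jnp.float32)
out = jax.jit(comp)(x, y, z)  # requires a CUDA backend (see cuDNN version note below)
```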

sergachev · Jul 27 '24 13:07

The required change in XLA has landed; this one is ready.

sergachev · Aug 08 '24 21:08

> Is there a minimum cuDNN version for this test to pass?

9.0.

sergachev · Aug 27 '24 16:08

The change looks fine to me, but it crashes in CI.

hawkinsp · Aug 27 '24 16:08

Am I seeing this right: both failing checks are in non-GPU configurations?

sergachev · Aug 27 '24 17:08

Yeah, you're right. You probably need to raise an error when lowering on a non-CUDA platform.

I think just add platform="cuda" to the register_lowering call; currently you're asserting that the lowering works on every platform.

You should also skip the test if not on CUDA, with @jtu.run_on_devices("cuda"), IIRC.
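Roughly like this sketch, assuming the primitive and rule names in the PR (cudnn_fusion_p and _cudnn_fusion_lowering are stand-ins here):

```python
# Sketch: register the lowering rule for the CUDA platform only, so other
# platforms fail with a clear "unimplemented" error instead of crashing CI.
from jax import core
from jax.interpreters import mlir

cudnn_fusion_p = core.Primitive("cudnn_fusion")  # stand-in for the real primitive

def _cudnn_fusion_lowering(ctx, *args, **params):
  ...  # emit the custom call that XLA rewrites into a cuDNN fusion

mlir.register_lowering(cudnn_fusion_p, _cudnn_fusion_lowering,
                       platform="cuda")

# And in the test file: skip unless CUDA devices are present.
from jax._src import test_util as jtu

class CudnnFusionTest(jtu.JaxTestCase):

  @jtu.run_on_devices("cuda")
  def testCudnnFusion(self):
    ...  # the actual test body
```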

hawkinsp · Aug 27 '24 17:08

Done.

sergachev · Aug 27 '24 18:08

Sorry, it took me a long time to look at this. This test fails in our internal CI because, on V100 (which we run in CI), the rewrite to a cuDNN fusion does not happen; instead, the after-optimization HLO ends up with a cuBLAS GEMM. Is that intended? Should the test be gated on particular GPU generations?

hawkinsp · Sep 04 '24 19:09

It should run on H100. Is this https://github.com/google/jax/pull/22699/files#diff-77b54950a53c3196a56e8f570cb6dcd4eca602b5a8b4220f5cd2acb86f060e7fR1548 not sufficient to filter by GPU type?

sergachev · Sep 04 '24 22:09

Anyway, I looked at other tests and added a check with skipTest(). It actually works on Ampere and newer.
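The gate is along these lines; the compute-capability helper name is an assumption about jtu, so treat this as a sketch:

```python
# Sketch: skip at runtime unless on a CUDA device of Ampere (sm80) or newer,
# since the cuDNN fusion rewrite does not happen on older generations.
from jax._src import test_util as jtu

class CudnnFusionTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if (not jtu.test_device_matches(["cuda"]) or
        not jtu.is_cuda_compute_capability_at_least("8.0")):  # assumed helper
      self.skipTest("cudnn_fusion requires Ampere (sm80) or newer")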

sergachev · Sep 04 '24 23:09

> It should run on H100. Is this https://github.com/google/jax/pull/22699/files#diff-77b54950a53c3196a56e8f570cb6dcd4eca602b5a8b4220f5cd2acb86f060e7fR1548 not sufficient to filter by GPU type?

It appears not. However, in general BUILD rules aren't enough, because we support running the tests via other means such as pytest. So a BUILD rule filter is helpful (it stops us from running pointless tests), but the test should also skip itself if the hardware it needs isn't present.

hawkinsp · Sep 05 '24 00:09

The rewrite also seems to fail on A100?

hawkinsp · Sep 06 '24 01:09

I tested it on A100.

sergachev · Sep 06 '24 09:09

I'm still seeing this fail in CI. It looks like the cuDNN fusion is present in the HLO fed to the compiler, but for some reason it gets rewritten away.

Are we guaranteed that the fusion will be emitted, or can it sometimes be autotuned away or something? Are there any other circumstances under which the fusion will fall back?

hawkinsp · Sep 08 '24 18:09

Indeed, I examined the tests we have (https://github.com/openxla/xla/blob/main/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc#L27 and https://github.com/openxla/xla/blob/main/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc#L709) and realised that the latter relies on xla_gpu_cublas_fallback(false): unless that fallback is disabled, the autotuner may replace the fusion with a cuBLAS call.

Fix: https://github.com/google/jax/pull/23505
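For anyone reproducing this from JAX, a sketch of disabling the cuBLAS fallback at compile time, assuming the XLA debug option is accepted via compiler_options under the same name:

```python
# Sketch: compile with the cuBLAS fallback off, so the GEMM fusion autotuner
# cannot rewrite the cuDNN fusion back into a cuBLAS GEMM custom call.
import jax
import jax.numpy as jnp

@jax.jit
def f(a, b):
  return a @ b

a = jnp.ones((64, 64), dtype=jnp.float16)
b = jnp.ones((64, 64), dtype=jnp.float16)

compiled = f.lower(a, b).compile(
    compiler_options={"xla_gpu_cublas_fallback": False})
print(compiled.as_text())  # inspect the after-optimization HLO
```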

sergachev · Sep 09 '24 12:09