[ROCm] Disable CUDA specific test for ROCm.
/cc @hawkinsp
@hawkinsp gentle ping
Sorry this dropped off my list of things to do. Apologies!
I had missed that multiprocess support isn't hooked into ROCm/RCCL (which is what nccl_unique_id_callback does). I'd guess that's a fairly mechanical thing to fix, but since that's the state of things, I agree with your original fix: disable the test.
https://github.com/google/jax/pull/12556 fixes this.
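For readers following along, here is a rough sketch of what "disable the test on ROCm" can look like in a plain `unittest` suite. It is not the code from the PR. The helper `is_rocm_platform`, the `PLATFORM_VERSION` string, and the test class name are all hypothetical and exist only for illustration; how the real test suite detects ROCm is not shown in this thread.

```python
import unittest


def is_rocm_platform(platform_version: str) -> bool:
    # Hypothetical helper: treat any platform-version string that mentions
    # "rocm" as a ROCm build. This is an assumption, not JAX's own check.
    return "rocm" in platform_version.lower()


# Assumption: a real suite would obtain this string from the active backend.
PLATFORM_VERSION = "rocm 5.2"


class MultiprocessGpuTest(unittest.TestCase):

    def test_cuda_only_feature(self):
        # Skip instead of failing where the CUDA-specific path is unavailable.
        if is_rocm_platform(PLATFORM_VERSION):
            self.skipTest("Multiprocess support is not hooked into ROCm/RCCL.")
        # CUDA-specific assertions would go here.
        self.assertTrue(True)


def run_suite() -> unittest.TestResult:
    # Load and run only the test case above, returning the result object.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(MultiprocessGpuTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Skipping inside the test body (rather than deleting the test) keeps it running on CUDA and makes it easy to re-enable once the ROCm/RCCL hook exists.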