`jax_debug_nans` should test for the presence of NaNs on device, not on the host
The JAX NaN checker computes:
```python
onp.any(onp.isnan(buf.to_py()))  # onp is numpy; buf is the device buffer
```
That is, it copies the whole array to the host and only then tests for NaNs.
It would be much more bandwidth-efficient to test for NaNs on the device and transfer only a single "NaNs present" boolean to the host.
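A minimal sketch of the proposed approach in present-day JAX. It is not the code path `jax_debug_nans` actually uses, and `has_nans` and `check_nans_on_device` are hypothetical names. The reduction runs on the device, and only the resulting 0-d boolean is copied back when `bool()` is called.

```python
import jax
import jax.numpy as jnp


@jax.jit
def has_nans(x):
    # Reduce on device: the result is a 0-d boolean array, not a copy of x.
    return jnp.any(jnp.isnan(x))


def check_nans_on_device(x):
    # Only the single boolean crosses the device-to-host boundary here.
    return bool(has_nans(x))


if __name__ == "__main__":
    x = jnp.arange(4.0)
    print(check_nans_on_device(x))                     # False
    print(check_nans_on_device(x.at[2].set(jnp.nan)))  # True
```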
Just checking: HLO-wise, can we rely on Neq(x, x) to indicate a NaN on all backends? Or should we use a CustomCall on GPU/CPU, which would avoid the compilation and perhaps ensure we're checking for NaNs correctly?
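For the HLO side of the question, one way to see what `jnp.isnan` currently lowers to is to print its IR without running it. This is a sketch assuming a recent JAX that provides `jit(...).lower(...).as_text()`; `isnan_ir` is a hypothetical helper, and the exact text varies by version (in recent releases we would expect a `compare` with direction `NE` of the input against itself).

```python
import jax
import jax.numpy as jnp


def isnan_ir(shape=(4,), dtype=jnp.float32):
    # Lower jnp.isnan for an abstract input (no data is allocated or run)
    # and return the IR text that JAX hands to XLA.
    spec = jax.ShapeDtypeStruct(shape, dtype)
    return jax.jit(jnp.isnan).lower(spec).as_text()


if __name__ == "__main__":
    print(isnan_ir())
```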
I think we can rely on:
a) XLA producing correct answers, at least in this debugging mode which is intended to debug user code, and
b) x != x being true for a NaN on all XLA platforms. If we don't already test this, we should.
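Point (b) could be covered by a small backend test along these lines. It is a sketch assuming pytest-style test discovery, and `self_neq` and `test_self_inequality_detects_nans` are hypothetical names. It compares `x != x` under `jit` with NumPy's `isnan` for a few float types on whichever backend JAX selects.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def self_neq(x):
    # Elementwise x != x; on IEEE-754 floats this is true exactly for NaNs.
    return x != x


def test_self_inequality_detects_nans():
    values = np.array([0.0, -0.0, 1.0, np.inf, -np.inf, np.nan, -np.nan],
                      dtype=np.float32)
    expected = np.isnan(values)
    for dtype in (jnp.float16, jnp.bfloat16, jnp.float32):
        got = np.asarray(self_neq(jnp.asarray(values, dtype=dtype)))
        np.testing.assert_array_equal(got, expected, err_msg=str(dtype))


if __name__ == "__main__":
    test_self_inequality_detects_nans()
    print("x != x matches isnan on", jax.default_backend())
```

Running it on each supported backend (CPU, GPU, TPU) would turn assumption (b) into something checked.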
Also, given that we are already compiling a computation, I don't see much downside to appending a boolean NaN-check output to it as well.
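A minimal sketch of fusing the check into the computation that is already being compiled, so it costs one extra boolean output rather than a separate pass. `nan_checked` and `step` are hypothetical names, and this checks only the function's outputs; it is not how JAX implements `jax_debug_nans`.

```python
import jax
import jax.numpy as jnp


def nan_checked(f):
    """Hypothetical decorator: compile f together with a NaN-flag output."""

    @jax.jit
    def fused(*args):
        out = f(*args)
        # One on-device reduction per floating-point (or complex) output leaf.
        flags = [
            jnp.any(jnp.isnan(leaf))
            for leaf in jax.tree_util.tree_leaves(out)
            if jnp.issubdtype(leaf.dtype, jnp.inexact)
        ]
        any_nan = jnp.any(jnp.stack(flags)) if flags else jnp.array(False)
        return out, any_nan

    def call(*args):
        out, any_nan = fused(*args)
        # Only this scalar is needed on the host to decide whether to raise.
        if bool(any_nan):
            raise FloatingPointError("NaN encountered in output of " + f.__name__)
        return out

    return call


@nan_checked
def step(x):
    return jnp.log(x)  # log of a negative number is NaN


if __name__ == "__main__":
    print(step(jnp.array([1.0, 2.0])))
    try:
        step(jnp.array([-1.0]))
    except FloatingPointError as e:
        print("caught:", e)
```

Because `out` is returned as a device array, the check itself never forces the full result to the host; a NaN that appears in an intermediate value but is later masked out would not be caught by this output-only variant.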