`jax_debug_nans` should test for the presence of NaNs on device, not on the host

Open hawkinsp opened this issue 5 years ago • 3 comments

The JAX NaN checker computes `onp.any(onp.isnan(buf.to_py()))`; i.e., it transfers the whole array to the host and only then tests for the presence of NaNs.

It would be much more bandwidth efficient to test for the presence of NaNs on device and then transfer a single "NaNs present" boolean value to the host.
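As a rough illustration of the difference, here is a minimal sketch using public `jax.numpy` calls. It is not the actual `jax_debug_nans` implementation; `has_nans` is a hypothetical helper name, and `np.asarray(x)` stands in for the `buf.to_py()` transfer mentioned above.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def has_nans(x):
    # Hypothetical helper: the reduction runs on the device and
    # produces a scalar boolean array.
    return jnp.any(jnp.isnan(x))


x = jnp.array([1.0, float("nan"), 3.0])

# Host-side check: the full array is copied to the host before testing.
host_flag = bool(np.any(np.isnan(np.asarray(x))))

# Device-side check: only the scalar result is transferred to the host.
device_flag = bool(has_nans(x))
```

Both flags agree; the saving is in how many bytes cross the device-to-host link, which matters more as the array grows.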

hawkinsp • Apr 06 '20 19:04

Just checking: HLO-wise, can we rely on Neq(x, x) to indicate a NaN on all backends? Or should we use a CustomCall on GPU/CPU, which would avoid the compilation and perhaps ensure we're checking for NaNs correctly?

mattjj • Apr 07 '20 01:04

I think we can rely on: a) XLA producing correct answers, at least in this debugging mode which is intended to debug user code, and b) x != x for a NaN on all XLA platforms. If we don't test this, we should.
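A test for assumption (b) might look roughly like the sketch below. It only exercises whichever backend JAX selects by default, so it would need to be run once per platform; `nan_mask_via_self_compare` is a hypothetical name, not an existing JAX function.

```python
import jax
import jax.numpy as jnp


def nan_mask_via_self_compare(x):
    # Under IEEE 754, NaN is the only value that is not equal to itself.
    return x != x


x = jnp.array([0.0, jnp.nan, jnp.inf, -jnp.inf, 1.5])

# Compile the self-comparison so the check goes through XLA.
mask_ne = jax.jit(nan_mask_via_self_compare)(x)
mask_isnan = jnp.isnan(x)

# True when the self-comparison and isnan agree elementwise.
agree = bool(jnp.array_equal(mask_ne, mask_isnan))
print(jax.default_backend(), agree)
```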

hawkinsp • Apr 07 '20 01:04

Also, given we are already compiling a computation, I don't see a lot of downside to appending a NaN check boolean output to it as well.
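At the user level, the idea of appending a flag to an already-compiled computation could be sketched as follows. This is an assumption about the shape of the approach, not how JAX implements it internally; `with_nan_flag` and `f` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def with_nan_flag(fn):
    """Hypothetical wrapper: jit `fn` and add a scalar NaN flag as an extra output."""

    @jax.jit
    def wrapped(*args):
        out = fn(*args)
        # Check only floating-point/complex outputs; integers cannot hold NaN.
        flags = [
            jnp.any(jnp.isnan(leaf))
            for leaf in jax.tree_util.tree_leaves(out)
            if jnp.issubdtype(leaf.dtype, jnp.inexact)
        ]
        flag = jnp.any(jnp.stack(flags)) if flags else jnp.array(False)
        return out, flag

    return wrapped


def f(x):
    return jnp.log(x)  # log of a negative number yields NaN


out, flag = with_nan_flag(f)(jnp.array([1.0, -1.0]))
```

Because the flag is computed inside the same jitted function, it adds no separate compilation, and the host only needs to read one boolean to decide whether to raise.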

hawkinsp • Apr 07 '20 01:04