neural-tangents
`utils.is_on_cpu` breaks inside of JIT.
For example,
def f(x):
  print(_arr_is_on_cpu(x))

f(np.array([1.0]))       # Prints False
jit(f)(np.array([1.0]))  # Prints True
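One likely reason for the mismatch is that under `jit` the function body runs on an abstract tracer rather than on a concrete array, so a check that inspects the array's device sees a different kind of object. The snippet below is only a minimal sketch of that effect. It does not use `_arr_is_on_cpu` itself, and the `seen` list and the `jnp` alias are illustrative names, not taken from neural-tangents.

```python
import jax
import jax.numpy as jnp
from jax.core import Tracer

# Records, for each call, whether the argument seen by `f` was a tracer.
seen = []

def f(x):
    # Outside of jit, `x` is a concrete array; while jit traces `f`,
    # `x` is an abstract Tracer standing in for the real value.
    seen.append(isinstance(x, Tracer))
    return x

f(jnp.array([1.0]))           # eager call: concrete array
jax.jit(f)(jnp.array([1.0]))  # traced call: Tracer

print(seen)
```

Any device check that depends on the concrete buffer would therefore have to behave differently, or fall back to some default, when it is handed a tracer.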
FYI I feel that this might be impossible to fix, but I also think that once https://github.com/google/jax/issues/2225, https://github.com/google/jax/issues/2226, and perhaps other bugs related to device placement are fixed in JAX, we shouldn't need this method and a bunch of other device-placement code in `predict`.
An update on this: since a76bbb494f19af4f8c9c1a1b0904e91b105f769e and the JAX device-placement bugfixes, this function is only used in `nt.predict.max_learning_rate`. Other functions don't need it, so it's less problematic now. I still don't know whether it's possible to write a function like this in JAX that works under JIT, though.
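One possible workaround, sketched here under assumptions rather than as a proposal for the library, is to make the CPU decision in plain Python outside the jitted function and pass it in as a static argument. The helper `backend_is_cpu` and the function `scale` are hypothetical names. Note also that `jax.default_backend()` reports the process's default platform, not where a particular array lives, so this is not a drop-in replacement for a per-array check.

```python
import functools

import jax
import jax.numpy as jnp

def backend_is_cpu() -> bool:
    # Hypothetical helper: checks the default backend of the process,
    # which is not the same as the placement of a specific array.
    return jax.default_backend() == "cpu"

@functools.partial(jax.jit, static_argnames="on_cpu")
def scale(x, on_cpu):
    # `on_cpu` is a static Python bool, so this branch is resolved at
    # trace time and never depends on a traced value.
    factor = 1.0 if on_cpu else 2.0
    return x * factor

x = jnp.array([1.0])
y_cpu = scale(x, on_cpu=True)
y_other = scale(x, on_cpu=False)
y = scale(x, on_cpu=backend_is_cpu())
```

Because the flag is static, each distinct value triggers its own compilation. That is cheap for a boolean, and it avoids asking a tracer about its device altogether.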