`utils.is_on_cpu` breaks inside of JIT.

Open sschoenholz opened this issue 5 years ago • 2 comments

For example,

from jax import jit
import jax.numpy as np  # assumed: np refers to jax.numpy here

def f(x):
  print(_arr_is_on_cpu(x))

f(np.array([1.0]))       # Prints False
jit(f)(np.array([1.0]))  # Prints True

sschoenholz, Feb 05 '20 21:02

FYI, I feel that this might be impossible to fix. However, once https://github.com/google/jax/issues/2225, https://github.com/google/jax/issues/2226, and perhaps other bugs related to device placement are fixed in JAX, we shouldn't need this method or a bunch of the other device-placement code in `predict`.

romanngg, Mar 31 '20 19:03

An update on this: since a76bbb494f19af4f8c9c1a1b0904e91b105f769e plus the JAX device-placement bugfixes, this function is only used in `nt.predict.max_learning_rate`. No other function needs it, so it's less problematic now. I still don't know whether it's possible to write a function like this in JAX that works under JIT, though...

romanngg, Jun 25 '20 11:06
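
For readers wondering why the two calls in the original report can disagree, here is a minimal sketch of the likely mechanism. It is not neural-tangents code: `is_traced`, `f`, and `seen` are hypothetical names used only for illustration. Under `jit`, the Python body of a function runs on abstract tracers rather than concrete device arrays, so any check that inspects the argument's type or placement sees a different kind of object than it does in an eager call.

```python
import jax
import jax.numpy as jnp

# Records what the function body observes each time Python executes it.
seen = []


def is_traced(x):
    """Hypothetical helper: True if x is an abstract tracer, not a concrete array."""
    return isinstance(x, jax.core.Tracer)


def f(x):
    # This line runs eagerly on a real array, but under jit it runs once
    # during tracing, where x only stands in for the eventual input.
    seen.append(is_traced(x))
    return x * 2.0


x = jnp.array([1.0])
f(x)           # eager call: x is a concrete array
jax.jit(f)(x)  # jitted call: x is a tracer while f is being traced
print(seen)
```

One possible way to sidestep the problem, assuming the caller controls the entry point, is to perform any device check on the concrete inputs before entering the jitted function and pass the result in as a static Python value.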