trax
Fix issues related to new behavior of JAX DeviceArray.copy()
In https://github.com/google/jax/pull/10069, JAX changed the behavior of DeviceArray.copy() so that it returns a DeviceArray instead of a NumPy array. Code in trax that relied on .copy() to convert a DeviceArray to NumPy needs to be updated: the preferred conversion is now np.asarray(device_array).
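A minimal sketch of the migration, assuming a JAX version that already includes the change above. It shows that .copy() now stays a JAX array, and that np.asarray gives a NumPy array instead. The helper name to_numpy and its writable flag are hypothetical and not part of trax or JAX. The array returned by np.asarray may share memory with the JAX buffer and be read-only, so the sketch falls back to np.array (which copies by default) when the caller needs to mutate the result.

```python
import numpy as np
import jax.numpy as jnp


def to_numpy(x, writable=False):
  """Hypothetical helper: convert a JAX array to a host NumPy array.

  Replaces the old pattern `x.copy()`, which no longer returns NumPy.
  """
  if writable:
    # np.array copies by default, so the result is independent and writable.
    return np.array(x)
  # np.asarray is the conversion recommended after the JAX change; the result
  # may share memory with the JAX buffer and may be read-only.
  return np.asarray(x)


x = jnp.arange(4.0)

# After the change, .copy() returns a JAX array, not a NumPy ndarray.
y = x.copy()
print(type(y))

x_np = to_numpy(x)
print(type(x_np), x_np)
```

In practice, each call site that did `np_value = device_array.copy()` would become `np_value = np.asarray(device_array)`, or `np.array(device_array)` where the code writes into the result.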