
jax numpy operation not recognised; other backends work fine

dkgaraujo opened this issue 6 months ago

The following (pretty trivial) code works with the tensorflow and torch backends, but not with jax:

import keras  # with KERAS_BACKEND=jax set in the environment
keras.ops.cumsum(keras.ops.eye(4))
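To compare backends, the expected result can be pinned down with plain NumPy. The sketch below assumes keras.ops.cumsum follows NumPy semantics, where the default axis=None flattens the input, so cumsum(eye(4)) is a 16-element vector. matches_numpy_reference is a hypothetical helper, and the commented usage lines assume Keras or JAX is installed; a backend bug such as the one reported here may raise instead of returning False.

```python
import numpy as np


def matches_numpy_reference(cumsum, eye, n=4, to_numpy=np.asarray):
    """Hypothetical helper: check cumsum(eye(n)) from any backend against NumPy.

    `cumsum`/`eye` can be keras.ops.cumsum/keras.ops.eye or
    jax.numpy.cumsum/jax.numpy.eye; `to_numpy` converts the backend result
    to a NumPy array (keras.ops.convert_to_numpy also handles GPU tensors).
    """
    # With axis=None the input is flattened, so the result has shape (n * n,):
    # for n=4 that is five 1s, five 2s, five 3s and a final 4.
    expected = np.cumsum(np.eye(n))
    actual = to_numpy(cumsum(eye(n)))
    return actual.shape == expected.shape and bool(np.allclose(actual, expected))


if __name__ == "__main__":
    print(matches_numpy_reference(np.cumsum, np.eye))
    # Backend checks, assuming the packages are installed:
    #   import keras
    #   matches_numpy_reference(keras.ops.cumsum, keras.ops.eye,
    #                           to_numpy=keras.ops.convert_to_numpy)
    #   import jax.numpy as jnp
    #   matches_numpy_reference(jnp.cumsum, jnp.eye)
```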

Below is the error:

.venv/lib/python3.10/site-packages/keras/src/backend/jax/numpy.py:306:0: note: see current operation: %31 = "mhlo.pad"(%30, %1) {edge_padding_high = dense<[1, 0]> : tensor<2xi64>, edge_padding_low = dense<0> : tensor<2xi64>, interior_padding = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<2x4xf32>


Arguments received by Cumsum.call():
  • x=jnp.ndarray(shape=(4, 4), dtype=float32)

I am running it on an Intel MacBook, in case it's relevant.

dkgaraujo, Jan 15 '24 22:01

Versions: jax 0.4.20, jax-metal 0.0.5
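Since this bug depends on which packages are installed, the relevant versions can be collected programmatically for a report. A minimal standard-library sketch; report_environment is a hypothetical helper, and its package list (including jaxlib, which the thread does not mention) is an assumption.

```python
import platform
from importlib import metadata


def report_environment(packages=("keras", "jax", "jaxlib", "jax-metal")):
    """Hypothetical helper: return machine/Python info plus installed versions."""
    info = {"machine": platform.machine(), "python": platform.python_version()}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = None  # package is not installed in this environment
    return info


if __name__ == "__main__":
    for key, value in report_environment().items():
        print(f"{key}: {value}")
```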

dkgaraujo, Jan 15 '24 22:01

Hi,

I was able to run the code successfully on Colab. Could you please upgrade your Keras version and try again? Here is the working gist for reference: https://colab.sandbox.google.com/gist/sachinprasadhs/3be1c8ea579ab0d60de9bf08241962b5/19059.ipynb

sachinprasadhs, Jan 17 '24 19:01

Thanks @sachinprasadhs. I ran the exact same code from the gist on my machine, using keras==3.0.2, both with jax==0.4.20 and jax-metal (as above) and with the newest jax==0.4.23. In both cases I got the same error. The code only ran successfully after I uninstalled jax-metal and then installed jax. But then, of course, nothing runs on the GPU.

Apparently this is a bug in jax-metal. With the jax-metal packages installed, if instead of

keras.ops.cumsum(keras.ops.eye(4))

I call the same functions from the jax.numpy namespace (jax.numpy.cumsum(jax.numpy.eye(4))), I get the same error.

Still, from a Keras perspective, could (or even should) this be "patched" to avoid breaking code for some users? For example, if the backend is jax and the machine uses a Metal GPU, certain operations could fall back to NumPy on the CPU. This would cost performance, but it could be a reasonable design choice to avoid breaking code until there is a comprehensive JAX backend for Apple hardware. Just a thought. Thanks again!
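A minimal sketch of the kind of fallback suggested above, under stated assumptions: with_cpu_fallback is a hypothetical helper, not an existing Keras API; it assumes the jax-metal failure surfaces as a Python exception at call time (the exception type is not visible in the log above, so it catches Exception by default); and it assumes a NumPy function with matching semantics exists for each wrapped op. The failing Metal call is simulated with a stub so the snippet runs without jax-metal.

```python
import functools
import warnings

import numpy as np


def with_cpu_fallback(numpy_fn, errors=(Exception,)):
    """Hypothetical decorator (not a Keras API): if the wrapped op raises one of
    `errors`, copy the positional array arguments to host memory and re-run
    the computation with the NumPy implementation `numpy_fn` on the CPU."""

    def decorator(op):
        @functools.wraps(op)
        def wrapper(*arrays, **kwargs):
            try:
                return op(*arrays, **kwargs)
            except errors as exc:
                warnings.warn(
                    f"{op.__name__} failed on the accelerator ({exc!r}); "
                    "falling back to NumPy on the CPU (slower).",
                    RuntimeWarning,
                    stacklevel=2,
                )
                # np.asarray pulls device arrays (e.g. jax.Array) back to the host,
                # so np.cumsum does not dispatch to the array's own .cumsum method.
                host_arrays = [np.asarray(a) for a in arrays]
                return numpy_fn(*host_arrays, **kwargs)

        return wrapper

    return decorator


# Stub standing in for the jax-metal failure; real use would call jax.numpy.cumsum.
@with_cpu_fallback(np.cumsum)
def cumsum(x, axis=None):
    raise RuntimeError("simulated jax-metal failure")


if __name__ == "__main__":
    print(cumsum(np.eye(4)))
```

Catching a broad exception type can hide genuine bugs, so in practice `errors` would be narrowed to whatever jax-metal actually raises, and the wrapper would have to be applied per op, which is one reason such a patch is a design trade-off rather than a free fix.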

dkgaraujo, Jan 17 '24 22:01

Of course, if a decision to follow this suggestion is made, I would be happy to contribute a PR to get things started.

dkgaraujo, Jan 17 '24 22:01