keras
jax numpy operation not recognised, other backends work fine
The following (pretty trivial) code works with the TensorFlow and PyTorch backends, but not with the JAX backend:
keras.ops.cumsum(keras.ops.eye(4))
Below is the error:
.venv/lib/python3.10/site-packages/keras/src/backend/jax/numpy.py:306:0: note: see current operation: %31 = "mhlo.pad"(%30, %1) {edge_padding_high = dense<[1, 0]> : tensor<2xi64>, edge_padding_low = dense<0> : tensor<2xi64>, interior_padding = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<2x4xf32>
Arguments received by Cumsum.call():
• x=jnp.ndarray(shape=(4, 4), dtype=float32)
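For comparison, the result that `keras.ops.cumsum(keras.ops.eye(4))` should produce can be computed with plain NumPy. This sketch assumes `keras.ops.cumsum` mirrors `numpy.cumsum` in flattening its input when `axis` is left at its default of `None`, so the output is a 1-D array of 16 running totals rather than a 4x4 matrix.

```python
import numpy as np

# Host-side reference computation, independent of any Keras backend.
x = np.eye(4, dtype="float32")

# With axis=None (the default), cumsum flattens the input before summing.
expected = np.cumsum(x)

print(expected.shape)  # (16,)
print(expected)
```

A working backend should return the same values; a mismatch or an error on one backend points at that backend rather than at the calling code.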
I am running it on an Intel MacBook, in case it's relevant.
Versions: jax 0.4.20, jax-metal 0.0.5
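When reporting a backend-specific failure like this one, it can help to record the exact installed versions. Below is a minimal standard-library sketch; `report_versions` is a hypothetical helper, and the distribution names `keras`, `jax`, `jaxlib` and `jax-metal` are assumed to match what pip installed.

```python
import platform
from importlib import metadata


def report_versions(packages=("keras", "jax", "jaxlib", "jax-metal")):
    """Return {distribution name: installed version}, using None if not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # OS and CPU architecture (e.g. x86_64 on an Intel Mac).
    print(platform.platform(), platform.machine())
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
```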
Hi,
I was able to run the code successfully on Colab. Could you please upgrade your Keras version and try again? Here is the working Gist for reference: https://colab.sandbox.google.com/gist/sachinprasadhs/3be1c8ea579ab0d60de9bf08241962b5/19059.ipynb
Thanks @sachinprasadhs. I ran the exact same code from the gist on my machine, using keras==3.0.2. I ran it both with jax==0.4.20 and jax-metal, as above, and with the newest jax==0.4.23. In both cases I got the same error. The only way I could run the code successfully was to first uninstall jax-metal and then install jax, but then, of course, nothing runs on the GPU.
It appears to be a bug in jax-metal. With the jax-metal packages installed, if instead of
keras.ops.cumsum(keras.ops.eye(4))
I call the same functions from the jax.numpy namespace, I get the same error.
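The jax.numpy-only reproduction described above can be written as a short script that takes Keras out of the picture entirely. This is a minimal sketch: on a CPU-only install it should pass, while according to the report above the `cumsum` call is where jax-metal 0.0.5 fails. The `jax.devices()` print is only there to confirm which backend JAX selected.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Shows which devices JAX is using (CPU, GPU, Metal, ...).
print(jax.devices())

# Same computation as the Keras snippet, calling jax.numpy directly.
x = jnp.eye(4)
result = jnp.cumsum(x)  # flattens by default, like numpy.cumsum

# Compare against a host-side NumPy reference.
np.testing.assert_allclose(np.asarray(result), np.cumsum(np.eye(4)))
print(result)
```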
Still, from a Keras perspective, is this something that could (or even should?) be "patched" so as to avoid breaking code for some users? For example, if the backend is jax and the machine is using a Metal GPU, certain operations could fall back to the CPU using NumPy. This would cost performance, but it could be a deliberate design choice to avoid breaking code in the absence of a comprehensive JAX backend for Apple hardware. Just a thought. Thanks again.
Of course, if you decide to follow this suggestion, I would be happy to contribute a PR to get things started.
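One way to prototype the suggested fallback outside Keras is to pin only the problematic op to JAX's CPU device, instead of converting to NumPy. A minimal sketch, assuming the CPU backend remains available alongside the jax-metal plugin; `cumsum_on_cpu` is a hypothetical helper, not a Keras or JAX API.

```python
import jax
import jax.numpy as jnp


def cumsum_on_cpu(x, axis=None):
    """Hypothetical workaround: compute cumsum on the CPU device.

    jax.device_put with an explicit device commits the array to that device,
    and JAX runs operations on committed inputs on the device they live on.
    """
    cpu = jax.devices("cpu")[0]
    x_cpu = jax.device_put(x, cpu)
    return jnp.cumsum(x_cpu, axis=axis)


# The input may be created on the default device (e.g. a Metal GPU);
# only the cumsum itself is moved to the CPU.
y = cumsum_on_cpu(jnp.eye(4))
print(y)
```

This matches the trade-off raised above: the device transfer and CPU execution cost performance, but the call no longer depends on the Metal plugin supporting that op.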