keras
Wrong binary accuracy with JAX
I am getting very strange results out of the `
Consider the code below:
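import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before keras is imported
import numpy as np
import keras

# A model whose output is just the thresholded input, inp > 0.5
inp = keras.Input(shape=(1,))
out = inp > 0.5
mm = keras.Model(inputs=inp, outputs=out)

x = np.random.rand(32, 1)
res = mm.predict(x)

# Labels and predictions are both x > 0.5, so the accuracy should be 1.0
met = keras.metrics.BinaryAccuracy()
met.update_state(x > 0.5, res > 0.5)
print(met.result())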
I would expect 1.0 on every run. Instead I get a seemingly random value close to 0.5.
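One plausible explanation, not confirmed against the Keras source here, is that the metric casts its default threshold of 0.5 to the dtype of y_pred before comparing. Both arguments passed to update_state above are boolean arrays, so the threshold would become True, and no boolean value is greater than True. Every prediction would then count as False, and the reported "accuracy" would simply be the fraction of False labels, which is close to 0.5 for uniform random x. The NumPy sketch below mimics that hypothesized path outside Keras; the variable names are illustrative and are not Keras internals.

```python
import numpy as np

# Same setup as the report: 32 random scores in [0, 1) and the labels they imply.
rng = np.random.default_rng(0)
x = rng.random((32, 1))
y_true = x > 0.5  # boolean labels
y_pred = x > 0.5  # boolean predictions, identical to the labels

# What the reporter expects: predictions match labels exactly.
expected = np.mean(y_true == y_pred)

# Hypothesized failure path (an assumption, not verified against Keras):
# the 0.5 threshold is cast to the prediction dtype, which is bool here.
threshold = np.asarray(0.5).astype(y_pred.dtype)  # becomes True
thresholded = y_pred > threshold                  # bool > True is always False
suspected = np.mean(y_true == thresholded)        # fraction of False labels

print(expected)                     # 1.0
print(suspected, np.mean(~y_true))  # the two values are equal
```

If this is the cause, passing float arrays instead, for example `met.update_state((x > 0.5).astype("float32"), res.astype("float32"))`, may sidestep the problem; this workaround is untested.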
Package versions: tensorflow 2.17.0, keras 3.5.0, jax 0.4.26, numpy 1.26.4