Wrong binary accuracy with Jax

Open eli-osherovich opened this issue 6 months ago • 4 comments

I have some very strange results out of the `

Consider the code below:

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras
import keras
import numpy as np


inp = keras.Input(shape=(1,))
out = inp > 0.5
mm = keras.Model(inputs=inp, outputs=out) 

x = np.random.rand(32, 1)

res = mm.predict(x)
met = keras.metrics.BinaryAccuracy()
met.update_state(x>0.5, res>0.5)
met.result()

I would expect the accuracy to be exactly 1 on every run. Instead I get a seemingly random value (close to 0.5).
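For reference, here is a minimal NumPy-only sketch of what I understand binary accuracy to compute: threshold the predictions, compare them element-wise with the labels, and take the mean. The helper name `binary_accuracy` is my own and is not a Keras API; the 0.5 threshold is assumed to match the metric's default. Feeding the same boolean array as both labels and predictions, as in the snippet above, should give 1.0 under this definition.

```python
import numpy as np


def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Hypothetical reference helper: fraction of predictions matching labels."""
    y_true = np.asarray(y_true, dtype=np.float32)
    # Threshold the predictions into hard 0/1 decisions.
    y_pred = (np.asarray(y_pred, dtype=np.float32) > threshold).astype(np.float32)
    return float(np.mean(y_true == y_pred))


rng = np.random.default_rng(0)
x = rng.random((32, 1))
labels = x > 0.5  # boolean labels, same shape as the model output

# Labels compared against themselves: every element matches.
expected = binary_accuracy(labels, labels)
print(expected)
```

This does not go through any Keras backend, so it only shows the value I am expecting, not where the discrepancy comes from.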

Packages' versions (tf, keras, jax, np)

'2.17.0', '3.5.0', '0.4.26', '1.26.4'
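In case it helps anyone trying to reproduce this, the versions can be collected with a small sketch like the one below. It uses the standard-library `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the distribution names (`tensorflow`, `keras`, `jax`, `numpy`) are assumed to match how the packages were installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names):
    """Hypothetical helper: map each distribution name to its version, or None."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The package is not installed in this environment.
            found[name] = None
    return found


print(installed_versions(["tensorflow", "keras", "jax", "numpy"]))
```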

eli-osherovich · Aug 28 '24 14:08