Inconsistent behavior of the metric `SpecificityAtSensitivity` across different backends
Hi, developers. I want to monitor the `SpecificityAtSensitivity` metric during training. I checked the docs to make sure it can be used with the `compile()` API.
However, I find that this metric does not work out of the box like others such as `TruePositives`: it cannot be used directly on every backend, and sometimes it even fails when combined with "accuracy"! Below are tables for the three backends showing the results I found. "with acc" means setting `metrics=["acc", SpecificityAtSensitivity(...)]`. A Colab notebook to reproduce the error: issue_inconsistent_manner_of_SpecificityAtSensitivity.ipynb
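For context, here is a minimal sketch of the kind of setup I am using (the data, model, and `sensitivity` value are arbitrary placeholders, not the exact colab code); switching `KERAS_BACKEND`, `run_eagerly`, and the `"acc"` entry reproduces the table cells below:

```python
import os
# Assumption: the backend is chosen before importing keras
# (tensorflow / torch / jax).
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import numpy as np
import keras

# Toy binary-classification data (placeholder, not the colab dataset).
x = np.random.rand(256, 8).astype("float32")
y = np.random.randint(0, 2, size=(256, 1)).astype("float32")

model = keras.Sequential([
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])

model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    # "with acc" column: keep "acc"; "without acc" column: drop it.
    metrics=["acc", keras.metrics.SpecificityAtSensitivity(sensitivity=0.5)],
    run_eagerly=False,  # also tried True
)

model.fit(x, y, epochs=1, batch_size=32)
```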
tensorflow
| run_eager \ metrics | with "acc" | without "acc" |
|---|---|---|
| False (default) | ❎ | ❎ |
| True | ✔️ | ✔️ |
```
<class 'NotImplementedError'> Cannot convert a symbolic tf.Tensor (Cast_12:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
```
torch
| run_eager \ metrics | with "acc" | without "acc" |
|---|---|---|
| False (default) | ❎ | ✔️ |
| True | ❎ | ✔️ |
```
<class 'NotImplementedError'> Cannot copy out of meta tensor; no data!
```
jax
| run_eager \ metrics | with "acc" | without "acc" |
|---|---|---|
| False (default) | ❎ | ❎ |
| True | ❎ | ❎ |
```
<class 'jax.errors.ConcretizationTypeError'> Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function wrapped_fn at /usr/local/lib/python3.10/dist-packages/keras/src/backend/jax/core.py:153 for make_jaxpr.
This concrete value was not available in Python because it depends on the values of the arguments args[1] and args[2].
```