Error when using tfp.stats.kendalls_tau as a Keras metric
I am trying to use TensorFlow Probability's `tfp.stats.kendalls_tau` as a metric in Keras, but I get the following error:
import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np
def kendalls_tau(y_true, y_pred):
    """Kendall's tau between flattened labels and predictions.

    Both inputs are reshaped to rank 1 before being handed to
    ``tfp.stats.kendalls_tau``.
    """
    flat_true = tf.reshape(y_true, shape=(-1,))
    flat_pred = tf.reshape(y_pred, shape=(-1,))
    return tfp.stats.kendalls_tau(flat_true, flat_pred)
# Minimal reproducer: a single Dense layer mapping 3 input features to 2 outputs.
inputs = tf.keras.layers.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
# kendalls_tau is passed directly as a Keras metric, so Keras calls it inside
# the traced training step (train_function in the traceback below).
model.compile(optimizer="Adam", loss="mse", metrics=kendalls_tau)
# Two random samples with 3 features each, and random 0/1 integer targets of
# shape (2, 2) to match the model's 2 outputs.
x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2, 2) )
# This call raises the TypeError shown below, coming from inside
# tfp.stats.kendalls_tau while the training step is being traced.
model.fit(x, y)
TypeError: in user code:
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:864 train_function *
return step_function(self, iterator)
<ipython-input-4-14a2210abe73>:5 kendalls_tau *
kendall = tfp.stats.kendalls_tau(a, b)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:196 kendalls_tau **
lexa = lexicographical_indirect_sort(y_true, y_pred)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:154 lexicographical_indirect_sort
left, _, lexicographic = tf.while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:614 new_func
return func(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2531 while_loop_v2
return while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2729 while_loop
return while_v2.while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:214 while_loop
body_graph = func_graph_module.func_graph_from_py_func(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:200 wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:152 body
tf.cond(not_equal, secondary_sort, lambda: lexicographic))
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1438 cond_for_tf_v2
return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:546 new_func
return func(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py:83 cond_v2
true_graph = func_graph_module.func_graph_from_py_func(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:148 secondary_sort
tensorshape_util.set_shape(x, [n])
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py:328 set_shape
tensor.set_shape(shape)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:758 set_shape
shape = tensor_shape.TensorShape(shape)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 __init__
self._dims = [Dimension(d) for d in dims]
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 <listcomp>
self._dims = [Dimension(d) for d in dims]
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:206 __init__
six.raise_from(
<string>:3 raise_from
TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'kendalls_tau/lexicographical_indirect_sort/size0/strided_slice_1:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/data/Mestrado/Ensaios/drbc_tf.py in
257 x = np.random.random((2, 3))
258 y = np.random.randint(0, 2, (2, 2) )
---> 259 model.fit(x, y)
260
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1190 _r=1):
1191 callbacks.on_train_batch_begin(step)
-> 1192 tmp_logs = self.train_function(iterator)
1193 if data_handler.should_sync:
1194 context.async_wait()
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
883
884 with OptionalXlaContext(self._jit_compile):
--> 885 result = self._call(*args, **kwds)
886
887 new_tracing_count = self.experimental_get_tracing_count()
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
927 # This is the first call of __call__, so we have to initialize.
928 initializers = []
--> 929 self._initialize(args, kwds, add_initializers_to=initializers)
930 finally:
931 # At this point we know that the initialization is complete (or less
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
757 self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
758 self._concrete_stateful_fn = (
--> 759 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
760 *args, **kwds))
761
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
3057 args, kwargs = None, None
3058 with self._lock:
-> 3059 graph_function, _ = self._maybe_define_function(args, kwargs)
3060 return graph_function
3061
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3454
3455 self._function_cache.missed.add(call_context_key)
-> 3456 graph_function = self._create_graph_function(args, kwargs)
3457 self._function_cache.primary[cache_key] = graph_function
3458
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3289 arg_names = base_arg_names + missing_arg_names
3290 graph_function = ConcreteFunction(
-> 3291 func_graph_module.func_graph_from_py_func(
3292 self._name,
3293 self._python_function,
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
1005 _, original_func = tf_decorator.unwrap(python_func)
1006
-> 1007 func_outputs = python_func(*func_args, **func_kwargs)
1008
1009 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
666 # the function a weak reference to itself to avoid a reference cycle.
667 with OptionalXlaContext(compile_with_xla):
--> 668 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
669 return out
670
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
992 except Exception as e: # pylint:disable=broad-except
993 if hasattr(e, "ag_error_metadata"):
--> 994 raise e.ag_error_metadata.to_exception(e)
995 else:
996 raise
TypeError: in user code:
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:864 train_function *
return step_function(self, iterator)
<ipython-input-4-14a2210abe73>:5 kendalls_tau *
kendall = tfp.stats.kendalls_tau(a, b)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:196 kendalls_tau **
lexa = lexicographical_indirect_sort(y_true, y_pred)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:154 lexicographical_indirect_sort
left, _, lexicographic = tf.while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:614 new_func
return func(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2531 while_loop_v2
return while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2729 while_loop
return while_v2.while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:214 while_loop
body_graph = func_graph_module.func_graph_from_py_func(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:200 wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:152 body
tf.cond(not_equal, secondary_sort, lambda: lexicographic))
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1438 cond_for_tf_v2
return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:546 new_func
return func(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py:83 cond_v2
true_graph = func_graph_module.func_graph_from_py_func(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:148 secondary_sort
tensorshape_util.set_shape(x, [n])
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py:328 set_shape
tensor.set_shape(shape)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:758 set_shape
shape = tensor_shape.TensorShape(shape)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 __init__
self._dims = [Dimension(d) for d in dims]
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 <listcomp>
self._dims = [Dimension(d) for d in dims]
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:206 __init__
six.raise_from(
<string>:3 raise_from
TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'kendalls_tau/lexicographical_indirect_sort/size0/strided_slice_1:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'
How can I fix this?
You can wrap kendalls_tau with tf.py_function:
def tf_kendalls_tau(y_true, y_pred):
    """Compute ``kendalls_tau`` eagerly through ``tf.py_function``.

    ``tf.py_function`` executes the wrapped Python callable eagerly, so the
    body of ``kendalls_tau`` is not converted by AutoGraph when Keras traces
    the training step.

    Returns:
        A float32 scalar tensor.
    """
    kt = tf.py_function(
        kendalls_tau,
        (y_true, y_pred),
        tf.float32
    )
    # tf.py_function cannot infer the static shape of its output. Kendall's
    # tau is a single value, so declare it as a scalar for downstream Keras
    # code (metric reduction, logging).
    kt.set_shape(())
    return kt
# Use the tf.py_function wrapper (not kendalls_tau directly) as the metric.
model.compile(optimizer="Adam", loss="mse", metrics=tf_kendalls_tau)
The error is raised when the graph (AutoGraph-converted) version of the function is called. When you call `model.fit()`, Keras traces the training step into a TensorFlow graph, and AutoGraph converts the `kendalls_tau` function as part of it. Wrapping it in `tf.py_function` avoids this. The wrapped function then runs eagerly as ordinary Python, and its result is inserted into the graph as a single op.
While this works, I don't think the underlying issue is resolved: calling the AutoGraph-converted version of `tfp.stats.kendalls_tau` still throws this error. I'm going to look into why.
Does anyone with more experience already see the problem here? The AutoGraph-converted version of `tfp.stats.kendalls_tau` should be callable, right?
Full code that worked for me:
import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np
def kendalls_tau(y_true, y_pred):
    """Kendall's tau between flattened labels and predictions.

    Both inputs are reshaped to rank 1 before being handed to
    ``tfp.stats.kendalls_tau``.
    """
    flat_true = tf.reshape(y_true, shape=(-1,))
    flat_pred = tf.reshape(y_pred, shape=(-1,))
    return tfp.stats.kendalls_tau(flat_true, flat_pred)
def tf_kendalls_tau(y_true, y_pred):
    """Compute ``kendalls_tau`` eagerly through ``tf.py_function``.

    ``tf.py_function`` executes the wrapped Python callable eagerly, so the
    body of ``kendalls_tau`` is not converted by AutoGraph when Keras traces
    the training step.

    Returns:
        A float32 scalar tensor.
    """
    kt = tf.py_function(
        kendalls_tau,
        (y_true, y_pred),
        tf.float32
    )
    # tf.py_function cannot infer the static shape of its output. Kendall's
    # tau is a single value, so declare it as a scalar for downstream Keras
    # code (metric reduction, logging).
    kt.set_shape(())
    return kt
# Same toy model as the reproducer: 3 input features -> Dense layer with 2 outputs.
inputs = tf.keras.layers.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
# The metric is the tf.py_function wrapper rather than kendalls_tau itself;
# this is the variant reported to work in this thread.
model.compile(optimizer="Adam", loss="mse", metrics=tf_kendalls_tau)
# Two random samples with 3 features each, and random 0/1 integer targets of
# shape (2, 2) to match the model's 2 outputs.
x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2, 2) )
model.fit(x, y)
I've started looking at this, but the stats function is not suitable for use as a Keras metric, for a number of reasons. You probably want the approximate O(n) version in TensorFlow Addons instead, which was intended for this use case: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/metrics/kendalls_tau.py
The TFP version expects two tensors of shape [n], and as far as I know Keras models cannot output scalars.
I've looked into this and I think there are a few issues, but fundamentally I'm not certain the original proposed use case makes much sense. If you use the `py_function` shim, you can use the SciPy version of Kendall's tau. Even then, I'm not certain it does what one would want, since the SciPy version returns NaN for lists of length 1.
It's possible to remove the assertions and make the TFP Kendall's tau behave more like the SciPy one; I'm still working on it.
Thank you for following up.
I'd love to help out on this. Let me know if there's anything I can do that would be valuable.
I've created https://github.com/tensorflow/probability/pull/1455 which changes behavior to more closely match scipy's implementation. I don't know if this addresses all of the issues raised here but wanted to share it in case it helps.
I think this is fixed.