addons
addons copied to clipboard
tfa.metrics.RSquare: ValueError: Shapes (1,) and () are incompatible
Environment: TensorFlow 2.2, Keras 2.3.0-tf, TensorFlow Addons (tfa) 0.10.0, Python 3.6
Code used:
# Metric configuration that triggers the error reported below.
# NOTE(review): with y_shape=(1,) the metric's state variables appear to be
# created with shape (1,); reset_states() then calls assign(0) with a scalar,
# which matches the "Shapes (1,) and () are incompatible" ValueError in the
# traceback — presumably the root cause, TODO confirm against r_square.py.
metrics = [ tfa.metrics.RSquare(name='RSquare', dtype=tf.float32, y_shape=(1,)) ]
When training for multiple epochs, training in the first epoch runs fine, but an error is raised when validation starts at the end of that first epoch (while the metric states are being reset).
Full error traceback:
history = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10, callbacks = callbacks)
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside run_distribute_coordinator already.
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
870 workers=workers,
871 use_multiprocessing=use_multiprocessing,
--> 872 return_dict=True)
873 val_logs = {'val_' + name: val for name, val in val_logs.items()}
874 epoch_logs.update(val_logs)
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside run_distribute_coordinator already.
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in evaluate(self, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, return_dict)
1071 callbacks.on_test_begin()
1072 for _, iterator in data_handler.enumerate_epochs(): # Single epoch.
-> 1073 self.reset_metrics()
1074 with data_handler.catch_stop_iteration():
1075 for step in data_handler.steps():
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in reset_metrics(self)
1289 """Resets the state of metrics."""
1290 for m in self.metrics:
-> 1291 m.reset_states()
1292
1293 def train_on_batch(self,
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow_addons/metrics/r_square.py in reset_states(self)
140 def reset_states(self) -> None:
141 # The state of the metric will be reset at the start of each epoch.
--> 142 self.squared_sum.assign(0)
143 self.sum.assign(0)
144 self.res.assign(0)
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py in assign(self, value, use_locking, name, read_value)
844 with _handle_graph(self.handle):
845 value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
--> 846 self._shape.assert_is_compatible_with(value_tensor.shape)
847 assign_op = gen_resource_variable_ops.assign_variable_op(
848 self.handle, value_tensor, name=name)
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py in assert_is_compatible_with(self, other)
1115 """
1116 if not self.is_compatible_with(other):
-> 1117 raise ValueError("Shapes %s and %s are incompatible" % (self, other))
1118
1119 def most_specific_compatible_shape(self, other):
ValueError: Shapes (1,) and () are incompatible
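Hi, can you try the following change? In `update_state`, it aligns the ranks of `y_pred` and `y_true` (via `losses_utils.squeeze_or_expand_dimensions`) before their shapes are checked for compatibility: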
from tensorflow.python.keras.utils import losses_utils
# ...
class RSquare(Metric):
# ...
def update_state(self, y_true, y_pred, sample_weight=None) -> None:
y_true = tf.cast(y_true, dtype=self._dtype)
y_pred = tf.cast(y_pred, dtype=self._dtype)
y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
y_pred,
y_true,
)
y_pred.shape.assert_is_compatible_with(y_true.shape)
# ...