`model.evaluate` resets training metrics
System information
- Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 20.04
- TensorFlow installed from (source or binary): Binary
- TensorFlow version (use command below): 2.8.0
- Python version: 3.9
- GPU model and memory: Irrelevant
- Exact command to reproduce: See the description
Describe the problem
The Model.evaluate function has a surprising side effect: it resets all metrics, including the loss. Because our epochs are long, we call Model.evaluate in a callback to report validation metrics many times per epoch, and we just found out that doing so overwrites the training metrics. It would be great to either change Model.evaluate so it does not reset metrics, or to provide a mechanism for saving and restoring metric state, along the lines of the sketch below.
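One possible shape for such a mechanism (a minimal sketch, not from the original report; it assumes the metric objects exposed by model.metrics are the same ones Model.evaluate resets, and that their state can be copied with the standard Layer get_weights/set_weights API, since Keras metrics subclass Layer):

    def evaluate_preserving_metrics(model, val_ds):
        # Snapshot the accumulator variables of every training metric,
        # including the loss tracker, before evaluate resets them.
        saved = [m.get_weights() for m in model.metrics]
        results = model.evaluate(x=val_ds, verbose=0, return_dict=True)
        # Restore the training accumulators so end-of-epoch logging is unaffected.
        for metric, weights in zip(model.metrics, saved):
            metric.set_weights(weights)
        return results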
Source code / logs
This is the callback we use to log validation metrics every N steps. The model.evaluate call, apparently, resets the training metrics.
from typing import Any, Dict

import tensorflow as tf
from wandb import sdk as wandb_sdk


class MetricsLogger(tf.keras.callbacks.Callback):
    """Logs training metrics every N steps.

    This callback can be used to log metrics more often than once per epoch,
    which makes the training process introspectable when epochs are very long
    and the out-of-the-box per-epoch logging is not enough.

    The metrics are logged to Weights and Biases.
    """

    def __init__(
        self,
        model: tf.keras.Model,
        val_ds: tf.data.Dataset,
        freq: int,
        run: wandb_sdk.wandb_run.Run,
        prefix: str = "",
    ):
        """
        Args:
            model: Model being trained.
            val_ds: Dataset to be used when calculating validation metrics.
            freq: Frequency of logging metrics, in number of batches. Logging
                too often with a large validation dataset will significantly
                slow down training.
            run: W&B run to publish metrics to.
            prefix: Prefix to prepend to metric names. If this is set to a
                non-empty value, metrics published by this callback will have
                names different from the metrics TensorFlow publishes at the
                end of each epoch. Defaults to "".
        """
        self.model = model
        self.val_ds = val_ds
        self.freq = freq
        self.run = run
        self.prefix = prefix
        self.step = 0

    def on_batch_end(self, batch, logs=None):
        self.step += 1
        if self.step % self.freq == 0:
            if logs:
                self.run.log(self._add_prefix(logs), step=self.step)
            # This evaluate call resets the model's training metrics,
            # which is the problem described in this issue.
            val_logs = self.model.evaluate(x=self.val_ds, verbose=0, return_dict=True)
            val_logs = {f"val_{key}": value for key, value in val_logs.items()}
            self.run.log(self._add_prefix(val_logs), step=self.step)

    def _add_prefix(self, logs: Dict[str, Any]) -> Dict[str, Any]:
        if self.prefix:
            return {f"{self.prefix}_{k}": v for k, v in logs.items()}
        return logs
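For context, a hypothetical usage sketch (the wandb.init arguments, the dataset variables, and the logging frequency are illustrative assumptions, not our actual setup):

    import wandb

    run = wandb.init(project="example-project")
    logger = MetricsLogger(model=model, val_ds=val_ds, freq=500, run=run)

    # The training metrics reported by fit() at the end of each epoch are wrong,
    # because every model.evaluate call inside the callback resets them.
    model.fit(train_ds, epochs=10, callbacks=[logger])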
@diggerk In order to expedite the troubleshooting process, please provide a complete code snippet to reproduce the issue reported here. I am facing a different error while reproducing this issue. Thank you!
This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.
Closing as stale. Please reopen if you'd like to work on this further.
Here's a Colab demonstrating the problem. It's a modified MNIST tutorial showing that when a callback is used to run model validation more often than once per epoch, the training metrics are lost.
https://colab.research.google.com/drive/1-lMkaKlzK5SEba11tJ0eA-LJxFleNFxx
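For readers who prefer not to open the notebook, a minimal sketch of the same kind of repro (the model architecture and evaluation frequency here are assumptions for illustration; the linked Colab is the authoritative version):

    import tensorflow as tf

    (x_train, y_train), (x_val, y_val) = tf.keras.datasets.mnist.load_data()
    x_train, x_val = x_train / 255.0, x_val / 255.0

    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    class MidEpochEvaluate(tf.keras.callbacks.Callback):
        """Runs Model.evaluate every 200 batches, in the middle of an epoch."""
        def on_train_batch_end(self, batch, logs=None):
            if batch > 0 and batch % 200 == 0:
                # logs["loss"] is the running average over the epoch so far.
                print(f"\nbatch {batch}: running training loss = {logs['loss']:.4f}")
                self.model.evaluate(x_val, y_val, verbose=0)
                # After this call the training loss/accuracy accumulators are
                # reset, so subsequent per-batch logs and the end-of-epoch
                # numbers no longer reflect the whole epoch.

    model.fit(x_train, y_train, epochs=1, callbacks=[MidEpochEvaluate()], verbose=2)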
The comment from google-ml-butler says to reopen the issue if it needs further work, but I don't seem to have permission to do that. @sushreebarsa, can you please reopen this?
@diggerk Reopening this issue as per the above comment. Thank you!
Closing as stale. Please reopen if you'd like to work on this further.
@gadagashwini I was able to replicate the issue on Colab; please find the gist here. Thank you!