
`model.evaluate` resets training metrics

Open diggerk opened this issue 3 years ago • 10 comments

System information

  • Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 20.04
  • TensorFlow installed from (source or binary): Binary
  • TensorFlow version (use command below): 2.8.0
  • Python version: 3.9
  • GPU model and memory: Irrelevant
  • Exact command to reproduce: See the description

Describe the problem

The Model.evaluate function has a surprising side effect: it resets all metrics, including the loss. We call Model.evaluate in a callback to report validation metrics many times during an epoch, because our epochs are long, and we just found out that doing so overwrites the training metrics. It'd be great to either change Model.evaluate so it does not reset metrics, or to provide a mechanism for saving and restoring metric state.
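
For illustration, the kind of save/restore workaround we have in mind looks roughly like this. It is only a sketch: it assumes metric state lives in each metric's .variables (which holds for the built-in Keras metrics), and evaluate_without_resetting is a name made up for this example.

import tensorflow as tf

def evaluate_without_resetting(model: tf.keras.Model, val_ds: tf.data.Dataset) -> dict:
    """Runs model.evaluate, then restores the training metric state it reset."""
    # Snapshot the current (training) state of every compiled metric.
    saved = [[v.numpy() for v in m.variables] for m in model.metrics]
    results = model.evaluate(x=val_ds, verbose=0, return_dict=True)
    # evaluate() left the metrics holding validation values; put the
    # accumulated training values back.
    for metric, values in zip(model.metrics, saved):
        for var, value in zip(metric.variables, values):
            var.assign(value)
    return results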

Source code / logs

This is the callback we use to log validation metrics every N steps. The model.evaluate call, apparently, resets the training metrics.

from typing import Any, Dict

import tensorflow as tf
from wandb import sdk as wandb_sdk  # provides wandb_sdk.wandb_run.Run for the annotation below


class MetricsLogger(tf.keras.callbacks.Callback):
    """Logs training metrics every N steps.

    This callback can be used to log metrics more often than once per epoch,
    which makes the training process introspectable when epochs are very long
    and the out-of-the-box per-epoch logging is not enough.

    The metrics are logged to Weights and Biases.
    """

    def __init__(
        self,
        model: tf.keras.Model,
        val_ds: tf.data.Dataset,
        freq: int,
        run: wandb_sdk.wandb_run.Run,
        prefix: str = "",
    ):
        """
        Args:
            model: Model being trained.
            val_ds: Dataset to be used when calculating validation metrics.
            freq: Frequency of logging metrics, in number of batches. Logging
                too often with a large validation dataset will significantly
                slow down training.
            run: W&B run to publish metrics to.
            prefix: Prefix to prepend to metric names. If this is set to a
                non-empty value, metrics published with this callback will have
                names different from the metrics TensorFlow publishes by the end
                of each epoch. Defaults to "".
        """
        self.model = model
        self.val_ds = val_ds
        self.freq = freq
        self.run = run
        self.prefix = prefix
        self.step = 0

    def on_batch_end(self, batch, logs=None):
        self.step += 1
        if self.step % self.freq == 0:
            if logs:
                self.run.log(self._add_prefix(logs), step=self.step)

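            # Note: this evaluate() call also resets the model's compiled
            # training metrics, which is the behavior reported in this issue.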
            val_logs = self.model.evaluate(x=self.val_ds, verbose=0, return_dict=True)
            val_logs = {f"val_{key}": value for key, value in val_logs.items()}
            self.run.log(self._add_prefix(val_logs), step=self.step)

    def _add_prefix(self, logs: Dict[str, Any]) -> Dict[str, Any]:
        if self.prefix:
            return {f"{self.prefix}_{k}": v for k, v in logs.items()}
        return logs
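
For completeness, this is roughly how we wire the callback into training; the project name, datasets, and logging frequency below are placeholders:

import wandb

run = wandb.init(project="my-project")  # hypothetical project name
logger = MetricsLogger(model=model, val_ds=val_ds, freq=500, run=run)
model.fit(train_ds, epochs=3, callbacks=[logger])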

diggerk avatar May 11 '22 23:05 diggerk

@diggerk In order to expedite the troubleshooting process, please provide a complete code snippet to reproduce the issue reported here. I am facing a different error while reproducing this issue. Thank you!

sushreebarsa avatar May 12 '22 04:05 sushreebarsa

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

google-ml-butler[bot] avatar May 19 '22 04:05 google-ml-butler[bot]

Closing as stale. Please reopen if you'd like to work on this further.

google-ml-butler[bot] avatar May 26 '22 05:05 google-ml-butler[bot]


Here's a Colab demonstrating the problem. It's a modified MNIST tutorial showing that when a callback is used to run model validation more often than once per epoch, the training metrics are lost.

https://colab.research.google.com/drive/1-lMkaKlzK5SEba11tJ0eA-LJxFleNFxx
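
In condensed form, what the Colab demonstrates boils down to something like this (toy model, and a fixed batch index picked for illustration; the Colab has the full version):

import tensorflow as tf

(x_train, y_train), (x_val, y_val) = tf.keras.datasets.mnist.load_data()
x_train, x_val = x_train / 255.0, x_val / 255.0

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

class MidEpochEval(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        if batch == 100:
            before = {m.name: float(m.result()) for m in self.model.metrics}
            self.model.evaluate(x_val, y_val, verbose=0)
            after = {m.name: float(m.result()) for m in self.model.metrics}
            # `after` now holds validation values; the training metrics
            # accumulated in `before` have been discarded.
            print("before evaluate:", before)
            print("after evaluate:", after)

model.fit(x_train, y_train, epochs=1, callbacks=[MidEpochEval()])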

diggerk avatar Jun 10 '22 20:06 diggerk

The comment from google-ml-butler says to reopen the issue if it needs further work, but I don't seem to have permissions to do that. @sushreebarsa, can you please reopen this?

diggerk avatar Jun 10 '22 20:06 diggerk

@diggerk Reopening this issue as per the above comment. Thank you!

sushreebarsa avatar Jun 11 '22 16:06 sushreebarsa

Closing as stale. Please reopen if you'd like to work on this further.

google-ml-butler[bot] avatar Jun 18 '22 16:06 google-ml-butler[bot]


@gadagashwini I was able to replicate the issue on Colab; please find the gist here. Thank you!

sushreebarsa avatar Jun 19 '22 15:06 sushreebarsa