
Add `epoch_logs` and `training_logs` variables from Model.fit as class attributes

Open qmarcou opened this issue 2 years ago • 5 comments

System information.

TensorFlow version (you are using): 2.8.0 Are you willing to contribute it (Yes/No) : Yes

Describe the feature and the current behavior/state.

Use-case: I'm using the keras-tuner library to optimize hyperparameters, but I also want to be able to use Early stopping. In order to get dataset orthogonalization right I'd like to be able to divide my dataset in 4 splits:

  • train: on which the actual training is made
  • train_val: which is used for Early stopping
  • dev_val or validation: used to monitor target metrics to tune hyperparameters
  • test: well to test my final best model

I want to make a clear separation between train_val and dev_val so as not to overfit the dev_val set due to early stopping. Thus I would like to perform evaluation on 2 different datasets on each epoch end. Model.fit only handles 1 validation set.

Because keras_tuner uses a ModelCheckpoint callback to save the best models, I can't add the evaluation outside the fit method in some kind of wrapper.

If I could instantiate a custom Callback to evaluate my model on my dev_val set and add the results to the logs before ModelCheckpoint.on_epoch_end is called, my problem would be solved, but the logs dict is passed as a read-only argument and cannot be modified. Right now the logs variables are temporary variables inside the model.fit function:

val_logs = {'val_' + name: val for name, val in val_logs.items()}
epoch_logs.update(val_logs)

callbacks.on_epoch_end(epoch, epoch_logs)
training_logs = epoch_logs

Solving my problem by subclassing the Model class and overriding the fit method to incorporate the train_val evaluation seems very heavy for such a simple task, which should be achievable with Callbacks.

Will this change the current api? How? This would enable custom Callbacks to update the logs dict, whereas right now the logs are effectively read only.

Who will benefit from this feature? Anybody wanting to compute some set of metrics only at batch end, for instance, or on different validation sets. This would also be an elegant way to solve this SO question. There are possibly many other applications for extending pieces of predict/evaluate/fit.

Contributing

  • Do you want to contribute a PR? (yes/no): yes
  • If yes, please read this page for instructions
  • Briefly describe your candidate solution (if contributing): simply add the two variables as Model attributes (~5 lines of code change)

qmarcou avatar Jun 09 '22 15:06 qmarcou

Damn, after going through the hassle of the required monkey patching to test my modification proposal, I've realized that because of the way the logs dict is passed to CallbackList.on_epoch_end, I can simply update the logs dict in place and the modification is propagated to the History callback ...
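The mechanism behind this can be sketched without Keras at all. The class and function names below are illustrative stand-ins, not the real Keras internals: Python passes the dict by reference, so keys added by one callback are visible to every callback invoked after it (including History).

```python
# Toy model of how fit() shares one epoch_logs dict across callbacks
# (illustrative names only, not the actual Keras implementation).
class ToyCallbackList:
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_epoch_end(self, epoch, logs):
        # The same dict object is handed to every callback in turn.
        for cb in self.callbacks:
            cb(epoch, logs)

def injector(epoch, logs):
    logs["dataEval_loss"] = 0.42  # mutate the shared dict in place

history = {}
def recorder(epoch, logs):
    history[epoch] = dict(logs)  # a later callback sees the injected key

epoch_logs = {"loss": 0.1}
ToyCallbackList([injector, recorder]).on_epoch_end(0, epoch_logs)
print(history[0])  # {'loss': 0.1, 'dataEval_loss': 0.42}
```

Because the dict is never copied between callbacks, the only requirement is ordering: the injecting callback must run before any callback that reads the new key.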

This piece of code gives the intended result:

from unittest import TestCase
from tensorflow import keras

class DataEvaluator(keras.callbacks.Callback):
    """A callback to evaluate a model on specific data."""

    def __init__(self, x, y,
                 eval_prefix="dataEval_", **kwargs):
        super().__init__()
        self.x_data = x
        self.y_data = y
        self.eval_prefix = eval_prefix
        self.eval_kwargs = kwargs
        if "return_dict" in kwargs:
            raise ValueError("return_dict cannot be specified when used "
                             "in the DataEvaluator callback")

    def on_epoch_end(self, epoch, logs=None):
        # Evaluate the model on the provided data
        eval_output = self.model.evaluate(
            x=self.x_data,
            y=self.y_data,
            return_dict=True,
            **self.eval_kwargs
        )
        # Add the prefix to all keys of the dict
        eval_output = {self.eval_prefix + str(key): val for key, val in
                       eval_output.items()}

        # Check that none of the new keys shadow the existing ones
        if set(eval_output) & set(logs):
            raise RuntimeError("Names resulting from the DataEvaluator "
                               "shadow names from the logs dict; "
                               "pick a different eval_prefix")
        # Update epoch_logs dict
        logs.update(eval_output)

class TestDataEvaluator(TestCase):
    def test(self):
        x_train = [[0], [1], [2], [3]]
        y_train = [0, 0, 1, 1]
        model = keras.Sequential(keras.layers.Dense(1, activation="sigmoid"))
        model.compile(loss="binary_crossentropy", metrics=["accuracy"])
        data_eval = DataEvaluator(x=x_train, y=y_train,
                                  eval_prefix="dataEvalTester_")
        x_test = [[3], [2], [1], [0]]
        y_test = [0, 0, 1, 1]
        hist: keras.callbacks.History = model.fit(x=x_test,
                                                  y=y_test,
                                                  epochs=3,
                                                  callbacks=[data_eval],)

        print(hist.history)
        # {'loss': [0.5381441116333008, 0.5373239517211914, 0.5367312431335449], 'accuracy': [0.5, 0.75, 0.75], 'dataEvalTester_loss': [1.1190409660339355, 1.1207287311553955, 1.1221365928649902], 'dataEvalTester_accuracy': [0.25, 0.25, 0.25]}
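The prefix-collision guard inside the callback can also be exercised standalone. The helper below is a hypothetical extraction for illustration, not part of the callback's API:

```python
# Standalone sketch of the guard: refuse to merge evaluation results whose
# prefixed keys would silently overwrite fit()'s own metrics.
def merge_eval_into_logs(logs, eval_output, prefix):
    prefixed = {prefix + str(k): v for k, v in eval_output.items()}
    if set(prefixed) & set(logs):
        raise RuntimeError("pick a different eval_prefix")
    logs.update(prefixed)

logs = {"loss": 0.1, "val_loss": 0.2}
merge_eval_into_logs(logs, {"loss": 0.3}, "dataEval_")
print(logs["dataEval_loss"])  # 0.3

try:
    # "val_" collides with fit()'s own "val_loss" key
    merge_eval_into_logs(logs, {"loss": 0.4}, "val_")
except RuntimeError as err:
    print(err)  # pick a different eval_prefix
```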

Unless you think there's a better way to solve this, I'll go with this solution. However, maybe it would help to add a few words about this in the Callback documentation? For now, the examples using logs in the tutorial only read from them.

qmarcou avatar Jun 10 '22 15:06 qmarcou

@qmarcou The documentation is already present and there are examples which describe this. Since your issue is solved, could you please close this issue? Thanks!

gowthamkpr avatar Jun 12 '22 00:06 gowthamkpr

@gowthamkpr I could not find any of the mentioned examples, only examples with read-only usage of the logs. Could you please point me to them for future reference? I'll close the issue afterwards. Thanks for your help! Best regards

qmarcou avatar Jun 13 '22 09:06 qmarcou

@qmarcou Sorry for the delayed response. Would you like to create a PR for the documentation? Thanks!

gowthamkpr avatar Jul 13 '22 13:07 gowthamkpr

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

google-ml-butler[bot] avatar Jul 25 '22 22:07 google-ml-butler[bot]


Closing as stale. Please reopen if you'd like to work on this further.

google-ml-butler[bot] avatar Aug 18 '22 16:08 google-ml-butler[bot]
