LibRecommender

Callbacks don't work

Open · apdullahyayik opened this issue 3 years ago · 6 comments

When I pass a Keras callback to the `fit` method (for a ranking task), it has no effect on training:

```python
from tensorflow.keras.callbacks import EarlyStopping

callbacks = [EarlyStopping(monitor="recall", patience=1, restore_best_weights=True)]
# model, train_data, eval_data and args are created earlier in the script
model.fit(train_data, verbose=3, shuffle=True, eval_data=eval_data,
          callbacks=callbacks, k=args.num_retrieved_items,
          metrics=["loss", "balanced_accuracy", "roc_auc", "pr_auc",
                   "precision", "recall", "map", "ndcg"])
```

apdullahyayik, Jul 21 '21 20:07

This is my solution:

1. I added a `metric_name_monitored` parameter to the `print_metrics_ranking` function in `evaluation/evaluate.py`:

```python
def print_metrics_ranking(metrics, y_prob=None, y_true=None, y_reco_list=None,
                          y_true_list=None, users=None, k=10, train=True,
                          metric_name_monitored=None):  # new: metric whose value to return
    ...  # existing body unchanged
```

2. In the `EvalMixin` class in `base.py`, I added an instance attribute, `self.metric_value_monitored`, that holds the value of the monitored metric. (Models inherit from three classes, `Base`, `TfMixin` and `EvalMixin`, so `EvalMixin` methods can be used inside the training loop.)

```python
# Inside EvalMixin.print_metrics(), where kwargs and test_params are defined
metric_name_monitored = kwargs.get("metric_name_monitored", None)
self.metric_value_monitored = print_metrics_ranking(
    metrics, **test_params, k=k, train=False, metric_name_monitored=metric_name_monitored)
```
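A toy sketch of why this works, using simplified stand-ins for the library's classes (the class bodies and the `train_loop` / `evaluate` methods below are hypothetical, not LibRecommender code): because the model inherits from all three classes, an attribute that an `EvalMixin` method sets on `self` is visible to a training loop defined in `Base`.

```python
class Base:
    def train_loop(self, epoch_values):
        """Hypothetical training loop that reads an attribute set by EvalMixin."""
        seen = []
        for value in epoch_values:
            self.evaluate(value)  # resolved on EvalMixin through the MRO
            seen.append(self.metric_value_monitored)
        return seen


class TfMixin:
    pass


class EvalMixin:
    def evaluate(self, value):
        # Store the monitored metric on the shared instance
        self.metric_value_monitored = value


class Model(Base, TfMixin, EvalMixin):
    pass


print(Model().train_loop([0.1, 0.2, 0.15]))  # [0.1, 0.2, 0.15]
print([cls.__name__ for cls in Model.__mro__])
```

The coupling is implicit: nothing in `Base` itself shows where `metric_value_monitored` comes from.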

3. In the `train_feat` method of `base.py` (where the training loop is implemented), I added the code below to track the monitored metric and stop training once it stops improving:

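```python
# Before the epoch loop in train_feat(): read the new fit() keyword arguments
metric_name_monitored = kwargs.get("metric_name_monitored", None)
patience_limit = kwargs.get("patience_limit", None)
metric_value_monitored_ = None  # best value of the monitored metric so far

# Inside the epoch loop, after each epoch: evaluate, then check for early stop
self.print_metrics(eval_data=eval_data, metrics=metrics, **kwargs)

# Assumes higher is better (true for "recall", not for "loss")
if metric_name_monitored in metrics and patience_limit is not None:
    if metric_value_monitored_ is None:
        metric_value_monitored_ = self.metric_value_monitored
    elif self.metric_value_monitored <= metric_value_monitored_:
        # No improvement: use up one unit of patience
        patience_limit -= 1
        if patience_limit == 0:
            print("Early stopped")
            break
    else:
        # Improvement: remember the new best value
        metric_value_monitored_ = self.metric_value_monitored
```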

In the end, the solution boils down to passing two keyword arguments to a model's `fit` method, e.g. `patience_limit=2, metric_name_monitored='recall'`. If these arguments are omitted or misspelled (e.g. `precison`), no error is raised; training simply runs for the full number of epochs. This approach probably applies to all models. It works well for me, and I can open a pull request if you'd like.
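The patience logic can also be factored into a small, framework-independent helper. This is a minimal sketch, not part of LibRecommender; the `EarlyStopper` class and its `patience`, `mode` and `min_delta` parameters are hypothetical. Unlike a plain `<=` comparison, it also handles metrics where lower is better (such as `loss`), and it resets the patience counter whenever the metric improves, as Keras's `EarlyStopping` does.

```python
class EarlyStopper:
    """Hypothetical helper: track one metric and signal when to stop."""

    def __init__(self, patience=2, mode="max", min_delta=0.0):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.best_value = None
        self.best_epoch = None
        self.wait = 0  # epochs since the last improvement

    def _improved(self, value):
        if self.best_value is None:
            return True
        if self.mode == "max":
            return value > self.best_value + self.min_delta
        return value < self.best_value - self.min_delta

    def step(self, epoch, value):
        """Record one epoch's metric; return True if training should stop."""
        if self._improved(value):
            self.best_value, self.best_epoch = value, epoch
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


# Usage with made-up recall values, one per epoch
stopper = EarlyStopper(patience=2, mode="max")
for epoch, recall in enumerate([0.10, 0.14, 0.13, 0.15, 0.15, 0.12], start=1):
    if stopper.step(epoch, recall):
        # prints: Early stopped at epoch 6; best epoch 4
        print(f"Early stopped at epoch {epoch}; best epoch {stopper.best_epoch}")
        break
```

Keeping the state in one object avoids spreading it across `Base` and `EvalMixin` attributes.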

If the library already provides a more compact way to do this, please let me know. Until then, I plan to use my version.

apdullahyayik, Jul 21 '21 22:07

OK thanks, I appreciate your work. The algorithms are mainly implemented with tf1.x syntax, so Keras callbacks won't work with them. Maybe I should have clarified this in the README. As for early stopping, I need some time to think it through.

massquantity, Jul 22 '21 14:07

Thanks. I'm closing this issue for now. Whenever you're ready, let me know and I can send a PR.

apdullahyayik, Jul 22 '21 23:07

Well, I have to say your solution is very creative, but there's always a but...

Reading and modifying the same `metric_value_monitored` attribute across different base classes can easily confuse other readers of the code. To be honest, even I would have been confused if you hadn't explained it first. Also, assigning the return value of `print_metrics_ranking` to `metric_value_monitored` doesn't feel right either, since it's a print function :) Besides, how do you plan to implement `restore_best_weights=True` when training is stopped in tf1.x?
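One way to address the naming concern is to split computing from printing: one function returns the metrics as a dict, and a separate function only prints them. A minimal sketch, assuming simple per-user precision@k and recall@k averaged over users; `compute_ranking_metrics` and `print_metrics` here are hypothetical names, not LibRecommender functions, and the library's own metric definitions may differ.

```python
def compute_ranking_metrics(reco_lists, true_lists, k=10):
    """Return mean precision@k and recall@k over users as a dict (no printing)."""
    precisions, recalls = [], []
    for user, reco in reco_lists.items():
        relevant = set(true_lists.get(user, ()))
        if not relevant:
            continue  # skip users with no held-out items
        hits = len(set(reco[:k]) & relevant)
        precisions.append(hits / k)
        recalls.append(hits / len(relevant))
    n = len(recalls)
    return {
        "precision": sum(precisions) / n if n else 0.0,
        "recall": sum(recalls) / n if n else 0.0,
    }


def print_metrics(metric_values, prefix="eval"):
    """Only format and print; never compute or return anything."""
    for name, value in metric_values.items():
        print(f"{prefix} {name}: {value:.4f}")


# Made-up data: user 1 has 2 of its 3 relevant items in the top 2, user 2 has 0 of 1
reco = {1: [10, 11, 12], 2: [20, 21, 22]}
truth = {1: [10, 11, 99], 2: [30]}
values = compute_ranking_metrics(reco, truth, k=2)
print_metrics(values)
monitored = values["recall"]  # early-stopping code reads the dict, not a print result
```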

As for early stopping, I did consider adding it when I first started writing this library, but I found it difficult to implement elegantly; the code would quickly become verbose. Furthermore, in my daily work I rarely use early stopping. Instead, I log the evaluation metrics during training, pick the optimal epoch from the log, and then train again with that number of epochs.
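The log-based workflow can be as simple as recording each epoch's evaluation metrics and picking the best epoch afterwards. A minimal sketch; the `best_epoch` function, the `history` structure and its numbers are made up for illustration.

```python
def best_epoch(history, metric, higher_is_better=True):
    """Return (epoch, value) of the best epoch for `metric` in a per-epoch log."""
    choose = max if higher_is_better else min
    epoch = choose(history, key=lambda e: history[e][metric])
    return epoch, history[epoch][metric]


# Hypothetical per-epoch evaluation log (epoch -> metrics)
history = {
    1: {"loss": 0.52, "recall": 0.11},
    2: {"loss": 0.47, "recall": 0.16},
    3: {"loss": 0.45, "recall": 0.18},
    4: {"loss": 0.46, "recall": 0.17},
}
print(best_epoch(history, "recall"))  # (3, 0.18)
print(best_epoch(history, "loss", higher_is_better=False))  # (3, 0.45)
```

The model would then be retrained from scratch for 3 epochs. This costs a second training run but needs no changes to the training loop.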

massquantity, Jul 24 '21 08:07

So for this task, there are three to-dos:

1. Contrary to its name, the `print_metrics_ranking` function has two roles: computing metrics and printing them. These roles should be separated for a more readable design; after that, returning a value would make sense.

2. So far I have only tested this with the Wide & Deep model on a ranking task, so the other models still need to be checked.

3. Implement the `restore_best_weights=True` option in tf1.x.
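For the third item, the core idea behind `restore_best_weights` is to snapshot the parameter values whenever the monitored metric improves and write them back when training stops. A framework-independent sketch, assuming the weights can be read out as a dict of values and written back; in a TF1.x graph the snapshot would presumably come from evaluating the trainable variables in the session and the restore from assigning those values back, which is the non-trivial part. The `train_with_restore` function and its `train_one_epoch` / `evaluate` / `get_weights` / `set_weights` arguments are hypothetical.

```python
import copy


def train_with_restore(n_epochs, train_one_epoch, evaluate, get_weights,
                       set_weights, patience=2):
    """Early stopping that restores the best-scoring weights (higher is better)."""
    best_value, best_weights, wait = None, None, 0
    for _ in range(n_epochs):
        train_one_epoch()
        value = evaluate()
        if best_value is None or value > best_value:
            best_value, wait = value, 0
            best_weights = copy.deepcopy(get_weights())  # a snapshot, not a reference
        else:
            wait += 1
            if wait >= patience:
                break
    if best_weights is not None:
        set_weights(best_weights)
    return best_value


# Toy "model": one weight that changes every epoch; the metric peaks at epoch 3
state = {"w": [0.0]}
scores = iter([0.10, 0.15, 0.20, 0.18, 0.17, 0.19])


def train_one_epoch():
    state["w"][0] += 1.0  # mutates in place, so a shallow reference would go stale


best = train_with_restore(
    n_epochs=6,
    train_one_epoch=train_one_epoch,
    evaluate=lambda: next(scores),
    get_weights=lambda: state,
    set_weights=state.update,
    patience=2,
)
print(best, state)  # 0.2 {'w': [3.0]}
```

Training stops after epoch 5 (two epochs without improvement), and the weights from epoch 3 are restored.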

apdullahyayik, Jul 24 '21 09:07

  1. Right.
  2. Not all models use the training loop in `base.py` (sequence models, for example), so this requires a lot of work.
  3. Yes, and this option is not trivial.

massquantity, Jul 24 '21 22:07