Add `training` parameter to `compute_metrics`
Related to https://github.com/keras-team/keras/issues/16809
Assigning to Rick for further input.
Thanks for the PR! There's a concern about breaking backward compatibility for those who have custom models that override `compute_metrics()`, so unless `model.train_step()` or `model.test_step()` inspect the signature of `compute_metrics()`, I don't see an easy backward-compatible change afaict. Also, it looks to me that this PR is missing a check of the `training` boolean before calling into `_compute_metrics_result()` in `compute_metrics()`.
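For concreteness, the kind of signature inspection alluded to above could look roughly like this (a hypothetical helper, not part of the PR):

```python
import inspect

def call_compute_metrics(model, x, y, y_pred, sample_weight, training):
    # Hypothetical helper: forward `training` only if the (possibly
    # user-overridden) compute_metrics() signature accepts it.
    params = inspect.signature(model.compute_metrics).parameters
    accepts_training = "training" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if accepts_training:
        return model.compute_metrics(x, y, y_pred, sample_weight, training=training)
    return model.compute_metrics(x, y, y_pred, sample_weight)
```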
> Thanks for the PR! There's a concern about breaking backward compatibility for those who have custom models that override `compute_metrics()`, so unless `model.train_step()` or `model.test_step()` inspect the signature of `compute_metrics()`, I don't see an easy backward-compatible change afaict.
Thanks for the review! Please see https://github.com/keras-team/tf-keras/issues/509 for a proposal to solve the API breaking issue.
> Also, it looks to me that this PR is missing a check of the `training` boolean before calling into `_compute_metrics_result()` in `compute_metrics()`.
The `training` argument is meant for subclasses that override `compute_metrics()` and want to use this information to adapt their logic accordingly. The behavior of `compute_metrics()` in the base class is not modified.
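For illustration, a minimal sketch of such an override (the `expensive_eval_metric` is a placeholder, and the `super()` call assumes the base-class signature proposed in this PR):

```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Placeholder for a metric that is too expensive to track every step.
        self.expensive_eval_metric = tf.keras.metrics.AUC(name="expensive_auc")

    def compute_metrics(self, x, y, y_pred, sample_weight, training=False):
        # Only update the expensive metric during evaluation.
        if not training:
            self.expensive_eval_metric.update_state(y, y_pred)
        # Base-class behavior is unchanged; it just receives the extra argument.
        return super().compute_metrics(x, y, y_pred, sample_weight, training=training)
```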
Thanks for the link to the proposal. While we're assessing its usefulness, I'm not fully convinced that how the metrics are computed should depend on whether it's training or evaluation (especially because the default implementation doesn't need that).
Wouldn't it be easier for your training code to customize train_step/test_step to achieve this behavior?
The training information during metric computation can be useful for several reasons:
- Some metrics are expensive, and one may want to disable or sub-sample them during training.
- Training and evaluation data carry different semantics. For example, training data is shuffled, but evaluation data may have an ordering, e.g. a time series. A metric may need to be aware of this ordering for decorrelation purposes.
Yes, it is possible to customize train_step/test_step instead of compute_metrics, but this increases the maintenance cost for the user. Those methods carry more logic than just metric computation.
That being said, maybe there is a more consistent way to achieve this? Since Metric is a subclass of base_layer.Layer, and the latter is already `training`-aware, maybe there is a more API-consistent way to pipe this information into metrics?
The variables of a regular layer in Keras are updated using gradient descent. But in ML in general, there are approaches that don't use gradient descent, such as non-parametric models (KNN, GMM, decision trees, ...). Metrics can be seen as special cases of such models, where update_state is a proxy for updating the variables without gradient descent.
By generalizing the Keras API to let layers update their parameters without gradient descent, i.e. via a method like update_variables(x, y_true, y_pred, **kwargs), non-parametric models could be implemented in Keras, and metrics, being special cases of those, would have a natural exposure to the training argument.
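For illustration only, a rough sketch of what such a hypothetical API could look like (none of these classes or methods exist in Keras today):

```python
import tensorflow as tf

class NonParametricLayer(tf.keras.layers.Layer):
    """Hypothetical base class: variables are updated without gradient
    descent, through an explicit update_variables() call."""

    def update_variables(self, x, y_true, y_pred, **kwargs):
        raise NotImplementedError

class MeanAbsoluteErrorTracker(NonParametricLayer):
    # A metric-like example: accumulates the mean absolute error.
    def build(self, input_shape):
        self.total = self.add_weight(name="total", shape=(), initializer="zeros", trainable=False)
        self.count = self.add_weight(name="count", shape=(), initializer="zeros", trainable=False)

    def update_variables(self, x, y_true, y_pred, training=None, **kwargs):
        # `training` would reach this method the same way it reaches Layer.call().
        self.total.assign_add(tf.reduce_sum(tf.abs(y_true - y_pred)))
        self.count.assign_add(tf.cast(tf.size(y_true), self.count.dtype))

    def result(self):
        return self.total / tf.maximum(self.count, 1.0)
```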
@rchao Can you please assist with the above comments from @nicolaspi? Thank you!
The API change looks good to me. I don't quite understand the factoring though.
This PR aims to grant users better control over metrics computation. This is done in two ways:

- Adding the `training` argument to `compute_metrics`.
- Adding logic to collect the metrics results at the epoch level, independently of `compute_metrics`. It allows users to control metrics collection within `compute_metrics` (by potentially removing them) while being safeguarded by the epoch-level collection. This explains the extra method `_compute_metrics_result` and the extra line there: https://github.com/keras-team/keras/blob/3545340d24f1fb7bf5e3d1e2ee860a09a1598eeb/keras/engine/training.py#L1579

Side note: The logic of `compute_metrics` is a bit inconsistent in the current factoring because it mixes the semantics of updating the states through `Metric.update_state` and collecting the results through `Metric.result`. The update part is only done on the `compiled_metrics`, because the updates of `compiled_loss` and `Layer._metrics` are done elsewhere. But the collection is done on all metrics. IMO I would decouple those logics with two public methods, `update_compiled_metrics` and `collect_metrics`, with a cache mechanism that computes `Metric.result` only if the state has been updated, otherwise returning the cached value.
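A rough sketch of that decoupling (method names follow the suggestion above; this is not what the PR implements, and the cache flag is hypothetical):

```python
import tensorflow as tf

class DecoupledMetricsModel(tf.keras.Model):
    def update_compiled_metrics(self, y, y_pred, sample_weight=None):
        # Only updates the states of the compiled metrics; compiled_loss and
        # per-layer metrics keep being updated where they are today.
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        self._metrics_dirty = True  # hypothetical cache-invalidation flag

    def collect_metrics(self):
        # Recompute Metric.result() only if some state changed since the last
        # collection, otherwise return the cached values.
        if getattr(self, "_metrics_dirty", True):
            self._cached_results = {m.name: m.result() for m in self.metrics}
            self._metrics_dirty = False
        return self._cached_results
```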
I've re-read through the code, but I find the factoring in the PR quite confusing.
> IMO I would decouple those logics with two public methods, `update_compiled_metrics` and `collect_metrics`, with a cache mechanism that computes `Metric.result` only if the state has been updated, otherwise returning the cached value.
Can you try implementing that in the PR? Maybe that will be a nice improvement.
Ok, I will try.
@nicolaspi - Keras recently introduced a `get_metrics_result` API for `Model` that addresses part of this issue / PR. Can you please rebase and only address the `training` parameter added to `compute_metrics`?
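For reference, a minimal usage sketch, assuming a Keras version that ships `Model.get_metrics_result()`:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])
model.fit(np.random.rand(8, 4), np.random.rand(8, 1), epochs=1, verbose=0)

# Returns the model's current metric values as a dict, without re-running
# evaluation.
print(model.get_metrics_result())
```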
@sampathweb Thanks, I have rebased the code and made the changes to address the `training` parameter only.
Hi @fchollet, @rchao, Any ETA for merging this? Thanks
This should get merged in a couple of days. Thanks for the patience!
Hello @nicolaspi, my apologies, but we've encountered issues when attempting to merge this internally, because it breaks backward compatibility and thus numerous tests (there are many subclassed models overriding `compute_metrics()` in which `super().compute_metrics()` is called, and this raises an error because the `training` argument is not provided).
Given the impact on existing users, I don't think we can easily have this merged in. Can you subclass `Model` and override `train_step`/`compute_metrics` for your specific use case?
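For reference, a minimal sketch of that workaround (the flag attribute and the evaluation-only metric are placeholders; it relies on the base `train_step`/`test_step` calling into `compute_metrics`):

```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Placeholder for a metric we only want to update during evaluation.
        self.eval_only_metric = tf.keras.metrics.AUC(name="eval_auc")

    def train_step(self, data):
        # Thread the phase through an instance attribute so that
        # compute_metrics() can see it.
        self._metrics_training = True
        return super().train_step(data)

    def test_step(self, data):
        self._metrics_training = False
        return super().test_step(data)

    def compute_metrics(self, x, y, y_pred, sample_weight):
        if not getattr(self, "_metrics_training", False):
            self.eval_only_metric.update_state(y, y_pred)
        return super().compute_metrics(x, y, y_pred, sample_weight)
```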
@rchao Damn. :(
> Can you subclass `Model` and override `train_step`/`compute_metrics` for your specific use case?
I am already running this solution, but it is a maintenance nightmare to keep the code up to date with Keras.
Any plans on designing a way to feed information to users' methods without breaking the public API? I have suggested this already, but a less clumsy solution could be using context managers:
```python
with TrainingContext(is_training=True):
    self.train_step()
...
with TrainingContext(is_training=False):
    self.test_step()
```
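A minimal sketch of how such a context manager could work and how `compute_metrics()` could query it (all names are hypothetical; nothing here is an existing Keras API):

```python
import contextlib
import threading

_phase = threading.local()

@contextlib.contextmanager
def TrainingContext(is_training):
    # Record the current phase in thread-local state for the duration of
    # the block.
    previous = getattr(_phase, "is_training", None)
    _phase.is_training = is_training
    try:
        yield
    finally:
        _phase.is_training = previous

def in_training_phase(default=False):
    # What compute_metrics() could call to know whether it is running under
    # train_step or test_step.
    value = getattr(_phase, "is_training", None)
    return default if value is None else value
```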
> I am already running this solution, but it is a maintenance nightmare to keep the code up to date with Keras.
Can you describe more why this is the case? The only events I can see that may require additional maintenance are when we break backward compatibility, which is rare.
> but a less clumsy solution could be using context managers:
I see a path forward with this proposal, but within the `compute_metrics` method, how would one access whether it's training or not?
@nicolaspi Can you please check @rchao's comments and keep us posted? Thank you!
Hi @nicolaspi Any update on this PR? Please. Thank you!
Hi, sorry I am 'out of office' at the moment. I will give an update within two weeks.
Hi @nicolaspi Any update on this PR? Please. Thank you!
Hi, sorry for the delay. I am out of sync with this issue. I am closing it. I will reopen or open a new issue in the future if I happen to be in sync again. Thanks for the time invested.