Correctly compute metrics over entire evalset/testset
When computing metrics for the evaluation/test datasets, the entire set should be processed to get the correct metrics. Our examples often compute metrics with the following pattern:
```python
import functools

import jax
import jax.numpy as jnp
from flax.training import common_utils


def compute_metrics(logits, labels):
  metrics = {
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == labels),
      # ...
  }
  # Average the metrics across devices along the 'batch' axis.
  metrics = jax.lax.pmean(metrics, 'batch')
  return metrics


@functools.partial(jax.pmap, axis_name='batch')
def eval_step(model, batch):
  logits = model(batch['input'])
  return compute_metrics(logits, batch['label'])


# `test_iter` and `optimizer` are assumed to be defined by the example.
test_metrics = []
for batch in test_iter:
  metrics = eval_step(optimizer.target, batch)
  test_metrics.append(metrics)
# Stacks the per-batch metric trees and transfers them to the host.
test_metrics = common_utils.get_metrics(test_metrics)
test_summary = jax.tree_map(lambda x: x.mean(), test_metrics)
```
Problems:

- If the last batch has a different number of examples, the metrics will be wrong: the combination of `jax.lax.pmean()` and `common_utils.get_metrics()` counts every batch equally, i.e. in the above example the accuracy of a full batch has the same influence as the accuracy of a half-full batch, whereas for correct metrics the half-full batch should only be weighted 50% (see the small numeric illustration below).
- If any batch size is not divisible by the number of devices, the code above could fail at the `@pmap`.
- If the batch size is not fixed, `@pmap` will re-compile `eval_step()` for every batch size.
- The above scheme is altogether incompatible with metrics that cannot be averaged per batch (e.g. precision, which has the number of detections in the denominator instead of the number of examples, as accuracy does).
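As a small illustration of the first problem (with purely hypothetical numbers), averaging per-batch accuracies directly gives the smaller batch too much influence:

```python
# Hypothetical numbers: a full batch of 128 examples with accuracy 0.75
# followed by a final half-full batch of 64 examples with accuracy 0.5.
batch_sizes = [128, 64]
batch_accuracies = [0.75, 0.5]

# Unweighted mean over batches, as produced by the pattern above: 0.625.
unweighted = sum(batch_accuracies) / len(batch_accuracies)

# Example-weighted mean, i.e. the true dataset accuracy: ~0.667.
weighted = (sum(a * n for a, n in zip(batch_accuracies, batch_sizes))
            / sum(batch_sizes))
```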
Solution: Pad batches with dummy examples so they all have the same size, apply a mask so that only real examples contribute to the metrics computation, and finally weight each batch according to its number of real examples. Addressing the last issue requires a more flexible metric computation.
This could be implemented separately in every example, but it probably makes more sense to factor out this functionality into a library module.
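A rough sketch of what such a utility could look like (everything below is illustrative and not an existing Flax API; it assumes batches are dicts of NumPy arrays with a leading batch dimension and a `'label'` entry):

```python
import jax
import jax.numpy as jnp
import numpy as np


def pad_batch(batch, batch_size):
  """Pads `batch` to `batch_size` with dummy examples (hypothetical helper).

  Adds a 'mask' entry that is 1.0 for real examples and 0.0 for padding.
  """
  actual = batch['label'].shape[0]
  pad = batch_size - actual
  padded = jax.tree_map(
      lambda x: np.concatenate([x, np.zeros((pad,) + x.shape[1:], x.dtype)]),
      batch)
  padded['mask'] = np.concatenate(
      [np.ones(actual, np.float32), np.zeros(pad, np.float32)])
  return padded


def compute_metrics(logits, labels, mask):
  """Returns per-batch sums rather than means, so batches with a different
  number of real examples can be combined correctly afterwards.

  Under pmap, the per-device sums can be combined with
  `jax.lax.psum(metrics, 'batch')` instead of `pmean`.
  """
  correct = (jnp.argmax(logits, -1) == labels) * mask
  return {
      'correct': correct.sum(),  # correctly classified real examples
      'total': mask.sum(),       # number of real examples in this batch
  }


# After the eval loop, divide once over the whole dataset, e.g.:
#   accuracy = sum(m['correct'] for m in ms) / sum(m['total'] for m in ms)
```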
We should make sure that all examples either compute the metrics correctly or at least add a comment to the README where we decide against it for simplicity:
- [ ] graph
- [ ] imagenet
- [ ] lm1b
- [ ] mnist
- [ ] nlp_seq
- [ ] pixelcnn
- [ ] ppo
- [x] seq2seq
- [ ] sst2
- [ ] vae
- [ ] wmt
Part of example improvements #231
This issue supersedes #262
@andsteing I wonder whether we should require this for all our examples (do we realistically think we have the bandwidth to fix this?), or just close this issue and only keep the HOWTO that @avital just opened. WDYT?
@andsteing gentle ping, maybe we should close this issue and focus on your HOWTO instead?
The current plan is to add the HOWTO and some utility functions at the same time in #2111, and then to use those utility functions to rewrite some of the examples (but not all).
Once that's done, I'll close this issue.
Note that the examples can only be rewritten once a new release is created that includes the change. Since this is not urgent, I'll simply wait for the next regular release that contains #2111.
Closing this since #2111 is submitted. @andsteing please re-open if you think this is incorrect.