
Correctly compute metrics over entire evalset/testset

andsteing opened this issue 5 years ago

When computing metrics for the evaluation/test datasets, the entire set should be processed to get the correct metrics. Our examples often compute metrics with the following pattern:

import functools

import flax
import jax
import jax.numpy as jnp

def compute_metrics(logits, labels):
  metrics = {
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == labels),
      # ...
  }
  # Average the per-device metrics across the 'batch' axis.
  metrics = jax.lax.pmean(metrics, 'batch')
  return metrics

@functools.partial(jax.pmap, axis_name='batch')
def eval_step(model, batch):
  logits = model(batch['input'])
  return compute_metrics(logits, batch['label'])

test_metrics = []
for batch in test_iter:
  metrics = eval_step(optimizer.target, batch)
  test_metrics.append(metrics)
test_metrics = flax.training.common_utils.get_metrics(test_metrics)
test_summary = jax.tree_map(lambda x: x.mean(), test_metrics)

Problems:

  1. If the last batch has a different number of examples, the metrics will be wrong: the combination of jax.lax.pmean() and common_utils.get_metrics() weights every batch equally, i.e. in the example above the accuracy of a half-full batch has the same influence on the final average as the accuracy of a full batch, whereas for correct metrics it should only carry half the weight (see the numeric illustration after this list).
  2. If any batch size is not divisible by the number of devices, the code above can fail at the pmapped eval_step().
  3. If the batch size is not fixed, pmap re-compiles eval_step() for every distinct batch size.
  4. The scheme above is altogether incompatible with metrics that cannot be averaged per batch (e.g. precision, whose denominator is the number of detections rather than the number of examples, as it is for accuracy).
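
To make problem 1 concrete, here is a small illustration with made-up numbers: a test set of 150 examples split into one full batch of 100 examples (accuracy 0.90) and one half-full batch of 50 examples (accuracy 0.60).

# Made-up numbers illustrating problem 1: averaging per-batch means weights a
# half-full batch exactly like a full one.
naive   = (0.90 + 0.60) / 2               # 0.75 <- what mean-of-batch-means reports
correct = (0.90 * 100 + 0.60 * 50) / 150  # 0.80 <- true accuracy over all 150 examples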

Solution: Pad batches with dummy examples so they all have the same size, apply a mask so that only real examples contribute to the metrics computation, and finally weight batches according to their number of real examples (a rough sketch follows below). Addressing the last issue requires a more flexible metrics computation.
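
A minimal sketch of this padding/masking idea, not the eventual library API: the helper names pad_batch / masked_metrics and the fixed batch_size are hypothetical, and the metric is reported as sums rather than per-batch means so it can be normalized once at the end.

import jax.numpy as jnp
import numpy as np

def pad_batch(batch, batch_size):
  # Pad every array in the batch with zeros up to `batch_size` rows and add a
  # boolean mask that marks the real (non-dummy) examples.
  n = batch['label'].shape[0]
  pad = batch_size - n
  padded = {k: np.concatenate([v, np.zeros((pad,) + v.shape[1:], v.dtype)])
            for k, v in batch.items()}
  padded['mask'] = np.concatenate([np.ones(n, bool), np.zeros(pad, bool)])
  return padded

def masked_metrics(logits, labels, mask):
  # Return sums instead of per-batch means, so every real example carries the
  # same weight regardless of how full its batch was.
  correct = jnp.sum((jnp.argmax(logits, -1) == labels) * mask)
  return {'correct': correct, 'count': jnp.sum(mask)}

# Sketch of the final reduction: accumulate the per-batch sums over the whole
# test set, then normalize once at the end, e.g.
#   totals = jax.tree_map(lambda *xs: sum(xs), *per_batch_metrics)
#   accuracy = totals['correct'] / totals['count']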

This could be implemented separately in every example, but it probably makes more sense to factor out this functionality into a library module.

We should make sure that all examples either compute the metrics correctly or at least add a comment to the README where we decide against it for simplicity:

  • [ ] graph
  • [ ] imagenet
  • [ ] lm1b
  • [ ] mnist
  • [ ] nlp_seq
  • [ ] pixelcnn
  • [ ] ppo
  • [x] seq2seq
  • [ ] sst2
  • [ ] vae
  • [ ] wmt

Part of example improvements #231

This issue supersedes #262

andsteing avatar Oct 12 '20 08:10 andsteing

@andsteing I wonder whether we should require this for all our examples (do we realistically think we have the bandwidth to fix this?), or just close this issue and only keep the HOWTO that @avital just opened. WDYT?

marcvanzee avatar Feb 03 '22 16:02 marcvanzee

@andsteing gentle ping, maybe we should close this issue and focus on your HOWTO instead?

marcvanzee avatar May 12 '22 07:05 marcvanzee

The current plan is to add the HOWTO and some utility functions at the same time in #2111, and then to use that utility function to rewrite some of the examples (but not all).

Once that's done, I'll close this issue.

Note that the example can only be rewritten once a new release is created that includes the change. Since this is not urgent, I'll simply wait for the next natural release to appear that contains #2111.

andsteing avatar May 12 '22 11:05 andsteing

Closing this since #2111 is submitted. @andsteing please re-open if you think this is incorrect.

marcvanzee avatar Sep 06 '22 12:09 marcvanzee