
RF BatchNorm running var small diff between TF-layers, pure RF and direct PyTorch, biased vs unbiased

albertz opened this issue 1 year ago · 0 comments

There are multiple implementations of batch norm, but here, three different cases are relevant:

  1. The pure RF implementation (which is used e.g. when use_mask=True)
  2. RF TF-layers backend, via LayerBase.batch_norm
  3. RF PyTorch backend, which uses torch.nn.functional.batch_norm

In cases 1 and 2, we use the biased estimate of the batch variance to update the running variance.

In case 3, the unbiased estimate of the batch variance is used to update the running variance. Specifically, the biased estimate is the unbiased one scaled by a factor of (n-1)/n, where n is the number of frames the batch statistics are computed over.
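For concreteness, here is a minimal sketch (plain PyTorch, not RETURNN code) of the two update variants, assuming the usual momentum-style update of the running variance; `n` is the number of (unmasked) frames that the batch statistics are computed over:

```python
import torch

def update_running_var(running_var, x, momentum=0.1, unbiased=False):
    """Sketch of the running-variance update, reducing over the batch axis of x."""
    n = x.shape[0]
    batch_var = x.var(dim=0, unbiased=False)  # biased estimate: divide by n
    if unbiased:
        batch_var = batch_var * n / (n - 1)   # unbiased estimate: divide by n - 1
    return (1 - momentum) * running_var + momentum * batch_var
```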

It's unclear what is really better. See my question here. What do other frameworks use to update the running variance?

  • PyTorch: As said above, unbiased batch variance (see the quick check after this list). But see https://github.com/pytorch/pytorch/issues/1410 and https://github.com/pytorch/pytorch/issues/77427 about exactly this topic.
  • Flax: Biased batch variance
  • Keras (ops.moments PyTorch backend, TF backend via tf.nn.moments): Biased batch variance
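The PyTorch claim above can be checked quickly. This snippet (standalone PyTorch, not RF code) compares the in-place update done by torch.nn.functional.batch_norm against both variants and should report that the unbiased one matches:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 4)            # (batch, features)
running_mean = torch.zeros(4)
running_var = torch.ones(4)
momentum = 0.1

# training=True makes the call update running_mean/running_var in place
torch.nn.functional.batch_norm(
    x, running_mean, running_var, training=True, momentum=momentum)

expected_unbiased = (1 - momentum) * 1.0 + momentum * x.var(dim=0, unbiased=True)
expected_biased = (1 - momentum) * 1.0 + momentum * x.var(dim=0, unbiased=False)
print(torch.allclose(running_var, expected_unbiased))  # expected: True
print(torch.allclose(running_var, expected_biased))    # expected: False
```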

It's also a bit unclear how to solve this. I guess we want RF to behave the same for every backend. Case 1 is independent of the backend, so there is no problem. But cases 2 and 3 are different, which is a problem, i.e. basically a bug in RF that it behaves differently across backends. But also, when using use_mask=True, you would expect the same behavior as with use_mask=False (regarding whether the biased or unbiased variance is used).

In any case, there should be a new option for this, and there should be a meaningful default. I guess we need a new behavior version for a new default here, and for the old behavior version, keep the current (inconsistent) behavior (i.e. the default depends on case and backend).

What would be a meaningful default? If there is a clear answer as to what is better, then it should be that one, but it's probably not clear. Maybe it actually does not really matter at all. Then only the question remains which variant is faster. In PyTorch, the fused function torch.nn.functional.batch_norm only supports one variant, i.e. using the unbiased batch variance to update the running variance.
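If the biased update were ever wanted on the PyTorch backend while still using the fused kernel for the normalization itself, one possible workaround (just a sketch, not how RF does it) would be to correct the running variance afterwards, at the cost of one extra variance reduction:

```python
import torch
import torch.nn.functional as F

def batch_norm_biased_running_update(x, running_mean, running_var,
                                     momentum=0.1, eps=1e-5):
    n = x.shape[0]
    # Fused normalization; this adds momentum * unbiased_var to the running var.
    y = F.batch_norm(x, running_mean, running_var,
                     training=True, momentum=momentum, eps=eps)
    # Convert that contribution to momentum * biased_var,
    # where biased_var = unbiased_var * (n - 1) / n.
    unbiased_var = x.var(dim=0, unbiased=True).detach()
    running_var -= momentum * unbiased_var / n
    return y
```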

I just skimmed over Batch Renormalization, and that variant sounds more meaningful. Not sure if it should be the default, but it should definitely be a possible variant in RF. Not sure whether this should be an option for our existing rf.BatchNorm or whether it should be a separate implementation.
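For reference, a rough sketch of the Batch Renormalization training-time forward pass (following Ioffe, 2017; plain PyTorch, names and clipping limits here are illustrative, affine scale/shift omitted):

```python
import torch

def batch_renorm(x, running_mean, running_var, momentum=0.1,
                 r_max=3.0, d_max=5.0, eps=1e-5):
    # x: (batch, features); statistics are reduced over the batch axis
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)          # biased batch variance
    sigma = (var + eps).sqrt()
    running_sigma = (running_var + eps).sqrt()
    # Correction terms; treated as constants for backprop (stop-gradient)
    r = (sigma / running_sigma).clamp(1.0 / r_max, r_max).detach()
    d = ((mean - running_mean) / running_sigma).clamp(-d_max, d_max).detach()
    x_hat = (x - mean) / sigma * r + d
    # Running statistics updated with the (biased) batch moments
    running_mean += momentum * (mean.detach() - running_mean)
    running_var += momentum * (var.detach() - running_var)
    return x_hat
```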

albertz · Jun 12 '24 09:06