
BatchNorm training instability fix

Open · andrewdipper opened this issue 3 months ago · 1 comment

This is in reference to issue 659.

I modified BatchNorm to support two approaches, "batch" and "ema". "batch" simply uses the batch statistics at training time. If approach is not specified, it defaults to "batch" with a warning. It's robust and seems to be the standard choice: it's far less likely to kill a model just by being added.
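To make the "batch" approach concrete, here is a minimal pure-Python sketch of training-time normalization with batch statistics only (no running averages). The real layer operates on JAX arrays and computes statistics across the named "batch" axis; this is just the idea.

```python
# Sketch of approach="batch": at training time, normalize each element
# using only the current batch's mean and variance.
def batch_norm_train(xs, eps=1e-5):
    """Normalize a list of scalars with batch mean/variance."""
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    return [(x - mean) / (var + eps) ** 0.5 for x in xs]

out = batch_norm_train([1.0, 2.0, 3.0, 4.0])
```

The normalized batch has (approximately) zero mean and unit variance by construction, which is why this variant is hard to break just by inserting it into a model.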

"ema" is based of the smooth start method in the above issue. So keep a running mean and variance but instead of renormalizing Adam style the parts of the running averages that are zeroed are filled with the batch statistics. The problem is it's still not robust - the momentum parameter is simultaneously specifying a warmup period (when we're expecting the input distribution to change significantly) and how long we want the running average to be. So I added a linear warmup period.

Now, for any choice of momentum there seems to be a warmup_period choice that gives good results, and validation performance was at least as good as batch mode in my tests. However, I don't see a good default for warmup_period.
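A toy illustration of why momentum alone is overloaded: it controls both how quickly early (unrepresentative) statistics are forgotten and how long the steady-state averaging window is. The numbers below are purely illustrative:

```python
# Compare a debiased EMA under a distribution shift for two momenta.
def ema(values, momentum):
    r, frac = 0.0, 1.0
    for v in values:
        r = momentum * r + (1 - momentum) * v
        frac *= momentum
    return r / (1 - frac)  # Adam-style debiasing, for comparison

# Input statistic shifts from 0.0 to 1.0 after step 50.
values = [0.0] * 50 + [1.0] * 50

fast = ema(values, momentum=0.9)   # short window: adapts to the shift
slow = ema(values, momentum=0.99)  # long window: still dragged down
```

With momentum=0.9 the estimate is close to the new value 1.0 after 50 steps, while momentum=0.99 is still heavily polluted by the pre-shift statistics; a separate warmup_period lets you decouple these two roles.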

Some considerations:

  • Having approach="batch" alongside the common axis_name="batch" is a little awkward.
  • There's an example using BatchNorm that will start raising a warning and should probably be updated.
  • The current BatchNorm behavior can't be exactly replicated: approach="ema" with momentum=0.99 and warmup_period=1 is close, but differs at the start.
  • There's one more piece of state, hence the test_stateful.py change. This could be conditionally removed for approach="batch" if desired.

Let me know what you think, or if any changes or tests need to be added.


andrewdipper · Mar 07 '24 02:03