BN serialization

Open kudkudak opened this issue 7 years ago • 3 comments

Using theano and blocks from master branches.

BN population_mean and population_std are not serialized. To reproduce, run the following script (based on blocks-examples/mnist/__init__.py), https://gist.github.com/kudkudak/04089a2a3a2442c935d9a26581b869b4, in three steps:

  1. python mnist.py --lr=0.1 --bn_momentum=0.1. This trains on MNIST for 1 epoch and saves the model. Accuracy at the end of the epoch should be around 90%.

  2. python mnist.py --lr=0.0 --bn_momentum=0.1. This reloads the model and runs it for one epoch without training (lr=0). In the first epoch its accuracy is 66%, not 90%! After that epoch it settles back to 90%.

  3. python mnist.py --lr=0.0 --bn_momentum=0.0. If we turn off the BN statistics updates, the accuracy stays at 66%.
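The pattern in the three steps is what one would expect if the population statistics come back at their initial values after reloading and are then re-estimated by a moving average. Here is a minimal NumPy sketch of that mechanism. It is not Blocks code: the update rule, the zero initial value, the helper name `update_population_mean` and the synthetic data are all assumptions made for illustration.

```python
import numpy as np


def update_population_mean(pop_mean, batches, momentum):
    """Exponential moving average of per-batch means (assumed update rule)."""
    for batch in batches:
        pop_mean = (1.0 - momentum) * pop_mean + momentum * batch.mean(axis=0)
    return pop_mean


rng = np.random.RandomState(0)
# Hypothetical pre-activations whose true mean is 3.0.
batches = [rng.normal(loc=3.0, scale=1.0, size=(64, 4)) for _ in range(200)]

# Statistics that were not serialized restart from their initial value
# (assumed here to be zero).
fresh = np.zeros(4)

# momentum > 0: the estimate drifts back towards the true mean within an epoch.
recovered = update_population_mean(fresh, batches, momentum=0.1)

# momentum == 0: updates are effectively off, so the estimate never moves.
frozen = update_population_mean(fresh, batches, momentum=0.0)
```

With a nonzero momentum the estimate ends up close to the true mean, which would match the accuracy recovering in step 2. With momentum 0 it stays at the initial value, which would match step 3.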

kudkudak, Apr 18 '17 07:04

Here is how to "fix" it without changing Blocks code: https://gist.github.com/kudkudak/f8d80e1113afbc34b8aa5efac498ce82.

  1. https://gist.github.com/kudkudak/f8d80e1113afbc34b8aa5efac498ce82#file-gistfile1-txt-L83 adds more parameters to save to Checkpoint

  2. https://gist.github.com/kudkudak/f8d80e1113afbc34b8aa5efac498ce82#file-gistfile1-txt-L101 makes sure they get deserialized

Alternatively, one could tag the BN population_mean/population_std with the PARAMETER role.

I'm not sure what the best way to fix this internally in Blocks is.

kudkudak, Apr 18 '17 08:04

Thanks! I'll take a look.

dmitriy-serdyuk, Apr 18 '17 13:04

Thanks for raising this issue, @kudkudak

For now, here is a brain dump of what I think about it. First of all, people who use continue_training won't be affected. It's specifically people who use Checkpoint + Load who are affected. But there are real advantages to using Checkpoint + Load, so we should find a solution.

It is clear that David thought of this when he implemented batch norm; for example, here is PersistentRole: https://github.com/mila-udem/blocks/blob/master/blocks/roles.py#L96 . So off the top of my head, we could fix everything by adding get_persistent_dict and set_persistent_dict to Model and also making Load use these instead of get_parameter_dict and set_parameter_dict.
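To make the idea concrete, here is a rough, self-contained sketch of what selecting by a persistent role rather than by the parameter role could look like. It does not use Blocks at all: `ToyVariable`, `ToyModel`, the string roles and the method bodies are hypothetical stand-ins, and it assumes that parameters count as persistent too.

```python
import copy

PARAMETER = "parameter"    # stand-in for the PARAMETER role
PERSISTENT = "persistent"  # stand-in for PersistentRole


class ToyVariable:
    """Hypothetical stand-in for a shared variable tagged with roles."""

    def __init__(self, name, value, roles):
        self.name = name
        self.value = value
        self.roles = set(roles)


class ToyModel:
    """Hypothetical stand-in for Model; not the real Blocks class."""

    def __init__(self, variables):
        self.variables = list(variables)

    def _values_with_role(self, role):
        return {v.name: copy.deepcopy(v.value)
                for v in self.variables if role in v.roles}

    def get_parameter_dict(self):
        # Only variables tagged as parameters are returned.
        return self._values_with_role(PARAMETER)

    def get_persistent_dict(self):
        # Everything that should survive a save/load round trip.
        return self._values_with_role(PERSISTENT)

    def set_persistent_dict(self, values):
        by_name = {v.name: v for v in self.variables}
        for name, value in values.items():
            by_name[name].value = copy.deepcopy(value)


def build_model():
    # Assumption: parameters are persistent as well as the BN statistics.
    return ToyModel([
        ToyVariable("W", [0.0], {PARAMETER, PERSISTENT}),
        ToyVariable("population_mean", [0.0], {PERSISTENT}),
        ToyVariable("population_std", [1.0], {PERSISTENT}),
    ])


trained = build_model()
for variable, value in zip(trained.variables, ([0.5], [3.0], [2.0])):
    variable.value = value  # pretend training changed every variable

param_checkpoint = trained.get_parameter_dict()        # misses the BN stats
persistent_checkpoint = trained.get_persistent_dict()  # includes them

reloaded = build_model()
reloaded.set_persistent_dict(persistent_checkpoint)
```

In this toy version the parameter-based checkpoint contains only W, while the persistent one also carries population_mean and population_std, so the reloaded model matches the trained one.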

@dwf, if you have time, can you please remind me what your take was on saving/loading batch norm stats?

rizar, Apr 18 '17 13:04