BN serialization
Using Theano and Blocks from their master branches.
BN population_mean and population_std are not serialized. To reproduce, run the following script (based on blocks-examples/mnist/__init__.py): https://gist.github.com/kudkudak/04089a2a3a2442c935d9a26581b869b4
- python mnist.py --lr=0.1 --bn_momentum=0.1 trains on MNIST for 1 epoch and saves the model. Accuracy in the last epoch should be around 90%.
- python mnist.py --lr=0.0 --bn_momentum=0.1 reloads the model and runs it for one epoch without training (lr=0). During that first epoch its accuracy is 66%, not 90%! After one epoch, however, it settles back to 90%.
- python mnist.py --lr=0.0 --bn_momentum=0.0 turns off the BN statistics updates, and the accuracy stays at 66%.
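To see why accuracy only recovers when the statistics keep updating, here is a toy sketch of a population mean tracked by an exponential moving average during an evaluate-only epoch. It assumes that --bn_momentum is the weight given to each minibatch estimate and that a reload which skips the BN statistics leaves the population mean at an initial value of 0; both are assumptions for illustration, not a reading of the gist or of Blocks internals, and the data is synthetic.

```python
import random

def ema_update(population, batch_stat, momentum):
    """One exponential-moving-average step that moves `population` toward `batch_stat`."""
    return (1.0 - momentum) * population + momentum * batch_stat

def run_epoch(population_mean, momentum, true_mean=3.0, n_batches=100, batch_size=64, seed=0):
    """Evaluate-only epoch (lr=0): weights are frozen, only the BN statistic is updated."""
    rng = random.Random(seed)
    for _ in range(n_batches):
        batch = [rng.gauss(true_mean, 1.0) for _ in range(batch_size)]
        batch_mean = sum(batch) / len(batch)
        population_mean = ema_update(population_mean, batch_mean, momentum)
    return population_mean

# The trained population mean would be close to true_mean, but a reload that skips
# the BN statistics leaves it at its initial value (assumed 0.0 here).
reset_mean = 0.0
print(run_epoch(reset_mean, momentum=0.1))  # ends close to 3.0
print(run_epoch(reset_mean, momentum=0.0))  # 0.0: the statistic never moves
```

With a nonzero momentum the lost statistic is re-estimated from the data within an epoch, which matches the recovery seen in the second run; with momentum 0 it is stuck at its initial value, as in the third.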
Here is how to "fix" it without changing Blocks code: https://gist.github.com/kudkudak/f8d80e1113afbc34b8aa5efac498ce82
- https://gist.github.com/kudkudak/f8d80e1113afbc34b8aa5efac498ce82#file-gistfile1-txt-L83 adds more parameters for Checkpoint to save.
- https://gist.github.com/kudkudak/f8d80e1113afbc34b8aa5efac498ce82#file-gistfile1-txt-L101 makes sure they get deserialized.
Alternatively, one could tag BN population_mean/population_std with the PARAMETER role.
I'm not sure what the best way to fix this internally in Blocks is.
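A toy, framework-free sketch of the failure and of both workarounds: a checkpoint that only serializes parameter-tagged variables silently drops the population statistic, while listing it explicitly (roughly analogous to what the gist does) or tagging it with the parameter role brings it back. The Variable class, the role strings and save_checkpoint/load_checkpoint/round_trip are hypothetical stand-ins, not Blocks' Checkpoint or Load.

```python
import pickle

class Variable:
    """Toy stand-in for a shared variable: a name, a value and a set of role tags."""
    def __init__(self, name, value, roles):
        self.name = name
        self.value = value
        self.roles = set(roles)

def save_checkpoint(variables, extra=()):
    """Serialize every variable tagged "parameter", plus any explicitly listed extras."""
    selected = [v for v in variables if "parameter" in v.roles] + list(extra)
    return pickle.dumps({v.name: v.value for v in selected})

def load_checkpoint(variables, blob):
    """Restore values by name; anything missing from the blob keeps its current value."""
    saved = pickle.loads(blob)
    for v in variables:
        if v.name in saved:
            v.value = saved[v.name]

def fresh_model():
    # "bn_population" is a hypothetical role name for population_mean/population_std.
    return [Variable("W", 0.0, {"parameter"}),
            Variable("population_mean", 0.0, {"bn_population"})]

def round_trip(trained, extra=()):
    """Save `trained`, load into a freshly initialized model, return its values."""
    model = fresh_model()
    load_checkpoint(model, save_checkpoint(trained, extra))
    return [v.value for v in model]

trained = fresh_model()
trained[0].value, trained[1].value = 1.5, 3.0  # pretend these were learned

params_only = round_trip(trained)               # statistic falls back to 0.0
with_extra = round_trip(trained, [trained[1]])  # statistic saved explicitly
trained[1].roles.add("parameter")               # the PARAMETER-role alternative
with_role = round_trip(trained)
print(params_only, with_extra, with_role)  # [1.5, 0.0] [1.5, 3.0] [1.5, 3.0]
```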
Thanks! I'll take a look.
Thanks for raising this issue, @kudkudak.
For now, here is a brain dump of what I think about it. First of all, people who use continue_training won't be affected. It's specifically people who use Checkpoint + Load who are affected. But using Checkpoint + Load has a lot of advantages, so we should find a solution.
It is clear that David thought of this when he implemented batch norm; for example, here is PersistentRole:
https://github.com/mila-udem/blocks/blob/master/blocks/roles.py#L96
So off the top of my head, we could fix everything by adding get_persistent_dict and set_persistent_dict to Model and making Load use these instead of get_parameter_dict and set_parameter_dict.
@dwf, if you have time, could you remind me what your take was on saving/loading batch norm stats?
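The get_persistent_dict/set_persistent_dict proposal could look roughly like the following toy sketch. It assumes that parameters are one kind of persistent variable (the PersistentRole link suggests this design but it is not confirmed in this thread), so the persistent dict is a superset of the parameter dict. ToyModel, Var, has_role, BatchNormPopulationRole and the dict-of-values argument to set_persistent_dict are all hypothetical; Blocks' actual Model and role classes differ in detail.

```python
class VariableRole:
    pass

class PersistentRole(VariableRole):
    """Values that must survive a save/load round trip."""

class ParameterRole(PersistentRole):
    """Trainable parameters; assumed here to be a special case of persistent values."""

class BatchNormPopulationRole(PersistentRole):
    """Hypothetical role for BN population_mean / population_std."""

PARAMETER, POPULATION = ParameterRole(), BatchNormPopulationRole()

class Var:
    def __init__(self, name, value, *roles):
        self.name, self.value, self.roles = name, value, list(roles)

def has_role(var, role_cls):
    """True if any of the variable's roles is an instance of `role_cls` (or a subclass)."""
    return any(isinstance(r, role_cls) for r in var.roles)

class ToyModel:
    """Toy model exposing the proposed persistent-dict methods (hypothetical API)."""
    def __init__(self, variables):
        self.variables = variables

    def _dict_for(self, role_cls):
        return {v.name: v for v in self.variables if has_role(v, role_cls)}

    def get_parameter_dict(self):
        return self._dict_for(ParameterRole)

    def get_persistent_dict(self):
        return self._dict_for(PersistentRole)

    def set_persistent_dict(self, values):
        # Raises KeyError for names that are not persistent in this model.
        persistent = self.get_persistent_dict()
        for name, value in values.items():
            persistent[name].value = value

model = ToyModel([Var("W", 1.5, PARAMETER), Var("population_mean", 3.0, POPULATION), Var("h", 0.0)])
print(sorted(model.get_parameter_dict()))   # ['W']
print(sorted(model.get_persistent_dict()))  # ['W', 'population_mean']

# What a Load-like step would do: copy every persistent value into a fresh model.
fresh = ToyModel([Var("W", 0.0, PARAMETER), Var("population_mean", 0.0, POPULATION), Var("h", 0.0)])
fresh.set_persistent_dict({k: v.value for k, v in model.get_persistent_dict().items()})
print([v.value for v in fresh.variables])  # [1.5, 3.0, 0.0]
```

Because the lookup goes through isinstance on the role hierarchy, anything tagged with a PersistentRole subclass is saved, while the parameter dict stays unchanged for code that only wants trainable variables.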