Training state of ResNet coupled with mutable batch_stats collection
Hey @n2cholas!
This is not an immediate issue, but I was playing around with jax_resnet and noticed that `ConvBlock` decides whether to update its batch statistics based on whether the `batch_stats` collection is mutable. This initially sounds like a safe bet, but if you embed `ResNet` inside another module that happens to also use `BatchNorm`, and you want to train that outer module while keeping `ResNet` frozen, it is not clear how you would do this.
https://github.com/n2cholas/jax-resnet/blob/5b00735aa0a68ec239af4a728ad4a596c1b551f6/jax_resnet/common.py#L43-L44
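To make the failure mode concrete, here is a minimal sketch (the `Finetuner` wrapper is hypothetical; `ResNet50` is jax-resnet's documented constructor):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax_resnet import ResNet50

class Finetuner(nn.Module):  # hypothetical wrapper around the backbone
    @nn.compact
    def __call__(self, x):
        x = ResNet50(n_classes=64)(x)                   # want this frozen
        x = nn.BatchNorm(use_running_average=False)(x)  # want this training
        return nn.Dense(10)(x)

x = jnp.ones((1, 224, 224, 3))
variables = Finetuner().init(jax.random.PRNGKey(0), x)

# The head's BatchNorm needs `batch_stats` to be mutable in order to
# record new statistics, but mutability is collection-wide, so every
# ConvBlock in the backbone now updates its statistics too:
out, updates = Finetuner().apply(variables, x, mutable=['batch_stats'])
```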
To solve this you have to:
- Accept a `use_running_average` (or equivalent) argument in `ConvBlock.__call__` and pass it to `norm_cls` (sketched after this list).
- Refactor `ResNet` to be a custom Module (instead of `Sequential`) so you also accept this argument in `__call__` and pass it around to the relevant submodules that expect it.
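A rough sketch of what the first bullet could look like (`ConvBlock`'s real fields are abridged here):

```python
from functools import partial
from typing import Any

import flax.linen as nn

ModuleDef = Any

class ConvBlock(nn.Module):
    n_filters: int
    norm_cls: ModuleDef = partial(nn.BatchNorm, momentum=0.9)

    @nn.compact
    def __call__(self, x, use_running_average: bool):
        x = nn.Conv(self.n_filters, (3, 3), use_bias=False)(x)
        # The caller now decides the mode explicitly, instead of
        # ConvBlock inferring it from collection mutability:
        x = self.norm_cls(use_running_average=use_running_average)(x)
        return nn.relu(x)
```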
Some repos use a single `train` flag to determine the state of both `BatchNorm` and `Dropout`.
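For example (a generic sketch, not jax-resnet code):

```python
import flax.linen as nn

class Head(nn.Module):  # hypothetical module
    @nn.compact
    def __call__(self, x, train: bool):
        # One flag drives both pieces of train/eval-dependent behaviour:
        x = nn.Dense(128)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        return nn.Dense(10)(x)
```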
Anyway, not an immediate issue for me, but it might help some users in the future. Happy to send a PR if the changes make sense.
Thanks for raising this @cgarciae, definitely a relevant use case. I would prefer having a `use_running_average` member variable in `ConvBlock`. Perhaps in the future we can add a `use_running_average=None` argument to `ConvBlock.__call__` if there is sufficient demand, then use `nn.merge_param` just like Flax does, but my general preference is to configure the behaviour of the module during construction (with `@nn.compact` you do both at once anyway).
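For reference, a sketch of the `nn.merge_param` pattern described above, mirroring how Flax's own `nn.BatchNorm` resolves the two sources (`ConvBlock` fields abridged):

```python
from typing import Optional

import flax.linen as nn

class ConvBlock(nn.Module):
    use_running_average: Optional[bool] = None  # set at construction...

    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        # ...or at call time; merge_param raises unless exactly one is set.
        use_running_average = nn.merge_param(
            'use_running_average', self.use_running_average, use_running_average)
        return nn.BatchNorm(use_running_average=use_running_average)(x)
```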
Would be amazing if you could open a PR. Let me know if you have any issues with the environment/tests.