
Training state of ResNet coupled with mutable batch_stats collection

Open · cgarciae opened this issue 3 years ago · 1 comment

Hey @n2cholas!

This is not an immediate issue, but I was playing around with jax_resnet and noticed that ConvBlock decides whether to update its batch statistics based on whether the batch_stats collection is mutable. This initially sounds like a safe bet, but if you embed ResNet inside another module that happens to also use BatchNorm, and you want to train that module while keeping ResNet frozen, it is not clear how you would do this.

https://github.com/n2cholas/jax-resnet/blob/5b00735aa0a68ec239af4a728ad4a596c1b551f6/jax_resnet/common.py#L43-L44
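To make the failure mode concrete, here is a minimal sketch (the Wrapper module, shapes, and head layers are hypothetical; ResNet18 is jax-resnet's constructor). The point is that batch_stats can only be made mutable for the whole apply call, so there is no way to let the wrapper's BatchNorm update while ResNet's stays frozen:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax_resnet import ResNet18

class Wrapper(nn.Module):
    """Hypothetical wrapper: a trainable head with its own BatchNorm
    on top of a ResNet backbone we would like to keep frozen."""

    @nn.compact
    def __call__(self, x):
        x = ResNet18(n_classes=10)(x)                   # want this frozen
        x = nn.Dense(64)(x)
        x = nn.BatchNorm(use_running_average=False)(x)  # want this trained
        return x

model = Wrapper()
x = jnp.ones((4, 64, 64, 3))
variables = model.init(jax.random.PRNGKey(0), x)

# To let the head's BatchNorm update, 'batch_stats' must be mutable for
# the whole apply call -- but then ResNet's ConvBlocks also see a mutable
# 'batch_stats' collection and update their (supposedly frozen) statistics.
out, updates = model.apply(variables, x, mutable=['batch_stats'])
```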

To solve this you have to:

  • Accept a use_running_average (or equivalent) argument in ConvBlock.__call__ and pass it to norm_cls.
  • Refactor ResNet into a custom Module (instead of a Sequential) so that it also accepts this argument in __call__ and passes it to the relevant submodules that expect it (see the sketch after this list).
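A minimal sketch of both steps, with class and field names simplified (this is illustrative, not the actual jax-resnet code):

```python
from typing import Callable

import flax.linen as nn

class ConvBlock(nn.Module):
    """Simplified stand-in for jax_resnet's ConvBlock: the caller now
    states explicitly whether BatchNorm should use its running averages."""
    n_filters: int
    norm_cls: Callable = nn.BatchNorm

    @nn.compact
    def __call__(self, x, use_running_average: bool = True):
        x = nn.Conv(self.n_filters, (3, 3), padding='SAME')(x)
        x = self.norm_cls(use_running_average=use_running_average)(x)
        return nn.relu(x)

class TinyResNet(nn.Module):
    """A custom Module instead of a Sequential, so the flag can be
    threaded through to every submodule that expects it."""

    @nn.compact
    def __call__(self, x, use_running_average: bool = True):
        for n_filters in (16, 32):
            x = ConvBlock(n_filters)(x, use_running_average=use_running_average)
        return x.mean(axis=(1, 2))  # global average pool
```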

Some repos use a single train flag to determine the state of both BatchNorm and Dropout.
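For instance, something like this (an illustrative pattern, not from jax-resnet):

```python
import flax.linen as nn

class Classifier(nn.Module):
    """Hypothetical module where a single `train` flag drives both
    BatchNorm statistics updates and Dropout."""

    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(128)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        return nn.Dense(10)(x)
```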

Anyway, not an immediate issue for me, but it might help some users in the future. Happy to send a PR if the changes make sense.

cgarciae · Aug 31 '22

Thanks for raising this @cgarciae, definitely a relevant use case. I would prefer having a use_running_average member variable in ConvBlock. Perhaps in the future we can add a use_running_average=None argument to ConvBlock.__call__ if there is sufficient demand, then use nn.merge_param just like Flax does, but my general preference is to configure a module's behaviour during construction (with @nn.compact you do both at once anyway).
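A sketch of that combination (simplified ConvBlock; the merge mirrors what Flax's own nn.BatchNorm does with this argument):

```python
from typing import Optional

import flax.linen as nn

class ConvBlock(nn.Module):
    """Sketch: use_running_average configured at construction, with an
    optional call-time override resolved via nn.merge_param."""
    n_filters: int
    use_running_average: Optional[bool] = None

    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        # merge_param errors unless exactly one of the two values is set,
        # just like nn.BatchNorm's handling of the same argument.
        use_running_average = nn.merge_param(
            'use_running_average', self.use_running_average, use_running_average)
        x = nn.Conv(self.n_filters, (3, 3), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=use_running_average)(x)
        return nn.relu(x)
```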

Would be amazing if you could open a PR. Let me know if you have any issues with the environment/tests.

n2cholas · Sep 03 '22