
Update PyTorch to Flax porting documentation to include information on weight initialisation

Open vvvm23 opened this issue 2 years ago • 3 comments

I stumbled across the google/jax#4862 discussion on parameter-initialisation differences between PyTorch and Flax.

Would it be worth adding a note in the documentation that highlights these weight-initialisation differences? Users porting training scripts may expect the initialisation to be the same, and then find that training behaviour differs as a result.

The issue mainly discusses Linear and Conv layers, but there are also differences in other layers, such as embedding layers.

I'd be willing to contribute this myself if you can point me to (or provide) a quick guide on contributing to the docs. Thanks :)

vvvm23 avatar Jun 01 '23 18:06 vvvm23

Hey @vvvm23, there are other subtle differences as well, such as layer hyperparameters and even numerical differences due to XLA/cuDNN. I think we can give a broad statement, without specific details, just to make users aware of the potential issues.

cgarciae avatar Jun 05 '23 13:06 cgarciae

+1 for both

  • adding broad statement that there are numerical differences, and linking to #3128 as an example
  • adding detailed information about known differences (like weight initializers for dense or convolutional layers)

to our guide https://flax.readthedocs.io/en/latest/guides/convert_pytorch_to_flax.html

andsteing avatar Jun 05 '23 13:06 andsteing

A broad statement would be helpful; no need to go into great detail.

vvvm23 avatar Jun 05 '23 17:06 vvvm23