Update PyTorch to Flax porting documentation to include information on weight initialisation
I stumbled across the google/jax#4862 discussion on parameter initialisation differences between PyTorch and Flax.
Would it be worth adding a note to the documentation highlighting these weight initialisation differences? Users porting training scripts may expect the initialisation to match, and then find that training behaves differently as a result.
The issue mainly discusses Linear and Conv layers, but there are also differences in other layers, such as embedding layers.
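For example (a rough sketch, assuming current `flax.linen` and `jax.nn.initializers` APIs; the layer sizes are made up): Flax's `nn.Dense` defaults to `lecun_normal()` for the kernel and zeros for the bias, whereas PyTorch's `nn.Linear` uses `kaiming_uniform_` with `a=sqrt(5)`, which works out to `Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in))` for both weight and bias:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# PyTorch's nn.Linear weight init, kaiming_uniform_(a=sqrt(5)), reduces to
# Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)). As a JAX initializer that is
# variance_scaling with scale=1/3, since the uniform bound works out to
# sqrt(3 * scale / fan_in) = 1/sqrt(fan_in).
torch_linear_init = jax.nn.initializers.variance_scaling(
    scale=1.0 / 3.0, mode="fan_in", distribution="uniform"
)

# Flax defaults: kernel_init=lecun_normal(), bias_init=zeros. PyTorch also
# draws the bias from Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)), so a
# faithful port would need a custom bias_init as well (omitted here).
dense = nn.Dense(features=128, kernel_init=torch_linear_init)
params = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 64)))
```

Embeddings differ too: PyTorch's `nn.Embedding` samples from `N(0, 1)`, while (if I'm reading the Flax source right) `nn.Embed` defaults to a variance-scaled normal.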
I'd be willing to contribute this myself if you can point me to (or provide) a quick guide on contributing to the docs. Thanks :)
Hey @vvvm23, there are other subtle differences too, such as layer hyperparameters, and even numerical differences due to XLA/cuDNN. I think we can give a broad statement, without specific details, just to make users aware of the potential issues.
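For instance, even with the weights copied across exactly, outputs should only be compared up to a tolerance, not bit-for-bit. A minimal self-contained sketch (the layer shapes and tolerances here are just for illustration):

```python
import numpy as np
import torch
import jax.numpy as jnp
import flax.linen as nn

torch_layer = torch.nn.Linear(64, 128)
flax_layer = nn.Dense(features=128)

# Flax stores the Dense kernel as (in_features, out_features); PyTorch
# stores weight as (out_features, in_features), so transpose when copying.
params = {
    "params": {
        "kernel": jnp.asarray(torch_layer.weight.detach().numpy().T),
        "bias": jnp.asarray(torch_layer.bias.detach().numpy()),
    }
}

x = np.random.randn(1, 64).astype(np.float32)
torch_out = torch_layer(torch.from_numpy(x)).detach().numpy()
flax_out = np.asarray(flax_layer.apply(params, jnp.asarray(x)))

# Exact equality will generally fail across XLA/cuDNN and PyTorch kernels;
# a loose elementwise tolerance is the realistic bar.
np.testing.assert_allclose(torch_out, flax_out, rtol=1e-4, atol=1e-4)
```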
+1 for both
- adding a broad statement that there are numerical differences, and linking to #3128 as an example
- adding detailed information about known differences (like weight initializers for dense or convolutional layers; see the sketch after this list)
to our guide https://flax.readthedocs.io/en/latest/guides/convert_pytorch_to_flax.html
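For the conv case, the detailed note could sketch something like this (assuming `flax.linen`; Flax's `nn.Conv` defaults to `lecun_normal()` for the kernel, while PyTorch's `nn.Conv2d` uses the same `kaiming_uniform_(a=sqrt(5))` scheme as `nn.Linear`, with `fan_in = in_channels * kh * kw`):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# The same variance_scaling trick as for Dense reproduces PyTorch's conv
# weight init, since JAX's variance_scaling folds the receptive field
# into fan_in for higher-rank kernels.
torch_conv_init = jax.nn.initializers.variance_scaling(
    scale=1.0 / 3.0, mode="fan_in", distribution="uniform"
)

conv = nn.Conv(features=32, kernel_size=(3, 3), kernel_init=torch_conv_init)

# NHWC input; note Flax conv kernels are (kh, kw, in_features, out_features),
# unlike PyTorch's (out_channels, in_channels, kh, kw).
params = conv.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 3)))
```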
A broad statement would be helpful; no need to go into great detail.