Truncated Normal initializer doesn't match PyTorch
System information
- WSL 2 on Windows
- Flax 0.8.4, JAX 0.4.30
- Python 3.10
- RTX 2080
- CUDA 12.3
Neither `nn.initializers.truncated_normal` (Flax) nor `jax.nn.initializers.truncated_normal` produces values similar enough to PyTorch's `nn.init.trunc_normal_`. All of these use a lower bound of -2 and an upper bound of 2 by default.
I'm running a test to make sure the outputs are similar if given the same arguments.
Here's my JAX code.
```python
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn

def summary_stats(name, x):
    print(f'Stats for {name}:')
    print('shape: ', x.shape)
    print('min: ', x.min())
    print('max: ', x.max())
    print('std: ', jnp.std(x))
    # print(x)

def make_trunc(key, stddev):
    lower = -2
    upper = -lower
    shape = (4096,)
    return nn.initializers.truncated_normal(stddev, lower=lower, upper=upper)(key, shape=shape)
    # return jax.nn.initializers.truncated_normal(stddev, lower=lower, upper=upper)(key, shape=shape)

summary_stats('trunc', make_trunc(random.key(0), .02))
summary_stats('trunc', make_trunc(random.key(1), .04))
summary_stats('trunc', make_trunc(random.key(2), .06))
```
Here's my PyTorch code:
```python
import torch
import torch.nn as nn

def summary_stats(name, x):
    print(f'Stats for {name}:')
    print('shape: ', x.shape)
    print('min: ', x.min().item())
    print('max: ', x.max().item())
    print('std: ', x.std().item())
    # print(x)

t = torch.zeros((4096,))
nn.init.trunc_normal_(t, std=0.02)
summary_stats('pytorch', t)
nn.init.trunc_normal_(t, std=0.04)
summary_stats('pytorch', t)
nn.init.trunc_normal_(t, std=0.06)
summary_stats('pytorch', t)
```
JAX output:
```
Stats for trunc:
shape: (4096,)
min: -0.039960504
max: 0.039985776
std: 0.018035976
Stats for trunc:
shape: (4096,)
min: -0.0797701
max: 0.07959695
std: 0.03491081
Stats for trunc:
shape: (4096,)
min: -0.11983735
max: 0.119733654
std: 0.053539284
```
PyTorch output:
```
Stats for pytorch:
shape: torch.Size([4096])
min: -0.06634494662284851
max: 0.0743524581193924
std: 0.020439231768250465
Stats for pytorch:
shape: torch.Size([4096])
min: -0.13382470607757568
max: 0.12441056221723557
std: 0.03931436687707901
Stats for pytorch:
shape: torch.Size([4096])
min: -0.22086666524410248
max: 0.20988918840885162
std: 0.05979840084910393
```
Although the std values look close enough, the min and max seem off.
However, let's look at the JAX output again when I set lower=-4 (and therefore upper=4), even though PyTorch is still using -2.
JAX output:
```
Stats for trunc:
shape: (4096,)
min: -0.07253291
max: 0.076029524
std: 0.020624608
Stats for trunc:
shape: (4096,)
min: -0.13531744
max: 0.12941647
std: 0.039439432
Stats for trunc:
shape: (4096,)
min: -0.21360189
max: 0.20679429
std: 0.060959056
```
Now the min/max line up better with PyTorch. I haven't found what in the source code explains this, but if it's an intended design, it would be nice to document it.
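A minimal pure-Python sketch (stdlib only) of the arithmetic that seems to account for these numbers. It assumes the JAX initializer returns `stddev * X`, where `X` is a standard normal truncated to `[lower, upper]`, so the bounds are measured in units of `stddev`, while PyTorch's `a`/`b` are absolute values. The helper names `_pdf`, `_cdf` and `truncated_std` are hypothetical.

```python
import math

def _pdf(x):
    # Standard normal density.
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def _cdf(x):
    # Standard normal CDF via the error function.
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def truncated_std(lower, upper):
    """Std of a standard normal truncated to [lower, upper] (bounds in std units)."""
    z = _cdf(upper) - _cdf(lower)
    mean = (_pdf(lower) - _pdf(upper)) / z
    var = 1.0 + (lower * _pdf(lower) - upper * _pdf(upper)) / z - mean ** 2
    return math.sqrt(var)

# Assumed JAX semantics: samples = stddev * X, X ~ N(0, 1) truncated to [lower, upper],
# so the absolute range is [lower * stddev, upper * stddev].
for lower in (-2.0, -4.0):
    upper = -lower
    for stddev in (0.02, 0.04, 0.06):
        lo, hi = lower * stddev, upper * stddev
        print(f"lower={lower:+.0f} stddev={stddev}: range [{lo:.2f}, {hi:.2f}], "
              f"expected std {stddev * truncated_std(lower, upper):.4f}")
```

For `stddev=0.02`, `lower=-2` gives a range of ±0.04 and an expected std of about 0.0176, roughly in line with the first JAX run; `lower=-4` doubles the range and leaves the std at essentially `stddev`. PyTorch's ±2 bounds, by contrast, sit about 100 standard deviations out when `std=0.02`, so they almost never bind, which would explain why its min/max look like those of an untruncated normal.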
From a quick look at the PyTorch documentation and the source code of `jax.random.truncated_normal`, it seems that:
- the PyTorch version samples from a fixed normal distribution and redraws any value that falls out of bounds (per its docs)
- the JAX version samples from a normal distribution customized to the given lower and upper bounds (it calls `jax.random.uniform` to sample uniformly within a range derived from those bounds, then maps the samples back through the inverse error function)
This might explain why PyTorch's min/max values are farther from 0, since it draws from a distribution that is more likely to produce out-of-bounds values.
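To test that hypothesis in isolation, here is a pure-Python sketch (stdlib only; the function names are hypothetical) of both strategies on the same standard-normal bounds: rejection sampling, as described in the PyTorch docs, and inverse-CDF sampling, which is the idea behind `jax.random.truncated_normal` (uniform between the transformed bounds, then mapped back). Both target the same truncated distribution, so with identical bounds their min/max/std should agree up to sampling noise; a gap between the libraries would then point at how the bounds are interpreted rather than at the sampling method.

```python
import random
import statistics
from statistics import NormalDist

STD_NORMAL = NormalDist(0.0, 1.0)

def rejection_sample(rng, n, lower, upper):
    # Draw from N(0, 1) and redraw any value outside [lower, upper].
    out = []
    while len(out) < n:
        x = rng.gauss(0.0, 1.0)
        if lower <= x <= upper:
            out.append(x)
    return out

def inverse_cdf_sample(rng, n, lower, upper):
    # Sample uniformly between CDF(lower) and CDF(upper), then map back through
    # the inverse CDF (the same idea as the erf / erf_inv pair used by JAX).
    lo, hi = STD_NORMAL.cdf(lower), STD_NORMAL.cdf(upper)
    return [STD_NORMAL.inv_cdf(rng.uniform(lo, hi)) for _ in range(n)]

rng = random.Random(0)
a = rejection_sample(rng, 20_000, -2.0, 2.0)
b = inverse_cdf_sample(rng, 20_000, -2.0, 2.0)
for name, xs in (("rejection", a), ("inverse CDF", b)):
    print(f"{name:12s} min={min(xs):+.3f} max={max(xs):+.3f} std={statistics.pstdev(xs):.4f}")
```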
If you'd like to know more, I'd recommend opening an issue or question on the JAX GitHub to get a response from the authors.
Thanks for taking a look.
I've been plotting histograms, and I've observed that I can get the same behavior in PyTorch and JAX with this procedure:
- Use the same std deviation argument in both PyTorch and JAX.
- Take the lower/upper values that you're using in PyTorch and divide them by the std deviation to get the lower/upper to use in JAX.
In JAX, if you change the std deviation parameter, the "shape" of the histogram doesn't change: with the x-axis set to auto, you essentially see the same shape, just with different bounds. This is not true for PyTorch. To get the same behavior in PyTorch, you'd multiply both the lower/upper bounds and the std deviation by the same factor.
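The procedure above can be sketched as a small conversion helper. The name `torch_trunc_normal_to_jax` is hypothetical, and it assumes `mean=0`, since `jax.nn.initializers.truncated_normal` takes `stddev`, `lower` and `upper` but no mean.

```python
import jax
import jax.numpy as jnp

def torch_trunc_normal_to_jax(std=1.0, a=-2.0, b=2.0):
    """Map PyTorch trunc_normal_(mean=0, std, a, b) arguments to JAX's
    truncated_normal(stddev, lower, upper). Hypothetical helper; assumes mean=0.

    PyTorch's a/b are absolute values, while JAX's lower/upper are multiples
    of stddev, so the bounds are divided by std.
    """
    return dict(stddev=std, lower=a / std, upper=b / std)

kwargs = torch_trunc_normal_to_jax(std=0.02)  # PyTorch's defaults a=-2, b=2
init = jax.nn.initializers.truncated_normal(**kwargs)
x = init(jax.random.key(0), (4096,))
print(kwargs)
print("min", float(x.min()), "max", float(x.max()), "std", float(jnp.std(x)))
```

With PyTorch's default `a=-2, b=2` and `std=0.02`, this yields `lower=-100, upper=100`, i.e. effectively no truncation, which matches the PyTorch runs above. Going the other way, JAX's defaults (`lower=-2, upper=2`) correspond to PyTorch bounds of `a=-2*std, b=2*std`.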
I think the "Convert PyTorch models to Flax" guide should have a section dedicated to initializers. I'm porting training code, not just weights, so notes on initializers would be helpful.
From my work so far, I've noticed that to get PyTorch's default behavior:
- For Flax `Conv`, the kernel init would be `nn.initializers.variance_scaling(1/3, "fan_in", "uniform")` instead of the default `lecun_normal`. For `bias_init`, though, you have to implement it yourself: https://github.com/google/flax/discussions/4131
- For Flax `ConvTranspose`, the kernel init would be `nn.initializers.variance_scaling(1/3, "fan_out", "uniform")` instead of `lecun_normal`. Again, for `bias_init` you have to implement it yourself: https://github.com/google/flax/discussions/4131
- For Flax `Embed`, the init would be `nn.initializers.normal(1)` instead of `nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)`.
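A sketch that checks the `Conv` rule numerically and shows one possible PyTorch-style `bias_init`. It assumes PyTorch's default Conv init is `kaiming_uniform_(weight, a=math.sqrt(5))` for the weight and `U(-1/sqrt(fan_in), 1/sqrt(fan_in))` for the bias. `torch_style_conv_bias_init` and the other helpers are hypothetical, not Flax APIs, and not necessarily what the linked discussion does.

```python
import math
import jax
import jax.numpy as jnp

def torch_kaiming_uniform_bound(fan_in, a=math.sqrt(5)):
    # PyTorch's default Conv weight init is kaiming_uniform_(w, a=sqrt(5)):
    # gain = sqrt(2 / (1 + a^2)), bound = gain * sqrt(3 / fan_in).
    gain = math.sqrt(2.0 / (1.0 + a * a))
    return gain * math.sqrt(3.0 / fan_in)

def variance_scaling_uniform_limit(scale, fan):
    # variance_scaling(scale, mode, "uniform") samples U(-limit, limit)
    # with limit = sqrt(3 * scale / fan).
    return math.sqrt(3.0 * scale / fan)

def torch_style_conv_bias_init(fan_in):
    """Hypothetical bias_init factory: U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

    fan_in is passed in explicitly because a Flax bias_init only sees the
    bias shape, not the kernel shape."""
    bound = 1.0 / math.sqrt(fan_in)

    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)

    return init

# Example: 3x3 Conv with 16 input and 32 output channels (Flax kernel layout HWIO).
kh, kw, c_in, c_out = 3, 3, 16, 32
fan_in = kh * kw * c_in
print(torch_kaiming_uniform_bound(fan_in), variance_scaling_uniform_limit(1 / 3, fan_in))

kernel_init = jax.nn.initializers.variance_scaling(1 / 3, "fan_in", "uniform")
kernel = kernel_init(jax.random.key(0), (kh, kw, c_in, c_out))
bias = torch_style_conv_bias_init(fan_in)(jax.random.key(1), (c_out,))
```

Because Flax calls `bias_init(key, shape, dtype)` with only the bias shape, `fan_in` has to be baked in up front, e.g. `nn.Conv(32, (3, 3), kernel_init=kernel_init, bias_init=torch_style_conv_bias_init(3 * 3 * 16))`. For `ConvTranspose`, PyTorch's weight has layout `(in, out, kH, kW)` and its fan-in helper reads dimension 1, giving `out * kH * kW`, which is Flax's fan-out for a `(kH, kW, in, out)` kernel; that appears to be why `"fan_out"` is the right mode there (assuming `groups=1`).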