Inconsistent results between Flax Dense and PyTorch Linear
Hi everyone,
I'm reading the Flax documentation on converting PyTorch models to Flax and testing some code myself, but I see an inconsistency between Flax Dense and PyTorch Linear. The gap is larger than I expected: the two layers are mathematically equivalent, so I would expect only a small numerical difference. Am I missing some detail here? I'd like to understand this issue. Any input will be appreciated!
```
AssertionError:
Arrays are not almost equal to 7 decimals
Mismatched elements: 1024 / 1024 (100%)
Max absolute difference: 0.00105295
Max relative difference: 0.39149174
 x: array([[-0.5895869, -1.6013966, -0.3044982, ...,  0.691458 , -0.2171268,
        -0.3255446],
       [ 1.3270106,  1.0085726,  0.5121558, ...,  0.4522566,  2.0366738,
        -1.3362962]], dtype=float32)
 y: array([[-0.5898346, -1.6013664, -0.3044372, ...,  0.6914209, -0.2169722,
        -0.3249587],
       [ 1.3268726,  1.0090327,  0.5124962, ...,  0.4526754,  2.0368207,
        -1.3363303]], dtype=float32)
```
The following is the minimal code to reproduce it.
```python
import torch
import jax
import jax.numpy as jnp
import numpy as onp
from jax import random
import flax.linen as nn

B, Cin, Cout = 2, 768, 512
key1, key2 = random.split(random.PRNGKey(0))
x = onp.random.normal(size=(B, Cin)).astype(onp.float32)  # input

# Flax module and its parameters
dense = nn.Dense(Cout, name='dense')
params = dense.init(key1, jnp.ones((B, Cin)))

# PyTorch module loaded with the same weights. Flax stores the kernel as
# (Cin, Cout), while torch.nn.Linear stores weight as (Cout, Cin), hence .T
torch_dense = torch.nn.Linear(Cin, Cout)
state_dict = jax.device_get(params)
with torch.no_grad():
    torch_dense.weight.copy_(torch.from_numpy(state_dict['params']['kernel'].T))
    torch_dense.bias.copy_(torch.from_numpy(state_dict['params']['bias']))

jax_x = jnp.asarray(x)
torch_x = torch.from_numpy(x)

dense_out = dense.apply(params, jax_x)
with torch.no_grad():
    torch_out = torch_dense(torch_x).numpy()

# decimal=7 requires |actual - desired| < 1.5e-7 for every element
onp.testing.assert_almost_equal(jax.device_get(dense_out), torch_out)
```
Environment:
- CUDA version: 11.7
- PyTorch version: 2.0.1
- jax version: 0.4.10
- jaxlib version: 0.4.10
- Flax version: 0.6.10
Originally posted by @devzhk in https://github.com/google/flax/discussions/3123
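A note on the "Max relative difference" line in the error above: numpy's testing helpers report it as the largest elementwise |actual - desired| / |desired|, so a single output that happens to sit close to zero can dominate it even when the absolute error is small everywhere. The toy arrays below are hypothetical, chosen only to show the effect:

```python
import numpy as np

# Hypothetical float32 outputs: both entries differ by only ~1e-4 in absolute
# terms, but the second "desired" value is close to zero.
actual = np.array([1.0001, 0.0005], dtype=np.float32)
desired = np.array([1.0000, 0.0003], dtype=np.float32)

abs_diff = np.abs(actual - desired)
# Relative difference divides by |desired| elementwise, so tiny denominators
# inflate it.
rel_diff = abs_diff / np.abs(desired)

print("max abs diff:", abs_diff.max())  # about 2e-4
print("max rel diff:", rel_diff.max())  # about 0.67, driven by the near-zero entry
```

So a max relative difference near 0.39 mostly says that some output element is small, not that every output is off by 39%.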
There isn't a guarantee that two matrix multiplies will produce the same output across frameworks. Both typically end up calling NVIDIA's GPU libraries (cuBLAS for matrix multiplies), but potentially with different algorithms or summation orders, which leads to different numerical results.
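To make the "different strategies" point concrete, here is a NumPy-only sketch (no GPU, JAX, or PyTorch) that computes the same float32 product two ways: one matmul call, and a version that splits the 768-long reduction into eight chunks, as a tiled kernel might. The shapes mirror the repro; the 1/sqrt(Cin) weight scaling is an assumption chosen to resemble Flax's default lecun_normal initializer.

```python
import numpy as np

rng = np.random.default_rng(0)
B, Cin, Cout = 2, 768, 512  # same shapes as the repro
x = rng.normal(size=(B, Cin)).astype(np.float32)
w = (rng.normal(size=(Cin, Cout)) / np.sqrt(Cin)).astype(np.float32)

# Strategy 1: a single float32 matmul (BLAS picks its own summation order).
out_a = x @ w

# Strategy 2: split the reduction dimension into 8 chunks of 96 and add the
# partial products in float32, mimicking a kernel that tiles K differently.
out_b = np.zeros((B, Cout), dtype=np.float32)
for idx in np.split(np.arange(Cin), 8):
    out_b += x[:, idx] @ w[idx, :]

# float64 reference for both
ref = x.astype(np.float64) @ w.astype(np.float64)

print("a vs b:  ", np.abs(out_a - out_b).max())
print("a vs ref:", np.abs(out_a - ref).max())
print("b vs ref:", np.abs(out_b - ref).max())
```

Reordering alone typically moves the results by something on the order of 1e-6 here, far below the ~1e-3 gap in the question, which suggests that something beyond summation order is involved.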
Depending on which GPU you are using, the matrix multiplication may also be performed in TF32 (on A100 or later GPUs).
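One way to see why TF32 is a plausible explanation for a gap near 1e-3: TF32 keeps float32's 8-bit exponent but only 10 explicit mantissa bits, so each input carries a relative rounding error of up to about 2^-11 ≈ 4.9e-4. The sketch below models this in NumPy by rounding both operands to 10 mantissa bits and accumulating in float32. The helper `round_to_tf32` is hypothetical and its rounding mode is an assumption; real tensor cores may round differently, so treat it as a rough model only.

```python
import numpy as np

def round_to_tf32(a: np.ndarray) -> np.ndarray:
    """Hypothetical TF32 model: keep 10 of float32's 23 mantissa bits.

    Assumption: round half up on the magnitude bits, then clear the low 13
    bits. Not a bit-exact model of any particular GPU.
    """
    bits = a.astype(np.float32).view(np.uint32)
    bits = (bits + np.uint32(0x1000)) & np.uint32(0xFFFFE000)
    return bits.view(np.float32)

rng = np.random.default_rng(0)
B, Cin, Cout = 2, 768, 512
x = rng.normal(size=(B, Cin)).astype(np.float32)
w = (rng.normal(size=(Cin, Cout)) / np.sqrt(Cin)).astype(np.float32)

ref = x.astype(np.float64) @ w.astype(np.float64)
fp32_out = x @ w
# TF32-style product: inputs rounded to 10 mantissa bits, float32 accumulation.
tf32_out = round_to_tf32(x) @ round_to_tf32(w)

print("fp32 vs float64 ref:", np.abs(fp32_out - ref).max())
print("fp32 vs TF32 model: ", np.abs(fp32_out - tf32_out).max())
```

In this model the first difference stays tiny while the second lands roughly in the 1e-4 to 1e-3 range, the same order of magnitude as the reported mismatch.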
Thanks for your response! I agree that they are not guaranteed to produce the same results, but I expected them to agree up to a small error. What surprises me is how large the difference is: up to about 0.001 in absolute terms, with a maximum relative difference of about 39%. Do you think this level of difference is expected?
BTW, I tested it on an A100 but didn't enable TF32.
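Note that "not enabling TF32" may only apply to the PyTorch side: PyTorch 2.x keeps `torch.backends.cuda.matmul.allow_tf32` off by default, whereas JAX's `DEFAULT` matmul precision may use TF32 on A100-class GPUs. Assuming that is the cause here, a quick check is to request `jax.lax.Precision.HIGHEST` explicitly. The sketch below uses plain `jnp.dot`; in Flax the same knob is the `precision` attribute, e.g. `nn.Dense(Cout, precision=jax.lax.Precision.HIGHEST)`, and it can also be set globally with `jax.config.update("jax_default_matmul_precision", "highest")`.

```python
import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 768)).astype(np.float32)
w = (rng.normal(size=(768, 512)) / np.sqrt(768)).astype(np.float32)
ref = x.astype(np.float64) @ w.astype(np.float64)  # float64 reference

# DEFAULT: the backend is free to trade accuracy for speed (possibly TF32 on GPU).
out_default = jnp.dot(x, w, precision=jax.lax.Precision.DEFAULT)
# HIGHEST: ask for full float32 accuracy in this matmul.
out_highest = jnp.dot(x, w, precision=jax.lax.Precision.HIGHEST)

print("DEFAULT vs ref:", np.abs(np.asarray(out_default) - ref).max())
print("HIGHEST vs ref:", np.abs(np.asarray(out_highest) - ref).max())
```

On CPU the two results are typically identical; on an A100, a noticeably larger error for `DEFAULT` than for `HIGHEST` would point to TF32 as the source of the gap.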
I guess this is the best we can get? The conclusion seems to be that there is no way to make PyTorch and JAX agree more closely numerically than discussed here.
Sometimes just running JAX in eager vs. jit mode already causes a difference 😅. That said, you can get within 4-5 decimal places of PyTorch, which is enough for many applications.
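For tests that only need that "4-5 decimal places" level of agreement, a tolerance-based comparison fits better than `assert_almost_equal` with its default `decimal=7`, which requires |actual - desired| < 1.5e-7, roughly one float32 ulp near 1.0. A minimal sketch; the helper name `assert_close_float32` and its rtol/atol values are hypothetical choices, not recommendations from JAX or PyTorch:

```python
import numpy as np

def assert_close_float32(actual, desired, rtol=1e-5, atol=1e-5):
    """Hypothetical helper: compare float32 results with float32-appropriate
    tolerances instead of decimal=7 (about one ulp for values near 1)."""
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)

# Usage: the same float32 matmul computed with two different summation orders.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 768)).astype(np.float32)
w = (rng.normal(size=(768, 512)) / np.sqrt(768)).astype(np.float32)
out_a = x @ w
out_b = sum(x[:, i:i + 96] @ w[i:i + 96] for i in range(0, 768, 96))
assert_close_float32(out_a, out_b)  # reordering noise stays well under 1e-5
```

Tolerances like these absorb ordinary reordering noise but would still flag a TF32-sized gap of ~1e-3.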