
Inconsistent results of Flax Dense and Pytorch Linear

Open devzhk opened this issue 1 year ago • 5 comments

Hi everyone,

I'm reading the Flax documentation about converting PyTorch models to Flax and trying to test some code myself, but I observed an inconsistency between Flax Dense and PyTorch Linear. There is clearly a big gap, which confuses me because the two layers are mathematically equivalent, and the numerical difference should not be this large. Are there any details I'm missing here? I'd like to understand this issue; any input would be appreciated!

AssertionError: 
Arrays are not almost equal to 7 decimals

Mismatched elements: 1024 / 1024 (100%)
Max absolute difference: 0.00105295
Max relative difference: 0.39149174
 x: array([[-0.5895869, -1.6013966, -0.3044982, ...,  0.691458 , -0.2171268,
        -0.3255446],
       [ 1.3270106,  1.0085726,  0.5121558, ...,  0.4522566,  2.0366738,
        -1.3362962]], dtype=float32)
 y: array([[-0.5898346, -1.6013664, -0.3044372, ...,  0.6914209, -0.2169722,
        -0.3249587],
       [ 1.3268726,  1.0090327,  0.5124962, ...,  0.4526754,  2.0368207,
        -1.3363303]], dtype=float32)

The following is a minimal example that reproduces the issue.

import torch
import jax
import jax.numpy as jnp
import numpy as onp
from jax import random
import flax.linen as nn

B, Cin, Cout = 2, 768, 512
key1, key2 = random.split(random.PRNGKey(0))
x = onp.random.normal(size=(B, Cin)).astype(onp.float32)    # input

dense = nn.Dense(Cout, name='dense')   # Flax module
params = dense.init(key1, jnp.ones((B, Cin))) 

torch_dense = torch.nn.Linear(Cin, Cout)   # PyTorch module
state_dict = jax.device_get(params)

with torch.no_grad():
    # Flax stores the kernel as (in_features, out_features); PyTorch's
    # Linear weight is (out_features, in_features), hence the transpose.
    torch_dense.weight.copy_(torch.from_numpy(state_dict['params']['kernel'].T))
    torch_dense.bias.copy_(torch.from_numpy(state_dict['params']['bias']))

jax_x = jnp.asarray(x)
torch_x = torch.from_numpy(x)

dense_out = dense.apply(params, jax_x)   # use the JAX array created above
with torch.no_grad():
    torch_out = torch_dense(torch_x).numpy()

onp.testing.assert_almost_equal(jax.device_get(dense_out), torch_out)

Environment:

  • CUDA version: 11.7
  • Pytorch version: 2.0.1
  • jax version: 0.4.10
  • jaxlib version: 0.4.10
  • Flax version: 0.6.10

Originally posted by @devzhk in https://github.com/google/flax/discussions/3123

devzhk avatar Jun 02 '23 16:06 devzhk

There isn't a guarantee that the output of two matrix multiplies will be the same across frameworks. Both end up calling cuDNN, but potentially with different strategies, causing different numerical results.

Depending on which GPU you are using, it could also be that the matrix multiplication happens in TF32 (A100 or later).
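
If TF32 turns out to be the cause, it can be disabled explicitly on both sides. A minimal sketch, assuming recent PyTorch and JAX versions (the JAX flag here is jax_default_matmul_precision; exact option names may differ across releases):

import torch
import jax

# PyTorch: disable TF32 for matmuls (and for cuDNN, for completeness)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# JAX: request full float32 precision for matmuls instead of the lowered
# default that XLA may pick on Ampere-class GPUs
jax.config.update("jax_default_matmul_precision", "float32")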

jheek avatar Jun 02 '23 21:06 jheek

Thanks for your response! I agree that they are not guaranteed to produce the same results, but I would expect them to be approximately the same, up to a small error. What surprises me is that the numerical difference is as large as 0.001 (about 39% relative difference, per the assertion output above). Do you think this level of difference is expected?
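
For reference, one way to see how far each framework is from an exact result is to compare both outputs against a float64 NumPy reference. A sketch reusing x, state_dict, dense_out, and torch_out from the snippet above:

# Float64 reference on CPU: if the gap is just accumulation-order noise,
# both float32 results should sit at a similar distance from it.
w64 = state_dict['params']['kernel'].astype(onp.float64)   # (Cin, Cout)
b64 = state_dict['params']['bias'].astype(onp.float64)
ref = x.astype(onp.float64) @ w64 + b64

print(onp.abs(jax.device_get(dense_out) - ref).max())   # JAX vs. reference
print(onp.abs(torch_out - ref).max())                   # PyTorch vs. reference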

devzhk avatar Jun 05 '23 00:06 devzhk

BTW, I tested it on A100 but didn't enable TF32.

devzhk avatar Jun 05 '23 00:06 devzhk

I guess this is the best we can get? The conclusion seems to be that there is no way of making PyTorch and JAX numerically closer than what is discussed here.

mra-h avatar Aug 16 '23 15:08 mra-h

Sometimes just running JAX in eager vs. jit mode already causes a difference 😅. That said, you can get within 4-5 decimal places of PyTorch, which works for many applications.
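
For example, a hedged sketch of a looser check, reusing dense_out and torch_out from the reproduction above (the tolerances are illustrative, not guaranteed):

# assert_almost_equal defaults to 7 decimals; for float32 matmuls across
# frameworks, an explicit tolerance around the observed gap is more realistic.
onp.testing.assert_allclose(
    jax.device_get(dense_out), torch_out, rtol=1e-2, atol=2e-3)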

cgarciae avatar Aug 16 '23 16:08 cgarciae