equinox
`eqx.nn.Conv2d` produces slightly different results than `torch.nn.Conv2d`
Hi,
I wanted to check whether it is possible for the two operations to produce identical results. I found that a `jnp.isclose` assertion fails on the outputs when comparing the torch and equinox 2D convolution operations. For a single convolution the results agree up to `atol=1e-7`, but my concern is that this difference compounds over the many layers of a network, producing very different results at the final layer.
The `dtype` of the weights and bias is `float32` in both cases.
Script I used for comparison:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import numpy as np
import torch
import torch.nn as tnn


def test_conv2d():
    random_image = jax.random.uniform(key=jax.random.PRNGKey(0), shape=(1, 3, 224, 224))
    t_cv = tnn.Conv2d(3, 3, kernel_size=3)
    # An explicit key stands in for the `getkey` fixture from Equinox's test suite.
    e_cv = eqx.nn.Conv2d(3, 3, kernel_size=3, key=jax.random.PRNGKey(1))

    # Copy PyTorch's parameters into the Equinox layer (Equinox stores the bias as (out_channels, 1, 1)).
    t_w = t_cv.weight.detach().numpy()
    t_b = t_cv.bias.detach().numpy().reshape(e_cv.bias.shape)
    e_cv = eqx.tree_at(lambda m: (m.weight, m.bias), e_cv, (jnp.asarray(t_w), jnp.asarray(t_b)))
    assert (e_cv.weight == t_w).all() and (e_cv.bias == t_b).all()

    t_out = t_cv(torch.tensor(np.asarray(random_image))).detach().numpy()
    e_out = jax.vmap(e_cv)(random_image)  # eqx.nn.Conv2d acts on one (C, H, W) image, so vmap over the batch
    assert jnp.isclose(t_out, e_out).all()
```
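A note on why the final assertion can fail even when outputs agree to about 1e-7: `jnp.isclose` (like `np.isclose`) checks `|a - b| <= atol + rtol * |b|` elementwise with defaults `rtol=1e-05` and `atol=1e-08`, so wherever a conv output is close to zero the allowed gap shrinks to roughly 1e-8. Rather than a single boolean, a small report of the error can be more informative. `diff_report` below is a hypothetical helper (not part of JAX, Equinox or PyTorch); it works on NumPy arrays, so it could be called as `print(diff_report(t_out, np.asarray(e_out)))` at the end of the script above.

```python
import numpy as np


def diff_report(a, b, rtol=1e-5):
    """Hypothetical helper: summarize the discrepancy between two arrays.

    `min_atol` is the smallest atol for which np.isclose(a, b, rtol=rtol, atol=atol)
    holds everywhere, using the rule |a - b| <= atol + rtol * |b|.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    abs_err = np.abs(a - b)
    # Avoid dividing by exact zeros in `b`.
    rel_err = abs_err / np.maximum(np.abs(b), np.finfo(np.float64).tiny)
    min_atol = np.maximum(abs_err - rtol * np.abs(b), 0.0).max()
    return {
        "max_abs": float(abs_err.max()),
        "max_rel": float(rel_err.max()),
        "min_atol": float(min_atol),
    }


# Toy case: the same ~1e-7 absolute error passes the default rtol near 0.5,
# but fails near 1e-3, where atol=1e-8 plus rtol * |b| allows only ~2e-8.
b = np.array([0.5, 1e-3])
a = b + 1e-7
print(np.isclose(a, b))  # default rtol=1e-5, atol=1e-8
report = diff_report(a, b)
print(report)
```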
Hmm. That's definitely unfortunate when transferring over pretrained models.
This is probably down to `jax.lax.conv_general_dilated`, which `eqx.nn.Conv2d` is ultimately just a thin wrapper for. If you can, I'd suggest trying either:
- changing the `precision` argument to `jax.lax.conv_general_dilated` and seeing if this fixes things; or
- removing the dependency on Equinox, simplifying this down to an issue with just `jax.lax.conv_general_dilated`, and then opening an issue on the JAX GitHub issue tracker.
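Along the lines of both suggestions, here is a minimal sketch that calls `jax.lax.conv_general_dilated` directly (no Equinox) and compares it against a float64 NumPy reference, with and without `precision=jax.lax.Precision.HIGHEST`. Measuring each float32 result against a float64 reference shows how much of the gap is ordinary float32 rounding, which neither library can remove. The helpers `conv2d_reference` and `conv` are hypothetical, and the layout `("NCHW", "OIHW", "NCHW")` is an assumption chosen to match PyTorch's. On CPU the precision flag will usually make no difference; on accelerators the default precision may use reduced-precision passes, which is where it could matter.

```python
import jax
import jax.numpy as jnp
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_reference(x, w, b):
    """Float64 'VALID', stride-1 cross-correlation: x is NCHW, w is OIHW, b is (O,)."""
    x = np.asarray(x, dtype=np.float64)
    w = np.asarray(w, dtype=np.float64)
    kh, kw = w.shape[2:]
    windows = sliding_window_view(x, (kh, kw), axis=(2, 3))  # (N, C, H_out, W_out, kh, kw)
    out = np.einsum("nchwij,ocij->nohw", windows, w)
    return out + np.asarray(b, dtype=np.float64)[None, :, None, None]


k_x, k_w, k_b = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.uniform(k_x, (1, 3, 32, 32), dtype=jnp.float32)
w = 0.2 * jax.random.normal(k_w, (3, 3, 3, 3), dtype=jnp.float32)
b = 0.1 * jax.random.normal(k_b, (3,), dtype=jnp.float32)


def conv(precision):
    """Plain lax convolution in PyTorch's NCHW/OIHW layout, plus a bias add."""
    out = jax.lax.conv_general_dilated(
        x,
        w,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        precision=precision,
    )
    return out + b[None, :, None, None]


ref = conv2d_reference(x, w, b)
errors = {}
for label, precision in [("default", None), ("highest", jax.lax.Precision.HIGHEST)]:
    out = np.asarray(conv(precision), dtype=np.float64)
    errors[label] = float(np.max(np.abs(out - ref)))
    print(f"{label:8s} max abs error vs float64: {errors[label]:.2e}")
```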
Thanks,
- The issue persists even after switching to the highest precision.
- I have opened an issue, https://github.com/google/jax/issues/11790, to follow up on this.
Feel free to close this issue.
Great. Hopefully this can be resolved one way or another; I can imagine this would be a bit of a blocker for eqxvision.
I'll leave this issue open for now.
Just an FYI: I'm finding this is also the case for PyTorch on GPU vs. CPU.
```python
import copy

import torch
import torch.nn as tnn

t_cv = tnn.Conv2d(3, 3, kernel_size=3)
img = torch.randn(1, 3, 224, 224)
assert torch.allclose(copy.deepcopy(t_cv)(img), t_cv(img))  # True: same device, same kernel
# Requires a CUDA device.
assert torch.allclose(copy.deepcopy(t_cv).to('cuda:0')(img.to('cuda:0')).cpu(), t_cv(img))  # False
```
I need to loosen the tolerance to `atol=1e-6` (the default is `atol=1e-8`) to get it to pass:

```python
assert torch.allclose(copy.deepcopy(t_cv).to('cuda:0')(img.to('cuda:0')).cpu(), t_cv(img), atol=1e-6)  # True
```
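Both the CPU-vs-GPU gap and the PyTorch-vs-XLA gap are consistent with a common explanation, though it is not confirmed in this thread: different convolution kernels accumulate their partial sums in a different order, and float32 addition is not associative, so reordering alone changes the last few bits. A minimal NumPy sketch of that effect, with no convolution involved (`sequential_sum` is a hypothetical helper):

```python
import numpy as np

rng = np.random.default_rng(0)
vals = rng.standard_normal(4096).astype(np.float32)


def sequential_sum(xs):
    """Left-to-right float32 accumulation, one addition at a time."""
    acc = np.float32(0.0)
    for v in xs:
        acc = acc + v  # float32 + float32 stays float32
    return acc


forward = sequential_sum(vals)
backward = sequential_sum(vals[::-1])
pairwise = vals.sum(dtype=np.float32)  # NumPy typically uses pairwise summation here
reference = vals.astype(np.float64).sum()  # float64 reference

for name, s in [("forward", forward), ("backward", backward), ("pairwise", pairwise)]:
    print(f"{name:8s} {float(s):.8f}  error vs float64: {abs(float(s) - reference):.2e}")
```

The same numbers summed in different orders generally give slightly different float32 results, which is the kind of discrepancy the `atol=1e-6` above absorbs.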
It would be interesting to quantify the difference between the outputs of models with many layers, to see whether the error actually compounds.
I'm also wondering what the difference is between GPU models.
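On whether the error compounds with depth, the sketch below is one way to measure it, under heavy assumptions. It uses a hypothetical toy network (dense layers with ReLU and He-style initialization, not a conv net), keeps the weights identical in both passes, and compares a float32 forward pass against a float64 one after every layer. The numbers it prints describe only this toy setup; a real model would need the same per-layer comparison on its own activations.

```python
import numpy as np


def depth_error(num_layers=20, width=256, seed=0):
    """Relative error of a float32 forward pass vs. float64, after each layer.

    Hypothetical toy model: dense layers + ReLU with He-style random weights.
    """
    rng = np.random.default_rng(seed)
    x32 = rng.standard_normal(width).astype(np.float32)
    x64 = x32.astype(np.float64)  # start both passes from identical inputs
    errors = []
    for _ in range(num_layers):
        w32 = (rng.standard_normal((width, width)) * np.sqrt(2.0 / width)).astype(np.float32)
        w64 = w32.astype(np.float64)  # same weights; only the arithmetic precision differs
        x32 = np.maximum(w32 @ x32, 0.0)
        x64 = np.maximum(w64 @ x64, 0.0)
        errors.append(float(np.linalg.norm(x32 - x64) / np.linalg.norm(x64)))
    return errors


for depth, err in enumerate(depth_error(), start=1):
    print(f"layer {depth:2d}: relative error {err:.2e}")
```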