`eqx.nn.Conv2d` produces slightly different results than `torch.nn.Conv2d`

Open • paganpasta opened this issue 1 year ago • 4 comments

Hi,

I wanted to check whether it is possible for the two operations to produce identical results. A `jnp.isclose` assertion on the outputs fails when I compare the torch and Equinox 2D convolutions. For a single conv operation the results agree to within about atol=1e-7, but my concern is that this difference compounds over the many layers of a network, leading to very different results at the final layer.

The weights and bias are float32 in both.

Script I used for comparison:

```python
import jax
import jax.numpy as jnp

import equinox as eqx
import torch.nn as tnn
import torch
import numpy as np


# `getkey` is a pytest fixture (as in Equinox's own test suite) that returns a fresh PRNG key.
def test_conv2d(getkey):
    random_image = jax.random.uniform(key=jax.random.PRNGKey(0), shape=(1, 3, 224, 224))
    t_cv = tnn.Conv2d(3, 3, kernel_size=3)
    e_cv = eqx.nn.Conv2d(3, 3, kernel_size=3, key=getkey())

    # Copy the torch parameters into the Equinox layer; the torch bias is
    # reshaped to match the shape of the Equinox bias.
    t_w, t_b = t_cv.weight.detach().numpy(), t_cv.bias.detach().numpy().reshape(e_cv.bias.shape)
    e_cv = eqx.tree_at(lambda m: (m.weight, m.bias), e_cv, (jnp.asarray(t_w), jnp.asarray(t_b)))
    assert (e_cv.weight == t_w).all() and (e_cv.bias == t_b).all()

    # torch takes a batched NCHW input; the Equinox layer is unbatched, hence the vmap.
    t_out = t_cv(torch.tensor(np.asarray(random_image))).detach().numpy()
    e_out = jax.vmap(e_cv)(random_image)

    assert jnp.isclose(t_out, e_out).all()  # this is the assertion that fails
```
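
A pass/fail result from `jnp.isclose` does not say how large the gap actually is. The sketch below uses only NumPy to report that. `report_diff` is a hypothetical helper and not part of Equinox or PyTorch. It returns three things:

* the largest absolute difference,
* the largest relative difference,
* how many elements fail `np.isclose` at its default tolerances (`rtol=1e-05`, `atol=1e-08`).

An element passes when |a - b| <= atol + rtol * |b|. `jnp.isclose` uses the same defaults.

```python
import numpy as np


def report_diff(a, b):
    """Hypothetical helper: summarise how far two arrays are apart."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    abs_err = np.abs(a - b)
    # Guard against division by zero where the reference value is exactly 0.
    rel_err = abs_err / np.maximum(np.abs(b), np.finfo(np.float32).tiny)
    return {
        "max_abs": float(abs_err.max()),
        "max_rel": float(rel_err.max()),
        # Number of elements failing |a - b| <= 1e-8 + 1e-5 * |b|.
        "n_fail": int((~np.isclose(a, b)).sum()),
    }
```

In the script above, `report_diff(t_out, e_out)` would show whether the mismatch is a handful of elements just outside the default tolerance or a systematic offset.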

paganpasta • Aug 08 '22 09:08

Hmm. That's definitely unfortunate when transferring pretrained models.

This is probably down to `jax.lax.conv_general_dilated`, which `eqx.nn.Conv2d` is ultimately just a thin wrapper for. If you can, I'd suggest trying either:

  • Changing the `precision` argument to `jax.lax.conv_general_dilated` (e.g. to `jax.lax.Precision.HIGHEST`) and seeing whether this fixes things.
  • Removing the dependency on Equinox, reducing this to an issue with just `jax.lax.conv_general_dilated`, and then opening an issue on the JAX GitHub issue tracker.
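
The sketch below illustrates the first suggestion by calling `jax.lax.conv_general_dilated` directly. It uses the layout from the script above: NCHW input, OIHW weight, stride 1 and no padding. The output is compared, at two precision settings, against a float64 NumPy reference. The helper names and test shapes are illustrative assumptions.

Even at `Precision.HIGHEST` the JAX result is float32. Each output element here sums 27 products (3 channels × 3×3 taps), and different libraries may add them in different orders. Floating-point addition is not associative, so agreement to within float32 rounding, rather than bitwise equality, is the realistic target.

```python
import numpy as np
import jax
import jax.numpy as jnp


def conv2d_reference(x, w, b):
    """float64 cross-correlation: NCHW input, OIHW weight, stride 1, no padding."""
    x = np.asarray(x, dtype=np.float64)
    w = np.asarray(w, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    _, _, h, wd = x.shape
    _, _, kh, kw = w.shape
    oh, ow = h - kh + 1, wd - kw + 1
    out = np.zeros((x.shape[0], w.shape[0], oh, ow))
    for i in range(kh):
        for j in range(kw):
            # Contract over input channels for one kernel tap at a time.
            out += np.einsum("nchw,oc->nohw", x[:, :, i:i + oh, j:j + ow], w[:, :, i, j])
    return out + b[None, :, None, None]


def conv2d_jax(x, w, b, precision=None):
    """The same convolution via jax.lax.conv_general_dilated at a chosen precision."""
    out = jax.lax.conv_general_dilated(
        x,
        w,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        precision=precision,
    )
    return out + b[None, :, None, None]


key_x, key_w, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.uniform(key_x, (1, 3, 32, 32), dtype=jnp.float32)
w = 0.2 * jax.random.normal(key_w, (3, 3, 3, 3), dtype=jnp.float32)
b = jax.random.normal(key_b, (3,), dtype=jnp.float32)

ref = conv2d_reference(x, w, b)
errors = {}
for name, prec in [("DEFAULT", None), ("HIGHEST", jax.lax.Precision.HIGHEST)]:
    out = np.asarray(conv2d_jax(x, w, b, prec), dtype=np.float64)
    errors[name] = float(np.abs(out - ref).max())  # max abs error vs float64
print(errors)
```

On CPU the two settings usually give similar errors. The `precision` flag tends to matter more on accelerators, where the default may use reduced-precision arithmetic (for example bfloat16 passes on TPU).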

patrick-kidger • Aug 08 '22 10:08

Thanks,

  • The issue persists even after moving to highest precision.
  • I have opened an issue: https://github.com/google/jax/issues/11790 to follow this up.

Feel free to close this issue.

paganpasta • Aug 08 '22 12:08

Great. Hopefully this can be resolved one way or another; I can imagine this would be a bit of a blocker for eqxvision.

I'll leave this issue open for now.

patrick-kidger • Aug 08 '22 13:08

Just an FYI -- I'm finding this is also the case for pytorch on the gpu vs cpu.

```python
import copy

import torch
import torch.nn as tnn

t_cv = tnn.Conv2d(3, 3, kernel_size=3)
img = torch.randn(1, 3, 224, 224)
assert torch.allclose(copy.deepcopy(t_cv)(img), t_cv(img))  # passes: same layer, same device
assert torch.allclose(copy.deepcopy(t_cv).to('cuda:0')(img.to('cuda:0')).cpu(), t_cv(img))  # fails
```

I need to raise `atol` from its default of 1e-8 to 1e-6 to get it to pass:

```python
assert torch.allclose(copy.deepcopy(t_cv).to('cuda:0')(img.to('cuda:0')).cpu(), t_cv(img), atol=1e-6)  # passes
```

It would be interesting to quantify the difference between the outputs of models with many layers, to see whether the error actually compounds.

I'm also curious how much the results differ between GPU models.
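
One way to test whether the error compounds is to push the same input through two float32 implementations of a stack of conv + ReLU layers. The two implementations differ only in how they accumulate sums, and the relative gap is recorded after each layer.

The sketch below rests on several assumptions:

* random He-scaled weights,
* 3×3 kernels with "SAME" padding,
* a NumPy einsum convolution standing in for a second library.

`conv_numpy_f32`, `conv_jax` and `divergence_by_depth` are hypothetical names.

```python
import numpy as np
import jax
import jax.numpy as jnp


def conv_numpy_f32(x, w):
    """float32 cross-correlation, NCHW/OIHW, stride 1, 'SAME' padding for odd kernels."""
    kh, kw = w.shape[2:]
    ph, pw = kh // 2, kw // 2
    xp = np.pad(x, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
    h, wd = x.shape[2:]
    out = np.zeros((x.shape[0], w.shape[0], h, wd), dtype=np.float32)
    for i in range(kh):
        for j in range(kw):
            # Accumulate one kernel tap at a time, in float32.
            out += np.einsum("nchw,oc->nohw", xp[:, :, i:i + h, j:j + wd], w[:, :, i, j])
    return out


def conv_jax(x, w):
    """The same convolution computed by XLA."""
    return jax.lax.conv_general_dilated(
        x, w, (1, 1), "SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        precision=jax.lax.Precision.HIGHEST,
    )


def divergence_by_depth(depth, channels=8, size=32, seed=0):
    """Relative max-abs gap between the two implementations after each layer."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((1, channels, size, size)).astype(np.float32)
    # He-style scaling keeps activations roughly O(1) through ReLU layers.
    scale = np.sqrt(2.0 / (channels * 9))
    weights = [
        (rng.standard_normal((channels, channels, 3, 3)) * scale).astype(np.float32)
        for _ in range(depth)
    ]
    a_np, a_jax = x, jnp.asarray(x)
    rel = []
    for w in weights:
        a_np = np.maximum(conv_numpy_f32(a_np, w), 0)
        a_jax = jax.nn.relu(conv_jax(a_jax, jnp.asarray(w)))
        diff = np.abs(a_np - np.asarray(a_jax)).max()
        rel.append(float(diff / max(np.abs(a_np).max(), 1e-30)))
    return rel
```

`divergence_by_depth(20)` returns one relative error per layer. Plotting it against depth shows whether the gap grows linearly, grows geometrically, or stays flat for this particular random network. A trained network may behave differently, so running the same comparison on the real model would be the more convincing check.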

jenkspt • Aug 08 '22 23:08