`WeightNorm` causes unexpected PyTrees inequality

Open yuanz271 opened this issue 10 months ago • 8 comments

The following assertion fails.

import equinox as eqx
import jax.random as jrandom

layer = eqx.nn.Linear(2, 2, key=jrandom.key(0))
m1 = eqx.nn.WeightNorm(layer)
m2 = eqx.nn.WeightNorm(layer)
assert eqx.tree_equal(m1, m2)  # unequal

The reason for the failure is that WeightNorm._norm differs (the == operation returns False) between any two WeightNorm instances.

This also makes tree_equal fail when it is used to check deserialized models.

yuanz271 avatar Mar 01 '25 17:03 yuanz271

This is expected. The values of the leaves are the same, but you now have two new trees, which means two objects that JAX treats separately. Static leaves (such as bound methods and their jaxprs) generally do not permit equality checks, and layer does have a bound method. If you partition the layers into their dynamic and static components, you will see that the equality check passes for the weights.

You can verify this:

import equinox as eqx
import jax.random as jr


layer = eqx.nn.Linear(2, 2, key=jr.key(0))
m1 = eqx.nn.WeightNorm(layer)
m2 = eqx.nn.WeightNorm(layer)
d1, s1 = eqx.partition(m1, eqx.is_array)
d2, s2 = eqx.partition(m2, eqx.is_array)

assert eqx.tree_equal(d1, d2)  # Dynamic components (arrays) are equal
assert not eqx.tree_equal(s1, s2)  # Static components are not equal

johannahaffner avatar Mar 01 '25 18:03 johannahaffner

@johannahaffner thank you.

This is expected. The values of the leaves are the same, but you now have two new trees, which means two objects that JAX treats separately. Static leaves (such as bound methods and their jaxprs) generally do not permit equality checks, and layer does have a bound method. If you partition the layers into their dynamic and static components, you will see that the equality check passes for the weights.

However, what's the motivation for declaring _norm as a field? The other instance (bound) methods are not compared.

yuanz271 avatar Mar 01 '25 19:03 yuanz271

This is not specific to Equinox; it is a general feature of how Python compares objects for equality. Callables such as plain functions do not define a value-based __eq__ method, so Python falls back to checking whether they are the exact same object.

That means that equality checks for callables only pass if both sides point at the exact same thing, which is a rare special case in practice. Because equinox modules are immutable, we get a new layer back. In the second example below, we create the callable f afresh each time, and you can see that the equality check fails, even though both functions close over the same value.

def f(x):
    pass

g = f
h = f

assert g == h  # g and h point to the same object

def make_f(x):
    def f():
        return x
    return f

x = 42
g = make_f(x)
h = make_f(x)

assert g != h  # g and h point to different objects

The main takeaway is that you won't ever care about having your methods be the exact same thing, but you will care about having your leaves be the exact same thing! And you can check that by partitioning, and then asserting equality on the arrays.

johannahaffner avatar Mar 02 '25 12:03 johannahaffner

Maybe I didn't ask clearly. My point is that, if you define _norm as an instance method instead of a Module field, it will not be compared (like __call__ of a Module). So, why this particular way?

yuanz271 avatar Mar 02 '25 13:03 yuanz271

Ah! That is what you mean. This is consistent with how things are done elsewhere in the library. More complicated modules such as MultiheadAttention actually have several fields that are callable, e.g. here:

https://github.com/patrick-kidger/equinox/blob/8191b113df5d985720e86c0d6292bceb711cbe94/equinox/nn/_attention.py#L122

Comparing any callable for equality is just not what is expected in Python, and I don't think that there is a good reason to push things into methods in order to be able to do so.

johannahaffner avatar Mar 02 '25 13:03 johannahaffner

That I understand, because query_proj is a Module and so is handled well recursively. However, _norm is supposed to be static and stateless inside WeightNorm. Since tree_equal compares arrays by value, the natural semantics, at least for a Module, would be to distinguish callables by what they compute rather than by reference (Module overloads __eq__ using tree_equal, IIRC).

yuanz271 avatar Mar 02 '25 13:03 yuanz271

Right. I can't speak for @patrick-kidger, but elsewhere in the Equinox ecosystem we do make frequent use of fields to specify some norm, e.g. here. The reason is that such a design allows the norm to be made public as an optional input argument if requested, and users might then like to use a norm other than the Frobenius norm that jnp.linalg.norm defaults to. If this were implemented as an instance method, exposing it would require a breaking change; as a field, it just requires adding a keyword argument to WeightNorm.__init__, which can default to the current behaviour and would not be breaking.
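
To illustrate the field-versus-method argument, here is a rough sketch in plain Python. It is not Equinox code: ToyWeightNorm, l2_norm and l1_norm are hypothetical names, and a frozen dataclass stands in for a Module. The point is only that a callable field with a default can later be exposed as a keyword argument without changing existing call sites.

```python
import dataclasses
import math
from typing import Callable, Tuple


def l2_norm(xs):
    # Euclidean norm of a flat sequence of floats.
    return math.sqrt(sum(x * x for x in xs))


def l1_norm(xs):
    # Sum of absolute values.
    return sum(abs(x) for x in xs)


@dataclasses.dataclass(frozen=True)
class ToyWeightNorm:
    # Hypothetical stand-in for a module; not the Equinox implementation.
    weight: Tuple[float, ...]
    # The norm is a field with a default, so callers may override it by keyword.
    norm: Callable = l2_norm

    def normalised(self):
        n = self.norm(self.weight)
        return tuple(w / n for w in self.weight)


default = ToyWeightNorm(weight=(3.0, 4.0))
custom = ToyWeightNorm(weight=(3.0, 4.0), norm=l1_norm)
# Both instances share the same default function object, so they compare equal.
same_default = ToyWeightNorm(weight=(3.0, 4.0)) == default
```

Because the default is one shared function object rather than a callable created per instance, two instances built with the default also compare equal here, which is the property the original assertion was looking for.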

johannahaffner avatar Mar 02 '25 13:03 johannahaffner

Okay, so! I think we can change this.

First of all, norm is assigned here:

https://github.com/patrick-kidger/equinox/blob/8191b113df5d985720e86c0d6292bceb711cbe94/equinox/nn/_weight_norm.py#L91

and I assume the fact that it is dynamically creating new partials means that these are not comparing equal to each other later.

The fix is probably to cache based on axis, which is the only thing that changes. Or we could switch ft.partial to eqx.Partial, which I think has better equality semantics.
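
A minimal sketch of the caching idea, using only the standard library. The names toy_norm and make_norm are hypothetical and this is not the Equinox implementation; toy_norm stands in for the real norm function so that the example runs without JAX.

```python
import functools as ft


def toy_norm(x, axis=None):
    # Hypothetical stand-in for the real norm function; `axis` is ignored here.
    return sum(v * v for v in x) ** 0.5


# Two partials built from identical arguments are still distinct objects,
# and functools.partial does not define value-based equality.
p1 = ft.partial(toy_norm, axis=0)
p2 = ft.partial(toy_norm, axis=0)
partials_equal = p1 == p2


@ft.lru_cache(maxsize=None)
def make_norm(axis):
    # Caching on `axis` (which must be hashable, e.g. an int, a tuple or None)
    # returns the very same partial object for repeated requests.
    return ft.partial(toy_norm, axis=axis)


cached_equal = make_norm(0) == make_norm(0)
different_axis_equal = make_norm(0) == make_norm(1)
```

Here partials_equal is False, while cached_equal is True because both calls return the same cached object. Whether a cache or eqx.Partial is the better fit for WeightNorm is left open, as in the comment above.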

As for why this is a field, and indeed one that isn't marked static (as opposed to a method or a static field): this is because there have been a few use-cases in which people would like to dynamically patch this field using eqx.tree_at. It's unusual but valid!

I'd be happy to take a PR on the above! I think equality here is a reasonable thing to want :)

patrick-kidger avatar Mar 04 '25 08:03 patrick-kidger