
Out-of-place updates with custom `__init__()` constructor

Open · Charl-AI opened this issue 2 years ago · 3 comments

Hi @patrick-kidger!


Sorry for the long-winded feature request; it got out of hand as I was writing it! The basic question is this: would you be interested in adding a `.replace()` method to `eqx.Module` that can deal with custom `__init__()` constructors?

I've been really enjoying Equinox, but one challenge I've noticed is that it is very difficult to do out-of-place updates with custom `__init__()` constructors. This is because `dataclasses.replace()` internally calls `__init__()`, so smart constructors (i.e. ones that differ from the auto-generated dataclass `__init__()`) break it. Here's a demo with a basic counter class:

Works:

```python
import equinox as eqx
import dataclasses

class Counter(eqx.Module):
    x: int

    def __init__(self, x: int):
        self.x = x

    def increment(self):
        return dataclasses.replace(self, x=self.x+1)

C1 = Counter(0)
assert C1.x == 0
C2 = C1.increment()
assert C2.x == 1
```

Doesn't work:

```python
import equinox as eqx
import dataclasses

class Counter(eqx.Module):
    x: int

    def __init__(self, z: int):
        # 'smart' constructor inits x by calculating from z
        self.x = 2 * z

    def increment(self):
        return dataclasses.replace(self, x=self.x+1)

C1 = Counter(0)
assert C1.x == 0
C2 = C1.increment()  # fails: __init__() got unexpected kwarg 'x'
```
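To see why the second version fails, here is a simplified sketch of what `dataclasses.replace()` does, written against plain standard-library dataclasses rather than Equinox so it runs on its own. `simple_replace`, `PlainCounter` and `SmartCounter` are hypothetical names for illustration, and the real `dataclasses.replace()` additionally handles `InitVar` and `init=False` fields. The key step is the final constructor call: every `init=True` field is passed back into `__init__()` by name, so a constructor that doesn't accept `x` raises a `TypeError`.

```python
import dataclasses


def simple_replace(obj, **changes):
    """Simplified sketch of dataclasses.replace (ignores InitVar and
    init=False handling): collect every init=True field from `obj`,
    overlay `changes`, then call the class constructor with the result."""
    kwargs = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj) if f.init}
    kwargs.update(changes)
    return type(obj)(**kwargs)  # a custom __init__ is called here


@dataclasses.dataclass(frozen=True)
class PlainCounter:
    x: int  # auto-generated __init__(self, x)


@dataclasses.dataclass(frozen=True)
class SmartCounter:
    x: int

    # dataclass() keeps a user-defined __init__ instead of generating one
    def __init__(self, z: int):
        # the instance is frozen, so assign through object.__setattr__
        object.__setattr__(self, "x", 2 * z)


print(simple_replace(PlainCounter(0), x=1))

try:
    simple_replace(SmartCounter(0), x=1)  # __init__ has no 'x' parameter
except TypeError as err:
    print("TypeError:", err)
```

The standard `dataclasses.replace(SmartCounter(0), x=1)` fails the same way, which is why this is a general dataclass limitation rather than an Equinox bug.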

This isn't unique to Equinox: the `dataclasses` docs warn about using `replace()` with `init=False` fields, and Flax recommends not overriding `__init__()` in `flax.struct.dataclass` at all (i.e. implementing smart constructors as separate classmethods instead).
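The `init=False` caveat from the `dataclasses` docs can be seen with a derived field computed in `__post_init__()`. A minimal plain-dataclass sketch (the `Scaled` class is a hypothetical example, not taken from Equinox or Flax): `replace()` does not copy such a field but lets `__post_init__()` recompute it, and naming it in `changes` is rejected with a `ValueError`.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Scaled:
    z: int
    # Derived field: excluded from __init__ and computed in __post_init__.
    x: int = dataclasses.field(init=False)

    def __post_init__(self):
        # frozen instance, so assign through object.__setattr__
        object.__setattr__(self, "x", 2 * self.z)


s = Scaled(3)
print(s.x)

# replace() does not copy init=False fields; __post_init__ recomputes x from z.
t = dataclasses.replace(s, z=5)
print(t.x)

# Passing an init=False field to replace() is rejected outright.
try:
    dataclasses.replace(s, x=100)
except ValueError as err:
    print("ValueError:", err)
```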

Personally, I don't love this solution. I think one of the strengths of Equinox is that people can write fairly Pythonic, PyTorch-style modules; asking people to define separate classmethods as constructors feels like a step backwards, and IMO it makes APIs more complicated. Out-of-place, `.replace()`-style updates should be a key feature of `equinox.Module`, so it is a nasty gotcha if people have to choose between that and writing custom `__init__()` methods.

One possible solution could be to add a custom `.replace()` method to `eqx.Module` that knows how to handle custom `__init__()` constructors. Below is an example of how you might do that, although it does require adding a hidden argument to `__init__()` (inspired by this issue, where they mention the `_root` attribute in `NamedTuple`, explained here).

```python
import equinox as eqx
import dataclasses


class Counter(eqx.Module):
    x: int

    def __init__(self, z: int, _replacing=False, **kwargs):
        if _replacing:
            # hidden path used by replace(): set the given fields directly
            self.__dict__.update(kwargs)
        else:
            # normal 'smart' constructor path
            self.x = 2 * z

    def increment(self):
        return self.replace(x=self.x + 1)

    def replace(self, **changes):
        # z=0 is just a placeholder; it is not used by __init__() on this path.
        # Note: only the fields in `changes` are set; other fields are not copied.
        return self.__class__(z=0, _replacing=True, **changes)


C1 = Counter(0)
assert C1.x == 0
C2 = C1.increment()
assert C2.x == 1
```

There are a bunch of issues with this exact implementation, but would you be interested in adding this kind of feature to `eqx.Module` in principle? Other (untested) ideas for implementing it include monkey-patching `__init__()` to the standard dataclass constructor and then un-monkey-patching it, or a hidden-argument approach with single dispatch based on the type of the hidden argument passed (inspired by this issue). Perhaps there's some metaclass madness that would make all of this easy?
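For comparison, here is a sketch of a different route from the ideas listed above: skipping `__init__()` altogether by shallow-copying the instance and overwriting fields with `object.__setattr__`. This is a swapped-in technique demonstrated on plain frozen dataclasses; `replace_without_init` is a hypothetical helper, not an Equinox API, and whether it is safe on `eqx.Module` (which may do extra work when an instance is constructed) is an untested assumption.

```python
import copy
import dataclasses


def replace_without_init(obj, **changes):
    """Hypothetical helper: return a shallow copy of dataclass instance `obj`
    with `changes` applied, without calling the class's (custom) __init__."""
    field_names = {f.name for f in dataclasses.fields(obj)}
    unknown = set(changes) - field_names
    if unknown:
        raise TypeError(f"not dataclass fields: {sorted(unknown)}")
    # copy.copy rebuilds the instance via __new__ and copies its __dict__,
    # so the custom __init__ is never invoked
    new = copy.copy(obj)
    for name, value in changes.items():
        # object.__setattr__ bypasses the frozen-dataclass __setattr__ guard
        object.__setattr__(new, name, value)
    return new


@dataclasses.dataclass(frozen=True)
class Counter:
    x: int

    def __init__(self, z: int):
        # 'smart' constructor: x is derived from z
        object.__setattr__(self, "x", 2 * z)

    def increment(self):
        return replace_without_init(self, x=self.x + 1)


c1 = Counter(0)
c2 = c1.increment()
print(c1.x, c2.x)
```

The trade-off is that any validation or derived-value logic in `__init__()` is skipped as well, so keeping the fields mutually consistent becomes the caller's job (here, `increment` changes `x` without going through `z`).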

Thanks for taking the time to read, let me know what you think :)

Charl-AI · Jun 22 '22 15:06

Hey there! Thanks for your interest. I think we clearly have the same conceptual preferences: using `__init__` rather than a new classmethod was important to me too.

Have you seen `equinox.tree_at`? I think this does what you're after, e.g.:

```python
import equinox as eqx

counter = Counter(0)  # Counter as defined in the examples above

inc_counter = eqx.tree_at(lambda c: c.x, counter, counter.x + 1)
# equivalently:
inc_counter = eqx.tree_at(lambda c: c.x, counter, replace_fn=lambda i: i + 1)
```

Moreover, this is powerful enough to make essentially arbitrary changes to any PyTree, not just a dataclass.

Of course, if there were a particular operation you did frequently (like incrementing a counter), you could wrap `equinox.tree_at` in a standalone function. (It's also important that it not be a module method, as we don't want to special-case any methods. This is the same preference-for-simplicity rationale as not wanting to define a custom replacement for `__init__`, discussed earlier.)

patrick-kidger · Jun 22 '22 16:06

Thanks for getting back to me; that's exactly what I was looking for! I can just wrap it into a Flax-style `replace` method for my use case. Something like this works:

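```python
# method on an eqx.Module subclass; assumes `import equinox as eqx`
def replace(self, **changes):
    if not changes:  # nothing to update; zip(*{}) would fail to unpack below
        return self
    keys, vals = zip(*changes.items())
    return eqx.tree_at(lambda c: [c.__dict__[key] for key in keys], self, vals)
```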

Really loving the design of Equinox right now!

Charl-AI · Jun 22 '22 17:06

I'm glad to hear it!

patrick-kidger · Jun 22 '22 17:06