equinox

Using class composition on abstract interfaces with methods and fields

Open · michael-0brien opened this issue 5 months ago · 4 comments

I am wondering what the recommended practice is for class composition on abstract interfaces with both methods and fields. Concretely, here's what I'm talking about:

Forwarding the methods of an abstract interface is easy:

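import abc

import equinox as eqx
from typing_extensions import override  # `typing.override` on Python 3.12+

class AbstractModule(eqx.Module):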

    @abc.abstractmethod
    def some_method(self):
        raise NotImplementedError()

class ConcreteModule(AbstractModule):

    @override
    def some_method(self):
        ...


class ForwardedModule(AbstractModule):

    example_module: ConcreteModule

    def __init__(self, example_module: ConcreteModule):
        self.example_module = example_module

    @override
    def some_method(self):
        return self.example_module.some_method()

Forwarding fields is also easy:

import equinox as eqx

class AbstractModule(eqx.Module):

    some_field: eqx.AbstractVar[int]

class ConcreteModule(AbstractModule):

    some_field: int

    def __init__(self, some_field: int):
        self.some_field = some_field


class ForwardedModule(AbstractModule):

    some_field: int

    def __init__(self, example_module: ConcreteModule):
        self.some_field = example_module.some_field

What should I do when the interface has both fields and methods? The solution must ensure that there is only one copy of "some_field" in the pytree.

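import abc

import equinox as eqx
from typing_extensions import override  # `typing.override` on Python 3.12+

class AbstractModule(eqx.Module):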

    some_field: eqx.AbstractVar[int]

    @abc.abstractmethod
    def some_method(self):
        raise NotImplementedError()

class ConcreteModule(AbstractModule):

    some_field: int

    def __init__(self, some_field: int):
        self.some_field = some_field


    @override
    def some_method(self):
        # use `some_field` somehow
        ...


class ForwardedModule(AbstractModule):

    some_field: int

    def __init__(self, example_module: ConcreteModule):
        self.some_field = example_module.some_field

    @override
    def some_method(self):  # how to implement? 
        ...

michael-0brien, Aug 12 '25 15:08

The only way I can think of is the following. It stores "some_field" twice (once at the top level and once inside example_module), which feels like a hack and would lead to a confusing user API:

class ForwardedModule(AbstractModule):

    some_field: int
    example_module: ConcreteModule

    def __init__(self, example_module: ConcreteModule):
        self.some_field = example_module.some_field
        self.example_module = example_module

    @override
    def some_method(self):
        # Overwrite the wrapped module's copy with the top-level value, then call it.
        example_module = eqx.tree_at(
            lambda x: x.some_field, self.example_module, self.some_field
        )
        return example_module.some_method()

michael-0brien, Aug 12 '25 15:08

Returning to just the 'field' case, it would usually be preferable to implement it as a @property that forwards to the wrapped module, just like the method case.

I think the issue you're bumping into with reassignment is that you now have two distinct locations in the pytree structure that may in general become different (e.g. via tree_at or otherwise)? Having a single source of truth via the @property makes this impossible.

patrick-kidger, Aug 13 '25 07:08

This is a good suggestion. My users typically need to reference individual pytree leaves for fine control over which parameters are transformed via vmap, grad, etc. (we’ve spoken about this before in #618). There’s a small trade-off, then, in that users could mistake properties for pytree leaves, and I’ve been liking the simplicity of defining leaves with eqx.AbstractVar. I think you’re right that this is the clean solution, though.

So far my policy has been that some of my classes are composable, while the less important ones are not because of this issue. This may be the way to continue…

michael-0brien, Aug 13 '25 12:08

Though, to avoid confusion with pytree leaves, this is easily solved by using a get_* method rather than a property!

michael-0brien, Aug 13 '25 12:08