Using class composition on abstract interfaces with methods and fields
I am wondering what the recommended practice is for class composition on abstract interfaces with both methods and fields. Concretely, here's what I'm talking about:
Forwarding the methods of an abstract interface is easy:
import abc
from typing_extensions import override  # or `from typing import override` on Python 3.12+

import equinox as eqx

class AbstractModule(eqx.Module):
    @abc.abstractmethod
    def some_method(self):
        raise NotImplementedError()

class ConcreteModule(AbstractModule):
    @override
    def some_method(self):
        ...

class ForwardedModule(AbstractModule):
    example_module: ConcreteModule

    def __init__(self, example_module: ConcreteModule):
        self.example_module = example_module

    @override
    def some_method(self):
        # Delegate to the wrapped module.
        return self.example_module.some_method()
Forwarding fields is also easy:
class AbstractModule(eqx.Module):
    some_field: eqx.AbstractVar[int]

class ConcreteModule(AbstractModule):
    some_field: int

    def __init__(self, some_field: int):
        self.some_field = some_field

class ForwardedModule(AbstractModule):
    some_field: int

    def __init__(self, example_module: ConcreteModule):
        self.some_field = example_module.some_field
What should I do when the interface has both fields and methods? The constraint is that there must be only one value of "some_field" in the pytree.
class AbstractModule(eqx.Module):
    some_field: eqx.AbstractVar[int]

    @abc.abstractmethod
    def some_method(self):
        raise NotImplementedError()

class ConcreteModule(AbstractModule):
    some_field: int

    def __init__(self, some_field: int):
        self.some_field = some_field

    @override
    def some_method(self):
        # use `some_field` somehow
        ...

class ForwardedModule(AbstractModule):
    some_field: int

    def __init__(self, example_module: ConcreteModule):
        self.some_field = example_module.some_field

    @override
    def some_method(self):  # how to implement?
        ...
The only way I can think of doing it is the following, which feels like a hack and would lead to a confusing user API:
class ForwardedModule(AbstractModule):
    some_field: int
    example_module: ConcreteModule

    def __init__(self, example_module: ConcreteModule):
        self.some_field = example_module.some_field
        self.example_module = example_module

    @override
    def some_method(self):
        example_module = eqx.tree_at(lambda x: x.some_field, self.example_module, self.some_field)
        return example_module.some_method()
Returning to just the 'field' case, it would usually be preferable to implement some_field as a @property that forwards to the wrapped module, just like the method case.
I think the issue you're bumping into with reassignment is that you now have two distinct locations in the pytree structure whose values may in general diverge (e.g. via tree_at or otherwise)? Having a single source of truth behind the @property makes this impossible.
This is a good suggestion. My users typically need to reference individual pytree leaves for fine control of which parameters are transformed via vmap, grad, etc. (we’ve spoken about this before on #618). There’s a small tradeoff, then, in that users could mistake properties for pytree leaves, and I’ve been liking the simplicity of defining leaves with eqx.AbstractVar. I think you’re right that this is the clean solution though.
So far I’ve had the policy where some of my classes are composable and the less important ones are not due to this issue. This may be the way to continue…
Though the confusion with pytree leaves is easily avoided: rather than using a property, use a get_* method!