Segmentation fault in tree_flatten when subclass method is passed to superclass __init__
I have the following code, which causes a segmentation fault when `tree_flatten` is called on any instance of `SubComponent`. I think it has to do with the `jtu.Partial(method, self)` in `_wrap_method`, so it may be related to #291. In any case, my guess is the following:
- `tree_flatten` goes through the leaves of the component and reaches `transform`.
- `transform` is a `jtu.Partial(self._transform, self)`, due to `_wrap_method`.
- The partial is itself a PyTree, so it can also be traversed for flattening.
- When the `self` leaf in the partial is reached, we have an object of type `SubComponent`, which can again be traversed for flattening.
- This starts a loop (returning to the first step) that terminates with either a segmentation fault or a maximum recursion depth error.
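The mechanism in the steps above can be reproduced in plain JAX, without Equinox. This is a minimal sketch that assumes only `jax.tree_util.Partial`. The dict `box` and the function `scaled` are hypothetical stand-ins for `SubComponent` and its `_transform` method. The cycle is checked by object identity rather than by flattening, because flattening would recurse without end.

```python
import jax.tree_util as jtu


def scaled(obj, x):
    # Hypothetical stand-in for a method: `obj` plays the role of `self`.
    return obj["k"] * x


# Step 3: a jtu.Partial is a pytree. The wrapped function is static auxiliary
# data; the bound arguments are its children.
p = jtu.Partial(scaled, {"k": 2.0})
print(jtu.tree_leaves(p))  # [2.0]

# Steps 2 and 4: store, inside a container, a Partial that binds that same
# container, mimicking `self.transform = jtu.Partial(self._transform, self)`.
box = {"k": 2.0}
box["fn"] = jtu.Partial(scaled, box)
print(box["fn"].args[0] is box)  # True: box -> "fn" -> args -> box -> ...
print(box["fn"](3.0))  # 6.0; calling is fine, only flattening loops

# Do NOT call jtu.tree_flatten(box) here: it would recurse without end.
```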
MWE:
from typing import Callable

import jax
import equinox as eqx


class Component(eqx.Module):
    transform: Callable[[float], float]
    validator: Callable[[Callable], Callable]

    def __init__(self, transform=lambda x: x, validator=lambda f: f) -> None:
        self.transform = transform  # Base transformation.
        self.validator = validator  # Validation wrapper around the transformation.

    def __call__(self, x):
        return self.validator(self.transform)(x)  # Execute validated transformation.


class SubComponent(Component):
    test: Callable[[float], float]

    def __init__(self, test=lambda x: 2 * x) -> None:
        self.test = test
        # From my understanding this ends up equivalent to
        # self.transform = jtu.Partial(self._transform, self)
        super().__init__(self._transform)

    # Custom implementation of transform that depends on information available in the
    # SubComponent PyTree. If `test` was modified with tree_at, then this method and the
    # __call__ defined in the parent should both be updated.
    def _transform(self, x):
        return self.test(x)


a = SubComponent()
jax.tree_util.tree_flatten(a)  # Causes a segmentation fault.
print(a)  # Run on its own, this raises a maximum recursion depth error.
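To check the diagnosis without triggering the crash, one option is to walk the pytree one level at a time and stop as soon as a node reappears on the current root-to-node path. The sketch below uses only public `jax.tree_util` calls. `find_pytree_cycle` and `_flatten_one_level` are hypothetical helpers, not part of JAX or Equinox. The sketch relies on one assumption: `is_leaf` is called on the root before any of its children. Under that assumption, returning `False` only on the first call expands exactly one level. The same walk should in principle apply to `a` from the MWE. The usage below sticks to a pure-JAX analogue so that Equinox is not needed.

```python
import jax.tree_util as jtu


def _flatten_one_level(node):
    """Expand only `node` itself; its children are returned unflattened."""
    calls = 0

    def is_leaf(x):
        nonlocal calls
        calls += 1
        return calls > 1  # False for the root only, so only the root is expanded

    return jtu.tree_flatten(node, is_leaf=is_leaf)


def find_pytree_cycle(tree):
    """Return True if flattening `tree` would revisit a node on the current path."""
    on_path = set()  # ids of the container nodes between the root and here

    def visit(node):
        if id(node) in on_path:
            return True
        children, treedef = _flatten_one_level(node)
        if jtu.treedef_is_leaf(treedef):
            return False  # a leaf, or a container with no children
        on_path.add(id(node))
        found = any(visit(child) for child in children)
        on_path.discard(id(node))
        return found

    return visit(tree)


def scaled(obj, x):
    return obj["k"] * x


box = {"k": 2.0}
box["fn"] = jtu.Partial(scaled, box)  # same shape of cycle as in the MWE
shared = [1.0]
print(find_pytree_cycle(box))  # True
print(find_pytree_cycle({"a": shared, "b": shared}))  # False: shared, not cyclic
```

Tracking only the current path, rather than every node ever seen, keeps legitimately shared subtrees from being reported as cycles.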
I don't know of a good solution. All I can say is that I don't have the problem with the `Pytree` class from `simple_pytrees`. For now I can subclass from the `simple_pytrees` `Pytree` instead, but I would much prefer to use `eqx.Module`. Of course, I may just be doing something wrong. Any help is greatly appreciated.
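One plausible reason a hand-rolled pytree avoids the crash depends on an assumption not verified here: that `simple_pytrees` stores bound methods as ordinary Python objects. A plain bound method is not a registered pytree node, so `jax.tree_util` treats it as an opaque leaf and flattening stops there. The sketch below uses a hypothetical `PlainComponent` registered with `jax.tree_util.register_pytree_node_class`. It is not the `simple_pytrees` implementation.

```python
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class PlainComponent:
    """Hypothetical hand-rolled pytree; not the simple_pytrees implementation."""

    def __init__(self, k):
        self.k = k
        # An ordinary bound method, not a jtu.Partial: JAX treats it as a leaf.
        self.transform = self._transform

    def _transform(self, x):
        return self.k * x

    def tree_flatten(self):
        return (self.k, self.transform), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)  # skip __init__ so `transform` is restored as-is
        obj.k, obj.transform = children
        return obj


c = PlainComponent(2.0)
print(len(jtu.tree_leaves(c)))  # 2: flattening stops at the bound method

# Caveat: the bound method still points at the original instance, so a
# tree_map (or a tree_at-style update) leaves it stale.
c2 = jtu.tree_map(lambda v: v * 2 if isinstance(v, float) else v, c)
print(c2.k)  # 4.0
print(c2.transform(1.0))  # 2.0, computed with the old k
```

This design only avoids the cycle because the method is opaque to JAX. As a result, updating `k` does not update `transform`, which is the behaviour the comment above `_transform` in the MWE says should not happen.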