
Segmentation fault in tree_flatten when subclass method is passed to superclass __init__

Open · tttc3 opened this issue 1 year ago · 5 comments

I have the following code, which causes a segmentation fault when tree_flatten is called on any instance of SubComponent. I think it has to do with the jtu.Partial(method, self) in _wrap_method, so it may be related to #291. In any case, my guess is the following:

  1. tree_flatten goes through the fields of the SubComponent instance and reaches transform.
  2. transform is a jtu.Partial(self._transform, self), due to _wrap_method.
  3. the partial is itself a PyTree, so it is also traversed during flattening.
  4. when the self leaf inside the partial is reached, we have a leaf of type SubComponent, which is again traversed for flattening.
  5. this starts a cycle (back to step 1) that only ends with either a segmentation fault or a maximum recursion depth error.
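The five steps above can be modelled in plain Python to make the cycle concrete. This is a minimal sketch, not JAX's or equinox's implementation: `Partial`, `Module`, `flatten` and `is_cyclic` are hypothetical toy stand-ins. Unlike the real tree_flatten described in this report, the toy `flatten` remembers which objects are on the current path, so it can report the cycle instead of recursing without bound.

```python
from typing import Any, Callable, List, Tuple


class Partial:
    """Hypothetical toy stand-in for jtu.Partial: `func` is static, `args` are children."""

    def __init__(self, func: Callable, *args: Any) -> None:
        self.func = func
        self.args = args


class Module:
    """Hypothetical toy stand-in for an eqx.Module: every attribute is a child."""

    def __init__(self, **fields: Any) -> None:
        self.__dict__.update(fields)


def flatten(node: Any, path: Tuple[int, ...] = ()) -> List[Any]:
    """Depth-first flatten; raises on a cycle instead of recursing forever."""
    if id(node) in path:
        raise ValueError(f"cycle: {type(node).__name__} contains itself")
    if isinstance(node, Partial):
        children = list(node.args)  # step 3: the partial's args are traversed
    elif isinstance(node, Module):
        children = list(vars(node).values())  # step 1: the fields are traversed
    else:
        return [node]  # anything else is treated as a leaf
    leaves: List[Any] = []
    for child in children:
        leaves.extend(flatten(child, path + (id(node),)))
    return leaves


def is_cyclic(node: Any) -> bool:
    """True if flattening `node` would revisit an object already on the path."""
    try:
        flatten(node)
    except ValueError:
        return True
    return False


def _transform(self, x):
    return self.test(x)


# The partial holds a *different* module: flattening terminates.
other = Module(test=3.0)
ok = Module(test=1.0, transform=Partial(_transform, other))
print(flatten(ok))  # [1.0, 3.0]

# The partial holds the module itself, as in SubComponent (steps 2 and 4).
sub = Module(test=lambda x: 2 * x)
sub.transform = Partial(_transform, sub)
print(is_cyclic(sub))  # True
```

Only the self-reference closes the loop; a partial bound to some other module is just one more subtree.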

MWE:

```python
from typing import Callable
import jax

import equinox as eqx

class Component(eqx.Module):
    transform: Callable[[float], float]
    validator: Callable[[Callable], Callable]

    def __init__(self, transform=lambda x: x, validator=lambda f: f) -> None:
        self.transform = transform  # Base transformation.
        self.validator = validator  # Validation wrapper around the transformation.

    def __call__(self, x):
        return self.validator(self.transform)(x)  # Execute validated transformation.


class SubComponent(Component):
    test: Callable[[float], float]

    def __init__(self, test=lambda x: 2 * x) -> None:
        self.test = test
        # From my understanding this ends up equivalent to
        # self.transform = jtu.Partial(self._transform, self)
        super().__init__(self._transform)

    # Custom implementation of transform that depends on information available in the
    # SubComponent PyTree. If `test` was modified with tree_at, then this method and the
    # __call__, defined in the parent should both be updated.
    def _transform(self, x):
        return self.test(x)


a = SubComponent()
jax.tree_util.tree_flatten(a)  # Will cause segmentation fault
print(a)  # Will raise a maximum recursion depth error.
```
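The comment above `_transform` expects that modifying `test` with tree_at would also update `transform` and `__call__`. As a plain-Python analogy (no equinox or JAX; a frozen dataclass standing in for eqx.Module and dataclasses.replace standing in for tree_at are both assumptions, and `Sub`/`make_sub` are hypothetical names), a callable that captures self keeps pointing at the instance it was created from, so a functional update of `test` alone does not reach it:

```python
import dataclasses
import functools
from typing import Callable, Optional


@dataclasses.dataclass(frozen=True)
class Sub:
    """Hypothetical frozen-dataclass analogue of SubComponent (no equinox)."""

    test: Callable[[float], float]
    transform: Optional[Callable[[float], float]] = None

    def _transform(self, x: float) -> float:
        return self.test(x)


def make_sub(test: Callable[[float], float]) -> Sub:
    obj = Sub(test)
    # Bind the method to this particular instance, mirroring the
    # jtu.Partial(self._transform, self) described in the MWE comment.
    object.__setattr__(obj, "transform", functools.partial(Sub._transform, obj))
    return obj


a = make_sub(lambda x: 2 * x)
b = dataclasses.replace(a, test=lambda x: 3 * x)  # functional update of `test` only

print(b.test(1.0))       # 3.0
print(b.transform(1.0))  # 2.0: the partial still holds the old instance `a`
```

In this analogy the stored partial is a second, independent path to the old `test`, which is the same self-reference that sets up the flattening cycle.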

I don't know of a good solution; all I can say is that I don't have this problem with the Pytree class from 'simple_pytrees'. For now I can subclass from the simple_pytrees Pytree, but I would much prefer to use eqx.Module. Of course, I may simply be doing something wrong. Any help is greatly appreciated.

tttc3, Apr 23 '23 13:04