
perf: enable module flatten/unflatten fastpath

Open nstarman opened this issue 3 months ago • 7 comments

This PR enables power users to write their own flattening/unflattening procedures. The advantage is that Module can be sped up to be as fast as custom pytree structures and JAX arrays. I'm attaching a Jupyter Notebook showing that in this example we achieve a ~50% speedup, which removes ~100% of the overhead, making raw arrays, simple pytrees, and Module all equivalently fast.

I'm happy to add documentation / tests.

overhead.ipynb.zip

nstarman avatar Oct 12 '25 08:10 nstarman

So Equinox looks to avoid special-casing any module methods. In particular it may be the case that someone already has a method called tree_unflatten etc, for some purpose unrelated to the tree map'ing of the module itself.

(The style you have here is actually what we used to have back in the distant early days of Equinox, and moved away from it.)

I'm not super sure what a better alternative is. Perhaps JAX might allow re-registering a PyTree with different flatten/unflatten rules.

patrick-kidger avatar Oct 12 '25 18:10 patrick-kidger

What about special-casing equinox-prefixed versions of these methods? eqx_tree_unflatten, etc. (they could even be private _eqx_tree_unflatten). That would be easy to support for fast-paths.
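To make the naming concern concrete, here is a minimal pure-Python sketch of how a namespaced hook could be looked up without touching an unrelated user method of the same un-prefixed name. The `Base` class and its `flatten` method are hypothetical stand-ins invented for illustration, not Equinox's implementation; only the `_eqx_tree_flatten` name comes from the discussion above.

```python
class Base:
    """Toy stand-in for a module base class (hypothetical, not eqx.Module)."""

    def flatten(self):
        # Fast path: use the namespaced hook if the class defines one.
        hook = getattr(type(self), "_eqx_tree_flatten", None)
        if hook is not None:
            return hook(self)
        # Default path: flatten every instance attribute, sorted by name.
        names = tuple(sorted(vars(self)))
        return tuple(getattr(self, n) for n in names), names


class Unrelated(Base):
    def __init__(self, x):
        self.x = x

    # A user method that merely shares a name with the pytree protocol.
    # Because the base class only looks for the prefixed hook, it is ignored.
    def tree_flatten(self):
        return "something unrelated to pytrees"


class FastPath(Base):
    def __init__(self, x):
        self.x = x

    # Opt-in handwritten flattening under the namespaced name.
    def _eqx_tree_flatten(self):
        return (self.x,), ("x",)


default_result = Unrelated(1.0).flatten()
fast_result = FastPath(2.0).flatten()
```

Both classes flatten to the same `(children, aux)` shape, and `Unrelated.tree_flatten` keeps whatever meaning its author intended.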

nstarman avatar Oct 12 '25 18:10 nstarman

@patrick-kidger if I can't crack #1119 (or even if I can, IMO it'd be nice to be able to customize) would _eqx_tree_unflatten be fine to add to Module?

nstarman avatar Oct 20 '25 20:10 nstarman

So I'm really leaning against adding something like that to eqx.Module. Part of the design thesis of eqx.Module, as compared to jax.tree_util, is that custom flatten/unflatten functions are error-prone and simply aren't necessary – it suffices to just set dataclass fields instead.
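As a rough illustration of why declaring fields can be enough, the sketch below derives a flatten/unflatten pair mechanically from standard-library dataclass fields. The helper `make_flatten_unflatten` and the `Point` class are hypothetical names invented here, and this is not how Equinox generates its functions. In real use, a pair like this would be handed to `jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)`.

```python
import dataclasses


def make_flatten_unflatten(cls):
    """Derive (un)flatten functions from dataclass fields (illustrative only)."""
    names = tuple(f.name for f in dataclasses.fields(cls))

    def flatten(obj):
        children = tuple(getattr(obj, n) for n in names)
        return children, names  # (children, aux_data)

    def unflatten(aux, children):
        # Bypass __init__: unflattening should not re-run user validation.
        obj = object.__new__(cls)
        for n, c in zip(aux, children):
            # object.__setattr__ also works on frozen dataclasses.
            object.__setattr__(obj, n, c)
        return obj

    return flatten, unflatten


@dataclasses.dataclass(frozen=True)
class Point:
    x: float
    y: float


flatten, unflatten = make_flatten_unflatten(Point)
children, aux = flatten(Point(1.0, 2.0))
roundtrip = unflatten(aux, children)
```

Nothing here is written per class, which is the sense in which handwritten versions add a place for field order or field count to drift out of sync with the class definition.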

Supposing #1119 comes good, what would be your use-case?

patrick-kidger avatar Oct 21 '25 13:10 patrick-kidger

E.g. skipping the _MISSING (/flatten_sentinel) logic if I'm sure modules will be fully initialized. Avoiding the wrapper stuff if it isn't needed. Writing a mypyc-transpiled mixin class with the (un)flattening logic that removes most (not all, because this would need @mypyc_attr(allow_interpreted_subclasses=True)) of the Python overhead. For frequently-in-hot-loop classes like unxt.Quantity it would be nice to be able to achieve JAX speeds.

nstarman avatar Oct 21 '25 17:10 nstarman

And with https://github.com/patrick-kidger/equinox/pull/1119, it would be cool to show the user what the (un)flattening code is doing by attaching the generated functions to the classes!

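# Pseudocode: generate_functions and register_stuff are placeholders, not real APIs.
class ModuleMeta(type):
    def __new__(mcs, *args, **kwargs):
        cls = super().__new__(mcs, *args, **kwargs)
        cls._eqx_tree_flatten, cls._eqx_tree_unflatten = generate_functions(cls)
        jax.tree_util.register_stuff(cls._eqx_tree_flatten, cls._eqx_tree_unflatten)
        return cls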

class MyClass(eqx.Module):
    attr1: float

MyClass._eqx_tree_flatten?
>>> def flatten(self):
...  return (self.attr1,), ()

nstarman avatar Oct 21 '25 17:10 nstarman

I've rebased this PR on #1119 and changed it so that the (un)flattening method names start with _eqx_.

New approximate timings on that performance notebook:

  • Baseline is @partial(jax.tree_util.register_pytree_with_keys_class) ~ 7.7 µs
  • Module before #1119 ~ 12.5 µs
  • Module after #1119 ~ 10 µs
  • Module after #1119, if I delete all wrapper field related code ~ 9 µs
  • This PR ~ 8.3 µs using a handwritten (un)flattening w/out the wrapper fields. So it's ~1 µs faster and it's the power-user's choice for not including the wrapper field stuff :).

This PR allows for an ~88% improvement ((12.5 - 8.3) / (12.5 - 7.7)) if a power user wants to write these methods. The default (from #1119) is a ~50% improvement.
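The percentages follow from treating the gap between the Module timing before #1119 and the baseline as the total overhead. A short check using only the approximate timings quoted above (variable names are chosen here for illustration):

```python
# Approximate timings from the list above, in microseconds.
baseline = 7.7     # register_pytree_with_keys_class baseline
before = 12.5      # Module before #1119
after_1119 = 10.0  # Module after #1119
this_pr = 8.3      # handwritten (un)flattening in this PR

overhead = before - baseline  # total overhead of Module over the baseline


def overhead_removed(t):
    """Fraction of the original overhead eliminated at timing t."""
    return (before - t) / overhead


default_gain = overhead_removed(after_1119)   # about 0.52
handwritten_gain = overhead_removed(this_pr)  # about 0.875

print(f"default: {default_gain:.1%}, handwritten: {handwritten_gain:.1%}")
```

The remaining overhead relative to the baseline is then about 0.6 µs out of the original 4.8 µs.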

nstarman avatar Nov 05 '25 01:11 nstarman