perf: enable module flatten/unflatten fastpath
This PR enables power users to write their own flattening/unflattening procedures.
The advantage is that Module can be sped up to be as fast as custom pytree structures and JAX arrays.
I'm attaching a Jupyter Notebook showing that in this example we achieve a ~50% speedup, which removes ~100% of Module's overhead and makes raw arrays, simple pytrees, and Module all equally fast.
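For reference, the "custom pytree structure" that Module is being compared against is a class whose flatten/unflatten rules are handwritten and registered directly with JAX. A minimal sketch, assuming a hypothetical two-field `Point` class (it is not the class from the notebook):

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Point:
    """Hypothetical pytree with handwritten (un)flattening rules."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # Return (children, aux_data): both arrays are children, nothing is static.
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Skip __init__ so that rebuilding does no extra validation work.
        obj = object.__new__(cls)
        obj.x, obj.y = children
        return obj


p = Point(jnp.array(1.0), jnp.array(2.0))
leaves, treedef = jax.tree_util.tree_flatten(p)
p2 = jax.tree_util.tree_unflatten(treedef, leaves)
doubled = jax.tree_util.tree_map(lambda a: 2 * a, p)
```

Every flatten/unflatten here is just a tuple build and two attribute writes, which is the cost Module would need to match.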
I'm happy to add documentation / tests.
So Equinox tries to avoid special-casing any module methods. In particular, someone may already have a method called tree_unflatten, etc., for some purpose unrelated to tree-mapping the module itself.
(The style you have here is actually what we used to have back in the distant early days of Equinox, and moved away from it.)
I'm not super sure what a better alternative is. Perhaps JAX might allow re-registering a PyTree with different flatten/unflatten rules.
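A quick check of the re-registration idea against current JAX: as far as I know, registering the same type a second time raises `ValueError`, so swapping in different rules after the fact would need new JAX API. `Pair` and its helper functions are hypothetical.

```python
import jax


class Pair:  # hypothetical container type
    def __init__(self, a, b):
        self.a, self.b = a, b


def flatten_pair(p):
    return (p.a, p.b), None


def unflatten_pair(aux, children):
    return Pair(*children)


jax.tree_util.register_pytree_node(Pair, flatten_pair, unflatten_pair)

# Try to swap in different ("faster") rules for the same type.
try:
    jax.tree_util.register_pytree_node(Pair, flatten_pair, unflatten_pair)
    reregistered = True
except ValueError:
    # JAX refuses a duplicate registration of one type.
    reregistered = False
```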
What about special-casing equinox-prefixed versions of these methods? eqx_tree_unflatten, etc. (they could even be private _eqx_tree_unflatten). That would be easy to support for fast-paths.
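A sketch of how special-casing only the prefixed names could look, assuming a hypothetical `register_with_fastpath` helper (this is not how Equinox registers Module). A class that defines `_eqx_tree_flatten`/`_eqx_tree_unflatten` gets its handwritten rules; every other class falls back to generic flattening over its dataclass fields; and an unrelated user method named `tree_unflatten` is never touched. `Generic` and `Fast` are illustrative classes.

```python
import dataclasses

import jax


def register_with_fastpath(cls):
    """Hypothetical helper: make `cls` a frozen dataclass and register it as a
    pytree, preferring `_eqx_`-prefixed hooks when the class defines them.

    Only the prefixed names are special-cased, so a user method that happens to
    be called `tree_flatten` or `tree_unflatten` is never picked up.
    """
    cls = dataclasses.dataclass(frozen=True)(cls)
    if hasattr(cls, "_eqx_tree_flatten"):
        # Fast path: the power user's handwritten rules.
        flatten, unflatten = cls._eqx_tree_flatten, cls._eqx_tree_unflatten
    else:
        names = tuple(f.name for f in dataclasses.fields(cls))

        def flatten(obj):
            # Generic path: every dataclass field is a child.
            return tuple(getattr(obj, n) for n in names), None

        def unflatten(aux, children):
            obj = object.__new__(cls)
            for n, c in zip(names, children):
                object.__setattr__(obj, n, c)
            return obj

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_with_fastpath
class Generic:
    a: float
    b: float


@register_with_fastpath
class Fast:
    x: float
    y: float

    @staticmethod
    def _eqx_tree_flatten(obj):
        return (obj.x, obj.y), None

    @staticmethod
    def _eqx_tree_unflatten(aux, children):
        x, y = children
        return Fast(x, y)

    def tree_unflatten(self):
        # Unrelated method with a JAX-looking name: not special-cased.
        return "something else"
```

The prefix keeps the opt-in explicit: nothing changes for a class unless it defines the `_eqx_` hooks.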
@patrick-kidger if I can't crack #1119 (or even if I can, IMO it'd be nice to be able to customize) would _eqx_tree_unflatten be fine to add to Module?
So I'm really leaning against adding something like that to eqx.Module. Part of the design thesis of eqx.Module, as compared to jax.tree_util, is that custom flatten/unflatten functions are error-prone and simply aren't necessary – it suffices to just set dataclass fields instead.
Supposing #1119 comes good, what would be your use-case?
For example: dropping the _MISSING (/flatten_sentinel) logic when I'm sure modules will be fully initialized; avoiding the wrapper-field handling when it isn't needed; or writing a mypyc-transpiled mixin class with the (un)flattening logic, which removes most of the Python overhead (not all, because this would need @mypyc_attr(allow_interpreted_subclasses=True)). For classes that sit in hot loops, like unxt.Quantity, it would be nice to be able to reach JAX speeds.
And with https://github.com/patrick-kidger/equinox/pull/1119, it would be cool to show the user what the (un)flattening code is doing by attaching the generated functions to the classes!
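```python
import jax

class ModuleMeta(type):
    def __new__(mcs, name, bases, namespace, **kwargs):
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        cls._eqx_tree_flatten, cls._eqx_tree_unflatten = generate_functions(cls)  # hypothetical helper wrapping #1119's codegen
        jax.tree_util.register_pytree_node(cls, cls._eqx_tree_flatten, cls._eqx_tree_unflatten)
        return cls
```
```python
import equinox as eqx

class MyClass(eqx.Module):
    attr1: float

MyClass._eqx_tree_flatten?  # in IPython, would display the generated code:
# def flatten(self):
#     return (self.attr1,), ()
```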
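A sketch of what a `generate_functions` helper like the one in the snippet above might do, assuming dataclass fields and `exec`-based code generation (this is not the code from #1119). It builds a flatten/unflatten pair specialized to the class's fields and stores each function's source in its `__doc__`, so IPython's `?` would show it. `MyClass` is a plain frozen dataclass here to keep the sketch independent of Equinox.

```python
import dataclasses

import jax


def generate_functions(cls):
    """Hypothetical codegen: specialized (un)flatten functions for a dataclass,
    with each function's source kept in its __doc__ for inspection."""
    names = [f.name for f in dataclasses.fields(cls)]
    # "(self.a, self.b,)" for two fields, "(self.a,)" for one, "()" for none.
    children = "(" + "".join(f"self.{n}, " for n in names).rstrip(" ") + ")"
    flatten_src = f"def flatten(self):\n    return {children}, ()\n"
    setters = "".join(
        f"    object.__setattr__(obj, {n!r}, children[{i}])\n"
        for i, n in enumerate(names)
    )
    unflatten_src = (
        "def unflatten(aux, children):\n"
        "    obj = object.__new__(cls)\n"
        f"{setters}"
        "    return obj\n"
    )
    namespace = {"cls": cls}  # globals for the generated functions
    exec(flatten_src, namespace)
    exec(unflatten_src, namespace)
    flatten, unflatten = namespace["flatten"], namespace["unflatten"]
    flatten.__doc__, unflatten.__doc__ = flatten_src, unflatten_src
    return flatten, unflatten


@dataclasses.dataclass(frozen=True)
class MyClass:
    attr1: float


flatten, unflatten = generate_functions(MyClass)
# staticmethod so instance access does not bind `self` a second time.
MyClass._eqx_tree_flatten = staticmethod(flatten)
MyClass._eqx_tree_unflatten = staticmethod(unflatten)
jax.tree_util.register_pytree_node(MyClass, flatten, unflatten)

print(MyClass._eqx_tree_flatten.__doc__)
```

The printed source is the straight-line function from the example above, with no loops over fields at flatten time.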
I've rebased this PR on #1119, and I've changed it so that the (un)flattening methods start with _eqx_.
New approximate timings on that performance notebook:
- Baseline, @partial(jax.tree_util.register_pytree_with_keys_class): ~7.7 µs
- Module before #1119: ~12.5 µs
- Module after #1119: ~10 µs
- Module after #1119, with all wrapper-field code deleted: ~9 µs
- This PR, using handwritten (un)flattening without the wrapper fields: ~8.3 µs

So it's ~1 µs faster, and it's the power user's choice to leave out the wrapper-field handling :).
If a power user writes these methods by hand, this PR removes ~88% of Module's overhead over the baseline ((12.5 - 8.3) / (12.5 - 7.7)). The default (from #1119) removes ~50% ((12.5 - 10) / (12.5 - 7.7)).
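The percentages can be recomputed from the approximate notebook timings listed above; "improvement" here means the fraction of Module's original overhead (before #1119, relative to the register_pytree_with_keys_class baseline) that is removed. The variable names are just labels for those timings.

```python
# Approximate timings (µs) from the performance notebook, as listed above.
baseline = 7.7        # @partial(jax.tree_util.register_pytree_with_keys_class)
module_before = 12.5  # Module before #1119
module_1119 = 10.0    # Module after #1119 (default)
this_pr = 8.3         # this PR, handwritten (un)flattening


def overhead_removed(t):
    """Fraction of Module's original overhead over the baseline that is removed."""
    return (module_before - t) / (module_before - baseline)


print(f"this PR: {overhead_removed(this_pr):.1%}")
print(f"#1119 default: {overhead_removed(module_1119):.1%}")
```

This gives about 87.5% and 52.1%, i.e. the ~88% and ~50% quoted above.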