Slow Initialization / JAX tree usage
Hi @patrick-kidger. I've recently been speed testing some code and found that Equinox is around 2x slower than a custom pytree.
import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Array

def func(x, y):
    return (x - y) / (x + y)
x = jnp.linspace(0.0, 1, 10)
y = jnp.linspace(1.0, 2, 10)
print(jax.make_jaxpr(func)(x, y))
# { lambda ; a:f32[10] b:f32[10]. let
# c:f32[10] = sub a b
# d:f32[10] = add a b
# e:f32[10] = div c d
# in (e,) }
f = jax.jit(func)
f(x, y)
# %timeit f(x, y)
# 2.87 µs ± 95 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Now with a custom Module
class MyEqxArray(eqx.Module):
    array: Array

    def __add__(self, other):
        return jax.tree.map(jnp.add, self, other)

    def __sub__(self, other):
        return jax.tree.map(jnp.subtract, self, other)

    def __truediv__(self, other):
        return jax.tree.map(jnp.divide, self, other)
mx = MyEqxArray(x)
my = MyEqxArray(y)
func(mx, my)
print(jax.make_jaxpr(func)(mx, my)) # same jaxpr 👍
# { lambda ; a:f32[10] b:f32[10]. let
# c:f32[10] = sub a b
# d:f32[10] = add a b
# e:f32[10] = div c d
# in (e,) }
f = jax.jit(func)
f(mx, my)
# %timeit f(mx, my)
# 8.4 µs ± 969 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Now with a custom PyTree
from dataclasses import dataclass
from typing import Any

@jax.tree_util.register_pytree_node_class
@dataclass
class CustomArray:
    array: Array

    def __add__(self, other):
        return jax.tree.map(jnp.add, self, other)

    def __sub__(self, other):
        return jax.tree.map(jnp.subtract, self, other)

    def __truediv__(self, other):
        return jax.tree.map(jnp.divide, self, other)

    def tree_flatten(self) -> tuple[tuple[Any], Any]:
        return (self.array,), None

    @classmethod
    def tree_unflatten(cls, aux_data: Any, children: tuple[Any]) -> "CustomArray":
        return cls(*children)
cx = CustomArray(x)
cy = CustomArray(y)
f = jax.jit(func)
f(cx, cy)
# %timeit f(cx, cy)
# 4.02 µs ± 960 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
So the timings are array : Module : custom PyTree = 2.87 µs : 8.4 µs : 4.02 µs.
Is there any way to speed up Module?
Running this code in a fresh Colab environment (and adding some block_until_readys), I see the raw array is still fastest, but the custom PyTree is actually the slowest:
24.5 µs ± 9.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
31.6 µs ± 5.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
33.7 µs ± 3.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
However, the timing doesn't seem strictly additive: if I increase to x = jnp.linspace(0.0, 1, 1000000),
838 µs ± 388 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.02 ms ± 506 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.38 ms ± 588 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
I also repeated this with block_until_ready and x = jnp.linspace(0.0, 1, 10_000)
JAX: 8.09 µs ± 261 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Equinox: 16.9 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
PyTree: 10.2 µs ± 2 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
The Equinox overhead still seems large.
Just spitballing here, but there seem to be three options:
- It's possible to speed up Equinox to be the same speed as the custom PyTree 🎉. This is my preferred solution!
- It isn't and this is a not-going-to-fix 😢
- It isn't because of some of the fancy stuff in Module, e.g. locking/unlocking __init__, etc., but there's interest in speed-ups. Then maybe a good solution would be to add an ABC (AbstractModule) and also vendor a faster, bare-bones FastModule that doesn't do the fancy slow stuff (except field(converter=...), which I've tested to be fast). (I've thought along these lines in dataclassish and did a speed test with that custom dataclass + converter, with results identical to the PyTree case in this issue.) With an ABC, ecosystem tools that expect a Module can be trivially adapted to work with FastModule.
I think what you're measuring here is an additive overhead of microseconds in the flattening and unflattening. This is expected/known, and rarely ever troublesome. Your computational work has to be essentially negligible for this to affect real-world results.
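One way to see where the time goes is to benchmark the flatten/unflatten round-trip on its own, independent of any arithmetic. A hypothetical check, reusing the classes and x defined above:
m = MyEqxArray(x)
c = CustomArray(x)
leaves_m, treedef_m = jax.tree_util.tree_flatten(m)
leaves_c, treedef_c = jax.tree_util.tree_flatten(c)
%timeit jax.tree_util.tree_flatten(m)  # Module flatten only
%timeit treedef_m.unflatten(leaves_m)  # Module unflatten only
%timeit jax.tree_util.tree_flatten(c)  # custom-PyTree flatten only
%timeit treedef_c.unflatten(leaves_c)  # custom-PyTree unflatten only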
However, the timing doesn't seem strictly additive
I think you're measuring the noise in the actual operation itself here. Those standard deviations are pretty large, and overlap quite a lot!
in the flattening and unflattening
Where does Module do that differently than the above example jax.tree_util.register_pytree_node_class PyTree?
This may be the important point! CustomArray is a PyTree just like MyEqxArray. Why is MyEqxArray slower / where is the flattening overhead?
Your computational work has to be essentially negligible for this to affect real-world results.
This appears to be non-negligible in quax-derived objects where the overhead happens many times.
At least this is what I've found thus far in trying to figure out why unxt and coordinax operations are slow.
I purposefully haven't jitted any of the quaxify(jax.foo) calls in https://github.com/GalacticDynamics/quaxed nor the dunder methods in https://github.com/GalacticDynamics/quax-blocks. But when I use units, things are much slower:
import jax
import jax.numpy as jnp
import unxt as u

def convert_cart2d_to_polar(params, aux):
    x, y = params["x"], params["y"]
    r = jnp.sqrt(x**2 + y**2)
    theta = jnp.arctan2(y, x)
    return {"r": r, "theta": theta}, aux

params = {
    "x": u.Quantity(jnp.array([1.0, 2.0]), "m"),
    "y": u.Quantity(jnp.array([3.0, 4.0]), "m"),
}
aux = {}
jac, aux = jax.jacfwd(convert_cart2d_to_polar, has_aux=True)(params, aux)
jac
# {'r': Quantity['length']({'x': Quantity['length'](Array([[0.31622777, 0. ],
# [0. , 0.4472136 ]], dtype=float64), unit='m'), 'y': Quantity['length'](Array([[0.9486833 , 0. ],
# [0. , 0.89442719]], dtype=float64), unit='m')}, unit='m'),
# 'theta': Quantity['angle']({'x': Quantity['length'](Array([[-0.3, 0. ],
# [ 0. , -0.2]], dtype=float64), unit='m'), 'y': Quantity['length'](Array([[0.1, 0. ],
# [0. , 0.1]], dtype=float64), unit='m')}, unit='rad')}
# NOTE: the weird nesting is something else I'm trying to correct,
# but it appears to be because jax.jacfwd doesn't have an `is_leaf`
func = jax.jit(jax.jacfwd(convert_cart2d_to_polar, has_aux=True))
func(params, aux)
%timeit jax.block_until_ready(func(params, aux))
# 26.6 µs ± 140 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# vs
params2 = {k: v.value for k, v in params.items()}
func(params2, aux)
%timeit jax.block_until_ready(func(params2, aux))
# 9.65 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
So it's 2.75x faster not to use a Quantity. If I don't jit then it becomes 8x faster.
in the flattening and unflattening
Where does Module do that differently than the above example
jax.tree_util.register_pytree_node_class PyTree? […] where is the flattening overhead?
It handles different types of fields and includes some checks:
https://github.com/patrick-kidger/equinox/blob/8191b113df5d985720e86c0d6292bceb711cbe94/equinox/_module.py#L908
https://github.com/patrick-kidger/equinox/blob/8191b113df5d985720e86c0d6292bceb711cbe94/equinox/_module.py#L953
I took a look at your MWE - it looks like we're seeing the same additive overhead as above. If you have many small modules and do tiny operations on all of these, you might indeed be in the regime in which this starts to matter. Not sure if you can batch instead? But I'm unfamiliar with quax, so I can't really say anything about that!
The only thing I can add -
# NOTE: the weird nesting is something else I'm trying to correct,
# but it appears to be because jax.jacfwd doesn't have an `is_leaf`
Your Jacobian actually looks exactly as expected. You have two elements in your radius array, and two angles, corresponding to two different points in Cartesian space. Both $r$ and $\theta$ have derivatives with respect to each of the elements of $x$ and $y$; these happen to be diagonal matrices because you take the Pythagorean sum element-wise.
To make this more readable and intuitive, you might want to try wl.pprint from the wadler_lindig library that Equinox now uses under the hood.
I took a look at your MWE - it looks like we're seeing the same additive overhead as above. If you have many small modules and do tiny operations on all of these, you might indeed be in the regime in which this starts to matter. Not sure if you can batch instead? But I'm unfamiliar with quax, so I can't really say anything about that!
Unfortunately batching isn't possible. Yes, the tiny operations with small modules appears to be the case with quax.
@patrick-kidger Is it possible in quax (or maybe via equinox) to provide a custom tree_flatten / tree_unflatten and have Module use that? To override the default behavior...
The only thing I can add -
# NOTE: the weird nesting is something else I'm trying to correct, # but it appears to be because jax.jacfwd doesn't have an `is_leaf`
Your Jacobian actually looks exactly as expected.
Yes, but see that the inner dicts end up inside outer Quantity objects. This is because jacfwd can't stop at a given leaf type, e.g. Quantity, instead of recursing down to the underlying arrays.
But that's a separate issue.
@johannahaffner thank you for identifying that this is still the additive overhead!
@nstarman I think what'd probably be most desirable is if we can just speed up the existing flattening/unflattening implementation along the lines of whatever alternative you have in mind!
Taking unflattening as an example, pretty much the only difference between what we already have, and something that just assigns attributes self.foo = foo; self.bar = bar is that the latter hardcodes the attributes (rather than doing an iteration).
If this is indeed the source of the overhead then it wouldn't be very hard to dynamically generate such a 'hardcoded' function for each new Module.
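To make the contrast concrete, a hypothetical sketch (SomeModule and its fields foo/bar are made up; this is not Equinox's actual implementation):
class SomeModule:
    pass

# Generic unflatten: iterates over the field names on every call.
def unflatten_generic(field_names, aux_data, children):
    self = object.__new__(SomeModule)
    for name, value in zip(field_names, children):
        object.__setattr__(self, name, value)
    return self

# 'Hardcoded' equivalent for a two-field module: no loop, no branches.
def unflatten_hardcoded(aux_data, children):
    self = object.__new__(SomeModule)
    object.__setattr__(self, "foo", children[0])
    object.__setattr__(self, "bar", children[1])
    return self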
@patrick-kidger I think that it's super worthwhile to speed up the flattening/unflattening implementation, however, this still can't provide the largest speed-up, which is to use https://docs.jax.dev/en/latest/_autosummary/jax.tree_util.register_dataclass.html.
Would a solution be to have a metaclass argument to prevent pytree registration, so the user can manually control the registration?
Supporting this would be a trivial if statement in ModuleMeta, and then we'd be able to do
import functools as ft

@ft.partial(jax.tree_util.register_dataclass, ...)
class MyClass(eqx.Module, register_pytree=False): ...
The default would be register_pytree=True and I presume almost everyone would use this, but it does allow for manual performance optimizations.
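For reference, here is roughly what the manual path looks like on a plain dataclass today (the register_pytree flag above is the hypothetical part; jax.tree_util.register_dataclass is the existing JAX API):
import dataclasses
import functools as ft
import jax

@ft.partial(jax.tree_util.register_dataclass, data_fields=["array"], meta_fields=["name"])
@dataclasses.dataclass
class Registered:
    array: jax.Array  # traced leaf
    name: str         # static metadata, kept in the treedef
With this registration, flattening and unflattening go through JAX's C++ dataclass fast path rather than Python callbacks.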
I'd be happy to submit a PR on this!
The JAX flavour of dataclass registration behaves differently for __init__: https://github.com/jax-ml/jax/issues/25486
That could mean that we get very different behaviour and limitations depending on how a potential flag would be set.
What happens if you just add the decorator on top of an existing class?
How does the dataclass approach benchmark against the codegen approach above?
I'd like to avoid adding extra flags, this quickly makes the easy path and the high-performance path diverge, which leads to bad UX.
The JAX flavour of dataclass registration behaves differently for __init__: https://github.com/jax-ml/jax/issues/25486
That could mean that we get very different behaviour and limitations depending on how a potential flag would be set.
A good point. But it also points to a good path forward. Equinox Modules are dataclasses and could also benefit from the same speed-ups used inside register_dataclass. Much of https://github.com/jax-ml/jax/blob/main/jax/_src/tree_util.py#L923-L1089 is just argument parsing, and then it registers:
default_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
    flatten_with_keys, unflatten_func
)
Is there any reason Equinox couldn't do the exact same using its own flatten/unflatten!? Then for Equinox Modules the "C++ registries [would] use the optimized C++ dataclass builtin instead of the argument functions".
What happens if you just add the decorator on top of an existing class?
Double registration! It errors.
How does the dataclass approach benchmark against the codegen approach above?
I can benchmark an example. Which codegen approach are you referring to? Existing Equinox?
Is there any reason Equinox couldn't do the exact same using its own flatten/unflatten!?
I don't think this works with custom __init__ functions: https://github.com/jax-ml/jax/issues/25486
Although if you have a way to finesse this then I'm certainly interested.
Which codegen approach are you referring to?
To dynamically eval a function equivalent to our existing flattening, but with the for loop and if branches hardcoded for each module, so it's just a flat list of attribute-getting.
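A minimal sketch of that idea (hypothetical; make_hardcoded_unflatten is made up and this is not the actual implementation):
def make_hardcoded_unflatten(cls, field_names):
    # Unroll the attribute loop into straight-line code, generated once
    # at class-definition time.
    assignments = "\n".join(
        f"    object.__setattr__(self, {name!r}, children[{i}])"
        for i, name in enumerate(field_names)
    )
    src = (
        "def unflatten(aux_data, children):\n"
        "    self = object.__new__(cls)\n"
        f"{assignments}\n"
        "    return self"
    )
    namespace = {"cls": cls}
    exec(src, namespace)
    return namespace["unflatten"]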
Although if you have a way to finesse this then I'm certainly interested.
Yeah. Just replacing https://github.com/patrick-kidger/equinox/blob/b4f9addedca56bb74b4fe06674e5566fa6cf9ff4/equinox/_module.py#L481-L487 with
default_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) # HERE
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
    flatten_with_keys, unflatten_func
)
Where this uses Equinox's custom flatten_func, unflatten_func. In another PR we can work to optimize the (un)flatten funcs, but this PR would enable Equinox Modules to use the optimized C++ dataclass structures.
I don't think that works, for the same reason as the issue I link:
import dataclasses
import jax._src.tree_util

@dataclasses.dataclass
class X:
    x: int

    def __init__(self, x: str):
        assert type(x) is str, f"Got type {type(x)}"
        self.x = int(x)

reg = jax._src.tree_util.default_registry
reg.register_dataclass_node(X, [], ["x"])
leaves, treedef = reg.flatten(X("1"))
treedef.unflatten(leaves)  # AssertionError: Got type <class 'int'>
The JAX C++ path does not understand custom __init__ methods.
(In addition it would also be messing with JAX-internal implementation details.)
I think if this is a thing you want then the necessary change would be to submit a PR to JAX, adjusting their C++ implementation of dataclass unflattening to support custom __init__ methods.
I think if this is a thing you want then the necessary change would be to submit a PR to JAX, adjusting their C++ implementation of dataclass unflattening to support custom __init__ methods.
I tried just replacing the Python unflatten in https://github.com/jax-ml/jax/blob/92be510f0b504d8f87a181721801ed759886dedc/jax/_src/tree_util.py#L924 with a version that bypasses __init__ and it doesn't work.
It looks like yes, I'd have to modify the C++ implementation. I think there's a workable approach to this based on dataclassish's implementation of building a separate __dataclass_init__ and using that in the unflattening instead of the user's __init__.
https://github.com/GalacticDynamics/dataclassish/blob/b22db248de44fe16bc323cbd4a1067e6d4747d22/src/dataclassish/_src/converters.py#L282-L297
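Roughly, the idea is to synthesize a plain field-assigning initializer under a separate name, so unflattening never has to call the user's __init__. A hypothetical sketch (attach_dataclass_init is made up, and much simpler than what dataclassish actually does):
import dataclasses

def attach_dataclass_init(cls):
    names = [f.name for f in dataclasses.fields(cls)]
    def __dataclass_init__(self, *field_values):
        # Plain attribute assignment: bypasses the user's __init__ and any
        # frozen-dataclass restrictions.
        for name, value in zip(names, field_values):
            object.__setattr__(self, name, value)
    cls.__dataclass_init__ = __dataclass_init__
    return cls

# Unflattening could then do:
#     obj = object.__new__(cls)
#     obj.__dataclass_init__(*children)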
Yeah, looks like something like this might be possible:
https://github.com/jax-ml/jax/blob/9f5f6edb85487569127429c7ac8be70b3d8cb2f9/jaxlib/xla/pytree.cc#L880-L891
However I'm not sure JAX will want to support a different kind of dataclass that has an alternative constructor.
There are a few possibilities:
- make a PR in JAX for a PyTreeKind::kDataclassAlt, which is a dataclass with an alternative constructor, e.g. classmethod(unflatten(...))
- make a PR in JAX to generalize PyTreeKind::kDataclass to avoid constructing via type(**kwargs). This feels very challenging; Python has stuff like descriptor-typed fields.
- Try to implement this C++ stuff here in Equinox. I'm not sure if it's possible to hook into the PyTree machinery at the C++ level outside of JAX...
- Something else.
I suspect option 2 might be reasonable. FWIW JAX already doesn't support the full gamut of dataclass things -- in particular as here, custom __init__ methods -- so it's already an improvement to handle these.
In addition we have eqx.field(converter=...) which partly obviates the need for descriptor-typed fields anyway.
I suspect option 2 might be reasonable.
My hesitation with 2 is how challenging this will be to get right. dataclass takes care of defining a constructor that can handle all the edge cases, so type(**kwargs) trivially works. But yes, this does mean that JAX can't handle a custom __init__. The way around this is to define the dataclass constructor in C++. Hard. I don't think Python itself has even gotten around to implementing dataclasses in C.
But for both options 1 and 2, I think I should open an Issue on JAX itself!
On a separate note:
I think that a potential speedup to initialization could happen by avoiding all the _make_initable_wrapper(cls) machinery if there wasn't a custom __init__.
- https://github.com/patrick-kidger/equinox/blob/4995b2bed015d6922ca46868cbaf59c767b44682/equinox/_module.py#L510
- https://github.com/patrick-kidger/equinox/blob/4995b2bed015d6922ca46868cbaf59c767b44682/equinox/_module.py#L512
- https://github.com/patrick-kidger/equinox/blob/4995b2bed015d6922ca46868cbaf59c767b44682/equinox/_module.py#L555
On a separate note: I think that a potential speedup to initialization could happen by avoiding all the _make_initable_wrapper(cls) machinery if there wasn't a custom __init__.
On this, I actually have a prototype overhaul of the definition of eqx.Module. The current implementation has grown from its original 100-ish line definition into something fairly unwieldy! Amongst other things this removes the initable-cls machinery in favor of checking to see if all attributes are set, and using that as a proxy for whether we are in __init__.
Not 100% sure how this will land yet, but it's on my radar.
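For what it's worth, the proxy could be as simple as this (a hypothetical sketch of the idea, not the actual prototype):
import dataclasses

def _inside_init(obj) -> bool:
    # If any dataclass field is still unset then __init__ hasn't finished,
    # so attribute assignment should still be allowed.
    return any(not hasattr(obj, f.name) for f in dataclasses.fields(obj))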
Something I was thinking of suggesting was changing _has_dataclass_init from a dict to a private final class variable:
from typing import Final

class Module:
    _HAS_DATACLASS_INIT: Final
It should work the same way, but is perhaps faster and could help with the _make_initable_wrapper simplification.
This runs afoul of the 'no custom attributes on Module' rule we have. Ever-so-hypothetically a user may have defined a _HAS_DATACLASS_INIT attribute themselves.
It's a bit of a philosophical point for something with an unusual-to-pick name, but I like knowing that Module is completely free of such points.
A different solution is to inspect the __dataclass_params__.init value!
That looks to store has_dataclass_init just like the _has_dataclass_init dict does. And since both require a traversal of the MRO they should be equivalent.
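A quick illustration (__dataclass_params__ is undocumented but long-standing CPython behaviour; the class names here are made up):
import dataclasses

@dataclasses.dataclass
class AutoInit:
    x: int

@dataclasses.dataclass(init=False)
class CustomInit:
    x: int
    def __init__(self, x):
        self.x = x

print(AutoInit.__dataclass_params__.init)    # True
print(CustomInit.__dataclass_params__.init)  # False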
Ah, that's a really good point! I hadn't appreciated that. I'll keep that in mind when I work on simplifying the definition.