Best practices to convert `torch.nn.Module` to `eqx.Module`
Thanks for the great package!
I was wondering whether there was some documentation regarding the best practice for converting a `torch.nn.Module` to an `eqx.Module`?
In particular:
- It is quite clear that `register_parameter` would be replaced by an attribute, e.g. `weights: Array`.
- How should one handle `register_buffer`?
- And `add_module(name, intertwiner_basis)`? Especially when the `name` is not known in advance, e.g. `f"module_{variable}"`.
Thanks a lot!
(For context, I'm looking at porting `escnn` to JAX, cf. https://github.com/QUVA-Lab/escnn/issues/55.)
Hey there!
Buffers are most simply handled by storing them as an array (just like a parameter) and then calling `jax.lax.stop_gradient(self.my_buffer)` when you access them in `__call__`.
Alternatively, you can follow the freeze-parameter example, which involves passing them through `jax.grad` as a nondifferentiable argument.
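For instance, a minimal sketch of the `stop_gradient` approach (the `LinearWithBuffer` class and the `scale` buffer name are just for illustration):

```python
import equinox as eqx
import jax
from jaxtyping import Array


class LinearWithBuffer(eqx.Module):
    weight: Array  # learnt parameter
    scale: Array   # buffer: stored like a parameter, but gradients are blocked on access

    def __call__(self, x):
        scale = jax.lax.stop_gradient(self.scale)  # no gradients flow into the buffer
        return scale * (self.weight @ x)
```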
Modules are typically stored as attributes just like parameters, e.g. see the source code for `eqx.nn.MLP`.
If you need dynamically-named parameters then you can store those in a dictionary, and then store the dictionary on the parent module. Modules themselves are not variadically-sized.
(Note that this dynamism should only happen at `__init__` time. At call time it is strongly recommended not to mutate your model, as it's easy to lose track of the changes when flattening/unflattening across JIT/grad/etc. boundaries.)
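As a minimal sketch of the dictionary approach (the `Parent` class and the `f"module_{i}"` keys are just illustrative):

```python
import equinox as eqx
import jax


class Parent(eqx.Module):
    submodules: dict

    def __init__(self, num, key):
        keys = jax.random.split(key, num)
        # Dynamically-named submodules live in a dict, which is itself a pytree.
        self.submodules = {
            f"module_{i}": eqx.nn.Linear(4, 4, key=k) for i, k in enumerate(keys)
        }

    def __call__(self, x):
        for layer in self.submodules.values():
            x = layer(x)
        return x
```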
Thanks @patrick-kidger for your answer! :)
I'd still have a question on how best to handle the following scenario, where I have a linear layer whose matrix `M` is given in a basis `B` (fixed) with coefficients `W` (to learn).
With `eqx.tree_inference(layer, True)` I can change the value of `layer.inference`, but I'd like to also store the matrix `M = B @ W` at evaluation time.
I feel that the two options are:
- returning a new layer when calling `layer = layer.eval()`, passing `M` as an (optional) argument
- passing `M` as a state (as in the batch norm layer)

I would be really keen on knowing your opinion, as neither option seems ideal :)
```python
class Linear(eqx.Module):
    W: Array
    B: Array
    M: Array
    inference: bool

    def __init__(self, ...):
        self.B = ...
        self.W = ...
        self.inference = inference

    def __call__(self, x):
        if self.inference:
            return self.M @ x
        else:
            return self.B @ self.W @ x

    def eval(self):
        self.M = self.B @ self.W
```
I think doing the conversion before inference time probably makes most sense. Here's an example of training a linear layer with a symmetric weight matrix:
```python
#
# Train-time: resolve on-the-fly.
#

class Symmetric(eqx.Module):
    array: Array

    def get(self):
        return 0.5 * (self.array + self.array.T)


is_symmetric = lambda x: isinstance(x, Symmetric)


@eqx.filter_jit
def train_step(model, ...):
    model = jax.tree_util.tree_map(
        lambda x: x.get() if is_symmetric(x) else x, model, is_leaf=is_symmetric
    )
    ...  # compute gradients, update, etc.


model = eqx.nn.Linear(...)
model = eqx.tree_at(lambda m: m.weight, model, replace_fn=Symmetric)
for _ in range(steps):
    model = train_step(model, ...)

#
# Inference time: perform conversion.
#

inference_model = eqx.tree_inference(model, True)
inference_model = jax.tree_util.tree_map(
    lambda x: x.get() if is_symmetric(x) else x, inference_model, is_leaf=is_symmetric
)
inference_model(...)  # evaluate
```
Doing some kind of train->inference conversion is pretty common -- e.g. quantisation, pruning, absorbing adjacent batchnorm and linear layers into a single linear transformation, etc. etc.
Also, note that I don't do something like `self.M = self.B @ self.W`. You can't assign to `eqx.Module`s outside of `__init__` -- much like tuples, they are immutable. You should use `eqx.tree_at` to create an out-of-place update instead.
This is a deliberate design choice, as it helps to reason about changes in the presence of jit, grad, etc.
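For instance, a minimal sketch of such an out-of-place update (assuming the `Linear` layer with `W`, `B`, `M` fields from above; the `to_inference` helper name is hypothetical):

```python
import equinox as eqx


def to_inference(layer):
    # Flip the inference flag, then return a *new* layer with `M` replaced,
    # leaving the original layer untouched.
    layer = eqx.tree_inference(layer, True)
    return eqx.tree_at(lambda l: l.M, layer, replace=layer.B @ layer.W)
```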
Thanks, that's really useful. I wasn't aware of `eqx.tree_at`!
I eventually implemented something like
```python
class Linear(eqx.Module):
    ...

    def eval(self):
        new = eqx.tree_inference(self, True)
        matrix = self.B @ self.W
        return eqx.tree_at(lambda m: m.matrix, new, replace=matrix)


model = Linear(...)
model = model.eval()
```
@patrick-kidger
In my `eqx.Module` I have some `Array` attributes which I want to learn (i.e. parameters) and other `Array` attributes which aren't, and which I'm only setting at evaluation time with `eqx.tree_at(lambda m: m.matrix, new, replace=matrix)`. Would there be something similar to `params, static = eqx.partition(model, eqx.is_array)`, but which would filter out the non-parameter `Array` attributes?
Yep, this is totally possible. First of all, if you just want to have non-learnt arrays, then call `lax.stop_gradient` after access:
```python
def __call__(self, ...):
    buffer = lax.stop_gradient(self.buffer)
    # now use `buffer` wherever you want to use it.
```
If you need to do something more complicated with filtering, then you can use a wrapper class:
```python
class FooArray(eqx.Module):
    array: Array


class Model(eqx.Module):
    foo: FooArray

    def __init__(self, ...):
        self.foo = FooArray(some_array)
        ...


model = Model(...)
is_foo = lambda x: isinstance(x, FooArray)
has_foo, no_foo = eqx.partition(model, is_foo, is_leaf=is_foo)
```
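As a hedged sketch of how that partition might then be used (the `loss_fn` below is just illustrative): differentiate with respect to the `no_foo` half only, and recombine with `eqx.combine` inside the loss so the `FooArray` leaves are present but never receive gradients.

```python
import jax
import jax.numpy as jnp


@eqx.filter_grad
def loss_fn(no_foo, has_foo, x, y):
    model = eqx.combine(no_foo, has_foo)  # stitch the two halves back together
    pred = jax.vmap(model)(x)             # assumes `Model` defines `__call__`
    return jnp.mean((pred - y) ** 2)


grads = loss_fn(no_foo, has_foo, x, y)  # gradients only for the non-FooArray leaves
```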
Here's a fully-fledged example for creating a linear transformation with a symmetric matrix.
Thanks @patrick-kidger!
FYI, I've been working on porting `escnn` to JAX & Equinox, as the only JAX-supported equivariant NN library is `e3nn_jax`, and it only supports `O(3)`, whilst `escnn` supports many subgroups (and it does not so far support `equinox`).
See for instance the `EquivariantModule` and `Linear` classes, and an MNIST example, with this `escnn_jax` library.
If that's something you're interested in and/or have any suggestions/remarks I'd be keen on hearing them :)
Thanks! I've just had a quick look.
- I recommend against using Distrax, it's known to have a number of correctness/performance issues: #252, #269.
- It's wildly experimental so please don't actually depend on it, but I spotted GeometricTensor, and so you might find Quax interesting, as a way to represent array-likes with some extra metadata etc.
- You don't need to inherit from `abc.ABC`; all Equinox classes are ABCs out-of-the-box. In fact they also support abstract attributes and abstract class attributes, as an extension on top of Python's built-in ABCs.
- This should probably be an iteration-over-pytree, not just `layers` only? Also note that you're technically doing O(n^2) work by calling `tree_inference` at multiple tree depths.
- I'd really recommend using an autoformatter+linter etc. -- take a look at Equinox's pre-commit hooks for inspiration. Likewise I'd recommend using the modern `pyproject.toml` approach over the older `setup.py`.
> This should probably be an iteration-over-pytree, not just `layers` only? Also note that you're technically doing O(n^2) work by calling `tree_inference` at multiple tree depths.
Regarding this, I completely agree. How can I achieve this? With something like the following?

```python
is_layer = lambda m: isinstance(m, eqx.Module)
new = jax.tree_util.tree_map(lambda m: m.train(mode), self, is_leaf=is_layer)
```

Is `new = eqx.tree_at(...)` needed?
@patrick-kidger would you have an idea by any chance whether it's usually better/faster in JAX when 'filling in an array' to (1) create an empty array, iterate, and fill values with `arr = arr.at[...].set(...)`, or (2) create an empty list `arr = []`, iterate whilst appending values, and finally concatenate them?
Also, to handle both stateful and stateless modules I found myself adding something like

```python
for layer in self.layers:
    if "state" in inspect.signature(layer).parameters:
        x, state = layer(x, state)
    else:
        x = layer(x)
```
Is there any way around this? I could wrap the stateless modules with

```python
def state_wrapper(layer: eqx.Module):
    if "state" in inspect.signature(layer).parameters:
        return layer
    else:
        return lambda x, state: (layer(x), state)
```

or something like `eqx.nn.Lambda`.
Would it be worth adding to `eqx.nn.Sequential` an optional `state: eqx.nn.State = None` argument, passed along and returned if not `None`? Then one could wrap everything into an `eqx.nn.Sequential` layer and simply call it once.
> Regarding this, I completely agree. How can I achieve this? With something like the following? `is_layer = lambda m: isinstance(m, eqx.Module)`; `new = jax.tree_util.tree_map(lambda m: m.train(mode), self, is_leaf=is_layer)`. Is `new = eqx.tree_at(...)` needed?
I'd recommend against this. Equinox modules are really just pytrees like any other, so it's not appropriate to special-case them. Moreover, what if there is some non-E3NN-Module that doesn't implement a `.train` method at all?
In the spirit of nominative subtyping, I would instead recommend the following pattern:
```python
# Declare that this method should exist.
class E3NNModule(eqx.Module):
    @abc.abstractmethod
    def train(self, mode):
        ...


# Now go looking for such layers, knowing that the train method must exist.
is_layer = lambda m: isinstance(m, E3NNModule)
new = jax.tree_util.tree_map(lambda m: m.train(mode), self, is_leaf=is_layer)


# On your concrete classes, go ahead and provide an implementation.
class SomeModule(E3NNModule):
    def train(self, mode):
        ...
```
If you have nested `E3NNModule`s inside each other, make sure that the `.train` method of the wrapper module calls the `.train` method of the wrapped module.
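Continuing the snippet above, a hedged sketch of what such a wrapper might look like (the `Wrapper` class here is hypothetical):

```python
class Wrapper(E3NNModule):
    inner: E3NNModule

    def train(self, mode):
        # Modules are immutable, so return a new wrapper containing the
        # retrained inner module, rather than mutating in place.
        return eqx.tree_at(lambda m: m.inner, self, replace=self.inner.train(mode))
```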
Incidentally the above is exactly the sort of thing I do very widely across my JAX libraries -- I'm a big fan of using ABCs to explicitly declare what interfaces are available.
> would you have an idea by any chance whether it's usually better/faster in JAX when 'filling in an array' to (1) create an empty array, iterate and fill values with `arr = arr.at[...].set(...)`, or (2) creating an empty list `arr = []`, iterate whilst appending values, and finally concatenating them?
I would recommend (2). JAX's heuristics for in-place updates are sometimes not great.
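For instance, a minimal sketch of option (2) (the `compute_row` function is hypothetical):

```python
import jax.numpy as jnp

rows = []
for i in range(10):
    rows.append(compute_row(i))  # each entry e.g. a 1D array of the same length
arr = jnp.stack(rows)            # one allocation at the end, instead of repeated .at[].set()
```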
> Also, to handle both stateful and stateless modules I found myself adding something like
Hmm, these aren't really designed to be used interchangeably. After all, one could easily define a module with a completely arbitrary custom signature; it's not like the only two valid ones are `(x,)` and `(x, state)`. Conversely, if someone defines a module with signature `(x, foo, state)` -- that just so happens to use the name `state` -- then your check will trigger incorrectly.
Stateful layers are pretty unusual -- in particular batchnorm is used very infrequently. What's your use case?
```python
def state_wrapper(layer: eqx.Module):
    if "state" in inspect.signature(layer).parameters:
        return layer
    else:
        return lambda x, state: (layer(x), state)
```
Note that this snippet is dangerous. The `else` branch does not return a pytree, so any parameters inside `layer` will not actually be trained.
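As a hedged sketch of a pytree-preserving alternative (the `StateWrapper` name is hypothetical), one could instead wrap the stateless layer in a small `eqx.Module`, so its parameters remain ordinary pytree leaves:

```python
class StateWrapper(eqx.Module):
    layer: eqx.Module

    def __call__(self, x, state):
        # Pass the state straight through; the wrapped layer's parameters are
        # still flattened and trained as usual.
        return self.layer(x), state
```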