PyTorch Hooks in Equinox?
I would like to record some model activations in an architecture-agnostic way. In PyTorch, we can do this with forward hooks, by registering a hook on every module that matches some criterion (for example, all instances of an MLP class).
Is there an equivalent strategy in Equinox?
One idea is to create a class `Wrapper(eqx.Module)` that wraps a module and, in `__call__`, passes the underlying module's output activations to some callback, and then to somehow swap these wrappers in for the original submodules of an Equinox model.
```python
from typing import Callable

import equinox as eqx


class Wrapper(eqx.Module):
    wrapped: eqx.Module
    callback: Callable

    def __init__(self, module, callback):
        self.wrapped = module
        self.callback = callback

    def __call__(self, *args, **kwargs):
        outs = self.wrapped(*args, **kwargs)
        self.callback(outs)  # this would save to disk or something
        return outs  # pass the activations through unchanged
```
Then in the main script, I could do something like:
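```python
model = MyViT()
for i in range(n_layers):
    # eqx.tree_at(where, pytree, ...) needs the model itself as its second argument
    model = eqx.tree_at(lambda m: m.layers[i].mlp, model, replace_fn=lambda m: Wrapper(m, my_callback))
```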
Is there a better/more obvious way to do this?
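You can use `jax.tree.leaves` with an `is_leaf` predicate to collect every submodule of the type you want, and then replace them all in a single `eqx.tree_at` call.
For example, to wrap every `eqx.nn.Linear` layer: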
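```python
import equinox as eqx
import jax

is_linear = lambda x: isinstance(x, eqx.nn.Linear)
# Treat each Linear as a leaf (so it isn't flattened further), then keep only those leaves.
get_linear = lambda m: [x
                        for x in jax.tree.leaves(m, is_leaf=is_linear)
                        if is_linear(x)]

linears = get_linear(model)
wrapped = [Wrapper(m, callback) for m in linears]
model = eqx.tree_at(get_linear, model, wrapped)  # tree_at returns a new model; it does not modify in place
```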
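However, I'm not sure your callback will do what you want inside a jitted module: under `jax.jit` (or `eqx.filter_jit`), ordinary Python code in `__call__` runs only while the function is being traced, and it receives tracers rather than concrete arrays.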
Wow, that's really neat; I'll try that. I think I can use `jax.debug.callback` or `jax.experimental.io_callback`, but I'm not sure which is better.