equinox
[Feature Request] filter_jacfwd, filter_jacrev
Equinox's `filter_*` functions are very helpful. It would be great if there were also `filter_jacfwd` and `filter_jacrev` functions.
For my specific use case, I am using `eqx.filter_value_and_grad(fn)(model, ...)`. As far as I can tell, if I wanted to compute the Jacobian of `fn` with respect to its inputs, I would need to rewrite `fn` with the type signature `fn(model_static, model_params, ...)`. This means I won't be able to use `eqx.filter_value_and_grad`, or really any other `eqx.filter_*` function, on `fn`. Furthermore, any calls to `eqx.filter_*` inside of `fn` must also be rewritten as `jax.*`!