
[Feature Request] filter_jacfwd, filter_jacrev

Open · smorad opened this issue 7 months ago · 3 comments

Equinox's filter_* functions are very helpful. It would be great if there were also filter_jacfwd and filter_jacrev functions.

For my specific use case, I am using eqx.filter_value_and_grad(fn)(model, ...). As far as I can tell, if I wanted to compute the Jacobian of fn with respect to its inputs, I would need to rewrite fn with the signature fn(model_static, model_params, ...). This means I would no longer be able to use eqx.filter_value_and_grad, or indeed any other eqx.filter_* transform, on fn. Furthermore, any calls to eqx.filter_* inside fn would also have to be rewritten as their jax.* equivalents!

smorad, Nov 28 '23