functorch
Add a linearize function like jax.linearize
jvp(f, x, tangents) computes both f(x) and the Jacobian-vector product in a single forward pass.
The goal of linearize(f, x) is to return f(x) together with a function that computes forward-mode Jacobian-vector products at x for many different tangents, without re-evaluating f(x) each time.
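To make the intended interface concrete, here is a minimal sketch that assumes torch.func.jacfwd is available (as in PyTorch 2.x). dense_linearize is a hypothetical helper, not a functorch API, and it swaps in a different technique: it materializes the full Jacobian once and reuses it for every tangent. That is only practical for small inputs. jax.linearize instead keeps the linearized computation without forming the Jacobian.

```python
import torch
from torch.func import jacfwd


def dense_linearize(f, x):
    # Hypothetical helper (not a functorch API). Setup evaluates f and
    # materializes the full Jacobian J = df/dx at x using forward-mode AD.
    y = f(x)
    jac = jacfwd(f)(x)  # shape: y.shape + x.shape

    def jvp_fn(v):
        # Computes J @ v by contracting over all input dimensions.
        # No call to f happens here; only the stored Jacobian is used.
        return torch.tensordot(jac, v, dims=v.dim())

    return y, jvp_fn


x = torch.tensor([0.0, 1.0, 2.0])
y, jvp_fn = dense_linearize(torch.sin, x)
print(jvp_fn(torch.ones(3)))  # equals cos(x) for f = sin
```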
Is there a workaround until a function like this has been properly implemented?
A workaround is to use jvp directly. However, jvp re-evaluates the function you pass to it on every call, so performance will be worse when you need many Jacobian-vector products at the same point.
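The jvp workaround can be wrapped to expose a linearize-style interface. This is a sketch; linearize_via_jvp is a hypothetical name, and the wrapper only mimics the interface. Each call to the returned function still re-runs f, which is exactly the cost the requested feature is meant to avoid.

```python
import torch
from torch.func import jvp


def linearize_via_jvp(f, *primals):
    # Hypothetical helper: linearize-like interface built on torch.func.jvp.
    output = f(*primals)

    def jvp_fn(*tangents):
        # jvp expects tuples of primals and tangents; it evaluates f again
        # on every call and returns (f(primals), J @ tangents).
        _, tangent_out = jvp(f, primals, tangents)
        return tangent_out

    return output, jvp_fn


x = torch.tensor([0.0, 1.0])
output, jvp_fn = linearize_via_jvp(torch.sin, x)
print(jvp_fn(torch.ones(2)))  # equals cos(x)
```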
Added in https://github.com/pytorch/pytorch/pull/94173
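In recent PyTorch releases this is exposed as torch.func.linearize(func, *primals), which returns the output and a jvp_fn that can be called repeatedly with new tangents. A minimal usage sketch, assuming a PyTorch version that ships torch.func.linearize (the exact first release is not stated here):

```python
import torch
from torch.func import linearize


def f(x):
    return x.sin()


x = torch.tensor([0.0, 1.0, 2.0])
# f is evaluated once here; jvp_fn reuses the linearized computation.
output, jvp_fn = linearize(f, x)

# Tangents must match the primals in shape, dtype, and device.
for v in (torch.ones(3), torch.arange(3.0)):
    print(jvp_fn(v))
```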