
Add a linearize like jax.linearize

Open · zou3519 opened this issue · 2 comments

jvp(f, (x,), (t,)) computes both f(x) and the Jacobian-vector product J_f(x) · t in a single forward pass.
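To make concrete what jvp returns, here is a toy pure-Python sketch of forward-mode AD with dual numbers. It is not functorch's jvp, which takes tuples of tensors. The Dual class, the sin helper and this jvp are hypothetical stand-ins that only handle +, * and sin on floats.

```python
import math


# Toy forward-mode AD; Dual, sin and jvp are hypothetical, not functorch's API.
class Dual:
    """A primal value paired with a tangent (directional derivative)."""

    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other):
        # Product rule: d(uv) = u dv + v du
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)


def sin(x):
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)


def jvp(f, primals, tangents):
    """One forward pass returns (f(primals), J_f(primals) @ tangents)."""
    out = f(*(Dual(p, t) for p, t in zip(primals, tangents)))
    return out.primal, out.tangent


def f(x, y):
    return x * y + sin(x)


# Tangent (1, 2): J @ v = (y + cos x) * 1 + x * 2
value, jv = jvp(f, (2.0, 3.0), (1.0, 2.0))
```

Both the output and the directional derivative come out of the same evaluation of f.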

The goal of linearize(f, x) is to return f(x) together with a function that computes Jacobian-vector products at x for any number of tangents, so forward-mode AD can be run repeatedly without re-evaluating f(x).
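A minimal sketch of the requested behavior, assuming a toy tape-based tracer rather than anything in functorch or JAX: f runs once at the primal point while each operation records its local partial derivatives, and the returned jvp_fn only replays that recorded linear map. The names Tape, Traced, sin and linearize are hypothetical; the (output, jvp_fn) return shape loosely mirrors jax.linearize.

```python
import math


# Toy tracer; Tape, Traced, sin and linearize are hypothetical names.
class Tape:
    """Records, for every intermediate value, its local partial derivatives."""

    def __init__(self):
        self.entries = []  # entry i: list of (parent_index, partial) pairs

    def push(self, parents):
        self.entries.append(parents)
        return len(self.entries) - 1


class Traced:
    def __init__(self, tape, primal, parents=()):
        self.tape = tape
        self.primal = primal
        self.index = tape.push(list(parents))

    def __add__(self, other):
        if not isinstance(other, Traced):
            return Traced(self.tape, self.primal + other, [(self.index, 1.0)])
        return Traced(self.tape, self.primal + other.primal,
                      [(self.index, 1.0), (other.index, 1.0)])

    __radd__ = __add__

    def __mul__(self, other):
        if not isinstance(other, Traced):
            return Traced(self.tape, self.primal * other, [(self.index, other)])
        return Traced(self.tape, self.primal * other.primal,
                      [(self.index, other.primal), (other.index, self.primal)])

    __rmul__ = __mul__


def sin(x):
    return Traced(x.tape, math.sin(x.primal), [(x.index, math.cos(x.primal))])


def linearize(f, *primals):
    """Evaluate f once at primals; return (f(primals), jvp_fn)."""
    tape = Tape()
    inputs = [Traced(tape, p) for p in primals]  # occupy tape slots 0..n-1
    out = f(*inputs)
    n_inputs = len(inputs)

    def jvp_fn(*tangents):
        # Push tangents through the recorded linear map; f is not called again.
        dots = list(tangents) + [0.0] * (len(tape.entries) - n_inputs)
        for i in range(n_inputs, len(tape.entries)):
            dots[i] = sum(partial * dots[j] for j, partial in tape.entries[i])
        return dots[out.index]

    return out.primal, jvp_fn


n_calls = [0]  # counts how often f is evaluated


def f(x, y):
    n_calls[0] += 1
    return x * y + sin(x)


value, f_jvp = linearize(f, 2.0, 3.0)
dx = f_jvp(1.0, 0.0)  # df/dx = y + cos(x)
dy = f_jvp(0.0, 1.0)  # df/dy = x
# n_calls[0] stays at 1 no matter how many tangents are pushed through f_jvp
```

The design trades memory for speed: the tape holds one entry per intermediate value, and in exchange each extra tangent costs only a pass over those stored partials.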

zou3519 · Apr 21 '22 17:04

Is there a workaround until a function like this has been properly implemented?

TJHeeringa · Sep 01 '22 08:09

A workaround is to call jvp once per tangent. However, jvp re-evaluates the function you pass to it on every call, so performance won't be as good as with a true linearize.
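The workaround can be sketched with a toy dual-number jvp (hypothetical names, floats only, not functorch's API). Wrapping jvp in a closure gives the same (output, jvp_fn) shape as linearize, but every call to jvp_fn runs f again, which is the performance cost mentioned above.

```python
import math


# Toy forward-mode AD; Dual, sin, jvp and linearize_via_jvp are hypothetical.
class Dual:
    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other):
        # Product rule: d(uv) = u dv + v du
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)


def sin(x):
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)


def jvp(f, primals, tangents):
    out = f(*(Dual(p, t) for p, t in zip(primals, tangents)))
    return out.primal, out.tangent


def linearize_via_jvp(f, *primals):
    """Workaround: same (output, jvp_fn) shape, but each jvp_fn call reruns f."""
    value, _ = jvp(f, primals, [0.0] * len(primals))

    def jvp_fn(*tangents):
        return jvp(f, primals, tangents)[1]

    return value, jvp_fn


n_calls = [0]  # counts how often f is evaluated


def f(x, y):
    n_calls[0] += 1
    return x * y + sin(x)


value, f_jvp = linearize_via_jvp(f, 2.0, 3.0)
dx = f_jvp(1.0, 0.0)  # df/dx = y + cos(x)
dy = f_jvp(0.0, 1.0)  # df/dy = x
# n_calls[0] == 3: one evaluation for the value plus one per tangent
```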

zou3519 · Sep 06 '22 13:09

Added in https://github.com/pytorch/pytorch/pull/94173

kshitij12345 · Sep 20 '23 13:09