
Caching computations in forward mode

Open sethaxen opened this issue 5 years ago • 2 comments

Suppose we have a function f: ℝᵐ → ℝⁿ that for some reason we want to differentiate in forward mode, which will require calling all frules m times (once per input basis vector). This seems wasteful, as the pushforwards often depend on intermediates of the primal function that don't change between passes. In the current implementation of frules, where the output of the pushforward is computed at the same time as the output of the primal, these intermediates would need to be recomputed m times. An example is symmetric eigendecomposition, where the decomposition really only needs to be computed once but will instead be computed m times.
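To make this concrete, here is a minimal sketch (not an actual ChainRules rule; `sqrt_sym` and its pushforward are purely illustrative) of an frule whose tangent computation reuses an expensive intermediate of the primal. Because the primal and the pushforward are produced by the same call, differentiating with respect to all m inputs means calling the frule, and hence the eigendecomposition, m times even though A never changes:

```julia
using ChainRulesCore, LinearAlgebra

# Illustrative primal: symmetric (assumed positive definite) matrix square root.
sqrt_sym(A::Symmetric) = sqrt(A)

function ChainRulesCore.frule((Δself, ΔA), ::typeof(sqrt_sym), A::Symmetric)
    # Expensive intermediate, computed once per frule call:
    λ, U = eigen(A)                                # O(n³) eigendecomposition
    Y = U * Diagonal(sqrt.(λ)) * U'                # primal output
    # The pushforward also needs λ and U, but they only live inside this call,
    # so every new tangent direction ΔA repeats the eigen(A) above.
    Ã = U' * ΔA * U
    F = [inv(sqrt(λ[i]) + sqrt(λ[j])) for i in eachindex(λ), j in eachindex(λ)]
    ΔY = U * (F .* Ã) * U'                         # solves ΔY*Y + Y*ΔY = ΔA
    return Y, ΔY
end
```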

I'm sure there are good reasons for implementing it this way. One I can think of is that it's easier to support mutating rules. Are there others?

sethaxen avatar May 17 '20 02:05 sethaxen

The reason for computing the primal and tangent together is cases like the frule for an ODE solver. In that case, you really don't want the frule to return a closure, because it would involve caching all of the intermediate state from the solve.
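A toy sketch of that trade-off (the "solver" and the closure-returning variant are hypothetical, not part of ChainRules): under the current contract the primal and tangent advance together, so per-step state is consumed as it is produced, whereas a closure-returning rule has to keep the whole trajectory alive until the pushforward is applied.

```julia
using ChainRulesCore

# Toy stand-in for an iterative solver whose primal builds up per-step state.
toy_solve(x; nsteps=1_000) = foldl((u, _) -> u + 0.01sin(u), 1:nsteps; init=x)

# Current contract: primal and tangent advance together, so each step's state
# is used and then discarded.
function ChainRulesCore.frule((Δself, dx), ::typeof(toy_solve), x; nsteps=1_000)
    u, du = x, dx
    for _ in 1:nsteps
        du = du + 0.01cos(u) * du   # tangent step uses the current u ...
        u  = u + 0.01sin(u)         # ... before the primal moves on
    end
    return u, du
end

# Hypothetical closure-returning style: applying the pushforward later means
# keeping every intermediate u around, which is the memory cost avoided above.
function toy_solve_with_pushforward(x; nsteps=1_000)
    us = similar([x], nsteps)
    u = x
    for i in 1:nsteps
        us[i] = u
        u = u + 0.01sin(u)
    end
    pushforward(dx) = foldl((du, u) -> du + 0.01cos(u) * du, us; init=dx)
    return u, pushforward
end
```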

In terms of evaluation at multiple tangents, the plan is to enable pushing entire bases through at the same time.
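As a rough sketch of what pushing a basis through can look like for rules that don't constrain the tangent type (the tangent-type name and whether a particular rule accepts a plain vector of directions are assumptions on my part, not guarantees of the current API):

```julia
using ChainRules

# Many scalar frules are generic in the tangent type, so a batch of tangent
# directions can, in principle, be pushed through in one call.
x = 0.3
dxs = [1.0, 0.0, 2.5]                       # three tangent directions at once
y, dys = frule((NoTangent(), dxs), sin, x)
# For a rule of the form (sin(x), cos(x) * dx), this gives
# dys == cos(x) .* dxs with a single primal evaluation.
```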

willtebbutt avatar May 17 '20 07:05 willtebbutt

To be clear, we actually do support the core of https://github.com/JuliaDiff/ChainRulesCore.jl/issues/92 already. ForwardDiff2 even makes use of it. Every frule that we have supports it.

I think the real question of https://github.com/JuliaDiff/ChainRulesCore.jl/issues/92 is: "Are there other frules that we have not written because we need to support it better somehow?"

oxinabox avatar May 17 '20 13:05 oxinabox