DifferentiationInterface.jl

Add common interface to mark functions as non-differentiable across backends

Open. adrhill opened this issue 1 year ago • 6 comments

This could be implemented in a similar way to DifferentiateWith.

adrhill, Aug 19 '24 14:08

The trouble will be the same as for DifferentiateWith: it won't work with all backends. For instance, there's no way to tell FiniteDiff not to differentiate through a function (unlike Zygote or Enzyme).

But here it is worse than for DifferentiateWith, because different backends may silently return different answers instead of erroring. With DifferentiateWith, at least, whenever it doesn't error, the output is the same whether or not the custom chain rule is hit.

gdalle, Aug 19 '24 16:08

So I guess there are 2 different use cases here:

  1. Marking as non-differentiable a function that would otherwise make a non-zero contribution to the differential, thus changing the differential.
  2. Marking as non-differentiable a function call that cannot contribute to the differential to avoid either performance issues or errors raised by the AD backend due to unsupported language features.

It seems like (1) is tricky to support for the reasons in https://github.com/gdalle/DifferentiationInterface.jl/issues/415#issuecomment-2296967975, but (2) should be safe, right? With e.g. FiniteDifferences we wouldn't want or need the behavior to change at all, while with operator-overloading ADs we'd want to strip away the AD types, and with source-to-source ADs we'd want to tell them to just compile the usual function.
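To make the two use cases concrete, here is a toy sketch in Python (not Julia, and not modeled on any real backend's API). `Dual`, `non_differentiable`, `forward_derivative` and `finite_difference` are hypothetical names: a minimal dual-number "operator-overloading backend" whose marker strips the dual type, next to a central finite difference that can only evaluate the function on floats. In case (1), marking `square` inside `x * square(x)` changes the forward-mode answer but not the finite-difference one, so the two silently disagree. In case (2), marking a piecewise-constant `floor` call only prevents an error, and both backends agree away from the integers.

```python
import math


class Dual:
    """Toy forward-mode dual number: a value plus one tangent (derivative) component."""

    def __init__(self, value, tangent=0.0):
        self.value = value
        self.tangent = tangent

    def __mul__(self, other):
        if isinstance(other, Dual):
            # Product rule with e**2 = 0: (a + a'e)(b + b'e) = ab + (a'b + ab')e
            return Dual(self.value * other.value,
                        self.tangent * other.value + self.value * other.tangent)
        # Multiplying by a plain number scales both components
        return Dual(self.value * other, self.tangent * other)

    __rmul__ = __mul__  # scalar * Dual; multiplication is commutative
    # No __floor__ or __float__: math.floor(Dual(...)) raises TypeError,
    # standing in for a language feature the AD backend does not support.


def non_differentiable(fn):
    """Hypothetical marker: unwrap a Dual to its plain value before calling fn,
    so fn runs on floats and contributes zero tangent downstream."""
    def wrapper(x):
        plain = x.value if isinstance(x, Dual) else x
        return fn(plain)
    return wrapper


def forward_derivative(f, x):
    """Derivative of f at x via a dual number seeded with tangent 1."""
    out = f(Dual(x, 1.0))
    return out.tangent if isinstance(out, Dual) else 0.0


def finite_difference(f, x, h=1e-6):
    """Central difference: only evaluates f on floats, so it cannot see the marker."""
    return (f(x + h) - f(x - h)) / (2 * h)


# Use case (1): square DOES contribute to the derivative of x * square(x) = x**3.
@non_differentiable
def square(x):
    return x * x


def f1(x):
    return x * square(x)


# Use case (2): floor is piecewise constant, so (away from integers) it cannot
# contribute; the marker only keeps it from receiving a Dual.
@non_differentiable
def bucket(x):
    return math.floor(x)


def f2(x):
    return bucket(x) * x * x


if __name__ == "__main__":
    print(forward_derivative(f1, 2.0))  # 4.0: the x * square'(x) term is dropped
    print(finite_difference(f1, 2.0))   # approximately 12.0, the true 3 * x**2
    print(forward_derivative(f2, 2.5))  # 10.0
    print(finite_difference(f2, 2.5))   # approximately 10.0: the backends agree
```

The marker can only act in the dual-number backend because that backend can inspect and unwrap the `Dual`; `finite_difference` never sees anything but floats, which mirrors the FiniteDiff limitation mentioned earlier in the thread.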

sethaxen, Aug 19 '24 16:08

I guess you're right about (2). Essentially this would only concern

Can you think of any other backend where it would work?

gdalle, Aug 19 '24 17:08

With Enzyme, I think one could use inactive: https://enzyme.mit.edu/index.fcgi/julia/stable/generated/custom_rule/#Marking-functions-inactive. Among source-to-source ADs, I think one would want it to work for Tapir as well, if possible.

It should ideally impact all of the operator-overloading ADs as well, at least ForwardDiff, ReverseDiff, and Tracker. Probably Symbolics also.

sethaxen, Aug 19 '24 17:08

Contributions are welcome on this; I'm not yet comfortable enough with metaprogramming to try it alone.

gdalle, Aug 20 '24 06:08

Inactive's semantics are different from ChainRules' non-differentiability. Marking something as inactive in Enzyme.jl says both that the operation itself doesn't transfer derivative information and that no value produced by the function call could contain differentiable data in the future.

For example, a function that allocates an empty vector would be legal to mark as non-differentiable in ChainRules, but not as inactive in Enzyme, since the vector could later hold differentiable values.

In that sense, Enzyme-inactive implies ChainRules-inactive (I think).

Enzyme's activity analysis can also express that only the instruction (e.g. a function call) doesn't transfer derivative information, without saying anything about its return value, but we haven't added Julia syntactic sugar for that yet.

wsmoses, Aug 20 '24 22:08