Function to tell if you are being differentiated
Quite a few packages define a function to tell whether they are being run inside AD. As of https://github.com/FluxML/Flux.jl/pull/1863/files#r806287154, Flux has:
```julia
istraining() = false
# @adjoint istraining() = true, _ -> nothing   # older, Zygote-only version
ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
```
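For context, a minimal sketch of how such a flag is typically consumed, guarding a mutating fast path so that it is only taken outside of AD (the function and its branches here are illustrative, not NNlib's actual code):

```julia
using ChainRulesCore

istraining() = false
ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

# Under AD the rrule makes istraining() return true, steering execution to
# the out-of-place branch; the mutating branch is only taken in ordinary
# forward evaluation.
function scale!(y, x, a)
    if istraining()
        return x .* a   # out-of-place, safe to differentiate through
    else
        y .= x .* a     # in-place, faster, but unsafe under most reverse AD
        return y
    end
end
```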
Maybe CRC should provide one?
It could return true/false. It could also return Val(true), or it could be something more elaborate, like:
```julia
order() = (fwd=0, rev=0, total=0)
```
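For the elaborate version, a purely speculative sketch of how the counters might be bumped, assuming each AD layer applies its rules to the call of order() made inside the inner rule, so that nesting increments them:

```julia
using ChainRulesCore

order() = (fwd=0, rev=0, total=0)

# Speculative: forward-mode bumps fwd, reverse-mode bumps rev; the inner
# call to order() can itself be intercepted by the next AD layer out.
function ChainRulesCore.frule(_, ::typeof(order))
    o = order()
    (fwd=o.fwd + 1, rev=o.rev, total=o.total + 1), ZeroTangent()
end

function ChainRulesCore.rrule(::typeof(order))
    o = order()
    (fwd=o.fwd, rev=o.rev + 1, total=o.total + 1), _ -> (NoTangent(),)
end
```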
Thoughts?
I am tempted to say we don't need this (since AD frameworks already provide it), and that it is out of scope for ChainRules.
However, it is a common operation and having a ChainRules way to do this would help make code easier to switch between AD systems. So I am in favour.
Yes, Zygote has such a function, but packages may want to check without depending on it. Like NNlib here, where it's guarding a mutating path... seems likely to be the right thing even for non-Zygote reverse AD.
> Like NNlib here, where it's guarding a mutating path... seems likely to be the right thing even for non-Zygote reverse AD.
I guess this specific use case could also be solved by an official SupportsMutation trait, as sketched in the docs?
Or, lower-tech for that use, ensure that `is_inplaceable_destination` gives false within gradients. (Although I'm not sure I understand its rules... why is `ChainRulesCore.is_inplaceable_destination([1])` true, for integers?)
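For what it's worth, here is roughly the shape such a trait could take, riffing on the existing RuleConfig mechanism (the names and API here are entirely hypothetical, not anything ChainRulesCore actually defines):

```julia
using ChainRulesCore

# Hypothetical trait: an AD system's RuleConfig subtype would opt in by
# overloading mutation_support for its own config type.
abstract type MutationSupport end
struct SupportsMutation <: MutationSupport end
struct NoMutationSupport <: MutationSupport end

mutation_support(::RuleConfig) = NoMutationSupport()  # conservative default

# A rule could then branch on the trait instead of on istraining():
function maybe_accumulate!(config::RuleConfig, dst, src)
    mutation_support(config) isa SupportsMutation || return dst .+ src
    dst .+= src   # in-place fast path, only when the AD has opted in
    return dst
end
```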
> I guess this specific use case could also be solved by an official SupportsMutation trait, as sketched in the docs?
It's not clear that this is actually the right way to do this. We might actually want a totally different API for mutating rules. We should probably remove that from the docs and put a more generic example.
> why is `ChainRulesCore.is_inplaceable_destination([1])` true, for integers?
Because it only considers whether the array type is one that can be mutated, not whether it can be mutated to hold a particular type of element.
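Concretely, in the REPL (the InexactError below is what a rule would hit if it then tried to write a fractional gradient into that Int array):

```julia
julia> using ChainRulesCore

julia> ChainRulesCore.is_inplaceable_destination([1])  # Vector{Int} is mutable
true

julia> x = [1]; x .+= 0.5  # but it cannot hold a Float64 gradient
ERROR: InexactError: Int64(1.5)
```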
> However, it is a common operation and having a ChainRules way to do this would help make code easier to switch between AD systems. So I am in favour.
I agree. In particular, this is similar to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/470, and we should do it here.
This turns out to be trickier than expected. See discussion in https://github.com/FluxML/NNlib.jl/pull/434, but the tl;dr is that `istraining` above relies on AD not being smart enough to notice that the condition in `cond ? a : b` can never see a gradient.
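To spell out the fragility (a sketch, not code from that PR):

```julia
using ChainRulesCore

istraining() = false
ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

# Both branches compute the same value, but only one mutates. The guard
# works only if the AD actually invokes the rrule while tracing. Since a
# Bool condition can never receive a gradient, an AD clever enough to
# constant-fold non-differentiable calls to their primal value
# (istraining() = false) would silently take the mutating branch.
f(x) = istraining() ? sum(2 .* x) : sum(x .*= 2)
```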