ChainRulesCore.jl
Add `differential_type`, to test for non-differentiability
This adds a function to check whether a given type is non-differentiable. The purpose is to let you test whether to take the trivial path for some rule.
It goes by whether `ProjectTo(x)` is inferred to be a trivial projector. This means the check should always be costless: rather than iterating through something like `x = Any[true, false]`, it will give up and return `false`.
Possibly this should return something whose type is the indication, like `Val(true)`? Or return one of `AbstractZero` / `Number` / `Any`? Some trait struct like `IsNonDiff` seems a bit heavyweight...
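As a rough illustration of the inference-only idea, here is a minimal sketch. The helper name `known_trivial_projector` is hypothetical and is not the function added in this PR; it only assumes that `ProjectTo` and `AbstractZero` are exported by ChainRulesCore and that `Base.promote_op` reports the inferred return type.

```julia
using ChainRulesCore

# Hypothetical helper (not part of ChainRulesCore): ask the compiler what
# `ProjectTo(x)` is inferred to return for `typeof(x)`, without running it.
function known_trivial_projector(x)
    P = Base.promote_op(ProjectTo, typeof(x))
    # Only answer `true` when inference proves the projector is a zero-like one.
    return P <: ProjectTo{<:AbstractZero}
end

known_trivial_projector(true)              # Bool is handled by a dedicated ProjectTo method
known_trivial_projector(1.0)               # Float64 has a real tangent
known_trivial_projector(Any[true, false])  # element types are unknown to inference
```

Because the answer depends on what the compiler can infer, an abstractly typed container is treated as "not known to be trivial" even if every element happens to be a `Bool`.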
Codecov Report
Merging #528 (7c8e040) into main (2f76da0) will increase coverage by 0.04%. The diff coverage is 100.00%.

| Coverage Diff | main | #528 | +/- |
|---|---|---|---|
| Coverage | 93.02% | 93.07% | +0.04% |
| Files | 15 | 15 | |
| Lines | 860 | 866 | +6 |
| Hits | 800 | 806 | +6 |
| Misses | 60 | 60 | |

| Impacted Files | Coverage Δ | |
|---|---|---|
| src/ChainRulesCore.jl | 100.00% <ø> (ø) | |
| src/projection.jl | 97.36% <100.00%> (+0.06%) | :arrow_up: |
🚲 Naming-wise, I wonder if "differentiable" should be reserved for functions (/callables), as in `@non_differentiable f(x, y)` or `@non_differentiable T(x)`? This is partly because of the existing `@non_differentiable` macro: on naming alone, I'd have guessed this was related to `@non_differentiable` (not so much `ProjectTo`), i.e. `is_non_differentiable(f, args...)` if and only if we had `@non_differentiable f(args...)`.
I think "perturb[able]" is a word that's been used in this package when talking about types (e.g. in CRTestUtils here and here). So that could be an option if we wanted to avoid "differentiable" for types.
(That said, a quick search suggests that Swift for TF thought "differentiable types" was a good name for "types that can be used as arguments and results of differentiable functions": https://github.com/tensorflow/swift/blob/f0d6c74ef5d016046afc1eac0b07a2f6b74b8fdf/docs/DifferentiableTypes.md)
Also, I wonder if we want to flip the sign for ease of understanding, e.g. `is_differentiable` rather than `is_non_differentiable`? Double negatives like "`is_non_diff(T) == false`, so T is not non-differentiable" can be hard to follow.
I don't love the name.
One reason for the apparent double negative is that the present implementation falls back to `false` whenever it is unsure. It answers the question "are we certain we can take the trivial path here?"
Latest commit changes this to return a type, not a function. This means you test differential_type(x) <: AbstractZero to see whether x is known to have no derivative.
In general it returns the T in ProjectTo{T}. I'm not certain that has any other uses. But at least it's more obviously something read out from the projection machinery, rather than being an independent concept.
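To make the `<: AbstractZero` test concrete, here is a small sketch under stated assumptions. The names `projector_param`, `tangent_kind`, and `describe` are hypothetical stand-ins, not the PR's `differential_type`; unlike the PR, this version builds the projector at runtime instead of relying on inference, so it is not guaranteed to be costless.

```julia
using ChainRulesCore

# Hypothetical stand-in: read the `T` out of `ProjectTo{T}`.
projector_param(::ProjectTo{T}) where {T} = T

# Assumption: constructs the projector at runtime, which the PR avoids.
tangent_kind(x) = projector_param(ProjectTo(x))

# Hypothetical caller: branch on whether `x` is known to have no derivative.
function describe(x)
    if tangent_kind(x) <: AbstractZero
        return "trivial path"   # a rule could return NoTangent() immediately here
    else
        return "general path"   # a rule would compute a real tangent here
    end
end

describe(true)
describe(1.0)
```

Returning a type rather than a `Bool` keeps the check tied to the projection machinery: the caller compares against `AbstractZero` instead of trusting a separate trait.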