ChainRulesCore.jl

Add `differential_type`, to test for non-differentiability

Open mcabbott opened this issue 3 years ago • 4 comments

This adds a function to check whether a given type is non-differentiable. The purpose is to let you test whether to take the trivial path for some rule.

It goes by whether `ProjectTo(x)` is inferred to be a trivial projector. This means it should always be costless: for something like `x = Any[true, false]` it will give up and return `false` rather than iterating through the elements.
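A minimal sketch of the idea, not the code in this PR: the helper name `known_nondiff` is hypothetical, and `Core.Compiler.return_type` is a Julia internal used here only as one assumed way of asking inference what `ProjectTo(x)` would return without running it.

```julia
using ChainRulesCore

# Hypothetical helper for illustration; not the function added by this PR.
# It inspects the *inferred* return type of ProjectTo(x), so it never
# iterates over the contents of x.
function known_nondiff(x)
    P = Core.Compiler.return_type(ProjectTo, Tuple{typeof(x)})
    # Union{} means inference concluded the call cannot return; since
    # Union{} is a subtype of everything, rule it out explicitly.
    return P !== Union{} && P <: ProjectTo{<:AbstractZero}
end

known_nondiff(true)              # a Bool projects to ProjectTo{NoTangent}
known_nondiff(1.5)               # false: the projector is not an AbstractZero one
known_nondiff(Any[true, false])  # false: a Vector{Any} gives inference nothing to go on
```

Because the answer depends only on `typeof(x)`, the conservative `false` for `Any[true, false]` is the price of never looking at the elements.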

Possibly this should return something whose type is the indication, like `Val(true)`? Or it could return one of `AbstractZero` / `Number` / `Any`. Some trait struct like `IsNonDiff` seems a bit heavyweight...

mcabbott avatar Jan 11 '22 15:01 mcabbott

Codecov Report

Merging #528 (7c8e040) into main (2f76da0) will increase coverage by 0.04%. The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #528      +/-   ##
==========================================
+ Coverage   93.02%   93.07%   +0.04%     
==========================================
  Files          15       15              
  Lines         860      866       +6     
==========================================
+ Hits          800      806       +6     
  Misses         60       60              
Impacted Files          Coverage Δ
src/ChainRulesCore.jl   100.00% <ø> (ø)
src/projection.jl       97.36% <100.00%> (+0.06%) :arrow_up:

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 2f76da0...7c8e040.

codecov-commenter avatar Jan 11 '22 15:01 codecov-commenter

🚲 Naming-wise, I wonder if "differentiable" should be reserved for functions (/callables), as in `@non_differentiable f(x, y)` or `@non_differentiable T(x)`? That's partly because of the existence of the `@non_differentiable` macro: on naming alone, I'd guessed this was related to `@non_differentiable` (not so much `ProjectTo`), i.e. `is_non_differentiable(f, args...)` if and only if we had `@non_differentiable f(args...)`.

I think "perturb[able]" is a word that's been used in this package when talking about types (e.g. in CRTestUtils here and here). So that could be an option if we wanted to avoid "differentiable" for types.

(That said, a quick search suggests that Swift for TF thought "differentiable types" was a good name for "types that can be used as arguments and results of differentiable functions" https://github.com/tensorflow/swift/blob/f0d6c74ef5d016046afc1eac0b07a2f6b74b8fdf/docs/DifferentiableTypes.md)

Also, I wonder if we want to flip the sign for ease of understanding, e.g. `is_differentiable` rather than `is_non_differentiable`? Double negatives like "`is_non_diff(T) == false`, so `T` is not non-differentiable" can be hard to follow.

nickrobinson251 avatar Jan 11 '22 15:01 nickrobinson251

I don't love the name.

One reason for the apparent double negative is that the present implementation defaults to `false` when it cannot tell. It answers the question "are we certain we can take the trivial path here?"

mcabbott avatar Jan 11 '22 15:01 mcabbott

The latest commit changes this to return a type, not a function. This means you test `differential_type(x) <: AbstractZero` to see whether `x` is known to have no derivative.

In general it returns the T in ProjectTo{T}. I'm not certain that has any other uses. But at least it's more obviously something read out from the projection machinery, rather than being an independent concept.
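To illustrate the "read out the `T` in `ProjectTo{T}`" idea, here is a rough sketch under stated assumptions. The names `projector_param`, `sketch_differential_type` and `double_pullback_sketch` are hypothetical, and this version builds the projector at runtime instead of relying on inference as the PR does.

```julia
using ChainRulesCore

# Hypothetical stand-ins for illustration; not the PR's implementation.
# Read the first type parameter off a constructed projector.
projector_param(::ProjectTo{P}) where {P} = P
sketch_differential_type(x) = projector_param(ProjectTo(x))

# Assumed example of a pullback for y = 2x that takes the trivial path
# when x is known to have no derivative.
function double_pullback_sketch(x, dy)
    sketch_differential_type(x) <: AbstractZero && return NoTangent()
    return ProjectTo(x)(2 * dy)
end

double_pullback_sketch(true, 3.0)  # trivial path: returns NoTangent()
double_pullback_sketch(1.5, 3.0)   # projects 2 * 3.0 back onto a Float64
```

Unlike an inference-based check, a runtime version like this one has to construct the projector first, so it is not guaranteed to be free for containers with abstract element types.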

mcabbott avatar Jul 08 '22 14:07 mcabbott