Diffractor.jl

Construct natural tangents for Number and AbstractArray{<:Number} in forward mode

Open · oxinabox opened this issue · 1 comment

Ok. Basically, a lot of the code in ChainRules assumes that if your type is a Number, then its tangent is a number of that same type. A slightly smaller, but still non-zero, portion of the code assumes the same about

This is an alternative to https://github.com/JuliaDiff/Diffractor.jl/pull/272 and https://github.com/JuliaDiff/ChainRules.jl/pull/787. If you always work with natural tangent types for Numbers, you don't run into problems like the inability to add them, since that same addition already occurs in the primal problem. And that feels better, mostly.

I think it will also make a lot of hackier things obsolete, like our special handling for StaticArrays (e.g. what had to be changed in https://github.com/JuliaDiff/Diffractor.jl/pull/275); we shouldn't need that rule at all. Probably some of the other rules in extra_rules.jl can go away too.

To support Diffractor over ForwardDiff, we need both handling of natural tangents for AbstractArray{<:Number} (for ForwardDiff.Partials) and handling of natural tangents for Number (for ForwardDiff.Dual, which uses ForwardDiff.Partials).
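A minimal sketch of the two cases above, assuming ForwardDiff's Dual(value, partials...) constructor and its value/partials accessors: a natural tangent for a Dual is just another Dual of the same concrete type, so accumulating tangents reuses the primal's own +, and the Dual's Partials field is an AbstractVector{<:Number} that is not an Array. This only illustrates the types involved; it is not how Diffractor implements the change.

```julia
using ForwardDiff: Dual, Partials, value, partials

# Primal: a ForwardDiff.Dual, which is a Number (Dual <: Real).
x = Dual(1.0, 2.0)

# Natural tangents of x: plain Duals of the same concrete type,
# rather than a structural tangent with separate value/partials fields.
dx1 = Dual(0.5, 0.1)
dx2 = Dual(0.25, 0.3)

# Accumulating tangents reuses the same `+` the primal computation already uses.
dx = dx1 + dx2
println(typeof(dx) == typeof(x))          # true

# The partials field is a Partials: an AbstractVector{Float64} but not an Array,
# which is why the AbstractArray{<:Number} case is needed as well.
p = partials(x)
println(p isa AbstractVector{Float64})    # true
println(p isa Array)                      # false
```

Because the tangent type matches the primal type, no extra rule is needed to teach the AD system how to add two tangents: whatever makes x + y work in the primal also makes dx1 + dx2 work.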

oxinabox, Mar 05 '24 11:03

This should also be called in zero_bundle, since zero_tangent only does it for Arrays, not for AbstractArrays in general.
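A sketch of the Array vs AbstractArray distinction, using a hypothetical helper natural_zero_tangent (not an existing Diffractor or ChainRules function): dispatching on AbstractArray{<:Number} instead of Array lets array types such as ForwardDiff.Partials get a zero of their own type from Base.zero. Whether zero_bundle would call something shaped exactly like this is an assumption.

```julia
using ForwardDiff: Dual, Partials, partials

# Hypothetical helper, for illustration only: the natural zero tangent of a
# Number or of an AbstractArray{<:Number} is simply `zero` of the primal.
natural_zero_tangent(x::Number) = zero(x)
natural_zero_tangent(x::AbstractArray{<:Number}) = zero(x)

# A Dual with two partials; its partials field is a Partials, not an Array.
d = Dual(1.0, 2.0, 3.0)
p = partials(d)

zd = natural_zero_tangent(d)           # a Dual of the same type, all zeros
zp = natural_zero_tangent(p)           # a Partials, via ForwardDiff's zero method
zv = natural_zero_tangent([1.0, 2.0])  # an ordinary Vector{Float64}
```

A method restricted to Array would miss p entirely; the AbstractArray{<:Number} signature covers Partials, plain Vectors, and other array types that define their own zero.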

oxinabox, Mar 07 '24 13:03