ChainRules.jl
Where to find derivatives of the basic arithmetic operators and their broadcasted versions
Hey,
I am trying to redefine @scalar_rule and some of the other macros to generate the appropriate code for our symbolic differentiation.
I googled and read through the source code, but I didn't see where to find the pullback for
:(+)(::AbstractArray, ::AbstractArray) ...
:(.+)(::AbstractArray, ::AbstractArray) ...
... -,*,/,^...
and so on.
I know these are easy to write, but the library looks really nice and I don't understand where the basic arithmetic rules are. I found * in some cases, but Array*Array is also something I couldn't find.
What am I missing? I couldn't figure out how Zygote applies ChainRules without these arithmetic rules. Can you help me?
The frule and rrules for scalar-functions are mostly defined with the @scalar_rule macro.
The rules for scalar + and * are in the src/rulesets/Base/fastmath_able.jl file.
See https://github.com/JuliaDiff/ChainRules.jl/blob/38caf4bfdb8af616fcbe7626d10699608af21904/src/rulesets/Base/fastmath_able.jl#L163-L165 and https://github.com/JuliaDiff/ChainRules.jl/blob/38caf4bfdb8af616fcbe7626d10699608af21904/src/rulesets/Base/fastmath_able.jl#L206-L220
Rules for Array functions are in src/rulesets/Base/array.jl or src/rulesets/Base/arraymath.jl (roughly matching the location of the functions in Julia Base). Some Array functionality comes from the LinearAlgebra standard library, so those rules are defined in src/rulesets/LinearAlgebra/.
The rules for Array*Array are here: https://github.com/JuliaDiff/ChainRules.jl/blob/76ef95c326e773c6c7140fb56eb2fd16a2af468b/src/rulesets/Base/arraymath.jl#L18-L63
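To illustrate what a reverse rule for Array*Array has to compute, here is a minimal hand-written sketch in plain Julia. It does not use ChainRules and is not the code at the link above; `matmul_with_pullback` is a hypothetical name, and the sketch assumes real-valued matrices.

```julia
using LinearAlgebra

# Hypothetical helper (not part of ChainRules): returns the product Y = A * B
# together with a closure that maps the cotangent of Y to cotangents of A and B.
function matmul_with_pullback(A::AbstractMatrix, B::AbstractMatrix)
    Y = A * B
    function matmul_pullback(dY)
        dA = dY * B'   # cotangent of A, same size as A
        dB = A' * dY   # cotangent of B, same size as B
        return dA, dB
    end
    return Y, matmul_pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]
Y, back = matmul_with_pullback(A, B)

# Seeding with ones gives the gradient of sum(A * B) with respect to A and B.
dA, dB = back(ones(size(Y)))
```

The pullback only needs the inputs captured by the closure, which is the same "primal plus pullback" shape that an rrule returns.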
Thank you for the detailed answer!
For me the unclear part is how * handles bigger arrays (3D, 4D, ...). But as I typed this I realised that there is no * operation between bigger arrays; what I am looking for is ".*", ".^", etc., which are defined between bigger arrays. Where are the frule and rrule for the broadcasted versions of these functions?
One more quick question: I see a rule for *, but only @scalar_rule for + and -. Where does Zygote get the information for the array cases?
Zygote doesn't use ChainRules for handling broadcasting at all; everything is defined here: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl. We'd first need to solve https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68 before we can define such rules in ChainRules.
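As a rough illustration of what a reverse rule for a broadcasted operator such as .* has to do for N-dimensional arrays, here is a hand-written sketch in plain Julia. It is not Zygote's implementation; `unbroadcast_sketch` and `bmul_with_pullback` are hypothetical names, and the sketch assumes real-valued arrays. The key step is summing the cotangent over every dimension that broadcasting expanded.

```julia
# Hypothetical helper: reduce `dx` back to the shape of `x` by summing over
# the dimensions where `x` had size 1 but the broadcast result did not.
function unbroadcast_sketch(x::AbstractArray, dx::AbstractArray)
    size(x) == size(dx) && return dx
    dims = Tuple(filter(d -> size(x, d) == 1 && size(dx, d) > 1, 1:ndims(dx)))
    reduced = isempty(dims) ? dx : sum(dx; dims=dims)
    return reshape(reduced, size(x))
end

# Hypothetical helper: elementwise product plus its pullback.
function bmul_with_pullback(x::AbstractArray, y::AbstractArray)
    z = x .* y
    function bmul_pullback(dz)
        dx = unbroadcast_sketch(x, dz .* y)
        dy = unbroadcast_sketch(y, dz .* x)
        return dx, dy
    end
    return z, bmul_pullback
end

x = reshape(collect(1.0:24.0), 2, 3, 4)   # 3-D array
y = reshape([1.0, 2.0, 3.0], 1, 3, 1)     # expanded along dims 1 and 3
z, back = bmul_with_pullback(x, y)
dx, dy = back(ones(size(z)))
```

Here `dx` keeps the full 2×3×4 shape, while `dy` is summed back down to 1×3×1, so each cotangent matches the shape of its input.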
It is really tricky code. I just can't understand how this broadcasting gets called/overloaded for each dotted operator. I also don't see the cases for N-dimensional Array types. But I will try to read more of their code.
So my only remaining question is: where are the array versions of + and - in ChainRules, or am I missing something here?
See https://github.com/JuliaDiff/ChainRules.jl/blob/2e6491c9dd20608b725b7ecdca2dbe872cf7833b/src/rulesets/Base/arraymath.jl#L306-L329. I'm not sure if there's a dedicated rrule for -(A::AbstractArray, B::AbstractArray) or if it falls through to broadcasting.
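For same-sized arrays, the reverse rules for + and - are simple enough to sketch by hand. The following is plain Julia, not the ChainRules code linked above; `add_with_pullback` and `sub_with_pullback` are hypothetical names, and the arrays are assumed to have equal shapes so that no broadcasting is involved.

```julia
# Hypothetical helpers: each returns the result and a pullback closure.
# For A + B the cotangent passes through unchanged to both inputs;
# for A - B the second input receives the negated cotangent.
add_with_pullback(A::AbstractArray, B::AbstractArray) = (A + B, dY -> (dY, dY))
sub_with_pullback(A::AbstractArray, B::AbstractArray) = (A - B, dY -> (dY, -dY))

A = [1.0 2.0; 3.0 4.0]
B = [0.5 0.5; 0.5 0.5]
dY = ones(2, 2)

S, add_back = add_with_pullback(A, B)
D, sub_back = sub_with_pullback(A, B)
dA_add, dB_add = add_back(dY)
dA_sub, dB_sub = sub_back(dY)
```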