ChainRules.jl

Where to find the derivatives of basic arithmetic operators and their broadcasted versions

Open Cvikli opened this issue 3 years ago • 5 comments

Hey,

I am trying to redefine @scalar_rule and some of the other macros, to generate the appropriate code for our symbolic differentiation.

I tried googling and reading the source code, but I didn't see where to find the pullbacks for

:(+)(::AbstractArray, ::AbstractArray) ...
:(.+)(::AbstractArray, ::AbstractArray) ...
... -,*,/,^...

and so on.

I know it is easy to do, but the library looks really nice and I don't understand where the basic arithmetic rules are. I found some rules for * in some cases, but Array*Array is also something I couldn't find.

What am I missing? I couldn't figure out how Zygote applies the chain rule without these arithmetic rules. Can you help me?

Cvikli avatar Apr 18 '21 07:04 Cvikli

The frules and rrules for scalar functions are mostly defined with the @scalar_rule macro.

The rules for scalar + and * are in the src/rulesets/Base/fastmath_able.jl file.

See https://github.com/JuliaDiff/ChainRules.jl/blob/38caf4bfdb8af616fcbe7626d10699608af21904/src/rulesets/Base/fastmath_able.jl#L163-L165 and https://github.com/JuliaDiff/ChainRules.jl/blob/38caf4bfdb8af616fcbe7626d10699608af21904/src/rulesets/Base/fastmath_able.jl#L206-L220

Rules for Array functions are in src/rulesets/Base/array.jl or src/rulesets/Base/arraymath.jl (roughly trying to match the location of the functions in Julia Base). Some Array functionality comes from the LinearAlgebra standard library, so its rules are defined in src/rulesets/LinearAlgebra/.

The rules for Array*Array are here https://github.com/JuliaDiff/ChainRules.jl/blob/76ef95c326e773c6c7140fb56eb2fd16a2af468b/src/rulesets/Base/arraymath.jl#L18-L63
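For readers who just want the shape of such a rule, here is a minimal sketch in plain Julia of what a reverse rule for matrix multiplication computes. The name `matmul_with_pullback` is hypothetical and is not part of ChainRules.jl; the real rrule in the linked file follows the ChainRulesCore conventions (for example it also returns a tangent for the function itself), which this sketch leaves out.

```julia
# Hypothetical helper, not the ChainRules.jl implementation:
# returns the product together with a closure that maps the
# output cotangent dY to cotangents for A and B.
function matmul_with_pullback(A::AbstractMatrix, B::AbstractMatrix)
    Y = A * B
    # For Y = A * B: dA = dY * B' and dB = A' * dY.
    matmul_pullback(dY) = (dY * B', A' * dY)
    return Y, matmul_pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]
Y, back = matmul_with_pullback(A, B)
dA, dB = back(ones(2, 2))   # seed with a cotangent of all ones
```

With the all-ones seed, `dA` holds the row sums of `B` repeated down each column position, and `dB` holds the column sums of `A` repeated across each row.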

nickrobinson251 avatar Apr 18 '21 14:04 nickrobinson251

Thank you for the detailed answer!

What is unclear to me is how * is handled for higher-dimensional arrays (3D, 4D, ...). But as I typed this I realised that there is no * operation between higher-dimensional arrays; what I am looking for is ".*", ".^", etc., which are defined between such arrays. What are the frule and rrule for the broadcasted versions of these functions?

One more quick question: I see rules for *, but for + and - I only see @scalar_rule. Where does Zygote get the rules for the array cases?

Cvikli avatar Apr 18 '21 14:04 Cvikli

Zygote doesn't use ChainRules for handling broadcasting at all, everything is defined here: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl. We'd first need to solve https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68 before we can define such rules in ChainRules.
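To give a feel for what a broadcasting rule has to do, here is a rough sketch in plain Julia of a pullback for `x .* y`. It is not Zygote's code; the names `unbroadcast_sketch` and `broadcast_mul_with_pullback` are hypothetical. The key point it illustrates is that when broadcasting expands a size-1 (or missing) dimension, the cotangent has to be summed back over that dimension so it matches the shape of the original argument.

```julia
# Hypothetical helper: sum dx over every dimension that broadcasting
# expanded, then reshape so the result has the same size as x.
function unbroadcast_sketch(x::AbstractArray, dx::AbstractArray)
    size(x) == size(dx) && return dx
    # size(x, d) is 1 for d > ndims(x), so trailing dimensions are covered too.
    dims = Tuple(d for d in 1:ndims(dx) if size(x, d) == 1 && size(dx, d) > 1)
    return reshape(sum(dx; dims=dims), size(x))
end

# Hypothetical helper: elementwise product plus its pullback.
function broadcast_mul_with_pullback(x::AbstractArray, y::AbstractArray)
    z = x .* y
    # d(x .* y)/dx = y and d(x .* y)/dy = x, applied elementwise.
    pullback(dz) = (unbroadcast_sketch(x, dz .* y), unbroadcast_sketch(y, dz .* x))
    return z, pullback
end

x = [1.0, 2.0, 3.0]   # size (3,)
y = [10.0 20.0]       # size (1, 2)
z, back = broadcast_mul_with_pullback(x, y)   # z has size (3, 2)
dx, dy = back(ones(3, 2))
```

The same pattern works for N-dimensional arrays, because nothing in the sketch depends on the number of dimensions; only the elementwise derivative changes from one operator to the next.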

simeonschaub avatar Apr 18 '21 14:04 simeonschaub

It is really tricky code; I just can't understand how this broadcasting gets called or overloaded for each dotted operator. I also don't see the cases for N-dimensional Array types. But I will try to read more of their code.

So the only question I have left is: where are the array versions of + and - in ChainRules, or am I missing something here?

Cvikli avatar Apr 18 '21 14:04 Cvikli

> So the only question I have left is: where are the array versions of + and - in ChainRules, or am I missing something here?

See https://github.com/JuliaDiff/ChainRules.jl/blob/2e6491c9dd20608b725b7ecdca2dbe872cf7833b/src/rulesets/Base/arraymath.jl#L306-L329. I'm not sure if there's a dedicated rrule for -(A::AbstractArray, B::AbstractArray) or if it falls through to broadcasting.
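Whichever path the library takes, the math for same-shaped arrays is simple. As a sketch in plain Julia (the names `add_with_pullback` and `sub_with_pullback` are hypothetical and not taken from ChainRules.jl), the pullback of + passes the cotangent through to both arguments, and the pullback of - negates it for the second argument:

```julia
# Hypothetical helpers, assuming A and B have the same size.
# Each returns the result and a pullback closure.
add_with_pullback(A::AbstractArray, B::AbstractArray) = (A + B, dY -> (dY, dY))
sub_with_pullback(A::AbstractArray, B::AbstractArray) = (A - B, dY -> (dY, -dY))

A = [1.0 2.0; 3.0 4.0]
B = [0.5 0.5; 0.5 0.5]

S, add_back = add_with_pullback(A, B)
D, sub_back = sub_with_pullback(A, B)

dA_add, dB_add = add_back(ones(2, 2))
dA_sub, dB_sub = sub_back(ones(2, 2))
```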

ToucheSir avatar Sep 04 '21 23:09 ToucheSir