ChainRulesCore.jl

Extend scalar_rule to handle broadcast

Open · CarloLucibello opened this issue 4 years ago · 6 comments

As a follow-up to https://github.com/FluxML/NNlib.jl/pull/257 and https://github.com/FluxML/NNlib.jl/pull/258, I wish the @scalar_rule macro could be extended with an optional keyword argument so that it also defines the broadcasted version of a rule, something like:

```julia
@scalar_rule tanh(x) 1 - Ω^2 broadcast=true
```

I would have attempted a PR, but I quickly realized that this is beyond my very, very limited macro expertise :(
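To make the request concrete, here is a minimal base-Julia sketch of the sort of reverse-mode rule a `broadcast=true` option might generate for `tanh`. It deliberately does not use ChainRulesCore: `toy_deriv` and `toy_broadcast_rrule` are hypothetical stand-ins for the derivative recorded by `@scalar_rule` and for a generated `rrule` on the broadcast. A real ChainRulesCore pullback would also return tangents for the functions themselves (typically `NoTangent()`), which is omitted here.

```julia
# Hypothetical stand-in for the scalar rule `@scalar_rule tanh(x) 1 - Ω^2`:
# tanh'(x) written in terms of the output Ω = tanh(x).
toy_deriv(::typeof(tanh), x, Ω) = 1 - Ω^2

# Hypothetical reverse-mode rule for the broadcast `f.(x)`, of the kind a
# `broadcast=true` option might generate from the scalar rule above.
function toy_broadcast_rrule(f, x::AbstractArray{<:Real})
    Ω = f.(x)                      # forward pass; output kept for the pullback
    # Elementwise chain rule: xbar[i] = ΔΩ[i] * f'(x[i]), with f' from toy_deriv
    broadcast_pullback(ΔΩ) = ΔΩ .* toy_deriv.(f, x, Ω)
    return Ω, broadcast_pullback
end

x = [-1.0, 0.0, 2.0]
y, back = toy_broadcast_rrule(tanh, x)
xbar = back(ones(3))
```

Because the derivative is written in terms of the output `Ω`, the sketch keeps the whole array `f.(x)` alive for the pullback instead of recomputing it.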

CarloLucibello · Dec 29 '20 02:12

Maybe also a map=true option for map(f, x...)?
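A `map=true` option would need the same idea for several arguments. As a hedged sketch in base Julia (with hypothetical `toy_partials` and `toy_map_rrule`), the scalar rule supplies one partial derivative per argument, and the pullback for `map(f, xs...)` multiplies each partial by the incoming cotangent elementwise. Real inputs are assumed, so no conjugation is applied.

```julia
# Hypothetical partials for the scalar function `*`: for Ω = a * b,
# ∂Ω/∂a = b and ∂Ω/∂b = a (real inputs assumed, so no conjugation).
toy_partials(::typeof(*), a, b, Ω) = (b, a)

# Hypothetical reverse-mode rule for `map(f, xs...)` built from toy_partials.
function toy_map_rrule(f, xs::AbstractArray...)
    Ω = map(f, xs...)
    function map_pullback(ΔΩ)
        # One tuple of partial derivatives per element position.
        partials = map((args...) -> toy_partials(f, args...), xs..., Ω)
        # Cotangent for the k-th input: ΔΩ[i] times the k-th partial at position i.
        return ntuple(k -> map((p, Δ) -> p[k] * Δ, partials, ΔΩ), length(xs))
    end
    return Ω, map_pullback
end

a = [1.0, 2.0]
b = [3.0, 4.0]
y, back = toy_map_rrule(*, a, b)
abar, bbar = back([1.0, 1.0])
```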

CarloLucibello · Dec 29 '20 07:12

This would be one way to implement https://github.com/JuliaDiff/ChainRules.jl/issues/222 and probably a pretty decent one.

It does help make @scalar_rule a kind of premium way of declaring rules, with extra bonus behavior. But that would also be true of other things being discussed, like #246.

I think we should probably have a separate macro that @scalar_rule would call. Something like @declare_higher_order_function_rules (or maybe @declare_functional_rules, so as not to confuse it with higher-order derivatives), so that it can also be triggered directly for things that were not declared via @scalar_rule.

I don't even think we need to make it optional. We should just always do it when it makes sense. (If we later find out we are wrong, we can change that.)
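A rough base-Julia sketch of the two-layer design suggested in this comment: a standalone macro that only adds the functional rules, and a scalar-rule macro that records the derivative and then delegates to it. All names (`@toy_functional_rules`, `@toy_scalar_rule`, `toy_deriv`, `toy_broadcast_rrule`) are hypothetical. This is not how ChainRulesCore's `@scalar_rule` is actually implemented, and the only functional rule generated here is a reverse-mode broadcast rule.

```julia
# Layer 1 (hypothetical): add "functional" rules for any f that already has a
# toy_deriv method. Here that is just a reverse-mode rule for broadcasting f.
macro toy_functional_rules(f)
    return esc(quote
        function toy_broadcast_rrule(::typeof($f), x::AbstractArray{<:Real})
            Ω = broadcast($f, x)
            pullback(ΔΩ) = ΔΩ .* broadcast(toy_deriv, $f, x, Ω)
            return Ω, pullback
        end
    end)
end

# Layer 2 (hypothetical): a scalar-rule macro that records the derivative
# (written in terms of the input name and the output Ω), then delegates.
macro toy_scalar_rule(call, dexpr)
    f, xname = call.args[1], call.args[2]   # :(tanh(x)) -> :tanh, :x
    return esc(quote
        toy_deriv(::typeof($f), $xname, Ω) = $dexpr
        @toy_functional_rules $f
    end)
end

@toy_scalar_rule(tanh(x), 1 - Ω^2)

# A derivative written by hand, then opted in to the functional rules directly.
toy_deriv(::typeof(sin), x, Ω) = cos(x)
@toy_functional_rules sin

x = [0.5, 1.5]
_, back_tanh = toy_broadcast_rrule(tanh, x)
_, back_sin = toy_broadcast_rrule(sin, x)
```

The hand-written `sin` case shows the point of keeping the layers separate: functions that were never declared with the scalar-rule macro can still opt in.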

oxinabox · Dec 29 '20 15:12

xref #68

nickrobinson251 · Apr 18 '21 14:04

If I got this right, when you define a rule via @scalar_rule, it does not get applied when you broadcast that function?

cossio · Dec 24 '21 15:12

> If I got this right, when you define a rule via @scalar_rule, it does not get applied when you broadcast that function?

That is incorrect. Zygote sometimes doesn't apply rules for broadcast (whether rrule, @scalar_rule, or even ZygoteRules.@adjoint) because it switches to forward mode using ForwardDiff.jl, which doesn't support extensible rules of any kind. This is unrelated to this issue.

This issue is proposing that @scalar_rule would automatically generate rrule(::typeof(Base.broadcasted), ::typeof(f), args...), and the corresponding frule for forward mode.
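For the forward-mode half of that proposal, here is a base-Julia sketch of the pushforward such a generated rule would compute: given `x` and an input tangent `Δx`, return `f.(x)` and `f'.(x) .* Δx`. `toy_deriv` and `toy_broadcast_frule` are hypothetical names. ChainRulesCore's actual forward rule has the shape `frule((Δself, Δargs...), f, args...)`, which this toy only loosely mirrors by taking the tangent first.

```julia
# Hypothetical scalar derivative of tanh, written in terms of its output Ω.
toy_deriv(::typeof(tanh), x, Ω) = 1 - Ω^2

# Hypothetical forward-mode rule for the broadcast `f.(x)`:
# returns the primal output and the pushed-forward tangent.
function toy_broadcast_frule(Δx::AbstractArray, f, x::AbstractArray{<:Real})
    Ω = f.(x)                          # primal output
    ΔΩ = toy_deriv.(f, x, Ω) .* Δx     # pushforward: ΔΩ[i] = f'(x[i]) * Δx[i]
    return Ω, ΔΩ
end

x = [0.0, 1.0]
y, Δy = toy_broadcast_frule([1.0, 2.0], tanh, x)
```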

It is less clear whether this is a good idea compared with just bringing in a general solution, now that we have all the parts, like https://github.com/JuliaDiff/Diffractor.jl/pull/68.

oxinabox · Jan 11 '22 14:01

@scalar_rule now defines derivatives_given_output, which seems like a good way to implement both broadcasting and things like sum(f, x).
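To illustrate why a standalone, output-aware scalar derivative also helps with reductions like `sum(f, x)`, here is a base-Julia sketch. `toy_derivs_given_output` is a hypothetical stand-in for ChainRulesCore's `derivatives_given_output` (its real signature and return shape are not reproduced here), and `toy_sum_rrule` is a hypothetical rule for `sum(f, x)`. Since the derivative of the sum of f(x[j]) over j with respect to x[i] is f'(x[i]), the pullback of a scalar cotangent `Δ` is `Δ .* f'.(x)`.

```julia
# Hypothetical stand-in: a scalar derivative that may use the output Ω = f(x).
toy_derivs_given_output(Ω, ::typeof(exp), x) = Ω   # d/dx exp(x) = exp(x)

# Hypothetical reverse-mode rule for `sum(f, x)` using only the scalar derivative.
function toy_sum_rrule(f, x::AbstractArray{<:Real})
    Ω = f.(x)              # elementwise outputs, kept for the pullback
    s = sum(Ω)
    sum_pullback(Δ::Real) = Δ .* toy_derivs_given_output.(Ω, f, x)
    return s, sum_pullback
end

x = [0.0, 1.0, 2.0]
s, back = toy_sum_rrule(exp, x)
xbar = back(1.0)
```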

One question is whether it should do more to help with this, perhaps by flagging some functions as cheap enough to do twice.

By default, derivatives_given_output wants both x and y = f.(x), and thus the simplest broadcasting rule defined using it must be unfused, i.e. it must store y even when you calculate z = g.(f.(x)) in the forward pass. If you know both that the rule for f doesn't need its output and that the rule for g doesn't need its input, then you could avoid storing y, but that seems tricky to check, and rare.

Alternatively, if you know that f is very cheap, then you might prefer to recalculate it in order to fuse. I think that's true for +, -, * and conj, which are easy to hard-wire. How many other functions are like that?
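The trade-off described here can be seen in a small base-Julia sketch for z = g.(f.(x)). `f(x) = 3x` stands in for a function cheap enough to recompute (like the `*` case mentioned), and `g = sin` for one whose derivative needs its input. All names are hypothetical and both rules are written by hand rather than generated. The unfused version keeps the intermediate y = f.(x) alive until the pullback runs; the fused version computes z in a single fused broadcast and pays for it by re-evaluating f.(x) in the pullback.

```julia
# Hypothetical pair of scalar functions with hand-written derivatives.
f(x) = 3x          # cheap enough that recomputing it is harmless
df(x) = 3          # f'(x)
g(y) = sin(y)
dg(y) = cos(y)     # g'(y) needs the input y = f(x)

# Unfused: materialise y = f.(x) and keep it alive until the pullback runs.
function rrule_unfused(x)
    y = f.(x)                      # stored intermediate array
    z = g.(y)
    pullback(Δz) = Δz .* dg.(y) .* df.(x)
    return z, pullback
end

# Fused: a single fused broadcast forward (no stored y); the pullback
# recomputes f.(x) inside another fused broadcast instead of reading y.
function rrule_fused_recompute(x)
    z = g.(f.(x))
    pullback(Δz) = Δz .* dg.(f.(x)) .* df.(x)
    return z, pullback
end

x = [0.1, 0.2, 0.3]
z1, back1 = rrule_unfused(x)
z2, back2 = rrule_fused_recompute(x)
```

Both rules produce the same gradient, 3 .* cos.(3 .* x); the difference is only whether the memory for y or the extra evaluation of f is paid.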

mcabbott · Jan 11 '22 16:01