Extend scalar_rule to handle broadcast
As a follow up to https://github.com/FluxML/NNlib.jl/pull/257 and https://github.com/FluxML/NNlib.jl/pull/258,
I wish the @scalar_rule macro could be extended with an optional keyword argument so that it also defines
the broadcasted version of a rule, so something like
@scalar_rule tanh(x) 1 - Ω^2 broadcast=true
I would have attempted a PR, but I quickly realized that this is above my very, very limited macro expertise :(
Maybe also a map=true option for map(f, x...)?
This would be one way to implement https://github.com/JuliaDiff/ChainRules.jl/issues/222 and probably a pretty decent one.
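To make the proposal concrete, here is a minimal sketch of roughly what a `broadcast=true` option might generate for the `tanh` example above. `broadcasted_tanh_rrule` is a hypothetical name and none of this is existing ChainRules code; the point is just that the macro would apply the declared scalar derivative `1 - Ω^2` elementwise:

```julia
# Hypothetical expansion of `@scalar_rule tanh(x) 1 - Ω^2 broadcast=true`:
# a reverse rule for the fused broadcast tanh.(x).
function broadcasted_tanh_rrule(x)
    Ω = tanh.(x)                                 # primal broadcast
    pullback(ΔΩ) = (ΔΩ .* (1 .- Ω .^ 2),)       # scalar derivative, elementwise
    return Ω, pullback
end

Ω, back = broadcasted_tanh_rrule([0.0, 1.0])
```

The `map=true` version would look the same with `map(tanh, x)` in place of `tanh.(x)`.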
It does contribute to making @scalar_rule kind of extra good.
Like a premium way of declaring things with extra bonus stuff.
but that would also be the case for other things being discussed like #246
I think we should probably have a separate macro, which @scalar_rule would call.
Something like @declare_higher_order_function_rules (or maybe @declare_functional_rules, so as not to confuse it with higher-order derivatives),
so that it can also be triggered directly, for things that were not declared by @scalar_rule.
I don't even think we need to have it as an optional thing. We should just always do it when it makes sense. (If we find out that we are wrong later we can change that.)
xref #68
If I got this right, when you define a rule via @scalar_rule, it does not get applied when you broadcast that function?
That is incorrect.
Zygote sometimes doesn't apply rules for broadcast (whether rrule, @scalar_rule, or even ZygoteRules.@adjoint) because it switches to forwards mode using ForwardDiff.jl, which doesn't support extensible rules of any kind.
This is unrelated to this issue.
This issue is proposing that @scalar_rule would automatically generate rrule(::typeof(Base.broadcasted), ::typeof(f), args...) and similar for forwards mode.
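For the forwards-mode side, the generated rule would push a tangent through the broadcast elementwise, which (unlike the reverse rule) can stay fused. A hedged sketch, again hard-coded for `tanh` with a hypothetical name rather than anything the macro emits today:

```julia
# Hypothetical forwards-mode counterpart for tanh.(x):
# push the tangent ẋ through using the known scalar derivative 1 - Ω^2.
function broadcasted_tanh_frule(x, ẋ)
    Ω = tanh.(x)
    ∂Ω = ẋ .* (1 .- Ω .^ 2)   # elementwise scalar frule, fusable with the primal
    return Ω, ∂Ω
end
```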
It is less clear that this is a good idea vs just bringing in a general solution now that we have all the parts: like https://github.com/JuliaDiff/Diffractor.jl/pull/68
@scalar_rule now defines derivatives_given_output, which seems like a good way to implement both broadcasting and things like sum(f, x).
One question is whether it should do more to help with this, perhaps by flagging some functions as cheap enough to do twice.
By default that derivatives_given_output wants both x and y = f.(x), and thus the simplest broadcasting rule defined using it must be un-fused, i.e. it must store y even when you calculate z = g.(f.(x)) forwards. If you know both that the rule for f doesn't need its output, and that the rule for g doesn't need its input, then you could avoid storing y, but that seems tricky to check, and rare. Alternatively, if you know that f is very cheap, then you might prefer to re-calculate it, in order to fuse. I think that's true for +, -, *, conj, which are easy to hard-wire. How many other functions are like that?
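To illustrate why the composed case un-fuses: the rule for the inner function wants its output y available at pullback time, so y must be materialized on the forward pass. A toy sketch with hand-written derivatives-given-output for z = exp.(tanh.(x)) (hypothetical helper names, not the actual derivatives_given_output machinery):

```julia
# Both derivatives here happen to need only the output Ω, but the rule for
# exp needs *its* input, which is tanh's output — so y gets stored anyway.
dtanh_given_output(y, x) = 1 - y^2   # tanh' in terms of its output
dexp_given_output(y, x) = y          # exp' equals its output

function rrule_exp_of_tanh(x)        # z = exp.(tanh.(x)), un-fused
    y = tanh.(x)                     # must be stored for the pullback
    z = exp.(y)
    function pullback(Δz)
        Δy = Δz .* dexp_given_output.(z, y)
        Δx = Δy .* dtanh_given_output.(y, x)
        return Δx
    end
    return z, pullback
end
```

If tanh were flagged as cheap enough to re-compute, the forward pass could fuse and the pullback could re-evaluate y = tanh.(x) instead of storing it.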