ChainRules.jl
rrule for multi-argument product
From dfdx/Yota.jl#93:
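using ChainRules  # re-exports rrule from ChainRulesCore
A = rand(100, 100)
x = rand(100)
rrule(*, x', A, x) # ==> nothing, i.e. no rule is defined for the three-argument product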
It's possible to binarize the operation on the AD engine side, but having a single rule that automagically works for all ADs seems to be a better approach.
Putting it into my own ToDo list, but if somebody feels like taking over, please do.
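For reference, here is a rough sketch of what a single three-argument rule would have to compute for real inputs. This is not the ChainRules.jl implementation: `triple_product_with_pullback` is a hypothetical name, the left-to-right grouping is picked only for simplicity, and projection onto structured types is ignored.

```julia
# Hypothetical helper, NOT part of ChainRules.jl: forward value and pullback
# of Y = A * B * C for real-valued arguments.
function triple_product_with_pullback(A, B, C)
    AB = A * B                 # left-to-right grouping, an arbitrary choice here
    Y = AB * C
    function pullback(dY)
        dA = dY * (B * C)'     # cotangent of A: dY * (B*C)'
        dB = A' * dY * C'      # cotangent of B: A' * dY * C'
        dC = (AB)' * dY        # cotangent of C: (A*B)' * dY
        return dA, dB, dC
    end
    return Y, pullback
end

# Usage with the quadratic form from the snippet above.
A = rand(100, 100)
x = rand(100)
y, pb = triple_product_with_pullback(x', A, x)
dxt, dA, dx = pb(1.0)          # dA ≈ x * x', dx ≈ A' * x, dxt ≈ (A * x)'
```

Binarizing on the AD side would produce the same three cotangents by chaining two two-argument pullbacks. A dedicated rule just computes them in one place and is free to pick the grouping.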
Xref https://github.com/JuliaDiff/ChainRules.jl/issues/544, which is about products of numbers, because they were slow. And https://github.com/JuliaDiff/ChainRules.jl/pull/412 for those 3- and 4-arg * cases which go to mul!.
Unlike for real numbers, the order of multiplication matters for vectors & matrices. And for up to 4 arguments, Julia tries to choose the fastest grouping for the forward pass. If that's still the optimum grouping for the gradient, then it should be respected by any rules defined here.
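To make the grouping point concrete, here is a back-of-the-envelope multiplication count. It is only a sketch: `matmul_cost` and `grouping_costs` are hypothetical helpers, and the count ignores memory traffic and BLAS effects, so it is not how Julia itself decides.

```julia
# Scalar multiplications for an (m×n) * (n×p) product (naive count).
matmul_cost(m, n, p) = m * n * p

# Cost of both groupings of A*B*C with sizes A: m×n, B: n×p, C: p×q.
function grouping_costs(m, n, p, q)
    left  = matmul_cost(m, n, p) + matmul_cost(m, p, q)   # (A*B)*C
    right = matmul_cost(n, p, q) + matmul_cost(m, n, q)   # A*(B*C)
    return (left = left, right = right)
end

# Forward pass A*B*v with 100×100 matrices and a length-100 vector:
forward = grouping_costs(100, 100, 100, 1)

# Cotangent of B is A' * dY * v', with sizes 100×100, 100×1, 1×100:
cotangent_B = grouping_costs(100, 100, 1, 100)
```

In this example the forward pass is far cheaper as `A*(B*v)`, while the cotangent of `B` is far cheaper as `(A'*dY)*v'`. So the forward optimum does not automatically carry over to the gradient, and a rule may have to pick the grouping per cotangent.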