ChainRules.jl

rrule for multi-argument product

Open • dfdx opened this issue 3 years ago • 1 comment

From dfdx/Yota.jl#93:

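using ChainRulesCore, ChainRules  # rrule comes from ChainRulesCore; ChainRules supplies the rules
A = rand(100, 100)
x = rand(100)
rrule(*, x', A, x)  # ==> nothing (no rule for the 3-argument method)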

It's possible to binarize the operation on the AD engine side (rewriting it as nested two-argument calls), but having a single rule that automagically works for all ADs seems to be a better approach.
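As a rough sketch of what a single three-argument rule would have to compute, here is the math in plain Julia (stdlib LinearAlgebra only) rather than an actual ChainRulesCore.rrule method; the helper name triple_mul_with_pullback is hypothetical. For Y = A * B * C with cotangent dY, the pullback is dA = dY * (B * C)', dB = A' * dY * C', dC = (A * B)' * dY, and the same formulas cover the x' * A * x case from above, where Y and dY are scalars.

```julia
using LinearAlgebra

# Hypothetical helper (not part of ChainRules): returns Y = A * B * C and a pullback closure.
# Works for matrices and for the x' * M * x case, where Y and dY are scalars.
function triple_mul_with_pullback(A, B, C)
    AB = A * B          # cached for the gradient w.r.t. C
    BC = B * C          # cached for the gradient w.r.t. A
    Y = AB * C
    function pullback(dY)
        dA = dY * BC'            # gradient w.r.t. A
        dB = (A' * dY) * C'      # gradient w.r.t. B
        dC = AB' * dY            # gradient w.r.t. C
        return dA, dB, dC
    end
    return Y, pullback
end

# Usage with the product from the issue
M = rand(100, 100)
x = rand(100)
Y, pb = triple_mul_with_pullback(x', M, x)
dA, dB, dC = pb(1.0)
# x appears as the first and third argument, so its total gradient is the sum: (M + M') * x
@assert vec(dA) + dC ≈ (M + M') * x
```

This caches both A * B and B * C, so it does one more product than the forward pass alone needs; a real rule would probably choose which intermediates to keep more carefully.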

Putting it on my own to-do list, but if somebody feels like taking over, please do.

dfdx, Feb 15 '22 22:02

Xref https://github.com/JuliaDiff/ChainRules.jl/issues/544, about numbers, since those rules were slow, and https://github.com/JuliaDiff/ChainRules.jl/pull/412 for the 3- and 4-argument * cases, which go to mul!.

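Unlike with real numbers, the order (grouping) of the multiplications matters for vectors and matrices: every grouping gives the same result, but not at the same cost. For up to 4 arguments, Julia tries to choose the fastest grouping for the forward pass. If that is still the optimal grouping for the gradient, any rules defined here should respect it.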
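To make the grouping point concrete, here is a hypothetical cost model (not LinearAlgebra's actual implementation) that counts scalar multiply-adds for the two ways of evaluating A * B * C when A is a×b, B is b×c and C is c×d. The function names are made up for illustration.

```julia
# Hypothetical cost model: multiply-adds for an (m×k) * (k×n) product.
matmul_cost(m, k, n) = m * k * n

# For A (a×b), B (b×c), C (c×d), estimate both groupings of A * B * C.
function grouping_costs(a, b, c, d)
    left  = matmul_cost(a, b, c) + matmul_cost(a, c, d)   # (A * B) * C
    right = matmul_cost(b, c, d) + matmul_cost(a, b, d)   # A * (B * C)
    return (left = left, right = right)
end

# Pick the cheaper grouping; ties go to left-to-right evaluation.
function choose_grouping(a, b, c, d)
    costs = grouping_costs(a, b, c, d)
    return costs.left <= costs.right ? :left : :right
end

grouping_costs(10, 1000, 10, 1000)   # (left = 200000, right = 20000000)
grouping_costs(1, 100, 100, 1)       # x' * A * x with n = 100: (left = 10100, right = 10100)
```

For the x' * A * x case both groupings cost the same going forward, but the gradient with respect to x, (A + A') * x, involves both A * x and x' * A, which is one reason the forward grouping may not carry over to the backward pass unchanged.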

mcabbott, Feb 15 '22 22:02