ChainRules.jl

Make the rrule for 3-arg dot lazy

Open mohdibntarek opened this issue 1 year ago • 6 comments

Currently this line https://github.com/JuliaDiff/ChainRules.jl/blob/9f1817a22404259113e230bef149a54d379a660b/src/rulesets/LinearAlgebra/dense.jl#L38 builds the cotangent for A as a dense outer product of two vectors, so any structure in A (for example sparsity) is lost. We should make this a lazy outer product (as in https://github.com/SciML/LinearSolve.jl/pull/484) so that the cotangent of the m × n matrix needs only O(m + n) memory instead of O(mn). See https://discourse.julialang.org/t/zygote-jl-how-to-get-the-gradient-of-sparse-matrix/59067/12 for the potential gain and a suggested implementation. If the idea gets a green light, I will open a PR.
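
A minimal sketch of the idea, assuming real inputs. LazyOuter and dot3_pullback_lazy are hypothetical names for illustration only, not the ChainRules.jl or LazyArrays.jl API. The lazy matrix stores just the two vectors and computes each entry of x * y' on demand, so the pullback never allocates the m × n array.

```julia
using LinearAlgebra

# Hypothetical lazy rank-1 matrix representing x * y' (illustration only;
# not the ChainRules.jl or LazyArrays.jl API). Only the two vectors are
# stored, so memory is O(m + n) instead of O(m * n).
struct LazyOuter{T,X<:AbstractVector,Y<:AbstractVector} <: AbstractMatrix{T}
    x::X
    y::Y
end

function LazyOuter(x::AbstractVector, y::AbstractVector)
    T = promote_type(eltype(x), eltype(y))
    return LazyOuter{T,typeof(x),typeof(y)}(x, y)
end

Base.size(L::LazyOuter) = (length(L.x), length(L.y))

# Entry (i, j) of x * y' is computed on demand, never stored.
Base.getindex(L::LazyOuter, i::Int, j::Int) = L.x[i] * conj(L.y[j])

# Hypothetical pullback of Ω = dot(x, A, y) = x' * A * y for real inputs,
# returning the cotangents (dx, dA, dy) for an incoming cotangent ΔΩ.
function dot3_pullback_lazy(ΔΩ::Real, x::AbstractVector{<:Real},
                            A::AbstractMatrix{<:Real}, y::AbstractVector{<:Real})
    dx = ΔΩ .* (A * y)
    dA = LazyOuter(ΔΩ .* x, y)  # lazy instead of the dense ΔΩ .* x .* y'
    dy = ΔΩ .* (A' * x)
    return dx, dA, dy
end

x = [1.0, 2.0]
A = [1.0 2.0 3.0; 4.0 5.0 6.0]
y = [1.0, 0.5, -1.0]
dx, dA, dy = dot3_pullback_lazy(2.0, x, A, y)
Matrix(dA) == 2.0 .* x .* y'  # true: same values as the dense outer product
```

A caller that really needs a dense array can still call Matrix(dA); the point is that the rule itself only holds O(m + n) data.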

mohdibntarek, Mar 23 '24

I guess the key question is: are we willing to add LazyArrays as a dependency?

oxinabox, Apr 04 '24

Related: https://discourse.julialang.org/t/zygote-much-slower-than-jax-for-automatic-differentiation-of-energy/114239/5

gdalle, May 14 '24

Since this has now come up twice, in two different contexts, I think it is worth it and we should add LazyArrays as a dependency. @mohamed82008, are you able to make the PR?

oxinabox, May 14 '24

Yes, I am able to. I will do it over the weekend unless someone beats me to it. It's essentially this implementation, https://discourse.julialang.org/t/zygote-jl-how-to-get-the-gradient-of-sparse-matrix/59067/12, plus unit tests.

mohdibntarek, May 14 '24

Hello, is there any news on this? I would like to use reverse-mode differentiation on dot(x, A, x) with a sparse A, but at the moment the gradient with respect to A comes out as a dense matrix.
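
For the sparse case, one way to avoid the dense matrix is to evaluate the outer product only at A's stored entries. A minimal sketch, assuming real Float64 inputs; sparse_dot3_gradA is a hypothetical helper, not part of ChainRules.jl, and projecting onto the sparsity pattern this way is an assumption rather than something settled in this thread.

```julia
using SparseArrays

# Hypothetical helper (not part of ChainRules.jl): cotangent of
# Ω = dot(x, A, y) with respect to a sparse real A, evaluated only on A's
# stored entries. Cost is O(nnz(A)) time and memory; no dense m×n matrix
# is formed.
function sparse_dot3_gradA(ΔΩ::Float64, x::AbstractVector{Float64},
                           A::SparseMatrixCSC{Float64}, y::AbstractVector{Float64})
    dA = copy(A)            # independent copy with the same sparsity pattern
    rows = rowvals(dA)
    vals = nonzeros(dA)
    for j in axes(dA, 2)
        for k in nzrange(dA, j)
            # entry (rows[k], j) of the outer product ΔΩ .* x .* y'
            vals[k] = ΔΩ * x[rows[k]] * y[j]
        end
    end
    return dA
end

# dot(x, A, x) with a sparse A corresponds to y = x.
A = sparse([1, 2, 2], [1, 1, 2], [1.0, 2.0, 3.0], 2, 2)
x = [1.0, 2.0]
dA = sparse_dot3_gradA(1.0, x, A, x)
```

Restricting to the stored entries keeps A's structure, but it drops the gradient for structurally zero entries, which a dense cotangent would include.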

albertomercurio, May 22 '24

https://github.com/JuliaDiff/ChainRules.jl/pull/796

mohdibntarek, May 25 '24