ChainRules.jl

More efficient projection in svd pullback

Open perrutquist opened this issue 2 years ago • 11 comments

This change makes the pullback for svd faster and more memory efficient when taking the SVD of very tall or very wide matrices (which is a common application).

The issue is that the "textbook" formula (I - U*U')*X can be extremely slow and memory-intensive when U is tall. The mathematically equivalent form X - U*(U'*X) avoids creating and then multiplying by the large matrix U*U'.
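Here is a minimal sketch of why the two forms differ so much in cost. It uses only the LinearAlgebra standard library. The sizes and the random `X` are illustrative assumptions, not taken from the PR.

```julia
using LinearAlgebra

m, n = 2000, 10                 # a tall matrix (m ≫ n), kept small so the dense form stays cheap
A = randn(m, n)
U = svd(A).U                    # thin SVD: U is m × n with orthonormal columns
X = randn(m, n)                 # stand-in for the matrix being projected

# "Textbook" form: materializes the dense m × m matrix U*U' (and I - U*U')
P_dense = (I - U * U') * X

# Equivalent form: the only intermediates are n × n (U'*X) and m × n
P_fast = X - U * (U' * X)

@show norm(P_dense - P_fast)    # rounding-level difference
```

For m = 20000, a single m × m Float64 matrix already takes 20000² × 8 bytes ≈ 3.2 GB. That is the same order as the allocation reported for the unpatched pullback. The rearranged form never allocates anything larger than m × n.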

Example:

```julia
using LinearAlgebra, Zygote
f(x) = sum(svd(x).U)            # loss that depends on the left singular vectors
X = randn(20000, 10);           # very tall input: 20000 × 10
@time Zygote.gradient(f, X);
```

Without this PR, the above runs in about 2 seconds and allocates 3 GB. With this PR, it runs in less than 0.01 seconds and allocates 11 MB.

There is further room for improvement. When the input to svd is wide, I - U*U' is zero, because the thin-SVD factor U is then square and orthogonal. Conversely, when the input is tall, I - V*V' is zero. A couple of if-statements could therefore skip some unnecessary computation.
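Here is a sketch of that shortcut. It assumes the default thin SVD (`svd(A)` with `full=false`), in which the square factor is orthogonal, so its complementary projector vanishes. `project_out` is a hypothetical helper name, not a ChainRules function.

```julia
using LinearAlgebra

# Hypothetical helper: computes (I - Q*Q')*X without forming Q*Q'.
# If Q is square (and orthogonal), Q*Q' == I, so the result is zero
# and the matrix products can be skipped entirely.
function project_out(Q::AbstractMatrix, X::AbstractMatrix)
    size(Q, 1) == size(Q, 2) && return zero(X)
    return X - Q * (Q' * X)
end

W = randn(5, 40)                     # wide input: more columns than rows
wsvd = svd(W)                        # thin SVD: U is 5 × 5, V is 40 × 5
@show size(wsvd.U) size(wsvd.V)
@show opnorm(I - wsvd.U * wsvd.U')   # rounding-level: nothing left to project out

X = randn(5, 3)
project_out(wsvd.U, X)               # returns zeros without any matrix products
```

For a tall input the roles swap. `V` is then square, so the same check applies to the `V` term instead.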

(This PR also removes the undocumented, unexported utility function _eyesubx! that is not used elsewhere in the package.)

perrutquist, Nov 10 '23 12:11

@sethaxen are you able to take a look at this? You wrote most of our factorizations, but not this one. I think this one came from Nabla.jl.

oxinabox, Nov 10 '23 13:11

Indeed, I've not inspected this one closely, but the approach of this PR looks right. I'll review in-depth shortly.

sethaxen, Nov 10 '23 13:11

As a consequence of the simplifications in my latest commits to this PR, the _mulsubtrans!! helper function can no longer be used. Since it is not used anywhere else, it has been removed.

I have not done any benchmarking, but I feel the removal of unnecessary matrix multiplication should speed things up in general.

I've also changed the function signature to take V̄t instead of V̄. This avoids an unnecessary transpose, but mainly it makes the code look much better, as there is an obvious symmetry between U and Vt. However, if there is any external code that calls the (undocumented) svd_rev function, then that would break. If that's a problem, I can revert the change.

perrutquist, Dec 04 '23 20:12

> I've also changed the function signature to take V̄t instead of V̄. This avoids an unnecessary transpose, but mainly it makes the code look much better, as there is an obvious symmetry between U and Vt. However, if there is any external code that calls the (undocumented) svd_rev function, then that would break. If that's a problem, I can revert the change.

svd_rev should be considered an implementation detail and not used by other packages. On GitHub, the only package I can see that uses svd_rev besides ChainRules is Neurogenesis: https://github.com/k8lion/Neurogenesis/blob/003d9e0a33f3c23a304b64968516638844bd837e/src/initializers.jl#L19-L25 . They also don't set compat bounds though, so they're not protected against breaking changes in any dependencies.

sethaxen, Dec 04 '23 21:12

In that case, I'll leave it as it is now, with V̄t as the input. The svd_rev method in Neurogenesis is for V̄::AbstractZero so it will still work.

The reason why that package has its own method might be that svd_rev has USV::SVD in the function signature. Changing it to accept any type would probably allow it to work with CUDA.CUSOLVER.CuSVD, and Neurogenesis could simply remove its specialized method.

perrutquist, Dec 04 '23 21:12

Actually, the method from Neurogenesis is still faster, because it avoids the allocation and computation of the F matrix. So I think it would make sense to have two methods:

- svd_rev(USV, Ū, s̄, V̄t) with the code from this PR. This applies in the case where the derivatives of the singular vectors are needed.
- svd_rev(USV, Ū::AbstractZero, s̄, V̄::AbstractZero) with all code related to Ū, and F removed (making it identical to the code in Neurogenesis.jl). This would apply in the case where the derivatives of only the singular values are needed, which I think is a common one.

Edit: In fact, the svdvals function exists specifically for this case, and this would yield an efficient derivative for that function, resolving #206.
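The two-method split proposed in this comment can be sketched with plain multiple dispatch. This is a hedged illustration, not the PR's code, and rests on these assumptions:

- `svd_rev_sketch` and `NoGrad` are hypothetical names. `NoGrad` stands in for ChainRulesCore's `AbstractZero` so that the snippet needs only LinearAlgebra.
- The general method follows the standard reverse-mode formula for a thin real SVD with distinct, nonzero singular values. It uses F[i,j] = 1/(s[j]^2 - s[i]^2) off the diagonal and the X - U*(U'*X) form for both projections.

```julia
using LinearAlgebra

# Stand-in for ChainRulesCore.AbstractZero, so the sketch runs with the stdlib only.
struct NoGrad end

# General method: dense cotangents for U, s and Vt (thin SVD, real entries,
# distinct nonzero singular values assumed).
function svd_rev_sketch(fact::SVD, Ubar::AbstractMatrix, sbar::AbstractVector,
                        Vtbar::AbstractMatrix)
    U, s, Vt = fact.U, fact.S, fact.Vt
    k = length(s)
    UtUbar = U' * Ubar                    # UᵀŪ
    VtVbar = Vt * Vtbar'                  # VᵀV̄, using V̄ = V̄t'
    F = [i == j ? zero(eltype(s)) : inv(s[j]^2 - s[i]^2) for i in 1:k, j in 1:k]
    S = Diagonal(s)
    M = (F .* (UtUbar - UtUbar')) * S + Diagonal(sbar) + S * (F .* (VtVbar - VtVbar'))
    Abar = U * M * Vt
    Abar += (Ubar - U * UtUbar) / S * Vt         # (I - UUᵀ) Ū S⁻¹ Vᵀ, no m × m matrix
    Abar += U * (S \ (Vtbar - VtVbar' * Vt))     # U S⁻¹ V̄ᵀ (I - VVᵀ), no n × n matrix
    return Abar
end

# Specialized method: only the singular values carry a cotangent, so F, the
# projections and the k × k products are all skipped.
svd_rev_sketch(fact::SVD, ::NoGrad, sbar::AbstractVector, ::NoGrad) =
    fact.U * Diagonal(sbar) * fact.Vt

A = randn(30, 5)
fact = svd(A)
sbar = randn(5)
fast = svd_rev_sketch(fact, NoGrad(), sbar, NoGrad())
slow = svd_rev_sketch(fact, zeros(30, 5), sbar, zeros(5, 5))
@show norm(fast - slow)   # rounding-level: both reduce to U * Diagonal(sbar) * Vt
```

`NoGrad` is not an `AbstractMatrix`, so dispatch selects the cheap method exactly when both singular-vector cotangents are absent. This mirrors the AbstractZero specialization discussed in the thread.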

perrutquist, Dec 05 '23 07:12

One more potential improvement: The matrix currently written as (FUᵀŪS + S̄ + SFVᵀV̄) can be computed directly without first computing F, FUᵀŪS and SFVᵀV̄. This would avoid those three memory allocations.
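Below is a hedged sketch of the fused computation. It assumes the standard form of the pullback, in which F[i,j] = 1/(s[j]^2 - s[i]^2) off the diagonal and the two F-weighted terms act on the antisymmetric parts of UᵀŪ and VᵀV̄. The exact conventions in the PR may differ. `middle_matrix` and `middle_matrix_reference` are hypothetical names. The inputs `UtUbar` and `VtVbar` stand for the k × k products UᵀŪ and VᵀV̄.

```julia
using LinearAlgebra

# Hypothetical fused version: builds M = (F∘J)*S + S̄ + S*(F∘K) in one pass, where
# J = UᵀŪ - ŪᵀU and K = VᵀV̄ - V̄ᵀV. F, F∘J*S and S*F∘K are never allocated.
function middle_matrix(UtUbar, s, sbar, VtVbar)
    k = length(s)
    M = similar(UtUbar, k, k)
    for j in 1:k, i in 1:k
        if i == j
            M[i, j] = sbar[i]                 # J and K have zero diagonals
        else
            Jij = UtUbar[i, j] - UtUbar[j, i]
            Kij = VtVbar[i, j] - VtVbar[j, i]
            M[i, j] = (Jij * s[j] + s[i] * Kij) / (s[j]^2 - s[i]^2)
        end
    end
    return M
end

# Reference version that materializes F and the two intermediate products
function middle_matrix_reference(UtUbar, s, sbar, VtVbar)
    k = length(s)
    F = [i == j ? 0.0 : 1 / (s[j]^2 - s[i]^2) for i in 1:k, j in 1:k]
    S = Diagonal(s)
    FJS = (F .* (UtUbar - UtUbar')) * S
    SFK = S * (F .* (VtVbar - VtVbar'))
    return FJS + Diagonal(sbar) + SFK
end

k = 6
UtUbar, VtVbar = randn(k, k), randn(k, k)
s = [k + 0.5 - i for i in 1:k]       # 5.5, 4.5, ..., 0.5: distinct and positive
sbar = randn(k)
@show norm(middle_matrix(UtUbar, s, sbar, VtVbar) -
           middle_matrix_reference(UtUbar, s, sbar, VtVbar))
```

Off the diagonal, each entry reduces to (Jᵢⱼ sⱼ + sᵢ Kᵢⱼ)/(sⱼ² − sᵢ²). On the diagonal only s̄ᵢ survives. The fused loop therefore allocates only M itself, beyond the two k × k products it takes as input.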

perrutquist, Dec 05 '23 08:12

> svd_rev(USV, Ū::AbstractZero, s̄, V̄::AbstractZero) with all code related to Ū, and F removed (making it identical to the code in Neurogenesis.jl). This would apply in the case where the derivatives of only the singular values are needed, which I think is a common one.

This is fine, though I would actually make the signature svd_rev(USV, Ū::AbstractZero, s̄, V̄t::AbstractZero). Even better is if our implementation here automatically compiles to the same thing due to all of the specializations on AbstractZero, though that's not easy to do while reusing storage.

Can you run JuliaFormatter on the code?

sethaxen, Dec 05 '23 19:12

Yes, I meant to write svd_rev(USV, Ū::AbstractZero, s̄, V̄t::AbstractZero) (copy-paste error).

The latest commit that I pushed improves performance a bit, and has been passed through JuliaFormatter.

Since Julia can't infer that M is diagonal when Ū and V̄t are AbstractZero, it still makes sense to include the method from Neurogenesis. It would perhaps be best if @neurogenesisauthors made that PR to preserve authorship.

perrutquist, Dec 05 '23 22:12

> It would perhaps be best if @neurogenesisauthors made that PR to preserve authorship.

That's not necessary. Specializing for the AbstractZero case is common in our codebase. e.g. here's where we do the same thing for svd(::Symmetric) (called by eigen): https://github.com/JuliaDiff/ChainRules.jl/blob/ae37562ea1f16816a0d8fff24e0aca6cd594a40f/src/rulesets/LinearAlgebra/symmetric.jl#L153-L154

sethaxen, Jan 01 '24 15:01

The latest commit to this PR adds the method for when Ū and V̄t are AbstractZero. I don't think there's anything else left to do at the moment.

perrutquist, May 15 '24 15:05