ChainRules.jl
More efficient projection in svd pullback
This change makes the pullback for svd faster and more memory efficient when taking the SVD of very tall or very wide matrices (which is a common application).
The issue is that the "textbook" formula (I - U*U')*X can be extremely slow and memory-intensive when U is
tall. The mathematically equivalent form X - U*(U'*X) avoids creating and then multiplying by the large matrix U*U'.
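To see the equivalence concretely, here is a quick stand-alone check (my own illustration, not code from the PR), with the dimensions kept small so that the dense U*U' still fits comfortably in memory:
using LinearAlgebra
X = randn(2_000, 10)
U = svd(X).U                  # 2000×10, so U*U' is a dense 2000×2000 matrix
Y1 = (I - U * U') * X         # "textbook" form: materializes U*U'
Y2 = X - U * (U' * X)         # rewritten form: only small intermediates
@assert Y1 ≈ Y2               # same result up to floating-point error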
Example:
using LinearAlgebra, Zygote
f(x) = sum(svd(x).U)
X = randn(20000, 10);
@time Zygote.gradient(f, X);
Without this PR, the above runs in about 2 seconds and allocates 3 GB. With this PR, it runs in less than 0.01 seconds and allocates 11 MB.
There is further room for improvement: when the input to svd is wide, I - U*U' is zero. Conversely, when the input is tall, I - V*V' is zero. This means that some unnecessary computation could be avoided by adding a couple of if-statements.
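A sketch of that shortcut (the helper name and structure are mine, not the PR's code):
# For a thin SVD, U is square and orthogonal when the input is wide (and V is
# square and orthogonal when the input is tall), in which case the projection
# is exactly the zero matrix.
function project_out_colspace(U, X)
    size(U, 1) == size(U, 2) && return zero(X)   # square orthogonal U: (I - U*U')*X == 0
    return X - U * (U' * X)                       # otherwise use the cheap rewritten form
end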
(This PR also removes the undocumented, unexported utility function _eyesubx! that is not used elsewhere in the package.)
@sethaxen are you able to take a look at this? You wrote most of our factorizations, but not this one; I think it came from Nabla.jl.
Indeed, I've not inspected this one closely, but the approach of this PR looks right. I'll review in-depth shortly.
As a consequence of the simplifications in my last commits to this PR, the _mulsubtrans!! helper function can no longer be used, and as it is not used anywhere else, it has been removed.
I have not done any benchmarking, but I feel the removal of unnecessary matrix multiplication should speed things up in general.
I've also changed the function signature to take V̄t instead of V̄. This avoids an unnecessary transpose, but mainly it makes the code look much better, as there is an obvious symmetry between U and Vt. However, if there is any external code that calls the (undocumented) svd_rev function, then that would break. If that's a problem, I can revert the change.
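Some context for that change (illustrative only, not the PR's code): the SVD object itself stores Vt, and its V property is only a lazy adjoint of it, so passing V̄t mirrors the stored fields directly:
using LinearAlgebra
F = svd(randn(6, 4))
F.Vt isa Matrix        # true: Vt is what is actually stored
F.V == F.Vt'           # true: V is derived by adjoint, so a V̄ cotangent implies a transpose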
svd_rev should be considered an implementation detail and not used by other packages. On GitHub, the only package I can see that uses svd_rev besides ChainRules is Neurogenesis: https://github.com/k8lion/Neurogenesis/blob/003d9e0a33f3c23a304b64968516638844bd837e/src/initializers.jl#L19-L25. They also don't set compat bounds though, so they're not protected against breaking changes in any dependencies.
In that case, I'll leave it as it is now, with V̄t as the input. The svd_rev method in Neurogenesis is for V̄::AbstractZero so it will still work.
The reason why that package has its own method might be that svd_rev has USV::SVD in the function signature. Changing it to accept any type would probably allow it to work with CUDA.CUSOLVER.CuSVD, and Neurogenesis could simply remove its specialized method.
Actually, the method from Neurogenesis is still faster, because it avoids the allocation and computation of the F matrix. So I think it would make sense to have two methods:
svd_rev(USV, Ū, s̄, V̄t) with the code from this PR. This applies in the case where the derivatives of the singular vectors are needed.
svd_rev(USV, Ū::AbstractZero, s̄, V̄::AbstractZero) with all code related to Ū, V̄ and F removed (making it identical to the code in Neurogenesis.jl). This would apply in the case where the derivatives of only the singular values are needed, which I think is a common one. Edit: In fact, the svdvals function exists specifically for this case, and this would yield an efficient derivative for that function, resolving #206.
One more potential improvement: The matrix currently written as (FUᵀŪS + S̄ + SFVᵀV̄) can be computed directly without first computing F, FUᵀŪS and SFVᵀV̄. This would avoid those three memory allocations.
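A sketch of that fused computation, taking the expression exactly as written above (the names are mine, and the real svd_rev term may differ in details such as any antisymmetrization of UᵀŪ and VᵀV̄):
# Builds M = F∘(UᵀŪ)·S + S̄ + S·F∘(VᵀV̄) entry by entry, with
# F[i, j] = 1/(s[j]^2 - s[i]^2) off the diagonal, so neither F nor the two
# intermediate products is ever allocated.
function fused_middle_term(UᵀŪ, VᵀV̄, s, s̄)
    k = length(s)
    M = similar(UᵀŪ)
    for j in 1:k, i in 1:k
        if i == j
            M[i, j] = s̄[i]                             # only S̄ contributes on the diagonal
        else
            Fij = inv(s[j]^2 - s[i]^2)                  # F entry computed on the fly
            M[i, j] = Fij * (UᵀŪ[i, j] * s[j] + s[i] * VᵀV̄[i, j])
        end
    end
    return M
end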
svd_rev(USV, Ū::AbstractZero, s̄, V̄::AbstractZero) with all code related to Ū, V̄ and F removed (making it identical to the code in Neurogenesis.jl). This would apply in the case where the derivatives of only the singular values are needed, which I think is a common one.
This is fine, though I would actually make the signature svd_rev(USV, Ū::AbstractZero, s̄, V̄t::AbstractZero). Even better is if our implementation here automatically compiles to the same thing due to all of the specializations on AbstractZero, though that's not easy to do while reusing storage.
Can you run JuliaFormatter on the code?
Yes, I meant to write svd_rev(USV, Ū::AbstractZero, s̄, V̄t::AbstractZero) (copy-paste error).
The latest commit that I pushed improves performance a bit, and has been passed through JuliaFormatter.
Since Julia can't infer that M is diagonal when Ū and V̄t are AbstractZero, it still makes sense to include the method from Neurogenesis. It would perhaps be best if @neurogenesisauthors made that PR to preserve authorship.
That's not necessary. Specializing for the AbstractZero case is common in our codebase. e.g. here's where we do the same thing for svd(::Symmetric) (called by eigen):
https://github.com/JuliaDiff/ChainRules.jl/blob/ae37562ea1f16816a0d8fff24e0aca6cd594a40f/src/rulesets/LinearAlgebra/symmetric.jl#L153-L154
The latest commit to this PR adds the method for when Ū and V̄t are AbstractZero. I don't think there's anything else left to do at the moment.
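For reference, a minimal sketch of what that method could look like (assumed to mirror the Neurogenesis.jl specialization; the actual commit may differ):
using LinearAlgebra
using ChainRulesCore: AbstractZero
# When Ū and V̄t are both zeros, e.g. when differentiating svdvals, the F matrix
# and the projection terms drop out and only the singular-value term remains.
function svd_rev(USV::SVD, Ū::AbstractZero, s̄, V̄t::AbstractZero)
    return USV.U * Diagonal(s̄) * USV.Vt
end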