ChainRules.jl
Rule for `svdvals`
Requested here: https://discourse.julialang.org/t/implementation-of-spectral-normalization-for-machine-learning/76074
The workaround is to call `svd(X).S`, which is slower in the forward pass because it also computes the singular vectors. But the gradient calculation, with something like `svd_rev((; U=NoTangent(), s=s, V=NoTangent(), Vt=NoTangent()), NoTangent(), S̄, NoTangent())`, looks fairly efficient, and could easily be extracted into its own rule:
https://github.com/JuliaDiff/ChainRules.jl/blob/3590f9421950508a97d5a9dbc207208e331c8b75/src/rulesets/LinearAlgebra/factorization.jl#L221-L225
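A quick way to see why a separate rule should be cheap: for a thin SVD X = U Σ Vᵀ with distinct singular values, the first-order change of the i-th singular value is dσᵢ = uᵢᵀ (dX) vᵢ. So when only the singular values carry a cotangent s̄, the cotangent for X is `U * Diagonal(s̄) * Vt`, which is what the `svd_rev` call above should reduce to with Ū and V̄ set to zero. The sketch below checks that formula against finite differences using only the LinearAlgebra and Random standard libraries. `svdvals_with_pullback` is a hypothetical name; a real ChainRules rule would presumably wrap the same pullback in an `rrule(::typeof(svdvals), X)` returning `(NoTangent(), X̄)`. It assumes a real matrix with distinct, nonzero singular values, and note that the forward pass still calls `svd(X)`, since the pullback needs U and V.

```julia
using LinearAlgebra, Random

# Hypothetical helper, not part of ChainRules.jl: returns svdvals(X) together
# with a pullback mapping a cotangent s̄ for the singular values to one for X.
# Assumes a real matrix with distinct, nonzero singular values.
function svdvals_with_pullback(X::AbstractMatrix{<:Real})
    F = svd(X)                 # thin SVD: X ≈ F.U * Diagonal(F.S) * F.Vt
    U, Vt = F.U, F.Vt
    # With zero cotangents for U and V, only the U * S̄ * Vt term remains.
    svdvals_pullback(s̄::AbstractVector) = U * Diagonal(s̄) * Vt
    return F.S, svdvals_pullback
end

# Compare with central finite differences of f(X) = dot(w, svdvals(X)).
Random.seed!(0)
X = randn(4, 3)
w = randn(3)
s, back = svdvals_with_pullback(X)
G = back(w)

f(A) = dot(w, svdvals(A))
h = 1e-6
G_fd = similar(X)
for i in eachindex(X)
    E = zeros(size(X))
    E[i] = h                   # perturb one entry of X at a time
    G_fd[i] = (f(X + E) - f(X - E)) / (2h)
end
println(isapprox(G, G_fd; rtol = 1e-5))
```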
I agree, yes, this should be its own rule. We already did that for Hermitian matrices.
https://github.com/JuliaDiff/ChainRules.jl/blob/ffbaa5fecca8da39f20aeb66cc6e4edf3e0c3f11/src/rulesets/LinearAlgebra/symmetric.jl#L286-L300
There I ended up writing a custom rule. I don't remember why; it may have been faster than using `NoTangent()`.
Ah, no, it goes via the `rrule` for `eigvals`, which uses the pullback for `eigen`. So yeah, just reusing the pullback for `svd` might be best.
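For comparison, the Hermitian path described above (the `eigvals` rrule reusing the `eigen` pullback) reduces in the same way: for a real symmetric A = U Λ Uᵀ with distinct eigenvalues, dλᵢ = uᵢᵀ (dA) uᵢ, so with a zero cotangent for the eigenvectors the cotangent for A is `U * Diagonal(λ̄) * U'`. The sketch below illustrates that reduction with a hypothetical helper `symeigvals_with_pullback`, again using only the standard libraries; it is not the code in symmetric.jl. Because the input is constrained to be symmetric, it is checked with a directional finite difference along a random symmetric direction rather than entry by entry.

```julia
using LinearAlgebra, Random

# Hypothetical helper, not ChainRules.jl's actual rule: eigenvalues of a real
# symmetric matrix plus a pullback built from the eigenvectors alone.
# Assumes distinct eigenvalues so the derivative is well defined.
function symeigvals_with_pullback(A::AbstractMatrix{<:Real})
    F = eigen(Symmetric(A))    # A ≈ F.vectors * Diagonal(F.values) * F.vectors'
    U = F.vectors
    # With a zero cotangent for the eigenvectors, only U * Λ̄ * U' remains.
    eigvals_pullback(λ̄::AbstractVector) = U * Diagonal(λ̄) * U'
    return F.values, eigvals_pullback
end

# Compare with a directional finite difference along a symmetric direction D.
Random.seed!(1)
B = randn(4, 4)
A = (B + B') / 2
w = randn(4)
λ, back = symeigvals_with_pullback(A)
G = back(w)

C = randn(4, 4)
D = (C + C') / 2
g(M) = dot(w, eigvals(Symmetric(M)))
h = 1e-6
fd = (g(A + h * D) - g(A - h * D)) / (2h)
println(isapprox(fd, dot(G, D); atol = 1e-6))
```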