ChainRules.jl
[WIP] Complex valued SVD
Nice project; I like the test coverage of this package. Sorry for the long delay (see https://github.com/GiggleLiu/BackwardsLinalg.jl/issues/17). I tried several times to add the backward rules to ChainRules, but found it hard to write tests that fit this framework. Can someone help me?
- Added a new function for complex-valued SVD back-propagation. However, I find it hard to test functions that have a gauge ambiguity (the singular vectors are only defined up to a phase).
- Added a safe_inv function to safely invert the matrix of squared singular-value differences. Without it, the rule breaks easily when the spectrum is degenerate.
- I am also planning to fix the real-valued SVD rule to reduce the number of matrix multiplications.
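The safe inversion mentioned in the list above could look roughly like the following. This is a minimal NumPy illustration, not the PR's Julia code. The names safe_inv and inv_sq_diff and the cutoff eta are assumptions. We also assume a Lorentzian-broadened inverse x / (x^2 + eta), which gives about 1/x for well-separated values and exactly 0 for an exact degeneracy. A thresholded inverse would be an equally plausible reading of the PR.

```python
import numpy as np


def safe_inv(x, eta=1e-40):
    """Hypothetical broadened inverse x / (x^2 + eta).

    Close to 1/x when |x| >> sqrt(eta); exactly 0 when x == 0.
    """
    return x / (x * x + eta)


def inv_sq_diff(s, eta=1e-40):
    """F[i, j] ~ 1 / (s_j^2 - s_i^2), and 0 wherever s_i == s_j (including the diagonal)."""
    s2 = np.asarray(s, dtype=float) ** 2
    return safe_inv(s2[None, :] - s2[:, None], eta)


F = inv_sq_diff([3.0, 2.0, 1.0])       # non-degenerate spectrum
F_deg = inv_sq_diff([2.0, 2.0, 1.0])   # degenerate pair: a plain 1/0 would give inf here
```

For the degenerate pair, the corresponding entries of F_deg are 0 instead of inf. The other entries still match 1 / (s_j^2 - s_i^2).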
Refs
- Automatic Differentiation for Complex Valued SVD
- https://j-towns.github.io/papers/svd-derivative.pdf
- https://giggleliu.github.io/2019/04/02/einsumbp.html
Before I review, can you comment on how this implementation differs from the existing real one? That is, it looks like the current real implementation would also work on complex numbers if we relaxed the real type constraint, so I'm wondering how this rule differs from that one, and whether they can be unified.
Sure. If you open the third link, you will see a term highlighted in red: that is the term missing from the real version. For a more detailed description, see the original paper (the first link). I agree they should be in a single function; this is just a draft.
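The following NumPy sketch of the thin-SVD reverse rule shows the complex-only term added to the real-style formula. As we read the references, that term is a purely imaginary diagonal, U diag(i Im(diag(U^H Ū)) / s) V^H. It vanishes identically for real inputs, which is why the real rule does not need it. This is an illustration under stated assumptions, not the PR's code:

- The names svd_back, loss_and_cotangents and numerical_grad are hypothetical.
- Gradients use the convention dL = Re tr(G^H dX).
- The cutoff eta is our choice.
- The rule only holds for losses that are invariant under the phase gauge U -> U D, V -> V D, with D diagonal unitary.

The finite-difference check uses such a gauge-invariant loss. This is also one possible way to deal with the gauge problem when testing.

```python
import numpy as np


def svd_back(U, s, V, gU, gs, gV, eta=1e-40, complex_term=True):
    """Sketch of the reverse rule for the thin SVD A = U @ diag(s) @ V^H."""
    s2 = s ** 2
    diff = s2[None, :] - s2[:, None]          # diff[i, j] = s_j^2 - s_i^2
    F = diff / (diff ** 2 + eta)              # broadened 1/diff, exactly 0 on the diagonal
    s_inv = s / (s2 + eta)                    # broadened 1/s
    S = np.diag(s)
    Vh = V.conj().T

    UgU = U.conj().T @ gU
    VgV = V.conj().T @ gV
    J = F * UgU
    K = F * VgV
    inner = (J + J.conj().T) @ S + S @ (K + K.conj().T) + np.diag(gs)
    if complex_term:
        # Term absent from the real rule: imaginary diagonal of U^H gU, divided by s.
        inner = inner + np.diag(1j * np.imag(np.diag(UgU)) * s_inv)

    gA = U @ inner @ Vh
    k = s.shape[0]
    if U.shape[0] > k:  # tall A: component of gU outside span(U)
        gA = gA + (gU - U @ UgU) @ np.diag(s_inv) @ Vh
    if V.shape[0] > k:  # wide A: component of gV outside span(V)
        gA = gA + U @ np.diag(s_inv) @ (gV.conj().T - VgV.conj().T @ Vh)
    return gA


def loss_and_cotangents(A, c, p, q):
    """Gauge-invariant test loss L = Re sum_j c_j U[p, j] conj(V[q, j]) + sum_j s_j."""
    U, s, Vh = np.linalg.svd(A, full_matrices=False)
    V = Vh.conj().T
    L = np.real(np.sum(c * U[p, :] * V[q, :].conj())) + np.sum(s)
    gU = np.zeros_like(U)
    gU[p, :] = c.conj() * V[q, :]
    gV = np.zeros_like(V)
    gV[q, :] = c * U[p, :]
    gs = np.ones_like(s)
    return L, (U, s, V, gU, gs, gV)


def numerical_grad(A, c, p, q, h=1e-6):
    """Central differences along the real and imaginary direction of each entry."""
    G = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        E = np.zeros_like(A)
        E[idx] = h
        f = lambda X: loss_and_cotangents(X, c, p, q)[0]
        re = (f(A + E) - f(A - E)) / (2 * h)
        im = (f(A + 1j * E) - f(A - 1j * E)) / (2 * h)
        G[idx] = re + 1j * im
    return G


rng = np.random.default_rng(0)
errors = []
for shape, p, q in [((4, 3), 1, 2), ((3, 4), 0, 3)]:   # one tall and one wide matrix
    A = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    c = rng.standard_normal(3) + 1j * rng.standard_normal(3)
    _, args = loss_and_cotangents(A, c, p, q)
    G_num = numerical_grad(A, c, p, q)
    err_with = np.abs(svd_back(*args) - G_num).max()
    err_without = np.abs(svd_back(*args, complex_term=False) - G_num).max()
    errors.append((err_with, err_without))
    print(shape, err_with, err_without)
```

With complex_term=False, the formula is the real rule applied to complex data (transposes replaced by adjoints). The check then fails for this loss, which suggests that relaxing the type constraint alone would not be enough.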
Also, please note that safe_inv is important in some applications. For example, in many physics applications there are degenerate singular values due to high symmetry.
Normally, when the denominator |s_i^2 - s_j^2| is zero, the numerator should also be zero (because the loss is gauge invariant). However, rounding errors can break this assumption, which is why we need to handle this case explicitly to avoid a spurious singularity in the gradient.
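Here is a toy illustration of the 0/0 situation described above, using plain Python floats. The helpers naive_ratio and safe_ratio are hypothetical. The numerator value 3e-17 is made up to stand in for rounding noise. We again assume a broadened inverse with eta = 1e-40.

```python
def naive_ratio(num, den):
    """Plain division: fails when den is exactly zero."""
    return num / den


def safe_ratio(num, den, eta=1e-40):
    """Multiply by the broadened inverse den / (den^2 + eta) instead of dividing."""
    return num * den / (den * den + eta)


# Exactly degenerate pair: s_i == s_j, so den = s_j^2 - s_i^2 is exactly 0.0,
# while the numerator (ideally 0 for a gauge-invariant loss) carries rounding noise.
s_i = s_j = 1.5
den = s_j ** 2 - s_i ** 2
num = 3e-17

safe = safe_ratio(num, den)       # 0.0: the spurious singularity is suppressed
try:
    naive = naive_ratio(num, den)
except ZeroDivisionError:
    naive = None                  # Python floats raise; NumPy arrays would give inf instead
```

With eta = 1e-40, the broadening only changes the result noticeably when |s_i^2 - s_j^2| is around 1e-20 or smaller. So this guards exact degeneracies, not merely near-degenerate ones.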