
Back Propagation for SVD

Open GiggleLiu opened this issue 5 years ago • 3 comments

Here I present a correct (but poor) implementation of BP for SVD. This implementation changes the original svd interface a bit; I hope someone can help improve it.

using LinearAlgebra
using Flux
using Flux.Tracker: @grad, data, track, TrackedTuple
import Flux.Tracker: _forward
import LinearAlgebra: svd

"""stabilized back propagation function for svd"""
function svd_back(U, S, V, dU, dS, dV)
    NS = length(S)
    S2 = S.^2
    Sinv = 1 ./ S
    F = S2' .- S2
    @. F = F/(F^2 + 1e-12)

    UdU = U'*dU
    VdV = V'*dV

    Su = (F.*(UdU-UdU'))*Diagonal(S)
    Sv = Diagonal(S) * (F.*(VdV-VdV'))

    U * (Su + Sv + Diagonal(dS)) * V' +
    (I - U*U') * dU*Diagonal(Sinv) * V' +
    U*Diagonal(Sinv) * dV' * (I - V*V')
end

svd(a::TrackedArray) = track(svd, a)
# I suspect the `@grad` macro interface is less intuitive than `_forward`
function _forward(::typeof(svd), a)
    U, S, V = svd(data(a))   # `svd` returns an `SVD` object rather than a tuple, making a Julian's life shorter.
    # returning a vector instead of a tuple won't work; one will get a 0 gradient
    # [U|>param, S|>param, V|>param],  -> (svd_back(U, S, V, dU, dS, dV),)
    (U, S, Matrix(V)), Δ -> (svd_back(U, S, V, Δ...),)
end

# This is a use case
M, N = 4, 6
K = min(M, N)
A = param(randn(M, N))
res = svd(A)
# implementing `Base.iterate(res::TrackedTuple) = ?` would make this prettier
U, S, V = res[1], res[2], res[3]

dU, dS, dV = randn(M, K), randn(K), randn(N, K)
Tracker.back!(res, (dU, dS, dV))
Tracker.grad(A)
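
A finite-difference check is one way to back up the claim that svd_back is correct. The sketch below restates svd_back so that it runs with LinearAlgebra alone (no Flux), then compares it with central differences. The test matrix A, the weights c and B, the loss function, and the helper finite_difference_gradient are all assumptions made for illustration; the loss is chosen so that it does not depend on the sign convention of the singular vectors.

```julia
using LinearAlgebra

# Same formula as `svd_back` above, restated so this block runs on its own.
function svd_back(U, S, V, dU, dS, dV)
    S2 = S .^ 2
    Sinv = 1 ./ S
    F = S2' .- S2                     # F[i, j] = S[j]^2 - S[i]^2
    F = @. F / (F^2 + 1e-12)          # regularized 1 / F, zero on the diagonal
    UdU = U' * dU
    VdV = V' * dV
    Su = (F .* (UdU - UdU')) * Diagonal(S)
    Sv = Diagonal(S) * (F .* (VdV - VdV'))
    U * (Su + Sv + Diagonal(dS)) * V' +
    (I - U * U') * dU * Diagonal(Sinv) * V' +
    U * Diagonal(Sinv) * dV' * (I - V * V')
end

# Hypothetical test data: singular values stay close to 4, 3, 2, 1,
# so they are well separated and far from zero.
A = [(i == j ? Float64(i) : 0.0) + 0.05 * sin(3 * i + 7 * j^2) for i in 1:4, j in 1:6]
c = [0.3, -1.2, 0.7, 2.0]
B = [cos(i + 2.0 * j^2) for i in 1:4, j in 1:6]

# Hypothetical scalar loss: a weighted sum of singular values plus u1' * B * v1.
# Flipping the sign of u1 and v1 together leaves it unchanged.
function loss(X)
    U, S, V = svd(X)
    dot(c, S) + dot(U[:, 1], B * V[:, 1])
end

# Central finite differences over every entry of X.
function finite_difference_gradient(f, X; h = 1e-6)
    G = similar(X)
    for k in eachindex(X)
        Xp = copy(X); Xp[k] += h
        Xm = copy(X); Xm[k] -= h
        G[k] = (f(Xp) - f(Xm)) / (2 * h)
    end
    G
end

# Gradients of the loss with respect to the outputs U, S and V.
U, S, V = svd(A)
dU = zeros(size(U)); dU[:, 1] = B * V[:, 1]
dV = zeros(size(V)); dV[:, 1] = B' * U[:, 1]
dS = c

analytic = svd_back(U, S, V, dU, dS, dV)
numeric = finite_difference_gradient(loss, A)
max_err = maximum(abs.(analytic - numeric))
println("max abs difference: ", max_err)
```

If the formula is right, max_err should be tiny for a spectrum this well separated. Nearly degenerate or very small singular values would make both the formula (through F and Sinv) and the check itself less reliable.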

Why do we use Matrix(V) here?

The following line in src/tracker/scalar.jl gets called:

track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))

Note that the function zero can sometimes change the type of its argument!
Here, svd returns V as an Adjoint, and zero of an Adjoint gives a plain Array!

Gotcha!
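
The type change can be seen with LinearAlgebra alone. This is a minimal sketch; the matrix and the variable names are made up for illustration, and the exact printed types may depend on the Julia version.

```julia
using LinearAlgebra

# Hypothetical small input; any real matrix should do.
M = [1.0 2.0 3.0; 4.0 5.0 6.0]
fact = svd(M)

V = fact.V                 # lazy Adjoint wrapper around the stored Vt
V_dense = Matrix(V)        # plain dense copy, as used in `_forward` above

println(typeof(V))
println(typeof(zero(V)))
println(typeof(zero(V_dense)))

# Does `zero` hand back the same type it was given?
v_keeps_type = typeof(zero(V)) == typeof(V)
dense_keeps_type = typeof(zero(V_dense)) == typeof(V_dense)
println("Adjoint kept: ", v_keeps_type, ", Matrix kept: ", dense_keeps_type)
```

Converting with Matrix(V) before tracking makes the value and its zero gradient buffer share one type, which appears to be why the conversion is needed here.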

Some aspects can be improved

  • Returning [U, S, V] in Flux should not cause gradient tracking failure.
  • Tracker should be able to propagate through the adjoint (dagger)?
  • Return value checking for _forward is necessary, so that a readable error message can be thrown.
  • @grad is an arguably useful interface
  • Julia should remove over-designed outputs for linear algebra functions like svd; I do not see many benefits in such a design.
  • zero and one should never change the type; here, this should be considered a bug.
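
As a rough idea of what return value checking could look like, here is a sketch of a standalone validator. check_forward_output is a hypothetical helper, not part of Tracker or Flux, and the rules it enforces (a 2-tuple, no vector of arrays as the value, a Function as the pullback) are my assumptions about what a _forward-style function should return.

```julia
# Hypothetical helper, not part of Tracker or Flux: validate the
# `(value, back)` pair returned by a `_forward`-style function.
function check_forward_output(f, out)
    if !(out isa Tuple && length(out) == 2)
        throw(ArgumentError("_forward for $f must return (value, back), got $(typeof(out))"))
    end
    value, back = out
    # A vector of arrays such as [U, S, V] is the case reported above
    # to give a zero gradient, so reject it with a hint.
    if value isa AbstractArray && eltype(value) <: AbstractArray
        throw(ArgumentError("_forward for $f returned a vector of arrays; return a tuple such as (U, S, V) instead"))
    end
    # Simplification: callable structs are rejected, only Functions pass.
    if !(back isa Function)
        throw(ArgumentError("_forward for $f must return a function as its second element, got $(typeof(back))"))
    end
    return out
end

# A well-formed pair passes through unchanged.
ok = check_forward_output(sum, (1.0, d -> (d,)))

# A vector of arrays is rejected with a readable message.
rejected = try
    check_forward_output(sum, ([ones(2, 2), ones(2)], d -> (d,)))
    false
catch err
    err isa ArgumentError
end
println("accepted: ", ok isa Tuple, ", rejected vector of arrays: ", rejected)
```

A check like this would only be worth having inside track itself, where it could run once per call; it is shown standalone here because I am not sure where it would best fit in Tracker.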

GiggleLiu, Oct 29 '18 16:10