Tracker.jl
Back Propagation for SVD
Here I present a correct (but crude) implementation of back propagation for SVD. It changes the original `svd`
interface a bit; I hope someone can help improve it.
using LinearAlgebra
using Flux
using Flux.Tracker: @grad, data, track, TrackedTuple
import Flux.Tracker: _forward
import LinearAlgebra: svd
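"""Stabilized back propagation function for svd."""
function svd_back(U, S, V, dU, dS, dV)
    S2 = S.^2
    Sinv = 1 ./ S
    # F[i, j] ≈ 1/(S[j]^2 - S[i]^2); the 1e-12 term keeps the diagonal
    # and nearly degenerate entries finite
    F = S2' .- S2
    @. F = F/(F^2 + 1e-12)
    UdU = U'*dU
    VdV = V'*dV
    Su = (F.*(UdU-UdU'))*Diagonal(S)
    Sv = Diagonal(S) * (F.*(VdV-VdV'))
    # the last two terms handle the parts of dU and dV that lie outside
    # the column spaces of U and V (non-square case)
    U * (Su + Sv + Diagonal(dS)) * V' +
        (I - U*U') * dU*Diagonal(Sinv) * V' +
        U*Diagonal(Sinv) * dV' * (I - V*V')
end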
svd(a::TrackedArray) = track(svd, a)
# I suspect the `@grad` macro interface is less intuitive than `_forward`
function _forward(::typeof(svd), a)
    # `svd` returns an `SVD` object rather than a plain tuple, so destructure it here
    U, S, V = svd(data(a))
    # returning a list won't work, one will get 0 gradient:
    # [U|>param, S|>param, V|>param], -> (svd_back(U, S, V, dU, dS, dV),)
    (U, S, Matrix(V)), Δ -> (svd_back(U, S, V, Δ...),)
end
# This is a use case
M, N = 4, 6
K = min(M, N)
A = param(randn(M, N))
res = svd(A)
# implementing `Base.iterate(res::TrackedTuple) = ?` would make this prettier
U, S, V = res[1], res[2], res[3]
dU, dS, dV = randn(M, K), randn(K), randn(N, K)
Tracker.back!(res, (dU, dS, dV))
Tracker.grad(A)
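One way to sanity-check `svd_back` without Flux is to compare it against central finite differences. The sketch below is not part of the implementation above: the weights `Wu`, `w`, `Wv`, the test matrix `A`, and the helper `fd_grad` are all made up for illustration, and `svd_back` is repeated only so the snippet runs on its own. The loss depends on `U` and `V` only through `U.^2` and `V.^2`, so it is unaffected by the sign ambiguity of the singular vectors.

```julia
using LinearAlgebra

# Copy of `svd_back` from the post, repeated so this check is self-contained.
function svd_back(U, S, V, dU, dS, dV)
    S2 = S.^2
    Sinv = 1 ./ S
    F = S2' .- S2
    @. F = F/(F^2 + 1e-12)
    UdU = U'*dU
    VdV = V'*dV
    Su = (F.*(UdU-UdU'))*Diagonal(S)
    Sv = Diagonal(S) * (F.*(VdV-VdV'))
    U * (Su + Sv + Diagonal(dS)) * V' +
        (I - U*U') * dU*Diagonal(Sinv) * V' +
        U*Diagonal(Sinv) * dV' * (I - V*V')
end

M, N = 4, 6
K = min(M, N)

# Hypothetical fixed weights, chosen only for illustration.
Wu = [0.1*i + 0.2*j for i in 1:M, j in 1:K]
w  = [1.0, -2.0, 0.5, 3.0]
Wv = [0.3*i - 0.1*j for i in 1:N, j in 1:K]

# Scalar loss that is invariant under sign flips of the singular vectors.
function loss(A)
    U, S, V = svd(A)
    sum(Wu .* U.^2) + sum(w .* S) + sum(Wv .* V.^2)
end

# Central finite-difference gradient (hypothetical helper).
function fd_grad(f, A; h=1e-6)
    g = zero(A)
    for k in eachindex(A)
        Ap = copy(A); Ap[k] += h
        Am = copy(A); Am[k] -= h
        g[k] = (f(Ap) - f(Am)) / (2h)
    end
    g
end

# Deterministic test matrix with well-separated singular values (near 1, 2, 3, 4).
A = [i == j ? Float64(i) : 0.1/(i + j) for i in 1:M, j in 1:N]

U, S, V = svd(A)
# Analytic adjoints of the loss with respect to U, S and V.
g = svd_back(U, S, V, 2 .* Wu .* U, w, 2 .* Wv .* V)
maxerr = maximum(abs.(g .- fd_grad(loss, A)))
println("max abs deviation from finite differences: ", maxerr)
```

A random `dU`, `dV` as in the use case above is fine for exercising the code path, but for a finite-difference comparison the adjoints need to come from a loss like this one, which does not depend on the arbitrary signs of the singular vectors.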
Why do we use `Matrix(V)` here?
The following line in the file `src/tracker/scalar.jl` is called:
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
Notice that the function `zero` can sometimes change the type!
Here, `svd` returns `V` as an `Adjoint`, and calling `zero` on an `Adjoint` matrix gives an `Array`!
Gotcha!
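The type change can be reproduced with `LinearAlgebra` alone. This is a small illustrative sketch, not code from Tracker; the matrix size is arbitrary and the names `zs` and `Vm` are introduced only for this example.

```julia
using LinearAlgebra

U, S, V = svd(randn(4, 6))

# `V` is a lazy adjoint wrapper around the stored `Vt` factor.
println(typeof(V))

# `zero` allocates through `similar`, which unwraps the adjoint,
# so the result is a plain `Matrix`.
println(typeof(zero(V)))

# The same broadcast as in the `track` line: the third entry changes type.
zs = zero.((U, S, V))
println(map(typeof, zs))

# Converting first keeps the type stable under `zero`.
Vm = Matrix(V)
println(typeof(zero(Vm)) == typeof(Vm))
```

Since `track` stores `zero.(xs)` as the gradient buffer, converting with `Matrix(V)` beforehand makes the output and its buffer have the same type.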
Some aspects can be improved
- Returning `[U, S, V]` in Flux should not cause gradient tracking to fail.
- Tracker should be able to propagate through the dagger (adjoint)?
- Checking the return value of `_forward` is necessary, so that a readable error message can be thrown.
- `@grad` is an arguably useful interface.
- Julia should remove the over-designed outputs of linear algebra functions like `svd`; I don't see many benefits in such a design.
- `zero` and `one` should never change the type; here, it should be considered a bug.