ForwardDiff.jl
Sparse backslash
I get a StackOverflowError when I try to differentiate code that builds a sparse matrix depending on the variable and then solves a linear system with it using backslash. MWE:
using SparseArrays, ForwardDiff, LinearAlgebra
A = sparse([1.0 1.0 ; 0.0 1.0])
f(x) = sparse(A + x * x') \ ones(2)
x = ones(2)
f(x)
ForwardDiff.jacobian(f, x)
Compared to the "full" version, which works:
f_full(x) = (A + x * x') \ ones(2)
x = ones(2)
f_full(x)
ForwardDiff.jacobian(f_full, x)
I feel like this could be solved by adding a backslash rule to the underlying dual-number machinery, something akin to
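using LinearAlgebra, DualNumbers
# M is a matrix of dual numbers (not a single Dual), x is a real vector
function Base.:\(M::AbstractMatrix{<:Dual}, x::AbstractVector{<:Real})
    MRf = factorize(realpart.(M))                          # factorize the real part only
    sol = MRf \ x                                          # zeroth-order solution
    sol -= Dual(0.0, 1.0) * (MRf \ (dualpart.(M) * sol))   # first-order (ε) correction
    return sol
end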
(That's what I do when I use DualNumbers.jl.) This should work, and it only ever calls LinearAlgebra on real-valued types, because only the real part of M is factorized.
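Spelled out, the first-order identity behind this trick is the following (a sketch, assuming ε² = 0 and an invertible real part A; b is the dual part of the right-hand side, which is zero in the snippet above):

```latex
\begin{aligned}
(A + \varepsilon B)\left(A^{-1} - \varepsilon A^{-1} B A^{-1}\right)
  &= I - \varepsilon B A^{-1} + \varepsilon B A^{-1} - \varepsilon^2 B A^{-1} B A^{-1} = I, \\
(A + \varepsilon B)^{-1} (a + \varepsilon b)
  &= A^{-1} a + \varepsilon \left(A^{-1} b - A^{-1} B A^{-1} a\right).
\end{aligned}
```

Every term on the right only involves solves with A, so a single factorization of the real part is enough.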
But I am not sure how to suggest changes directly in ForwardDiff.jl. Does that make sense?
I saw that there were many discussions revolving around moving ForwardDiff.jl's dual-number implementation into DualNumbers.jl and got a bit lost in those. What is the current status of this? Is there still an ongoing discussion?
I guess a solution could actually define a new factorization type, which stores the factorization of the real part and the dual part of the matrix. So for a dual-valued matrix M, have something like
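using LinearAlgebra, DualNumbers
# Stores the factorization of the real part and the (unfactorized) dual part
struct DualFactors{F,D}
    RealFactors::F
    DualMatrix::D
end
function LinearAlgebra.factorize(M::AbstractMatrix{Dual{T}}) where T
    return DualFactors(factorize(realpart.(M)), dualpart.(M))
end
function Base.:\(Mf::DualFactors, x::AbstractVector{T}) where T<:Real
    sol = Mf.RealFactors \ x
    # reuse the stored real factors (MRf was undefined here before)
    sol -= Dual(0.0, 1.0) * (Mf.RealFactors \ (Mf.DualMatrix * sol))
    return sol
end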
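To check the idea without depending on DualNumbers.jl, here is a minimal sketch that represents the dual matrix A + εB and the dual vector a + εb as pairs of plain Float64 arrays instead of arrays of dual numbers. The names `DualPairFactors`, `factorize_dual` and `solve_dual` are hypothetical, chosen for illustration; the real part is factorized once with `lu` and reused for both solves.

```julia
using LinearAlgebra

# Hypothetical container: factors of the real part A plus the unfactorized dual part B
struct DualPairFactors{F,M<:AbstractMatrix{Float64}}
    Af::F   # factorization of A, computed once
    B::M    # dual part of the matrix
end

# Hypothetical constructor for the dual matrix A + εB
factorize_dual(A::AbstractMatrix{Float64}, B::AbstractMatrix{Float64}) =
    DualPairFactors(lu(A), B)

# Solve (A + εB)(x0 + εx1) = a + εb, keeping terms up to first order in ε
function solve_dual(Fd::DualPairFactors, a::AbstractVector{Float64}, b::AbstractVector{Float64})
    x0 = Fd.Af \ a                 # x0 = A⁻¹a
    x1 = Fd.Af \ (b - Fd.B * x0)   # x1 = A⁻¹(b - B x0)
    return x0, x1
end

# Usage on a small example
A = [4.0 1.0; 2.0 3.0]
B = [1.0 0.0; 0.5 2.0]
a = [1.0, 2.0]
b = [0.0, 1.0]
Fd = factorize_dual(A, B)
x0, x1 = solve_dual(Fd, a, b)
```

The dual part x1 should match the derivative of t ↦ (A + tB)⁻¹(a + tb) at t = 0, which a central finite difference can confirm.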
Is that sensible? Is someone willing to help write this up to work with DualNumbers.jl or ForwardDiff.jl?
The stack overflow happens here: https://github.com/JuliaLang/julia/blob/v1.0.0/stdlib/SparseArrays/src/sparsematrix.jl#L421. Once again, float is the culprit. See also:
- https://discourse.julialang.org/t/automatic-differentation-of-f-x-real-f-float-x-leads-to-stackoverflow/16076/6
- https://github.com/JuliaLang/julia/issues/26552
- https://github.com/JuliaDiff/ForwardDiff.jl/issues/261
- https://github.com/JuliaDiff/ForwardDiff.jl/issues/324
- https://github.com/musm/SLEEF.jl/issues/11
Thank you for trying to enlighten me. However, I am still quite lost and overwhelmed with the links you provided. Could you comment on my understanding and my suggestion?
Right now, I believe the way \ is treated by ForwardDiff.jl is defined in DiffRules.jl:
@define_diffrule Base.:\(x, y) = :( -($y / $x / $x) ), :( inv($x) )
Is this correct? I am not sure I understand exactly how the @define_diffrule macro works, but I feel like there is something wrong here. It seems like this rule only applies to scalars. That means that the differentiation rules are applied to whatever scalar operations happen inside whichever algorithm is called by \ when applied to a matrix M and a vector x.
Now, if what I just wrote is true, this sounds quite inefficient for computing M \ x. And Julia is the opposite of inefficient, right? As mentioned in my first post, M \ x can be computed more quickly for a dual-valued matrix M = A + εB (and also a dual-valued vector x = a + εb). So what about adding a rule that takes advantage of this and does:
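using LinearAlgebra, DualNumbers
function Base.:\(M::AbstractMatrix{Dual{T}}, x::AbstractVector{Dual{T}}) where T
    A, B = factorize(realpart.(M)), dualpart.(M)   # factorize the real part only
    a, b = realpart.(x), dualpart.(x)
    out = A \ a                        # real part of the solution
    out = out - ε * (A \ (B * out))    # correction from the dual part of M
    out = out + ε * (A \ b)            # correction from the dual part of x
    return out
end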
Can a rule be defined for non-scalars in order to take precedence and shortcut the whole thing?
The rule can be generalized to higher-order duals (hyperduals, etc.) because it is simply based on the Taylor expansion of the inverse of (A + ε₁ B₁ + ε₂ B₂ + ε₁₂ B₁₂), and so on.
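To make the second-order case concrete, here is a sketch of that expansion, assuming the usual hyperdual rules ε₁² = ε₂² = 0 and ε₁ε₂ = ε₁₂ (so ε₁₂ times any ε vanishes). Writing E for the non-real part, any product with three factors of E vanishes, so the Neumann series stops after the quadratic term:

```latex
\begin{aligned}
E &= \varepsilon_1 B_1 + \varepsilon_2 B_2 + \varepsilon_{12} B_{12}, \qquad
(A + E)^{-1} = A^{-1} - A^{-1} E A^{-1} + A^{-1} E A^{-1} E A^{-1} \\
  &\Rightarrow (A + E)^{-1} = A^{-1}
  - \varepsilon_1 A^{-1} B_1 A^{-1}
  - \varepsilon_2 A^{-1} B_2 A^{-1} \\
  &\qquad + \varepsilon_{12}\left(A^{-1} B_1 A^{-1} B_2 A^{-1} + A^{-1} B_2 A^{-1} B_1 A^{-1} - A^{-1} B_{12} A^{-1}\right).
\end{aligned}
```

As in the first-order case, every term only needs solves with the real part A.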
Even better, as mentioned in my first post, would be a rule using a new type of matrix factorization, which stores the factors of only the real part of M but also keeps the non-real part for use in the \ shortcut described above. (I.e., the factors contain B and the "factorized" A; see how I factorized A in the little snippet above.)
Is this already implemented in Julia? Would it not be much faster than the current "brute-force" approach of applying the diff rule to every scalar operation? It would also allow ForwardDiff to "percolate" through any non-Julia-native parts of factorize and \, right?
I don't know if I am out of my depth here, but I would love to understand this and/or help if I can.
Sorry for being terse. It just seems like this issue keeps cropping up.
Right now, I believe the way \ is treated by ForwardDiff.jl is defined in DiffRules.jl:
DiffRules.jl only defines differentiation rules for scalar-valued functions. Note that \ is defined for scalars as well, e.g. 2 \ 4 == 2.0.
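As a quick sanity check of what that scalar rule encodes, using only Base Julia: since x \ y == y / x, the rule gives −y/x² for the partial with respect to x and 1/x for the partial with respect to y, and a central finite difference reproduces both. The helper `central_diff` is a hypothetical name for illustration.

```julia
# Scalar backslash: x \ y == y / x
x, y = 2.0, 4.0
@assert x \ y == 2.0

# Hypothetical helper: central finite difference of a scalar function
central_diff(f, t; h = 1e-6) = (f(t + h) - f(t - h)) / (2h)

# Partials given by the DiffRules.jl rule for Base.:\(x, y)
dx_rule = -(y / x / x)   # ∂(x \ y)/∂x
dy_rule = inv(x)         # ∂(x \ y)/∂y

# Numerical counterparts
dx_fd = central_diff(s -> s \ y, x)
dy_fd = central_diff(s -> x \ s, y)
```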
Julia is the opposite of inefficient, right
Julia can be about as efficient or as inefficient as you want 🙂.
The direct cause of the stack overflow is that https://github.com/JuliaLang/julia/blob/v1.0.1/stdlib/SuiteSparse/src/umfpack.jl#L174 calls float on a SparseMatrixCSC of Duals, but doing so returns a SparseMatrixCSC of the same type (https://github.com/JuliaLang/julia/blob/v1.0.1/stdlib/SparseArrays/src/sparsematrix.jl#L432 and https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/dual.jl#L324), so lu ends up calling itself.
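The dispatch pattern behind that recursion can be reproduced in a few lines without ForwardDiff.jl or SuiteSparse. This is a toy analogue, not the actual library code: `ToyDual` and `toy_lu` are hypothetical stand-ins. The generic method converts its input with `float` and re-dispatches, expecting to land on the Float64 method, but `float` of a `ToyDual` returns a `ToyDual`, so for such input the generic method would keep calling itself.

```julia
# Hypothetical number type whose `float` returns itself (as for dual numbers)
struct ToyDual <: Number
    value::Float64
    partial::Float64
end
Base.float(x::ToyDual) = x

# Hypothetical stand-ins for a factorization entry point
toy_lu(v::AbstractVector{Float64}) = "factorized with Float64 kernel"
toy_lu(v::AbstractVector{<:Number}) = toy_lu(float.(v))   # expects float to reach Float64

v = [ToyDual(1.0, 0.0), ToyDual(2.0, 1.0)]
# float.(v) is still a Vector{ToyDual}, so toy_lu(v) would recurse until a StackOverflowError
```

For ordinary numbers such as Int, the conversion does reach the Float64 method, which is why the problem only shows up with types like Dual.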
Can a rule be defined for non-scalars in order to take precedence and shortcut the whole thing?
Yes, although I don't think there's currently a precedent for that in ForwardDiff.jl itself. By the way, I'm mostly just an onlooker here, @jrevels is the main maintainer and can answer your questions better.
I'm sorry to insist, but I still think this could improve ForwardDiff.jl's efficiency when dealing with \ and factorizations. I have implemented my own factorization for my use case, but I thought I should share the benefits here. (Maybe this will motivate some more responses?) Below is a MWE with DualNumbers.jl (because I did not really know how to use ForwardDiff.jl's implementation of dual numbers). The goal is to solve J * x = y for dual-valued x with given dual-valued J and y.
using LinearAlgebra, DualNumbers, BenchmarkTools, SparseArrays
n = 2000 # size of J and y
# make J and y
J = Matrix(sprand(n, n, 1/n) + I) + Matrix(sprand(n, n, 1/n) + I) * ε
y = rand(n) + rand(n) * ε
# Type that stores the required terms only
struct DualJacobianFactors{F}
    Af::F                  # the factors of the real part (an LU for a generic dense matrix)
    B::Matrix{Float64}     # the non-real part
end
# Constructor function for a dual-valued matrix M
fast_factorize(M::Matrix{Dual{Float64}}) = DualJacobianFactors(factorize(realpart.(M)), dualpart.(M))
# overload backslash (maybe this can be improved on)
function Base.:\(J::DualJacobianFactors, y::Vector{Dual{Float64}})
a, b = realpart.(y), dualpart.(y)
Af, B = J.Af, J.B
out = zeros(Dual{Float64}, length(y))
out .= Af \ a
out .-= (Af \ (B * out)) * ε
out .+= (Af \ b) * ε
return out
end
# Benchmark the factorization
@benchmark (Jf = factorize(J) ;)
@benchmark (fast_Jf = fast_factorize(J) ;)
# Benchmark the back-substitution
Jf = factorize(J) ;
fast_Jf = fast_factorize(J) ;
@benchmark (x1 = Jf \ y ;)
@benchmark (x2 = fast_Jf \ y ;)
# Check that the results are approximately the same
x1 = Jf \ y ;
x2 = fast_Jf \ y ;
realpart.(x1) ≈ realpart.(x2)
dualpart.(x1) ≈ dualpart.(x2)
I am not sure this is the right way to benchmark these things, but you can play around with the size n, edit the MWE, and see for yourself. On my laptop, for n = 2000, this MWE speeds up the factorization by a factor of about 20. However, the memory usage is slightly worse and the back-substitution is slightly slower, but this should not matter much because computing the factors is, I believe, the bottleneck in most cases. For the lazy, here are the @benchmark results for the factorization:
julia> @benchmark (Jf = factorize(J) ;)
BenchmarkTools.Trial:
memory estimate: 61.05 MiB
allocs estimate: 16
--------------
minimum time: 10.883 s (0.01% GC)
median time: 10.883 s (0.01% GC)
mean time: 10.883 s (0.01% GC)
maximum time: 10.883 s (0.01% GC)
--------------
samples: 1
evals/sample: 1
julia> @benchmark (fast_Jf = fast_factorize(J) ;)
BenchmarkTools.Trial:
memory estimate: 91.57 MiB
allocs estimate: 9
--------------
minimum time: 432.454 ms (2.36% GC)
median time: 477.843 ms (2.32% GC)
mean time: 487.862 ms (6.14% GC)
maximum time: 592.935 ms (1.87% GC)
--------------
samples: 11
evals/sample: 1
@briochemc Hey thanks for your solution, it helps me a lot. I patched up ForwardDiff for this functionality, as below:
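using LinearAlgebra
using SparseArrays
using ForwardDiff
import Base: \
# Sparse matrix of ForwardDiff duals: route \ to a specialized solver
function \(A::SparseMatrixCSC{ForwardDiff.Dual{T, V, N}, P}, b::AbstractVector{G}) where {T, V, N, P<:Integer, G}
    return __FDbackslash(A, b, T, V, N)
end
# If A x = b, the i-th partial satisfies dx_i = -A \ (dA_i * x - db_i),
# so only the real part of A needs to be factorized, once, and then reused
function __FDbackslash(A, b, T, V, N)
    Areal = ForwardDiff.value.(A)      # sparse matrix of primal values
    breal = ForwardDiff.value.(b)      # works for real or dual b
    F = factorize(Areal)               # factorize the real part only once
    outreal = F \ breal
    M = length(outreal)
    deriv = zeros(V, M, N)
    for i in 1:N
        pAi = ForwardDiff.partials.(A, i)   # i-th partial of A (sparse)
        pbi = ForwardDiff.partials.(b, i)   # i-th partial of b (zeros if b is real)
        deriv[:, i] = -(F \ (pAi * outreal - pbi))
    end
    out = Vector{eltype(A)}(undef, M)
    for j in eachindex(out)
        out[j] = ForwardDiff.Dual{T}(outreal[j], ForwardDiff.Partials(tuple(deriv[j, :]...)))
    end
    return out
end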
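The same per-partial scheme can be sketched with stdlib types only, which makes it easy to test in isolation and shows the factorization of the real part being reused for every partial. Here `sparse_solve_partials` is a hypothetical name; `dA` and `db` stand in for the N partial-derivative matrices and vectors that ForwardDiff would store inside the duals.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical helper: solve A x = b and propagate N directional derivatives,
# using dx_i = -A \ (dA_i * x - db_i) with a single factorization of A
function sparse_solve_partials(A::SparseMatrixCSC{Float64},
                               b::Vector{Float64},
                               dA::Vector{<:SparseMatrixCSC{Float64}},
                               db::Vector{Vector{Float64}})
    F = factorize(A)                  # factorize the real part once
    x = F \ b
    dx = similar(x, length(x), length(dA))
    for i in eachindex(dA)
        dx[:, i] = -(F \ (dA[i] * x - db[i]))
    end
    return x, dx
end

# Usage: a 3×3 system with two partial directions
A = sparse([4.0 1.0 0.0; 1.0 5.0 2.0; 0.0 1.0 3.0])
b = [1.0, 2.0, 3.0]
dA = [sparse([1.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 1.0]), spzeros(3, 3)]
db = [zeros(3), [0.0, 1.0, 0.0]]
x, dx = sparse_solve_partials(A, b, dA, db)
```

Column i of dx should agree with a central finite difference of (A + t dA[i]) \ (b + t db[i]) at t = 0.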
I'm also not very familiar with the Dual API in ForwardDiff so this implementation is likely to be inefficient, but it's still better than the dense factorization alternatives. As far as I understand, ForwardDiff mainly works at the scalar level, so this kind of patch is not really desired. ForwardDiff2 is the package for vector-level implementation, and with ChainRules.jl it's much easier to add customized rules, but they are still under heavy development.
Yes I try to follow what's happening with ForwardDiff2.jl :)
FYI, I also made DualMatrixTools.jl and HyperDualMatrixTools.jl a year or so back, if that's of any interest to you.
I have a use case for this where I solve a sparse linear system a bunch of times and need to be careful about allocating memory. In case helpful to anyone, I've implemented the code from @FuZhiyu for A\b where both have eltype Dual at https://github.com/magerton/ForwardDiffSparseSolve.jl
See also Sparspak.jl for a take on this...
Thanks, @j-fu!