LinearSolve.jl

linearsolve using the transpose of the factorization

Open MKAbdElrahman opened this issue 2 years ago • 10 comments

To define an rrule, I need to reuse the factorization from the forward pass. The pullback is also a linear solve, but with the transpose of the linear system. How can I avoid defining a new LinearProblem?

This is my current code. Using the LinearSolve interface would let it use other algorithms and make it cleaner.


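using ChainRulesCore, LinearAlgebra

# efield, Simulation, and ConstructLinearSystem are user-defined and not shown here.
function ChainRulesCore.rrule(::typeof(efield), sim::Simulation, ϵ)
    linsys = ConstructLinearSystem(sim, ϵ)
    A, b = linsys.A, linsys.b
    x = similar(b); x_adj = similar(b)
    F = lu(A)                          # factorize once in the forward pass
    LinearAlgebra.ldiv!(x, F, b)       # forward solve: A * x = b
    sim.E = x
    function efield_pullback(ȳ)
        # The pullback reuses F: transpose(A) * x_adj = conj(ȳ)
        LinearAlgebra.ldiv!(x_adj, transpose(F), conj.(ȳ))
        f̄ = NoTangent()
        f̄oo = NoTangent()
        ϵoo = real(x .* x_adj)
        return f̄, f̄oo, ϵoo
    end
    return x, efield_pullback
end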
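
The reuse the code above relies on can be checked in isolation. The following is a minimal sketch using only the LinearAlgebra standard library; the matrix and right-hand side are arbitrary values chosen for illustration. It shows that a single LU factorization can serve the original, transposed, and adjoint systems without refactorizing.

```julia
using LinearAlgebra

# Small complex test system (arbitrary values, for illustration only).
A = ComplexF64[4 1 0; 2 5 1im; 0 1 3]
b = ComplexF64[1, 2, 3]

F = lu(A)                 # factorize once
x = F \ b                 # solves A * x = b

# Lazy wrappers around the same factorization; no new factorization is computed.
x_t = transpose(F) \ b    # solves transpose(A) * x_t = b
x_a = F' \ b              # solves A' * x_a = b (conjugate transpose)

@assert A * x ≈ b
@assert transpose(A) * x_t ≈ b
@assert A' * x_a ≈ b
```

For complex element types the transposed and adjoint solves differ, which is why the pullback above conjugates ȳ before the transposed solve.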


Thanks!

MKAbdElrahman avatar Jan 14 '22 22:01 MKAbdElrahman

Oh yes, this would be good to add. It might be hard for Krylov methods, because it requires the operator to have an adjoint defined, and operators are often defined Jacobian-free. In that case it will just throw an appropriate error if the adjoint is not defined.

I think the right thing to do would be to add a boolean to the LinearProblem, transpose=false by default, and then we can set up the algorithms to specialize on it. Many solvers, such as Pardiso, have a lazy transpose option, so we'd use that bool to flip the option.
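
As a rough sketch of that idea, the flag could live on the problem and each algorithm could decide how to honor it. Everything below is hypothetical: `ToyLinearProblem`, `ToyLU`, and `toy_solve` are illustrative stand-ins, not LinearSolve.jl types or functions.

```julia
using LinearAlgebra

# Hypothetical problem type carrying a transpose flag (NOT LinearSolve.LinearProblem).
struct ToyLinearProblem{TA,Tb}
    A::TA
    b::Tb
    transpose::Bool
end
ToyLinearProblem(A, b; transpose=false) = ToyLinearProblem(A, b, transpose)

# Hypothetical algorithm tag; a real solver would flip its own transpose option here.
struct ToyLU end

function toy_solve(prob::ToyLinearProblem, ::ToyLU)
    F = lu(prob.A)
    # The factorization of A also serves the transposed system.
    return prob.transpose ? transpose(F) \ prob.b : F \ prob.b
end

A = [4.0 1.0; 2.0 3.0]
b = [1.0, 2.0]

x  = toy_solve(ToyLinearProblem(A, b), ToyLU())                  # A * x = b
xt = toy_solve(ToyLinearProblem(A, b; transpose=true), ToyLU())  # transpose(A) * xt = b
```

Putting the flag on the problem rather than on the solve call keeps it next to the data that defines the answer.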

ChrisRackauckas avatar Jan 15 '22 11:01 ChrisRackauckas

@MKAbdElrahman did the issue get resolved?

vpuri3 avatar Apr 01 '22 15:04 vpuri3

It did not

ChrisRackauckas avatar Apr 01 '22 18:04 ChrisRackauckas

In LinearCache, we can add a flag symmetric defaulting to false and a field Atransp defaulting to Adjoint(A), the lazy wrapper. Then the adjoint problem can be solved via

solve(prob, alg, adjoint=true)
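
A minimal sketch of what such a cache might look like, assuming a dense LU factorization. `ToyCache` and `toy_solve` are hypothetical names, and the field layout is a guess for illustration, not the actual LinearCache.

```julia
using LinearAlgebra

# Hypothetical cache (NOT LinearSolve's LinearCache); fields follow the proposal above.
struct ToyCache{TA,TAt,TF,Tb}
    A::TA
    Atransp::TAt      # lazy wrapper A'; no copy of A is made
    symmetric::Bool   # if true, the adjoint problem equals the original one
    fact::TF          # factorization computed once and reused
    b::Tb
end
ToyCache(A, b; symmetric=false) = ToyCache(A, A', symmetric, lu(A), b)

function toy_solve(cache::ToyCache; adjoint::Bool=false)
    if adjoint && !cache.symmetric
        return cache.fact' \ cache.b   # reuse the factorization for A' * x = b
    end
    return cache.fact \ cache.b
end

A = ComplexF64[4 1 0; 2 5 1im; 0 1 3]
b = ComplexF64[1, 2, 3]
cache = ToyCache(A, b)

x  = toy_solve(cache)                 # A * x = b
xa = toy_solve(cache; adjoint=true)   # A' * xa = b
```

An iterative method would use `cache.Atransp` for its matrix-vector products, while the factorization path only needs the adjoint of `cache.fact`.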

vpuri3 avatar Apr 15 '22 18:04 vpuri3

Maybe change the terminology to make it clearer for complex eltypes:

  • symmetric --> selfadjoint
  • Atransp --> Aadjoint

vpuri3 avatar Apr 15 '22 18:04 vpuri3

@ChrisRackauckas is Adjoint(::DiffEqArrayOperator) defined?

vpuri3 avatar Apr 15 '22 18:04 vpuri3

No. What needs to be done is the operator interface documentation should get a note about adjoint(::SciMLOperator) as being a part of the (optional) interface, required for reverse mode automatic differentiation. Then, adjoint(::DiffEqArrayOperator) should be added to SciMLBase by just taking the adjoint of the internal array. Many other operators would have to be handled though, but that can be done over time.
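
For an array-backed operator, "taking the adjoint of the internal array" could look roughly like the sketch below. `ToyArrayOperator` is a hypothetical stand-in and not the real DiffEqArrayOperator; only `Base.adjoint`, `Base.size`, and `*` are overloaded.

```julia
using LinearAlgebra

# Hypothetical minimal array-backed operator (NOT the real DiffEqArrayOperator).
struct ToyArrayOperator{TA<:AbstractMatrix}
    A::TA
end

Base.size(op::ToyArrayOperator) = size(op.A)
Base.:*(op::ToyArrayOperator, v::AbstractVector) = op.A * v

# The adjoint operator wraps the lazy adjoint of the internal array.
Base.adjoint(op::ToyArrayOperator) = ToyArrayOperator(op.A')

A  = ComplexF64[1 2im; 3 4]
op = ToyArrayOperator(A)
v  = ComplexF64[1, 1im]
w  = ComplexF64[2, -1]

# Defining property of the adjoint: <w, A*v> == <A'*w, v>
@assert dot(w, op * v) ≈ dot(op' * w, v)
```

Because `op.A'` is a lazy wrapper, building the adjoint operator does not copy the array.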

ChrisRackauckas avatar Apr 16 '22 11:04 ChrisRackauckas

Sounds good. Is solve(prob, alg, adjoint=true) the standard interface for solving an adjoint problem in diffeq?

vpuri3 avatar Apr 16 '22 20:04 vpuri3

Not in solve; it's not a solve-level thing. It changes the result, so it's not a solver control but something to do with the problem. The real question is whether the answer is for it to just be LinearProblem(A',b) or LinearProblem(A,b,adjoint=true). The advantage of the latter is that it makes it easier to do general implementations of adjoints for operators which only define A*x, since then it can use AbstractDifferentiation.jl behind the scenes. However, you don't have to do that if adjoint(A) already exists. So we need has_adjoint(A::AbstractSciMLOperator) and adjoint(A::AbstractSciMLOperator) in the interface in order to write it effectively. Also, I'm not so sure we want the AD overloads, just because of the dependencies that would add. So it would need to do something like DiffEqSensitivity, where by default an error is raised telling the user to load DiffEqSensitivity (using DiffEqSensitivity) to get the adjoint overloads, unless has_adjoint(A). This complexity is eliminated if you just assume LinearProblem(A',b) is the way to do it.
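
The has_adjoint check with an error fallback might be sketched as follows. All names here (`ToyOperator`, `ToyMatrixOp`, `ToyFunctionOp`, `adjoint_operator`) are hypothetical, and this `has_adjoint` is a local toy function written for the example, not the SciML interface function.

```julia
using LinearAlgebra

abstract type ToyOperator end

# Matrix-backed operator: an adjoint is available.
struct ToyMatrixOp{TA<:AbstractMatrix} <: ToyOperator
    A::TA
end

# Matrix-free operator: only the action v -> A*v is known.
struct ToyFunctionOp{F} <: ToyOperator
    f::F
end

Base.:*(op::ToyMatrixOp, v::AbstractVector) = op.A * v
Base.:*(op::ToyFunctionOp, v::AbstractVector) = op.f(v)

has_adjoint(::ToyOperator) = false    # conservative default
has_adjoint(::ToyMatrixOp) = true
Base.adjoint(op::ToyMatrixOp) = ToyMatrixOp(op.A')

# Return the operator of the adjoint problem, or fail with a clear message.
function adjoint_operator(op::ToyOperator)
    has_adjoint(op) && return op'
    error("adjoint is not defined for $(typeof(op)); an AD-based fallback would be needed")
end

mop = ToyMatrixOp([1.0 2.0; 3.0 4.0])
fop = ToyFunctionOp(v -> 2 .* v)

y = adjoint_operator(mop) * [1.0, 0.0]   # action of the adjoint on a vector
```

Calling `adjoint_operator(fop)` raises the error, which is the place where an AD-based overload could hook in when the extra dependency is loaded.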

ChrisRackauckas avatar Apr 18 '22 00:04 ChrisRackauckas

LinearProblem(prob, alg, adjoint=true) would be good for factorizations, since the adjoint of a LinearAlgebra.Factorization is defined. For iterative methods and the like, we can fall back to the adjoint wrapper with a has_adjoint check.

vpuri3 avatar Apr 18 '22 14:04 vpuri3