ChainRules.jl

Can rules make decisions based on which primal method is used?

Open · sethaxen opened this issue 5 years ago · 3 comments

Can we safely detect whether a primal function call will hit a specific default (fallback) method, and use that to change the logic in the frule or rrule?

For example, the rrule for det(A) calls inv(A). If we know that the primal call will hit the generic det(A::AbstractMatrix) definition in generic.jl, then we know the primal computes the determinant from an LU decomposition, and we can reuse that decomposition to compute the inverse faster. But if a specialized primal method is hit instead, we probably just want to call the primal and invert separately, since that method is likely more efficient for that type than LU.
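A minimal sketch of the reuse idea, using only the LinearAlgebra standard library (the matrix `A` is an arbitrary example, not taken from the issue): a single `lu` call provides both the determinant and the inverse. For a real matrix, the gradient of `det` is `det(A) * inv(A)'`, the same expression the pullbacks in this thread compute.

```julia
using LinearAlgebra

A = [4.0 3.0; 6.0 3.0]   # arbitrary nonsingular example matrix

F = lu(A)                # one O(n^3) factorization, P*A = L*U
d = det(F)               # determinant from the diagonal of U and the sign of P
Ainv = inv(F)            # inverse from the same factors, no second factorization

# Gradient of det at A (real case): det(A) * inv(A)'
grad = d * Ainv'
```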

sethaxen · Jul 16 '20 00:07

For example, something like this:

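using LinearAlgebra
using ChainRulesCore
import ChainRulesCore: rrule

function rrule(::typeof(det), x::Union{Number, AbstractMatrix})
    # Factorize with LU only if det(x) would hit the generic
    # det(::AbstractMatrix) fallback, which itself uses LU.
    F = if which(det, Tuple{typeof(x)}) === which(det, Tuple{AbstractMatrix})
        lu(x; check = false)
    else
        x  # specialized det method: keep x as-is
    end
    Ω = det(F)
    function det_pullback(ΔΩ)
        # NoTangent() replaces the 2020-era NO_FIELDS
        return NoTangent(), Ω * ΔΩ * inv(F)'
    end
    return Ω, det_pullback
end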

This is type-unstable, though: F is either an LU object or x itself, depending on a check made at run time. As a result the rrule is about 10x slower, although the pullback is a little faster.
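To make the type instability concrete, here is a small sketch assuming a recent Julia with LinearAlgebra; `uses_generic_det` and `maybe_factorize` are hypothetical helpers that reproduce the `which` check and the branch from the rrule above. `Diagonal` is used because LinearAlgebra gives it its own `det` method.

```julia
using LinearAlgebra

# Same runtime check as in the rrule above: does x dispatch to the
# generic det(::AbstractMatrix) fallback? (hypothetical helper)
uses_generic_det(x) = which(det, Tuple{typeof(x)}) === which(det, Tuple{AbstractMatrix})

# Mirrors the branch in the rrule: an LU object or the input itself.
maybe_factorize(x) = uses_generic_det(x) ? lu(x; check = false) : x

D = Diagonal([2.0, 3.0])
@show uses_generic_det(D)   # Diagonal has its own det method

# `which` is a reflection call evaluated at run time, so inference cannot
# pick a branch; the inferred return type is not a concrete type.
T = only(Base.return_types(maybe_factorize, (Matrix{Float64},)))
@show T isconcretetype(T)
```

Because the branch depends on a reflection call, the compiler has to carry both possible types of `F` through the rest of the rrule, which likely accounts for much of the slowdown.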

sethaxen · Jul 16 '20 00:07

It seems to me that the "correct" way to distinguish these cases is just standard dispatch:

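using LinearAlgebra
using ChainRulesCore
import ChainRulesCore: rrule

# Generic fallback: det via LU, and the same LU is reused in the pullback
function rrule(::typeof(det), x::AbstractMatrix)
    F = lu(x; check = false)
    Ω = det(F)
    function det_pullback(ΔΩ)
        # NoTangent() replaces the 2020-era NO_FIELDS
        return NoTangent(), Ω * ΔΩ * inv(F)'
    end
    return Ω, det_pullback
end

# YourSpecialMatrixType is a placeholder for any type with its own det method
function rrule(::typeof(det), x::YourSpecialMatrixType)
    # Do whatever you have to do
end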

The rationale would be that if "det via LU" is a good enough fallback for the primal function, then it should also be good enough for the derivative.
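A self-contained sketch of the dispatch approach, without ChainRulesCore: `det_and_pullback` is a hypothetical stand-in for `rrule` that returns the primal and a pullback closure. The generic method reuses one LU factorization for both the determinant and the inverse, while the `Diagonal` method (standing in for `YourSpecialMatrixType`) calls the type's own `det` and never factorizes. Only the real-valued case is handled.

```julia
using LinearAlgebra

# Hypothetical stand-in for an rrule: returns det(x) and a pullback closure.
# Generic fallback: factorize once with LU and reuse it for both values.
function det_and_pullback(x::AbstractMatrix)
    F = lu(x; check = false)
    Ω = det(F)
    det_pullback(ΔΩ) = Ω * ΔΩ * inv(F)'   # real-valued case
    return Ω, det_pullback
end

# Specialized method for a type with its own cheap det: no LU needed.
function det_and_pullback(x::Diagonal)
    Ω = det(x)
    det_pullback(ΔΩ) = Ω * ΔΩ * inv(x)'   # result stays a Diagonal
    return Ω, det_pullback
end

A = [2.0 1.0; 1.0 3.0]
Ω, back = det_and_pullback(A)
@show Ω back(1.0)

D = Diagonal([2.0, 4.0])
ΩD, backD = det_and_pullback(D)
@show ΩD backD(1.0)
```

Because the choice is made by dispatch on the argument type, each method works with a concretely typed value and stays type-stable, unlike the `which`-based check.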

ettersi · Jul 16 '20 01:07

Possibly related: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/155

nickrobinson251 · Jul 16 '20 08:07