ChainRules.jl
Can rules make decisions based on which primal method is used?
Can we safely detect whether a primal call is going to dispatch to a specific default method, and use that to change the logic in the frule or rrule?
For example, the rrule for det(A) calls inv(A). If we know that the primal call would dispatch to the det(A::AbstractMatrix) definition in generic.jl, then we know the primal computes the determinant from the LU decomposition, and we can reuse that factorization to compute the inverse faster. If a specialized primal method would be hit instead, we probably just want to call the primal and invert separately, since that method is probably more efficient for that type than lu.
e.g. something like this:
function rrule(::typeof(det), x::Union{Number, AbstractMatrix})
    # Reuse the LU factorization only when det(x) would hit the generic AbstractMatrix method.
    F = if which(det, Tuple{typeof(x)}) === which(det, Tuple{AbstractMatrix})
        lu(x; check = false)
    else
        x
    end
    Ω = det(F)
    function det_pullback(ΔΩ)
        return NO_FIELDS, Ω * ΔΩ * inv(F)'
    end
    return Ω, det_pullback
end
This is type-unstable, though (F is either an LU factorization or x itself, decided at run time), and the rrule call is 10x slower, even though the pullback is a little faster.
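One way to see the instability without benchmarking is to ask the compiler what it infers for F. The following is a minimal sketch, assuming a hypothetical helper pick_factorization (not part of ChainRules.jl) that isolates the branch from the rrule above. Base.return_types is a reflection utility, and the exact type it prints may differ between Julia versions.

```julia
using LinearAlgebra

# Hypothetical helper that isolates the branch from the rrule above.
function pick_factorization(x::AbstractMatrix)
    if which(det, Tuple{typeof(x)}) === which(det, Tuple{AbstractMatrix})
        return lu(x; check = false)   # generic fallback: reuse the LU factorization
    else
        return x                      # specialized det method: keep the input as is
    end
end

# The method comparison happens at run time, so inference has to account for
# both branches and cannot narrow the result to a single concrete type.
inferred = only(Base.return_types(pick_factorization, (Matrix{Float64},)))
println(inferred)
println(isconcretetype(inferred))
```

Because the pullback closure captures F, the non-concrete type propagates into everything that follows the branch.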
It seems to me that the "correct" way to distinguish these cases is just standard dispatch:
function rrule(::typeof(det), x::AbstractMatrix)
    F = lu(x; check = false)
    Ω = det(F)
    function det_pullback(ΔΩ)
        return NO_FIELDS, Ω * ΔΩ * inv(F)'
    end
    return Ω, det_pullback
end
function rrule(::typeof(det), x::YourSpecialMatrixType)
    # Do whatever you have to do
end
The rationale would be that if "det via LU" is a good enough fallback for the primal function, then it should also be good enough for the derivative.
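As a concrete illustration of the dispatch approach, here is a minimal sketch that uses Diagonal as the special matrix type. To stay self-contained it defines a stand-in function det_rule instead of extending ChainRulesCore.rrule, and returns nothing where the real rule returns NO_FIELDS. Both are assumptions made for illustration, and real-valued, nonsingular inputs are assumed throughout.

```julia
using LinearAlgebra

# Generic fallback: det goes through LU, so the pullback reuses the factorization.
function det_rule(x::AbstractMatrix)
    F = lu(x; check = false)
    Ω = det(F)
    det_pullback(ΔΩ) = (nothing, Ω * ΔΩ * inv(F)')
    return Ω, det_pullback
end

# Specialized method: for a Diagonal, det is the product of the diagonal entries,
# and the cotangent is again diagonal with entries Ω * ΔΩ / dᵢ (assumes dᵢ != 0).
function det_rule(x::Diagonal)
    Ω = det(x)
    det_pullback(ΔΩ) = (nothing, Diagonal(Ω * ΔΩ ./ x.diag))
    return Ω, det_pullback
end

D = Diagonal([2.0, 3.0, 4.0])
Ω_special, pb_special = det_rule(D)          # dispatches to the Diagonal method
Ω_generic, pb_generic = det_rule(Matrix(D))  # dispatches to the LU fallback
```

Both methods should agree on the value and the cotangent for the same input. The specialized one simply avoids the factorization, and no runtime method lookup is needed to choose between them.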
Maybe related: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/155