ChainRules.jl
Make sensitivities for structured matrix arguments structured
Quoting Will in #29:
FWIW, the other thing to think about is what is actually happening computationally under the hood. Ultimately the Diagonal matrix type doesn't use any off-diagonal elements when used in e.g. a matrix-matrix multiply. The Diagonal type simply doesn't allow you to have non-zero off-diagonal elements, so it's a slightly odd question to ask what happens if you perturb the off-diagonals by an infinitesimal amount (i.e. compute the gradient w.r.t. them).
It's this slightly weird situation in which thinking about a Diagonal matrix as a regular dense matrix that happens to contain zeros on its off-diagonals isn't really faithful to the semantics of the type (not sure if I've really phrased that correctly, but hopefully the gist is clear).
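To make this concrete: for y = D * x, the usual dense-matrix rule gives the cotangent ybar * x' for the matrix argument, where ybar is the cotangent of y. The sketch below uses only the LinearAlgebra standard library with arbitrary example values (all variable names are illustrative). It shows that the dense cotangent has non-zero off-diagonal entries, i.e. "perturbations" a Diagonal cannot represent, while keeping only its diagonal gives ybar .* x.

```julia
using LinearAlgebra

D = Diagonal([1.0, 2.0, 3.0])   # only the three diagonal entries are stored
x = [4.0, 5.0, 6.0]
ybar = [1.0, -1.0, 0.5]         # example cotangent of y = D * x

# Dense-matrix rule for y = A * x (real case): Abar = ybar * x'
dense_Dbar = ybar * x'

# Off-diagonal entries are non-zero (here ybar[1] * x[2]),
# but a Diagonal has no storage for them
@show dense_Dbar[1, 2]

# Structured cotangent: keep only what a Diagonal can represent
structured_Dbar = Diagonal(ybar .* x)
@assert diag(dense_Dbar) == diag(structured_Dbar)
```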
Since (say) *(x::Diagonal, y::AbstractVector) in LinearAlgebra is defined in terms of a combination of operations whose rules are already defined, would simply "nulling" the specialized rules, as below, handle the situation automatically?
frule(::typeof(*), ::Diagonal, ::AbstractVector) = nothing
rrule(::typeof(*), ::Diagonal, ::AbstractVector) = nothing
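If the specialized rules are disabled like this, an AD system would differentiate through whatever primitive operations LinearAlgebra uses for a Diagonal times a vector. Assuming that method reduces to an elementwise product of the stored diagonal with x (an assumption about LinearAlgebra internals, which may differ across Julia versions), the gradient with respect to the stored diagonal is ybar .* x, which is already "structured". The sketch below checks that formula with central finite differences taken over the diagonal entries only; loss and fd_grad are hypothetical helpers, not part of any package.

```julia
using LinearAlgebra

x = [4.0, 5.0, 6.0]
ybar = [1.0, -1.0, 0.5]

# Hypothetical scalar loss whose gradient w.r.t. y = Diagonal(d) * x is ybar
loss(d) = dot(ybar, Diagonal(d) * x)

# Central finite differences, perturbing only the stored diagonal entries
function fd_grad(f, d; h = 1e-6)
    g = similar(d)
    for i in eachindex(d)
        e = zeros(length(d))
        e[i] = h
        g[i] = (f(d + e) - f(d - e)) / (2h)
    end
    return g
end

d = [1.0, 2.0, 3.0]
g = fd_grad(loss, d)
@assert isapprox(g, ybar .* x; atol = 1e-6)
```

Because the loss is linear in d, the finite-difference estimate is exact up to floating-point rounding, so a tight tolerance is enough.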
Ref a similar question I posted in Zygote: https://github.com/FluxML/Zygote.jl/issues/316
I think "disabling" the rules by returning nothing would do it, assuming we're inside an AD framework. It would be better not to, though, and to just define the correct answer directly.
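Defining the correct answer directly could look roughly like the following sketch. It mimics the shape of an rrule (the primal result plus a pullback closure) without depending on ChainRulesCore, so diag_times_vector_rrule is a hypothetical name and not part of ChainRules.jl. The cotangent for the Diagonal argument is returned as a Diagonal; the conj makes the formula also match the dense rule ybar * x' for complex inputs.

```julia
using LinearAlgebra

# Hypothetical helper mirroring the rrule convention: returns the primal
# result and a pullback mapping the output cotangent to argument cotangents.
function diag_times_vector_rrule(D::Diagonal, x::AbstractVector)
    y = D * x
    function pullback(ybar)
        # Structured cotangent: only the diagonal of ybar * x' is kept
        Dbar = Diagonal(ybar .* conj.(x))
        # Cotangent for the vector argument: D' * ybar
        xbar = D' * ybar
        return Dbar, xbar
    end
    return y, pullback
end

D = Diagonal([1.0, 2.0, 3.0])
x = [4.0, 5.0, 6.0]
y, back = diag_times_vector_rrule(D, x)
Dbar, xbar = back([1.0, -1.0, 0.5])
```

A rule written against current ChainRulesCore would additionally return NoTangent() for the function itself and would usually unthunk the incoming cotangent; those details are left out here.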
I think there are too many specializations in LinearAlgebra (and likely many others in the wild). But maybe not so bad, at least as a "midterm" solution? Looking at how Zygote handles broadcasting, I have a feeling that it'll take some time to support broadcasting properly...
If ChainRules.jl's stance is "rules specializations are welcome" I can make a PR to implement https://github.com/FluxML/Zygote.jl/issues/316 in ChainRules.jl (after the overhaul #91).
I think there are too many specializations in LinearAlgebra (and likely many others in the wild). But maybe not so bad, at least as a "midterm" solution?
Yeah, I think bailing out is sometimes the right answer for ChainRules, probably with a link to a TODO issue for these cases.
If ChainRules.jl's stance is "rules specializations are welcome" I can make a PR to implement FluxML/Zygote.jl#316 in ChainRules.jl (after the overhaul #91).
Yeah, that would be great.
(Eventually we will have an @rrule or something similar with close to the same surface semantics as @adjoint, see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/44. At that point those rules will be copy-pasteable. No promises on when, though.)