ChainRules.jl
Representing (co)tangents of structured matrices
This follows up a discussion on Slack. @sethaxen:
The pushforward of a function that produces a Symmetric matrix should also produce a Symmetric matrix. Is it also true that the pullback of a function that takes a Symmetric matrix should produce a Symmetric matrix?
@ChrisRackauckas:
yes
@sethaxen:
Okay, then what about when the input is Symmetric{<:Real}, but the pullback is passed an AbstractMatrix{<:Complex}? Should the pullback then produce a Hermitian{<:Complex} or a Symmetric{<:Real}, or something else?
@willtebbutt:
I’ve been wondering about this. My old answer would have been (assuming that the symmetric matrix looks at the upper triangle of the underlying matrix): add the lower triangle to the upper triangle (don’t touch the diagonal), and represent the cotangent via the upper triangle, the rationale being that the lower triangle of the data backing the symmetric matrix never gets used. My newer thinking is that passing an asymmetric cotangent (I’m unclear whether we want to represent this literally as a Symmetric or a Composite{<:Symmetric}) to the pullback (e.g. a plain old AbstractMatrix) should just error. My reasoning for this is that if you’ve somehow managed to get an asymmetric cotangent, then the downstream operations have been implemented incorrectly (e.g. getindex(::Symmetric, ::Int, ::Int) must be wrong and treating a Symmetric matrix as a general Matrix or something). Nothing in the Julia ecosystem currently implements this AFAICT though.
There are several points for discussion here. Under this perspective (@willtebbutt's newer thinking, which I tend to agree with), if downstream rules and AD have done everything right, then the pullback for Y = Symmetric(A) should always receive an object ΔY with a data field (either Composite{<:Symmetric} or Symmetric), and its pullback should just be ΔY.data. If the pullback is passed an UpperTriangular, LowerTriangular or Diagonal matrix, as in the current rrule implementation and as in #178, then something is wrong somewhere else. Moreover, we don't need to do anything to data, such as zeroing a triangle, because if that triangle should be zeroed, it is already zeroed in ΔY.data (e.g. if Matrix(Y) was called, then the unused triangle was overwritten by the used triangle in the forward pass. Consequently, a correctly implemented Matrix_pullback will zero out the unused triangle in the cotangent vector before wrapping with Symmetric or Composite{Symmetric}). Thus a custom rule is probably not even necessary for the Symmetric constructor. Have I got this right?
One thing that worries me is e.g. what if a user defined an override like (::Diagonal * ::MyDiagonal)::MyDiagonal. This would trigger our general rrule. The pullback would expect an AbstractMatrix, but it will be passed a Composite{MyDiagonal} (or MyDiagonal). To do the right thing, it would need to produce a Composite{Diagonal} (or Diagonal) and a Composite{MyDiagonal} (or MyDiagonal). So how should we define generic rules that handle such cases?
Also, should we adopt a convention regarding whether the (co)tangent of structured matrices should be matrices or Composite? A point for the former is that we can automatically multiply them by other matrices and add them, and things should just work. A point for the latter is that in many cases, the (co)tangent doesn't share the same structure as the primal (e.g. the (co)tangent of a unitary matrix is a unitary transformation of a skew-Hermitian matrix). A compromise is a utility method that in most cases is a no-op but is meant to convert from a composite type to a primal when possible.
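To spell out the unitary example above as a short derivation (a sketch, writing $dU$ for a tangent at a unitary $U$ and introducing $K$ only for illustration): differentiating the constraint $U^H U = I$ gives

$$
dU^H\, U + U^H\, dU = 0
\quad\Longrightarrow\quad
K := U^H\, dU \ \text{satisfies}\ K^H = -K
\quad\Longrightarrow\quad
dU = U K .
$$

So the tangent $dU$ is a unitary matrix times a skew-Hermitian one, and in general it is neither unitary nor skew-Hermitian itself, which is why it cannot simply reuse the primal's structured type.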
Relates to https://github.com/JuliaDiff/ChainRules.jl/issues/52
cc @mcabbott @oxinabox
i.e. this would be the entirety of the frule and rrule for Symmetric and Hermitian:
function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
    Y = T(A, uplo)
    return Y, Composite{typeof(Y)}(data = ΔA)
end
function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
    HermOrSym_pullback(ΔY) = Composite{typeof(A)}(ΔY.data)
    return T(A, uplo), HermOrSym_pullback
end
What these don't cover are the cases where one calls Hermitian(::Symmetric) or Symmetric(::Symmetric), which are essentially no-ops. In those cases, this fails, since it incorrectly wraps the cotangent of the data field. Hence why it's probably best to not have a rule for these at all.
Here's an example of a function whose first operation is Symmetric but for which finite differencing does not produce a triangular matrix:
julia> using FiniteDifferences, LinearAlgebra
julia> x = collect(reshape(1.0:9.0, 3, 3))
3×3 Array{Float64,2}:
 1.0  4.0  7.0
 2.0  5.0  8.0
 3.0  6.0  9.0
julia> function g(x)
           s = Symmetric(x)
           d = s.data
           h = s * d
           z = h[1, 2]
           return z
       end
g (generic function with 1 method)
julia> only(j′vp(central_fdm(5, 1), g, 1.0, x))
3×3 Array{Float64,2}:
 4.0  6.0  6.0
 0.0  4.0  0.0
 0.0  7.0  0.0
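As a cross-check of the finite-difference result, the gradient can be worked out by hand. This is a sketch that assumes the default uplo = :U, so that Symmetric(x) reads only the upper triangle of x:

$$
g(x) = \big(\operatorname{Symmetric}(x)\,x\big)_{12} = x_{11}x_{12} + x_{12}x_{22} + x_{13}x_{32},
\qquad
\frac{\partial g}{\partial x} =
\begin{pmatrix}
x_{12} & x_{11} + x_{22} & x_{32} \\
0 & x_{12} & 0 \\
0 & x_{13} & 0
\end{pmatrix}
=
\begin{pmatrix}
4 & 6 & 6 \\
0 & 4 & 0 \\
0 & 7 & 0
\end{pmatrix}.
$$

The entry at position (3, 2) comes from the direct use of the data field in the product, not from the Symmetric wrapper, which is why the result is not triangular.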
Yeah, it's just by virtue of accessing the data field, which I guess you're not really meant to do. Of course you can access it and people will, so we have to support it, but it'll definitely break our Symmetric assumptions.
Yeah, so we have two cases:
- A rule consumes a Symmetric, using it by accessing the underlying .data property. This rule's pullback will produce a valid Composite{<:Symmetric} cotangent with a .data field containing a matrix that can be asymmetric.
- A general matrix rule is triggered, consuming a Symmetric. Its pullback will in general produce an AbstractMatrix cotangent, though some Symmetric-specific functions' pullbacks might produce a Symmetric.
Turning the Composite into a Symmetric will destroy the asymmetric part of the cotangent.
Because the Composite can be asymmetric, it should be preferred for the cotangents.
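To make the loss concrete, here is a minimal sketch using only LinearAlgebra (no ChainRules types). It takes the cotangent of the data field from the finite-difference example in this thread and re-wraps it as a Symmetric; the variable names are just for illustration.

```julia
using LinearAlgebra

# Cotangent of `.data`, copied from the finite-difference example in this thread
Δdata = [4.0 6.0 6.0;
         0.0 4.0 0.0;
         0.0 7.0 0.0]

# Re-wrapping as `Symmetric` (default uplo = :U) reads only the upper triangle
ΔS = Symmetric(Δdata)

# The lower-triangle entry Δdata[3, 2] is no longer visible through the wrapper:
# ΔS[3, 2] mirrors Δdata[2, 3] instead
lost = Δdata[3, 2] - ΔS[3, 2]
```

The underlying array still holds the value, but any rule that treats ΔS as a matrix never sees it, whereas a Composite carrying the raw data field keeps it available.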
But to pull these two types of cotangents through the pullback of the Symmetric constructor, we need a rule for combining these.
Here's an idea:
# when a matrix sensitivity is added to a `Composite{<:Symmetric}`, first pull it back
# to the correct triangle of `.data`, then combine.
function Base.:+(a::P, b::Composite{P}) where {P<:Symmetric}
    return Composite{P}(data=_symmetric_back(a)) + b
end
function Base.:+(a::AbstractMatrix, b::Composite{P}) where {P<:Symmetric}
    return Composite{P}(data=_symmetric_back(a)) + b
end
function rrule(::Type{<:Symmetric}, A::AbstractMatrix)
    function Symmetric_pullback(ȳ)
        return (NO_FIELDS, @thunk(_symmetric_back(ȳ)))
    end
    return Symmetric(A), Symmetric_pullback
end
# If no composites were accumulated, pull back to used triangle and return
_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
# If any composites, pullback has already been called, so just return
_symmetric_back(ΔΩ::Composite{<:Symmetric}) = ΔΩ.data
One problem with this approach is that the information about which triangle is storing the data is not contained in the type of Symmetric. So when accumulating an AbstractMatrix and a Composite{<:Symmetric}, one doesn't know which triangle to use. I don't know how to deal with this yet.
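One possible partial workaround, sketched here under the assumption that the caller can supply uplo explicitly (for instance by capturing it from the primal inside the pullback closure): a variant of the helper that takes the triangle as an argument. The name symmetric_back and its signature are hypothetical, it is restricted to real matrices, and it does not solve the accumulation case where only an AbstractMatrix and a Composite are available without the primal.

```julia
using LinearAlgebra

# Hypothetical helper, not part of ChainRules: pull a dense real cotangent ΔΩ of
# Symmetric(A, uplo) back onto the triangle of A that the wrapper actually reads.
function symmetric_back(ΔΩ::AbstractMatrix{<:Real}, uplo::Symbol)
    uplo in (:U, :L) || throw(ArgumentError("uplo must be :U or :L, got :$uplo"))
    # Sum each off-diagonal pair (i, j) and (j, i); keep the diagonal counted once
    folded = ΔΩ + transpose(ΔΩ) - Diagonal(diag(ΔΩ))
    # Keep only the triangle that backs the Symmetric wrapper
    return uplo === :U ? triu(folded) : tril(folded)
end

# Example cotangent with no particular structure
Δ = [0.5 -1.0 2.0; 3.0 0.25 -4.0; 1.5 2.5 -0.75]
ΔA_upper = symmetric_back(Δ, :U)
ΔA_lower = symmetric_back(Δ, :L)
```

For uplo = :U this computes the same thing as the _symmetric_back definition above; the :L branch is the mirrored case. Both satisfy the adjoint identity dot(Symmetric(dA, uplo), Δ) == dot(dA, symmetric_back(Δ, uplo)) for any real dA.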
Is there anything like a Composite that can store different representations of sensitivities for later accumulation?