ChainRules.jl

Representing (co)tangents of structured matrices

Open sethaxen opened this issue 5 years ago • 4 comments

This follows up on a discussion on Slack. @sethaxen:

The pushforward of a function that produces a Symmetric matrix should also produce a Symmetric matrix. Is it also true that the pullback of a function that takes a Symmetric matrix should produce a Symmetric matrix?

@ChrisRackauckas:

yes

@sethaxen:

Okay, then what about when the input is Symmetric{<:Real}, but the pullback is passed an AbstractMatrix{<:Complex}? Should the pullback then produce a Hermitian{<:Complex} or a Symmetric{<:Real}, or something else?

@willtebbutt:

I’ve been wondering about this. My old answer would have been (assuming that the symmetric matrix looks at the upper triangle of the underlying matrix): add the lower triangle to the upper triangle (don’t touch the diagonal), and represent the cotangent via the upper triangle, the rationale being that the lower triangle of the data backing the symmetric matrix never gets used. My newer thinking is that passing an asymmetric cotangent (I’m unclear whether we want to represent this literally as a Symmetric or a Composite{<:Symmetric}) to the pullback (e.g. a plain old AbstractMatrix) should just error. My reasoning is that if you’ve somehow managed to get an asymmetric cotangent, then the downstream operations have been implemented incorrectly (e.g. getindex(::Symmetric, ::Int, ::Int) must be wrong, treating a Symmetric matrix as a general Matrix or something). Nothing in the Julia ecosystem currently implements this AFAICT, though.
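Both options can be sketched with LinearAlgebra alone. This is a minimal illustration, not ChainRules code: `fold_to_upper` and `checked_cotangent` are hypothetical names, the wrapper is assumed to read the upper triangle (uplo = :U), and the exact `issymmetric` test would probably need a tolerance in practice.

```julia
using LinearAlgebra

# Old approach (assumes uplo = :U): fold the strict lower triangle of a dense
# cotangent onto the upper triangle, leaving the diagonal untouched, so the
# result only has entries where the Symmetric wrapper actually reads `data`.
fold_to_upper(ΔY::AbstractMatrix) = triu(ΔY) + transpose(tril(ΔY, -1))

# Newer approach: an asymmetric cotangent signals a bug downstream, so throw.
function checked_cotangent(ΔY::AbstractMatrix)
    issymmetric(ΔY) || throw(ArgumentError("asymmetric cotangent for a Symmetric input"))
    return ΔY
end

ΔY = [1.0 2.0; 5.0 4.0]
fold_to_upper(ΔY)       # [1.0 7.0; 0.0 4.0]
# checked_cotangent(ΔY) # throws ArgumentError
```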

There are several points for discussion here. Under this perspective (@willtebbutt's newer thinking, which I tend to agree with), if downstream rules and AD have done everything right, then the pullback for Y = Symmetric(A) should always receive an object ΔY with a data field (either a Composite{<:Symmetric} or a Symmetric), and the pullback should just return ΔY.data. If the pullback is passed an UpperTriangular, LowerTriangular or Diagonal matrix, as in the current rrule implementation and as in #178, then something is wrong somewhere else. Moreover, we don't need to do anything to data, such as zeroing a triangle, because if that triangle should be zeroed, it is already zeroed in ΔY.data. For example, if Matrix(Y) was called, then the unused triangle was overwritten by the used triangle in the forward pass, so a correctly implemented Matrix_pullback will zero out the unused triangle in the cotangent before wrapping it in a Symmetric or Composite{<:Symmetric}. Thus a custom rule is probably not even necessary for the Symmetric constructor. Have I got this right?
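The claim about Matrix(Y) can be checked numerically without any AD. In this sketch (hypothetical helpers `symmetric_matrix_pullback` and `fd_gradient`, uplo = :U assumed), the finite-difference gradient of a linear function of Matrix(Symmetric(A)) has an exactly zero strict lower triangle, and it matches folding the dense cotangent onto the upper triangle.

```julia
using LinearAlgebra

# Hypothetical helper (uplo = :U): cotangent of `A` for `M = Matrix(Symmetric(A))`,
# given a dense cotangent `ΔM`. The strict lower triangle of the result is zero.
symmetric_matrix_pullback(ΔM) = triu(ΔM) + transpose(tril(ΔM, -1))

# Hypothetical helper: central finite-difference gradient of a scalar function `f` at `A`.
function fd_gradient(f, A; h = 1e-4)
    G = zero(A)
    for i in eachindex(A)
        E = zero(A)
        E[i] = h
        G[i] = (f(A + E) - f(A - E)) / (2h)
    end
    return G
end

A = collect(reshape(1.0:9.0, 3, 3))
ΔM = collect(reshape(10.0:18.0, 3, 3))      # stands in for the cotangent of Matrix(Y)
loss(A) = sum(ΔM .* Matrix(Symmetric(A)))   # linear, so its gradient is the pullback of ΔM

G = fd_gradient(loss, A)
```

Perturbing an entry below the diagonal of `A` leaves `Symmetric(A)` unchanged, which is why those gradient entries come out as exact zeros.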

One thing that worries me: suppose a user defines an override like (::Diagonal * ::MyDiagonal)::MyDiagonal. This would trigger our general rrule, whose pullback expects an AbstractMatrix but will be passed a Composite{MyDiagonal} (or a MyDiagonal). To do the right thing, it would need to produce a Composite{Diagonal} (or Diagonal) and a Composite{MyDiagonal} (or MyDiagonal). So how should we define generic rules that handle such cases?
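The output side of this problem, mapping a generic rule's dense cotangents back onto the structure of each input, might look like the following. This is a sketch using only LinearAlgebra: `MyDiagonal`, `mul_pullback` and `project` are hypothetical names, plain matrices stand in for `Composite` cotangents, and it does not address the input side (a pullback being handed a Composite).

```julia
using LinearAlgebra

# Hypothetical user-defined structured matrix, as in the comment above.
struct MyDiagonal{T} <: AbstractMatrix{T}
    diag::Vector{T}
end
Base.size(D::MyDiagonal) = (length(D.diag), length(D.diag))
Base.getindex(D::MyDiagonal{T}, i::Int, j::Int) where {T} = i == j ? D.diag[i] : zero(T)

# Generic dense pullback of C = A * B, given a cotangent ΔC of C.
mul_pullback(A, B, ΔC) = (ΔC * B', A' * ΔC)

# Hypothetical projection of a dense cotangent onto the primal's structure:
# a no-op for plain matrices, keeps only the diagonal for diagonal types.
project(::AbstractMatrix, X::AbstractMatrix) = X
project(::Diagonal, X::AbstractMatrix) = Diagonal(diag(X))
project(::MyDiagonal, X::AbstractMatrix) = MyDiagonal(diag(X))

A = Diagonal([1.0, 2.0, 3.0])
B = MyDiagonal([4.0, 5.0, 6.0])
ΔC = ones(3, 3)                       # a dense cotangent arriving from downstream
ΔA_dense, ΔB_dense = mul_pullback(A, B, ΔC)
ΔA = project(A, ΔA_dense)             # a Diagonal
ΔB = project(B, ΔB_dense)             # a MyDiagonal
```

The design choice here is that each structured type opts in with one `project` method, while the generic rule itself stays dense.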

Also, should we adopt a convention regarding whether the (co)tangent of structured matrices should be matrices or Composite? A point for the former is that we can automatically multiply them by other matrices and add them, and things should just work. A point for the latter is that in many cases, the (co)tangent doesn't share the same structure as the primal (e.g. the (co)tangent of a unitary matrix is a unitary transformation of a skew-Hermitian matrix). A compromise is a utility method that in most cases is a no-op but is meant to convert from a composite type to a primal when possible.
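The unitary example can be checked numerically in the real case (orthogonal U, skew-symmetric S). In this minimal sketch the specific matrices are arbitrary choices, and the check shows that the tangent U * S satisfies the linearized orthogonality constraint but is not itself orthogonal, so it does not share the primal's structure.

```julia
using LinearAlgebra

# A fixed orthogonal matrix, the real case of a unitary matrix.
U = Matrix(qr([2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 4.0]).Q)

# A skew-symmetric matrix (S' == -S), the real case of a skew-Hermitian one.
S = [0.0 1.0 2.0; -1.0 0.0 3.0; -2.0 -3.0 0.0]

# Tangent at t = 0 of the curve t -> U * exp(t * S), which stays orthogonal for all t.
dU = U * S

# Central finite-difference slope of that curve, for comparison with dU.
h = 1e-6
fd_slope = (U * exp(h * S) - U * exp(-h * S)) / (2h)

# dU satisfies the linearized constraint U' * dU + dU' * U == 0 ...
constraint_residual = norm(U' * dU + dU' * U)

# ... but dU is not itself orthogonal: dU' * dU == S' * S, which is not the identity.
gram = dU' * dU
```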

Relates https://github.com/JuliaDiff/ChainRules.jl/issues/52

cc @mcabbott @oxinabox

sethaxen, May 17 '20 01:05

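That is, this would be the entirety of the frule and rrule for Symmetric and Hermitian:

using ChainRulesCore, LinearAlgebra
import ChainRulesCore: frule, rrule

function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
    Y = T(A, uplo)
    # The tangent of the wrapper is a Composite whose `data` field is the tangent of `A`.
    return Y, Composite{typeof(Y)}(data = ΔA)
end

function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
    # One cotangent per argument: none for the type `T`, the `data` field of the
    # wrapper's cotangent for `A`, and none for the `uplo` flag.
    HermOrSym_pullback(ΔY) = (NO_FIELDS, ΔY.data, DoesNotExist())
    return T(A, uplo), HermOrSym_pullback
end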

What these don't cover are the cases where one calls Hermitian(::Symmetric) or Symmetric(::Symmetric), which are essentially no-ops. In those cases this fails, since it incorrectly wraps the cotangent of the data field. This is why it's probably best not to have a rule for these at all.

sethaxen, May 17 '20 06:05

Here's an example of a function whose first operation is Symmetric but for which finite differencing does not produce a triangular matrix:

julia> using FiniteDifferences, LinearAlgebra

julia> x = collect(reshape(1.0:9.0, 3, 3))
3×3 Array{Float64,2}:
 1.0  4.0  7.0
 2.0  5.0  8.0
 3.0  6.0  9.0

julia> function g(x)
          s = Symmetric(x)
          d = s.data
          h = s * d
          z = h[1, 2]
          return z
       end
g (generic function with 1 method)

julia> only(j′vp(central_fdm(5, 1), g, 1.0, x))
3×3 Array{Float64,2}:
 4.0  6.0  6.0
 0.0  4.0  0.0
 0.0  7.0  0.0

sethaxen, May 18 '20 03:05

Yeah, it's just by virtue of accessing the data field, which I guess you're not really meant to access. Of course you can access it, and people will, so we have to support that, but it'll definitely break our Symmetric assumptions.
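The finite-difference result above can be reproduced by hand by splitting the gradient of g into its two paths: through the Symmetric wrapper s and through the raw s.data. In this sketch (hypothetical helper `fold_upper`, uplo = :U), only the .data path contributes entries below the diagonal.

```julia
using LinearAlgebra

x = collect(reshape(1.0:9.0, 3, 3))
s = Symmetric(x)   # wraps x without copying and reads its upper triangle

# Hypothetical helper (uplo = :U): pull a dense cotangent of `Symmetric(A)` back to `A`
# by folding its strict lower triangle onto the upper triangle.
fold_upper(ΔY) = triu(ΔY) + transpose(tril(ΔY, -1))

# z = (s * d)[1, 2] with d = s.data (which is x itself), so
# ∂z/∂s[1, k] = d[k, 2] and ∂z/∂d[k, 2] = s[1, k].
Δs = zeros(3, 3); Δs[1, :] = x[:, 2]   # cotangent reaching the Symmetric wrapper
Δd = zeros(3, 3); Δd[:, 2] = s[1, :]   # cotangent reaching the raw `.data` field

via_symmetric = fold_upper(Δs)   # zero strict lower triangle
via_data = Δd                    # unstructured
total = via_symmetric + via_data # same matrix as the j′vp result
```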

willtebbutt, May 18 '20 08:05

Yeah, so we have two cases:

  1. A rule consumes a Symmetric, using it by accessing the underlying .data property. Such a rule's pullback will produce a valid Composite{<:Symmetric} cotangent whose .data field contains a matrix that can be asymmetric.
  2. A general matrix rule is triggered, consuming a Symmetric. Its pullback will in general produce an AbstractMatrix cotangent, though some Symmetric-specific functions' pullbacks might produce a Symmetric. Turning a Composite from case 1 into a Symmetric would destroy the asymmetric part of its cotangent.
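A two-line illustration of the information loss in case 1, assuming uplo = :U and using a plain matrix `D` to stand in for the `.data` field of a `Composite{<:Symmetric}` cotangent:

```julia
using LinearAlgebra

# Stand-in for the asymmetric `.data` of a Composite{<:Symmetric} cotangent (case 1).
D = [1.0 2.0; 3.0 4.0]

# Wrapping it as a Symmetric (uplo = :U) keeps only the upper triangle,
# so the 3.0 below the diagonal is silently discarded.
S = Symmetric(D)
Matrix(S)   # [1.0 2.0; 2.0 4.0]
```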

Because the Composite can be asymmetric, it should be preferred for the cotangents. But to pull these two types of cotangents through the pullback of the Symmetric constructor, we need a rule for combining these. Here's an idea:

using ChainRulesCore, LinearAlgebra
import ChainRulesCore: rrule

# when a matrix sensitivity is added to a `Composite{<:Symmetric}`, first pull it back
# to the correct triangle of `.data`, then combine.
function Base.:+(a::P, b::Composite{P}) where {P<:Symmetric}
    return Composite{P}(data=_symmetric_back(a)) + b
end
function Base.:+(a::AbstractMatrix, b::Composite{P}) where {P<:Symmetric}
    return Composite{P}(data=_symmetric_back(a)) + b
end

function rrule(::Type{<:Symmetric}, A::AbstractMatrix)
    function Symmetric_pullback(ȳ)
        return (NO_FIELDS, @thunk(_symmetric_back(ȳ)))
    end
    return Symmetric(A), Symmetric_pullback
end

# If no composites were accumulated, pull back to the used triangle (upper, i.e. uplo = :U) and return.
# Use `transpose`, not `'`: Symmetric mirrors entries without conjugation, so complex
# cotangents from the lower triangle must not be conjugated.
_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + transpose(LowerTriangular(ΔΩ)) - Diagonal(ΔΩ)
# If any composites, pullback has already been called, so just return `.data`
_symmetric_back(ΔΩ::Composite{<:Symmetric}) = ΔΩ.data

One problem with this approach is that the information about which triangle is storing the data is not contained in the type of Symmetric. So when accumulating an AbstractMatrix and a Composite{<:Symmetric}, one doesn't know which triangle to use. I don't know how to deal with this yet.
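The triangle in use is stored in Symmetric's runtime `uplo` field rather than in its type, which a quick check confirms. The uplo-aware fold below (`symmetric_back`, a hypothetical name) is a sketch under the assumption that the primal's `uplo` is available: a pullback closure can capture it, but the `+` overloads above only see the cotangents.

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
Su = Symmetric(A, :U)
Sl = Symmetric(A, :L)

# Same concrete type: the triangle in use is the runtime field `uplo`, not a type parameter.
typeof(Su) == typeof(Sl)   # true
(Su.uplo, Sl.uplo)         # ('U', 'L')

# Hypothetical uplo-aware variant of `_symmetric_back`; it needs the primal's `uplo`.
function symmetric_back(ΔY::AbstractMatrix, uplo::Char)
    if uplo == 'U'
        return triu(ΔY) + transpose(tril(ΔY, -1))   # fold lower onto upper
    else
        return tril(ΔY) + transpose(triu(ΔY, 1))    # fold upper onto lower
    end
end

symmetric_back([1.0 2.0; 5.0 4.0], Su.uplo)   # [1.0 7.0; 0.0 4.0]
symmetric_back([1.0 2.0; 5.0 4.0], Sl.uplo)   # [1.0 0.0; 7.0 4.0]
```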

Is there anything like a Composite that can store different representations of sensitivities for later accumulation?

sethaxen, May 18 '20 11:05