ChainRulesCore.jl

StaticArrays

Open mcabbott opened this issue 4 years ago • 2 comments

I think we would generically like StaticArray arguments to have StaticArray tangents. The current behaviour depends on what path you hit:

julia> using StaticArrays, ChainRulesCore

julia> p = ProjectTo(SA[1,2,3])  # has SOneTo
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (SOneTo(3),))

julia> p([4,5,6])  # doesn't reshape
3-element Vector{Float64}:
 4.0
 5.0
 6.0

julia> p(ones(3,1))  # does reshape
3-element SizedVector{3, Float64, Vector{Float64}} with indices SOneTo(3):
 1.0
 1.0
 1.0

If we change this line to test with ===, then the first case would behave like the second: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/0e560c648ae29bb0860af77f9010763c0cc6fb48/src/projection.jl#L214. Would this have any surprising downsides? It would also improve the type-stability of things like this:

julia> using OffsetArrays

julia> p2 = ProjectTo(zeros(1:3))
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (OffsetArrays.IdOffsetRange(values=1:3, indices=1:3),))

julia> @code_warntype p2(ones(3))
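
To make the difference between the two comparisons concrete, here is a minimal sketch (not the code from projection.jl) of how == and === treat the axes of an SVector and an ordinary Vector. The variable names sv and v are only for illustration.

```julia
using StaticArrays

sv = SA[1, 2, 3]   # axes(sv) is (SOneTo(3),)
v  = [4, 5, 6]     # axes(v) is (Base.OneTo(3),)

# `==` compares the ranges by value, so the two axes tuples count as equal
# and a check written with `==` sees no reason to reshape.
same_values = axes(v) == axes(sv)    # true

# `===` also requires identical types, so SOneTo(3) and Base.OneTo(3) differ
# and a check written with `===` would take the reshape branch.
identical = axes(v) === axes(sv)     # false
```

When the two axes types differ, the result of === is fixed by the types alone, which is presumably why the change would help inference in the OffsetArrays case above; that explanation is an assumption, not something checked against the compiler output.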

Going the other way, if the argument is an ordinary Vector and the tangent is an SVector, then reshape won't do anything:

julia> reshape(SA[1,2,3], axes([4,5,6]))
3-element SVector{3, Int64} with indices SOneTo(3):
 1
 2
 3

This is the case of https://github.com/FluxML/Zygote.jl/issues/1093, I think. A simple example of where the rule accidentally makes an SVector is:

julia> using Zygote, LinearAlgebra

julia> gradient(x -> dot(SA[1,2,3], x), rand(3))[1]  # dy = reshape(x .* ΔΩ, axes(y))
3-element SVector{3, Float64} with indices SOneTo(3):
 1.0
 2.0
 3.0

Is it a good idea, as a general rule, to force those to be converted back to Array? And if so, is there an easy way to implement this without ChainRulesCore depending on StaticArrays, or the reverse?
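
As a rough sketch of what a Base-only conversion could look like, the snippet below uses a hypothetical helper, to_plain_array, which is not part of ChainRulesCore or StaticArrays. It assumes Array(dx) is defined for the incoming tangent type, and it uses a view as a stand-in for a non-Array tangent so that it runs without any packages.

```julia
# Hypothetical helper, not part of ChainRulesCore: give an Array primal a
# plain Array tangent, copying only when the tangent is some other array type.
to_plain_array(dx::Array) = dx                 # already an Array, returned as-is
to_plain_array(dx::AbstractArray) = Array(dx)  # generic fallback, makes a copy

# Stand-in for a non-Array tangent, so the sketch needs nothing beyond Base.
dx = view([1.0, 2.0, 3.0, 4.0], 1:3)
plain = to_plain_array(dx)        # a Vector{Float64} holding 1.0, 2.0, 3.0

# An Array input is passed through without copying.
orig = [1.0, 2.0]
same = to_plain_array(orig) === orig
```

A fallback this broad would also turn structured tangents such as Diagonal into dense arrays, so a real implementation would need a narrower criterion for when to convert.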

mcabbott, Oct 04 '21 19:10

To summarize my thoughts:

| Primal | Tangent | Projected Tangent | Reason |
| --- | --- | --- | --- |
| StaticArray | StaticArray | StaticArray | Don't stop being static. |
| Not StaticArray | Not StaticArray | Not StaticArray | Don't start being static for no reason. |
| StaticArray | Other dense (e.g. Array) | Ideally StaticArray | If the primal is the right size for a StaticArray then so is the tangent, but this is about performance, not correctness. |
| StaticArray | Other non-dense (e.g. Diagonal) | Either is good | The sparsity of the tangent is not an important mathematical fact, as it doesn't reflect the sparsity of the primal, so it is just a performance optimization. Being a StaticArray is also just a performance optimization, so either works. |
| Other dense | StaticArray | Ideally StaticArray | StaticArray is a valid representation of dense, and making it into an Array would do a copy. It seems like we should assume it was given to us because it was optimal. |
| Other non-dense (e.g. Diagonal) | StaticArray | Other non-dense | It is the point of ProjectTo to prevent losing the structure of the primal when making the tangent. |

oxinabox, Oct 06 '21 11:10

Here's the behaviour, in the order of this table, with the === change (which only alters 2 of them):

julia> using StaticArrays, ChainRulesCore, LinearAlgebra

julia> ProjectTo(SA[1,2])(SA[3,4]) isa SVector
true

julia> ProjectTo([1,2])([3,4]) isa Vector
true

julia> ProjectTo(SA[1,2])([3,4])  # was a Vector
2-element SizedVector{2, Float64, Vector{Float64}} with indices SOneTo(2):
 3.0
 4.0

julia> ProjectTo(SA[1 2; 3 4])(Diagonal([5.0, 6.0]))  # was unchanged
2×2 SizedMatrix{2, 2, Float64, 2, Base.ReshapedArray{Float64, 2, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}} with indices SOneTo(2)×SOneTo(2):
 5.0  0.0
 0.0  6.0

julia> ProjectTo([1,2])(SA[3,4])
2-element SVector{2, Float64} with indices SOneTo(2):
 3.0
 4.0

julia> ProjectTo(Diagonal([1,2]))(SA[3 4; 5 6])
2×2 Diagonal{Float64, SVector{2, Float64}}:
 3.0   ⋅ 
  ⋅   6.0

mcabbott, Oct 06 '21 12:10