ChainRulesCore.jl
StaticArrays
I think we would generally like StaticArray arguments to have StaticArray tangents. The current behaviour depends on which code path you hit:
julia> using StaticArrays, ChainRulesCore
julia> p = ProjectTo(SA[1,2,3]) # has SOneTo
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (SOneTo(3),))
julia> p([4,5,6]) # doesn't reshape
3-element Vector{Float64}:
4.0
5.0
6.0
julia> p(ones(3,1)) # does reshape
3-element SizedVector{3, Float64, Vector{Float64}} with indices SOneTo(3):
1.0
1.0
1.0
If we change this line to test === rather than ==, then the first case would behave like the second:
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/0e560c648ae29bb0860af77f9010763c0cc6fb48/src/projection.jl#L214
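A Base-only sketch of why the two checks disagree. It assumes, as the linked line suggests, that the projector returns the tangent unchanged when `axes(dx)` matches the stored axes and otherwise calls `reshape`. Here `(1:3,)` is a stand-in for `(SOneTo(3),)`, since both have the same index values as a Vector's `(Base.OneTo(3),)` but a different type, and `needs_reshape` is a hypothetical helper, not part of ChainRulesCore:

```julia
# Base-only illustration (no StaticArrays needed).
# `(1:3,)` stands in for a non-Base axis type such as `(SOneTo(3),)`:
# same values as `axes(dx) == (Base.OneTo(3),)`, but a different type.
stored_axes = (1:3,)
dx = [4.0, 5.0, 6.0]

# `==` compares index values only, so the tangent would be returned as-is.
loose_match = axes(dx) == stored_axes     # true

# `===` also requires identical types, so this would fall through to `reshape`.
strict_match = axes(dx) === stored_axes   # false

# Hypothetical helper mirroring the branch: skip reshaping only on a match.
needs_reshape(stored, x::AbstractArray; strict::Bool = true) =
    !(strict ? axes(x) === stored : axes(x) == stored)

println(needs_reshape(stored_axes, dx; strict = false))  # false
println(needs_reshape(stored_axes, dx))                   # true
println(needs_reshape((Base.OneTo(3),), dx))              # false: plain Vector primal
```

For an ordinary Vector primal with a Vector tangent the two checks agree, so only mismatched axis types are affected. Which container comes out of the `reshape` branch is then decided by whichever `reshape` method matches the primal's axes, which is how the `ones(3,1)` example above already produces a SizedVector.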
Would this have any surprising downsides? It would also improve the type stability of cases like this one (with OffsetArrays loaded):
julia> p2 = ProjectTo(zeros(1:3))
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (OffsetArrays.IdOffsetRange(values=1:3, indices=1:3),))
julia> @code_warntype p2(ones(3))
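To see where the type-stability gain comes from, here is a Base-only stand-in; the function names are hypothetical. When the stored axes and the tangent's axes have different concrete types, inference can fold `===` to `false` from the types alone, so only one branch survives. `==` has to be evaluated at run time, so the inferred result is a `Union`. `float.(dx)` stands in for the `reshape` branch, which in the OffsetArrays example returns a different array type than the tangent:

```julia
# Hypothetical stand-ins for the shape check; Base only.
# `float.(dx)` plays the part of the `reshape(dx, stored)` branch: it returns a
# different array type than the `dx` branch, as the OffsetArray case does.
shape_loose(stored, dx)  = axes(dx) ==  stored ? dx : float.(dx)
shape_strict(stored, dx) = axes(dx) === stored ? dx : float.(dx)

# Stored axes are a UnitRange tuple; a Vector's axes are a Base.OneTo tuple.
sig = (Tuple{UnitRange{Int}}, Vector{Int})

rt_loose  = only(Base.return_types(shape_loose, sig))   # Union of both branch types
rt_strict = only(Base.return_types(shape_strict, sig))  # only the Vector{Float64} branch

println(rt_loose)
println(rt_strict)
```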
Going the other way, if the argument is an ordinary Vector and the tangent is an SVector, then reshape won't do anything:
julia> reshape(SA[1,2,3], axes([4,5,6]))
3-element SVector{3, Int64} with indices SOneTo(3):
1
2
3
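Base has no static arrays, but a `UnitRange` shows the same effect: reshaping to axes of matching length hands back a non-`Array` argument without converting it. In this sketch the range is a stand-in for the SVector tangent. If one did want an Array, an explicit copy such as `collect` is what produces it; whether that copy is worth making is a separate question.

```julia
# Base-only stand-in: the UnitRange plays the role of the SVector tangent,
# and `primal` is the ordinary Vector argument.
tangent = 1:3
primal  = [4, 5, 6]

# The lengths already agree, so reshape does not turn the range into an Array.
reshaped = reshape(tangent, axes(primal))
println(reshaped isa Array)          # false

# An explicit copy is what produces an Array.
converted = collect(tangent)
println(converted isa Vector{Int})   # true
```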
This is, I think, the case in https://github.com/FluxML/Zygote.jl/issues/1093. A simple example (with Zygote and LinearAlgebra loaded) of where the rule accidentally makes an SVector is:
julia> gradient(x -> dot(SA[1,2,3], x), rand(3))[1] # dy = reshape(x .* ΔΩ, axes(y))
3-element SVector{3, Float64} with indices SOneTo(3):
1.0
2.0
3.0
Is it a good idea, as a general rule, to force those to be converted back to Array? And if so, is there an easy way to implement this without ChainRulesCore depending on StaticArrays, or vice versa?
To summarize my thoughts:
| Primal | Tangent | Projected Tangent | reason |
|---|---|---|---|
| StaticArray | StaticArray | StaticArray | Don't stop being static |
| Not StaticArray | Not StaticArray | Not StaticArray | Don't start being static for no reason |
| StaticArray | Other dense (e.g. Array) | Ideally StaticArray | If the primal is the right size for a StaticArray then so is the tangent, but this is a matter of performance, not correctness |
| StaticArray | Other non-dense (e.g. Diagonal) | Either is good | The sparsity of the tangent is not an important mathematical fact, since it doesn't reflect the sparsity of the primal, so preserving it is just a performance optimization. Being a StaticArray is also just a performance optimization, so either works. |
| Other dense | StaticArray | Ideally StaticArray | A StaticArray is a valid representation of a dense array, and turning it into an Array would make a copy. It seems we should assume it was given to us because it was optimal. |
| Other non-dense (e.g. Diagonal) | StaticArray | Other non-dense | The point of ProjectTo is to prevent losing the structure of the primal when making the tangent |
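The table can also be read as a small decision function. This is only a sketch with hypothetical kind labels: `:static` for a StaticArray, `:dense` for another dense array such as Array, and `:structured` for a non-dense one such as Diagonal. `preferred_kind` is illustrative, not ChainRulesCore code, and returns `:either` where the table says either is good:

```julia
# Hypothetical encoding of the six rows of the table; not ChainRulesCore code.
const KINDS = (:static, :dense, :structured)

function preferred_kind(primal::Symbol, tangent::Symbol)
    (primal in KINDS && tangent in KINDS) ||
        throw(ArgumentError("unknown kinds: $primal, $tangent"))
    if primal === :static
        tangent === :static && return :static   # don't stop being static
        tangent === :dense  && return :static   # ideally; the size is known (performance only)
        return :either                          # structured tangent: both are optimizations
    else
        tangent === :static || return :not_static  # don't start being static
        primal === :dense && return :static        # ideally; avoid copying a valid dense tangent
        return :structured                         # keep the primal's structure
    end
end

println(preferred_kind(:structured, :static))  # structured
```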
Here's the behaviour, in the order of this table, with the === change (which alters only two of them):
julia> using StaticArrays, ChainRulesCore, LinearAlgebra
julia> ProjectTo(SA[1,2])(SA[3,4]) isa SVector
true
julia> ProjectTo([1,2])([3,4]) isa Vector
true
julia> ProjectTo(SA[1,2])([3,4]) # was a Vector
2-element SizedVector{2, Float64, Vector{Float64}} with indices SOneTo(2):
3.0
4.0
julia> ProjectTo(SA[1 2; 3 4])(Diagonal([5.0, 6.0])) # was unchanged
2×2 SizedMatrix{2, 2, Float64, 2, Base.ReshapedArray{Float64, 2, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}} with indices SOneTo(2)×SOneTo(2):
5.0 0.0
0.0 6.0
julia> ProjectTo([1,2])(SA[3,4])
2-element SVector{2, Float64} with indices SOneTo(2):
3.0
4.0
julia> ProjectTo(Diagonal([1,2]))(SA[3 4; 5 6])
2×2 Diagonal{Float64, SVector{2, Float64}}:
3.0 ⋅
⋅ 6.0