`ProjectTo(::AbstractArray)` does not infer
For example:
julia> @inferred ProjectTo(rand(3, 3))(Diagonal(rand(3)))
ERROR: return type Diagonal{Float64, Vector{Float64}} does not match inferred return type Union{Base.ReshapedArray{Float64, 2, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}}
The reason for the failure is
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/6efb2d258dc225a50d961fb17fa5ba88d7296ce7/src/projection.jl#L174-L183
which decides whether `dy` is `dx` or `reshape(dx, ...)` based on the values of the axes, which inference cannot see from the types alone.
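To see why a value-based check defeats inference, here is a minimal sketch with a hypothetical `same_or_reshape` that mirrors the shape of the linked branch (it is not the ChainRulesCore code). From the argument types alone, both arms look reachable, so the inferred return type is a union of `Diagonal` and whatever `reshape` returns for a `Diagonal`.

```julia
using LinearAlgebra

# Hypothetical helper mirroring the value-dependent branch: whether we return
# `dx` unchanged or a reshaped wrapper depends on the *values* of the axes,
# which inference cannot see from the argument types alone.
same_or_reshape(dx::AbstractArray, ax::Tuple) = axes(dx) == ax ? dx : reshape(dx, ax)

D = Diagonal(rand(3))
ax = axes(D)  # (Base.OneTo(3), Base.OneTo(3))

# At runtime the axes match, so `dx` itself comes back...
@assert same_or_reshape(D, ax) === D

# ...but the inferred return type is not concrete, because the `reshape` arm
# produces a different type (a Base.ReshapedArray wrapper) for a Diagonal.
rt = only(Base.return_types(same_or_reshape, (typeof(D), typeof(ax))))
println(rt)
```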
We could instead do:
dy = if length(axes(dx)) == length(project.axes)
    axes(dx) == project.axes || throw(_projection_mismatch(project.axes, size(dx)))
    dx
else
    for d in 1:max(ndims(dx), length(project.axes))
        if size(dx, d) != length(get(project.axes, d, 1))
            throw(_projection_mismatch(project.axes, size(dx)))
        end
    end
    reshape(dx, project.axes)
end
which does infer, but throws an error when `dx` already has the right number of dimensions but different axes, i.e. exactly the case where only a `reshape` would fix it. In practice, this means that only the test
poffv = ProjectTo(OffsetArray(rand(3), 0:2))
@test axes(poffv([1, 2, 3])) == (0:2,)
fails in ChainRulesCore, and all the `rrule` inference tests in https://github.com/JuliaDiff/ChainRules.jl/pull/459#issuecomment-884757000 (except the one where `InplaceableThunk` inference fails) are fixed.
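The proposed branch can be tried out in isolation with a hypothetical `project_shape(ax, dx)` that copies it, plus a hypothetical `shape_mismatch` standing in for ChainRulesCore's internal `_projection_mismatch`. This is a sketch for experimenting, not the package code. For a `Diagonal`, the number of axes is part of the type, so the `length` comparison should be settled at compile time and the inferred return type is just `Diagonal`. A plain `0:2` range plays the role of the `OffsetArray` axes and shows the case that now throws.

```julia
using LinearAlgebra

# Hypothetical stand-in for ChainRulesCore's internal `_projection_mismatch`.
shape_mismatch(ax, sz) = DimensionMismatch("variable with axes $ax cannot have a gradient with size $sz")

# Sketch of the proposed branch (with `M` written as `ndims(dx)`).
function project_shape(ax::Tuple, dx::AbstractArray)
    if length(axes(dx)) == length(ax)
        # Same number of dimensions: this condition depends only on types,
        # so inference can drop the `reshape` arm entirely.
        axes(dx) == ax || throw(shape_mismatch(ax, size(dx)))
        return dx
    else
        # Different number of dimensions: only trivial (length-1) dimensions
        # may be added or removed.
        for d in 1:max(ndims(dx), length(ax))
            if size(dx, d) != length(get(ax, d, 1))
                throw(shape_mismatch(ax, size(dx)))
            end
        end
        return reshape(dx, ax)
    end
end

D = Diagonal(rand(3))
println(only(Base.return_types(project_shape, (typeof(axes(D)), typeof(D)))))

# Dropping a trivial trailing dimension still works:
v = project_shape((Base.OneTo(3),), ones(3, 1))

# Same ndims, different axes (like an OffsetArray's 0:2): now an error.
threw = try
    project_shape((0:2,), [1.0, 2.0, 3.0])
    false
catch e
    e isa DimensionMismatch
end
```

The tradeoff is visible in the last case: the current code would reshape there, while this version refuses.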
Do we want this tradeoff or not?
Lyndon, replying to "I found a “solution” with a small tradeoff.", says: "It is a union of <=4 elements. It is fine."
My reply:
I’ve heard that small unions are fine, but never really understood why. Does it mean that something (JIT compilation?) happens for each of the union's elements, rather than for all possible types? My concern (which may be unjustified) is that having a few of these small unions throughout the program means the inferred unions snowball as functions call each other,
i.e. something like:
julia> inner(x) = x > 0 ? 3.0 : 3
inner (generic function with 1 method)
julia> wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper (generic function with 1 method)
julia> wrapper(x::Float64) = x > 0 ? "hello" : :world
wrapper (generic function with 2 methods)
julia> together(x) = wrapper(inner(x))
together (generic function with 1 method)
julia> @code_warntype together(2.0)
Variables
#self#::Core.Const(together)
x::Float64
Body::Any
1 ─ %1 = Main.inner(x)::Union{Float64, Int64}
│ %2 = Main.wrapper(%1)::Any
└── return %2
where the result is inferred as `Any` rather than as a `Union` of four types.
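Roughly, when an argument is inferred as a small `Union`, Julia can "union-split" the call: it infers (and compiles) the call separately for each member type and dispatches with a runtime type check, so only those few concrete signatures are involved. The sketch below uses hypothetical `small_wrapper` and `small_together` (and redefines `inner` so it is self-contained) to show the benign case, where each member maps to a single type and the result is still inferred as a two-element `Union`. The `Any` above presumably comes from inference widening a merged union that has grown past its size limit, which is the snowballing concern.

```julia
# Same `inner` as in the example above: returns Float64 or Int.
inner(x) = x > 0 ? 3.0 : 3

# Hypothetical wrapper whose methods each return one concrete type.
small_wrapper(x::Int) = 1
small_wrapper(x::Float64) = "one"

small_together(x) = small_wrapper(inner(x))

# The call on Union{Float64, Int64} is split per member, so the result
# stays a small, precise union instead of widening.
rt = only(Base.return_types(small_together, (Float64,)))
println(rt)
```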
I wonder if we actually only want the reshape when `dx` is an `Array`.
In that case we know we will never get a `ReshapedArray` out, just another `Array` that references the same memory.
`ReshapedArray` is probably not our friend: it is still a view, but I don't know how consistently it is supported or how well it plays with BLAS etc.
Then again, not having the reshape always work might defeat the utility of the feature.
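The "only reshape an `Array`" idea can be sketched with a hypothetical `reshape_or_check` (not ChainRulesCore's API): dispatch sends `Array`s to `reshape`, which returns another `Array` sharing memory, while every other array type must already have the right axes. Both methods infer a concrete return type, and the cost is that a non-`Array` with the wrong shape errors instead of being wrapped.

```julia
using LinearAlgebra

# Hypothetical policy sketch: only reshape when `dx` is an `Array`,
# where `reshape` returns another `Array` sharing the same memory.
reshape_or_check(dx::Array, ax::Tuple) = reshape(dx, ax)

# Any other array type must already have the right axes; otherwise error
# instead of wrapping it in a `Base.ReshapedArray`.
function reshape_or_check(dx::AbstractArray, ax::Tuple)
    axes(dx) == ax || throw(DimensionMismatch("expected axes $ax, got $(axes(dx))"))
    return dx
end

x = [1.0, 2.0, 3.0]
y = reshape_or_check(x, (Base.OneTo(3), Base.OneTo(1)))  # a 3×1 Matrix
y[2, 1] = 10.0  # writes through to `x`, since the memory is shared

D = Diagonal(rand(3))
same = reshape_or_check(D, axes(D))  # returned unchanged
```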
Surely this can be made to infer.
The trick `if length(axes(dx)) == length(project.axes)` is clever, but I think it misses that the reshape currently restores `OffsetArray`s, which tend to go missing, e.g. `hcat(OffsetArray(rand(3), 0:2)) isa Matrix`.