ChainRulesCore.jl

`ProjectTo(::AbstractArray)` does not infer

Open · mzgubic opened this issue 4 years ago · 3 comments

For example:

julia> using ChainRulesCore, LinearAlgebra, Test

julia> @inferred ProjectTo(rand(3, 3))(Diagonal(rand(3)))
ERROR: return type Diagonal{Float64, Vector{Float64}} does not match inferred return type Union{Base.ReshapedArray{Float64, 2, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}}

The reason for the failure is https://github.com/JuliaDiff/ChainRulesCore.jl/blob/6efb2d258dc225a50d961fb17fa5ba88d7296ce7/src/projection.jl#L174-L183, which decides whether `dy` is `dx` or `reshape(dx, ...)` based on the values of the axes rather than their types. Inference only sees the types, so it has to allow for both results.
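The cause can be reproduced without ChainRulesCore. Below is a minimal sketch, assuming the shape-handling step behaves as described above; `shape_by_value` is a hypothetical stand-in (not the library's code), and `DimensionMismatch` replaces the internal `_projection_mismatch` helper. Because the branch depends on the runtime value of `axes(dx) == ax`, inference must include both the `dx` result and the `reshape(dx, ax)` result, and for a `Diagonal` the latter is a `Base.ReshapedArray`.

```julia
using LinearAlgebra

# Hypothetical stand-in for the current shape-handling step (not ChainRulesCore's code).
# Whether dx is returned as-is or reshaped depends on the *values* of the axes,
# which inference cannot know, so both return types end up in the inferred type.
function shape_by_value(dx::AbstractArray{S,M}, ax::Tuple) where {S,M}
    if axes(dx) == ax
        return dx
    else
        # Only trivial (length-1) dimensions may differ; `get` falls back to 1,
        # whose length is 1, for dimensions beyond the end of `ax`.
        for d in 1:max(M, length(ax))
            size(dx, d) == length(get(ax, d, 1)) ||
                throw(DimensionMismatch("cannot project size $(size(dx)) onto axes $ax"))
        end
        return reshape(dx, ax)
    end
end

D = Diagonal([1.0, 2.0, 3.0])
println(shape_by_value(D, axes(D)) === D)  # true: at runtime the first branch is taken

# ...but the inferred return type is not just typeof(D), as in the error above.
rt = only(Base.return_types(shape_by_value, (typeof(D), typeof(axes(D)))))
println(rt)
```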

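We could instead do:

    # Branch on the number of dimensions, which is known from the types,
    # rather than on the axis values, so inference sees only one branch.
    # `M` is the number of dimensions of `dx` (a type parameter of the enclosing method).
    dy = if length(axes(dx)) == length(project.axes)
        # Same number of dimensions: the axes must match exactly.
        axes(dx) == project.axes || throw(_projection_mismatch(project.axes, size(dx)))
        dx
    else
        # Different number of dimensions: only trivial (length-1) dimensions
        # may be added or dropped, e.g. a 3×1 matrix for a length-3 vector.
        for d in 1:max(M, length(project.axes))
            if size(dx, d) != length(get(project.axes, d, 1))
                throw(_projection_mismatch(project.axes, size(dx)))
            end
        end
        reshape(dx, project.axes)
    end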

This does infer, but it throws an error for arrays that have the right number of dimensions yet need to be reshaped onto different axes. In practice, the only test in ChainRulesCore that fails is

        poffv = ProjectTo(OffsetArray(rand(3), 0:2))
        @test axes(poffv([1, 2, 3])) == (0:2,)

and all the `rrule` inference tests in https://github.com/JuliaDiff/ChainRules.jl/pull/459#issuecomment-884757000 are fixed (except the one where `InplaceableThunk` inference fails).
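The proposal and its tradeoff can be sketched as a standalone function. This is a minimal sketch, not ChainRulesCore code: `shape_by_ndims` is a hypothetical name, `DimensionMismatch` stands in for `_projection_mismatch`, and `Base.IdentityUnitRange(0:2)` is used only as a Base-only stand-in for OffsetArray-style axes, so the example needs no extra packages.

```julia
using LinearAlgebra

# Hypothetical standalone version of the proposed logic (illustrative names).
function shape_by_ndims(dx::AbstractArray{S,M}, ax::Tuple) where {S,M}
    if M == length(ax)
        # Same number of dimensions (known from the types): never reshape, just check.
        axes(dx) == ax ||
            throw(DimensionMismatch("axes $(axes(dx)) do not match $ax"))
        return dx
    else
        # Different number of dimensions: only trivial (length-1) dimensions may differ.
        for d in 1:max(M, length(ax))
            size(dx, d) == length(get(ax, d, 1)) ||
                throw(DimensionMismatch("cannot project size $(size(dx)) onto axes $ax"))
        end
        return reshape(dx, ax)
    end
end

D = Diagonal([1.0, 2.0, 3.0])
# The condition folds to a constant, so inference sees only the `dx` branch.
println(only(Base.return_types(shape_by_ndims, (typeof(D), typeof(axes(D))))) === typeof(D))  # true

# Adding a trailing singleton dimension still reshapes.
println(size(shape_by_ndims([1.0, 2.0, 3.0], (Base.OneTo(3), Base.OneTo(1)))))  # (3, 1)

# The tradeoff: same ndims but different axes now errors instead of reshaping.
try
    shape_by_ndims([1.0, 2.0, 3.0], (Base.IdentityUnitRange(0:2),))
catch err
    println(err isa DimensionMismatch)  # true
end
```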

Do we want this tradeoff or not?

mzgubic, Jul 22 '21 12:07

Lyndon says "It is a union of <=4 elements. It is fine."

My reply: I found a “solution” with a small tradeoff.


I've heard that small unions are fine, but I never really understood why. Does it mean that something (JIT?) happens for each element of the union, rather than for all possible types? My concern, which may be unjustified, is that with a few of these small unions throughout a program, the inferred union snowballs as functions call each other.

For example:

julia> using LinearAlgebra

julia> inner(x) = x > 0 ? 3.0 : 3
inner (generic function with 1 method)

julia> wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper (generic function with 1 method)

julia> wrapper(x::Float64) = x > 0 ? "hello" : :world
wrapper (generic function with 2 methods)

julia> together(x) = wrapper(inner(x))
together (generic function with 1 method)

julia> @code_warntype together(2.0)
Variables
  #self#::Core.Const(together)
  x::Float64
Body::Any
1 ─ %1 = Main.inner(x)::Union{Float64, Int64}
│   %2 = Main.wrapper(%1)::Any
└──      return %2

where the inferred result is actually `Any` rather than a `Union` of four types.
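A rough sketch of what is going on, checked from code rather than by reading `@code_warntype` output. Julia's compiler handles a call on a small `Union` argument by union splitting: it infers (and dispatches) each member separately. When every branch returns the same type the union disappears, as with the hypothetical `as_float` below. In `together`, the branches return four unrelated types, and inference widens unions past a small, version-dependent size limit, which is roughly how the result becomes `Any`. `Base.return_types` is a Base reflection function (not exported); `inferred` is a hypothetical convenience wrapper.

```julia
using LinearAlgebra

# The functions from the REPL session above.
inner(x) = x > 0 ? 3.0 : 3
wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper(x::Float64) = x > 0 ? "hello" : :world
together(x) = wrapper(inner(x))

# Hypothetical consumer of the small union: inference splits the call on
# Union{Float64, Int64}, and both branches return Float64.
as_float(x) = Float64(inner(x))

# Hypothetical helper: the single inferred return type for the given argument types.
inferred(f, argtypes) = only(Base.return_types(f, argtypes))

println(inferred(inner, (Float64,)))     # Union{Float64, Int64}
println(inferred(as_float, (Float64,)))  # Float64
println(inferred(together, (Float64,)))  # the @code_warntype output above shows Any
```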

mzgubic, Jul 22 '21 12:07

I wonder if we actually only want the reshape when `dx` is an `Array`. That is the case where we know we will never get a `ReshapedArray` out; we just get another `Array` that references the same memory. `ReshapedArray` is probably not our friend: it is still a view, but I don't know how consistently it is supported or how well it plays with BLAS etc. Then again, not having the reshape always work might defeat the utility of the feature.
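The suggestion can be sketched with dispatch: reshape only when `dx` is a plain `Array`, and otherwise require matching axes. This is an illustration of the idea, not a proposed implementation; `reshape_if_array` is a hypothetical name, and the sketch assumes the target axes are `Base.OneTo` ranges, which is what Base's `reshape` accepts for plain arrays.

```julia
using LinearAlgebra

# Hypothetical helper illustrating the "only reshape Arrays" idea; not ChainRulesCore code.

# A plain Array reshapes into another Array sharing the same memory.
reshape_if_array(dx::Array, ax::Tuple) = reshape(dx, ax)

# Any other AbstractArray is never reshaped (so no ReshapedArray can appear);
# its axes must already match.
function reshape_if_array(dx::AbstractArray, ax::Tuple)
    axes(dx) == ax || throw(DimensionMismatch("axes $(axes(dx)) do not match $ax"))
    return dx
end

v = [1.0, 2.0, 3.0]
m = reshape_if_array(v, (Base.OneTo(3), Base.OneTo(1)))
m[2, 1] = 10.0
println(v[2])                    # 10.0, because m and v share memory
println(m isa Matrix{Float64})   # true

D = Diagonal([1.0, 2.0])
println(reshape_if_array(D, axes(D)) === D)    # true
println(reshape(D, 4) isa Base.ReshapedArray)  # true: what a generic reshape produces
```

Each method has a single concrete return type for concrete inputs, which is why this avoids the inference problem; the cost, as noted above, is that non-`Array` inputs lose the reshape.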

oxinabox, Jul 22 '21 13:07

Surely this can be made to infer.

The `if length(axes(dx)) == length(project.axes)` trick is clever, but what I think it will miss is that the reshape currently restores OffsetArrays, which tend to go missing, e.g. `hcat(OffsetArray(rand(3), 0:2)) isa Matrix`.

mcabbott, Jul 27 '21 13:07