ForwardDiff.jl

Cannot compute gradient with `ArrayPartition` that holds containers of different types

Open bvdmitri opened this issue 1 year ago • 3 comments

ArrayPartition is a useful structure for concatenating arrays of different types. The type is defined in SciML/RecursiveArrayTools.jl.

ArrayPartitions are used in many places in the SciML ecosystem, and also elsewhere, for example in Manopt.jl. It appears, though, that if an ArrayPartition holds two containers, one with eltype Float64 and the other with eltype Int64, ForwardDiff.gradient fails.

MWE is:

julia> using ForwardDiff, RecursiveArrayTools

julia> v = [ 0.0, 1 ]
2-element Vector{Float64}:
 0.0
 1.0

julia> f(v) = sum(v)
f (generic function with 1 method)

julia> ForwardDiff.gradient(f, [ 0.0, 1 ])
2-element Vector{Float64}:
 1.0
 1.0

julia> ForwardDiff.gradient(f, ArrayPartition([ 0.0 ], [ 1 ]))
ERROR: MethodError: no method matching ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}(::Int64, ::ForwardDiff.Partials{2, Float64})

Closest candidates are:
  ForwardDiff.Dual{T, V, N}(::Number) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:78
  ForwardDiff.Dual{T, V, N}(::Any) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:77
  ForwardDiff.Dual{T, V, N}(::V, ::ForwardDiff.Partials{N, V}) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:17

Stacktrace:
  [1] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [2] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
  [3] getindex
    @ ./broadcast.jl:636 [inlined]
  [4] macro expansion
    @ ./broadcast.jl:1004 [inlined]
  [5] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [6] copyto!
    @ ./broadcast.jl:1003 [inlined]
  [7] copyto!
    @ ./broadcast.jl:956 [inlined]
  [8] materialize!
    @ ./broadcast.jl:914 [inlined]
  [9] materialize!
    @ ./broadcast.jl:911 [inlined]
 [10] seed!(duals::ArrayPartition{…}, x::ArrayPartition{…}, seeds::Tuple{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:52
 [11] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:23 [inlined]
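
A possible workaround, not proposed in the thread itself, is to copy the partition into a plain vector before differentiating. The sketch below assumes that `collect` on an ArrayPartition returns a Vector of the declared eltype (Float64 here); the names `ap`, `x` and `g` are only illustrative.

```julia
using ForwardDiff, RecursiveArrayTools

f(v) = sum(v)

# Mixed Float64 / Int blocks, as in the MWE above.
ap = ArrayPartition([0.0], [1])

# Assumption: collect copies the elements into a plain Vector{Float64},
# converting the Int entry on the way. The partition structure is lost.
x = collect(ap)

# Every element of x now really is a Float64, so seeding the duals
# no longer sees an Int.
g = ForwardDiff.gradient(f, x)
```

The gradient comes back as a flat vector, so this only helps when the block structure is not needed in the result.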

bvdmitri avatar Jul 25 '24 16:07 bvdmitri

I think the reason ForwardDiff is confused is that this ArrayPartition declares itself to have Float64 elements, but in fact sometimes returns an Int:

julia> ap = ArrayPartition([ 0.0 ], [ 1 ])
([0.0], [1])

julia> size(ap), axes(ap), eltype(ap), supertype(typeof(ap))
((2,), (Base.OneTo(2),), Float64, AbstractVector{Float64})

julia> ap[1]  # no surprise
0.0

julia> ap[2]  # very surprising
1

julia> ap[1:2]  # here you get the expected eltype
2-element Vector{Float64}:
 0.0
 1.0

julia> ap[2,1:1]
1-element Vector{Float64}:
 1.0

The usual way to encode that elements of a vector have different types is to have an abstract eltype, which it seems ForwardDiff is able to handle.

(Note that the first example in the original post constructs a Vector{Float64}, promoting 1 to 1.0 when making the array.)

julia> x64 = [ 0.0, 1 ]  # this promotes on construction
2-element Vector{Float64}:
 0.0
 1.0

julia> ForwardDiff.gradient(f, x64)  # as in question
2-element Vector{Float64}:
 1.0
 1.0

julia> xabs = Real[ 0.0, 1 ]  # abstract eltype, could also use  xabs = Union{Float64, Int}[ 0.0, 1 ] 
2-element Vector{Real}:
 0.0
 1

julia> ForwardDiff.gradient(f, xabs)  # also OK, ForwardDiff not confused
2-element Vector{Float64}:
 1.0
 1.0

Fixing ArrayPartition to declare its eltype accurately would be the obvious fix here, and would probably avoid many other weird edge cases. (Or else fixing its getindex definition to convert to the declared eltype.) Although I'm sure there's going to be some reason that consistency is inconvenient for something.
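
To make the second option concrete, here is a minimal sketch of a getindex that converts to the declared eltype. `ToyPartition` is a hypothetical stand-in written only with Base Julia; it is not the real ArrayPartition and says nothing about how RecursiveArrayTools is implemented.

```julia
# Hypothetical toy type (not the real ArrayPartition): two blocks with
# different element types behind one declared eltype of Float64.
struct ToyPartition <: AbstractVector{Float64}
    a::Vector{Float64}
    b::Vector{Int}
end

Base.size(p::ToyPartition) = (length(p.a) + length(p.b),)

# Convert on the way out so every element matches the declared eltype.
function Base.getindex(p::ToyPartition, i::Int)
    @boundscheck checkbounds(p, i)
    raw = i <= length(p.a) ? p.a[i] : p.b[i - length(p.a)]
    return convert(eltype(p), raw)
end

tp = ToyPartition([0.0], [1])
second = tp[2]  # a Float64, consistent with eltype(tp)
```

With this definition, scalar indexing and eltype agree, which is the consistency generic code such as ForwardDiff relies on.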

It's possible that ForwardDiff could be made more robust to misleading signals. For instance making the ForwardDiff.Dual constructor called above promote its first argument might work here?
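
As a rough illustration of that idea, the sketch below uses a hypothetical `ToyDual` type, not ForwardDiff's actual Dual. It has a strict inner constructor, analogous to the method the error message lists, plus a fallback that converts the value first. Whether such a method would be acceptable in ForwardDiff is a separate question.

```julia
# Hypothetical stand-in for a dual number; this is not ForwardDiff.Dual.
struct ToyDual{V<:Real,N}
    value::V
    partials::NTuple{N,V}
    # Strict inner constructor: the value must already have type V.
    ToyDual{V,N}(value::V, partials::NTuple{N,V}) where {V<:Real,N} =
        new{V,N}(value, partials)
end

# Lenient fallback: convert any other Real value to V, then call the
# strict constructor. Without this method the call below would throw a
# MethodError, much like the one in the original report.
ToyDual{V,N}(value::Real, partials::NTuple{N,V}) where {V<:Real,N} =
    ToyDual{V,N}(convert(V, value), partials)

d = ToyDual{Float64,2}(1, (0.0, 1.0))  # the Int value is converted to Float64
```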

mcabbott avatar Jul 25 '24 16:07 mcabbott

I opened an issue in RecursiveArrayTools as well, though I have a feeling that this behaviour might be by design.

> For instance making the ForwardDiff.Dual constructor called above promote its first argument might work here?

To me that would be an obvious fix that shouldn't break anything, right?

bvdmitri avatar Jul 25 '24 18:07 bvdmitri

> I have a feeling that this behaviour might be by design.

That just seems broken though. If fixing RecursiveArrayTools also fixes this then I don't think anything should be done here.

KristofferC avatar Jul 26 '24 11:07 KristofferC