RecursiveArrayTools.jl

Issue with OrdinaryDiffEq.jl when using an ArrayPartition of Arrays of StaticArrays

Status: Open. jlchan opened this issue 2 years ago • 7 comments

OrdinaryDiffEq.jl fails when the state is an ArrayPartition constructed from nested arrays of StaticArrays. The following MWE

using StaticArrays, StructArrays
using RecursiveArrayTools
using OrdinaryDiffEq

u1 = StructArray{SVector{2, Float64}}(ntuple(x -> x * ones(10), 2))
u2 = deepcopy(u1)
u = ArrayPartition(u1, u2)

function rhs!(du, u, p, t) 
    du .= u
end
prob = ODEProblem(rhs!, u, (0, 1.0))
sol = solve(prob, Tsit5())

yields the following (truncated) stacktrace:

ERROR: MethodError: Cannot `convert` an object of type Float64 to an object of type SVector{2, Float64}
Closest candidates are:
  convert(::Type{SA}, ::Tuple) where SA<:StaticArray at ~/.julia/packages/StaticArrays/0T5rI/src/convert.jl:171
  convert(::Type{SA}, ::SA) where SA<:StaticArray at ~/.julia/packages/StaticArrays/0T5rI/src/convert.jl:170
  convert(::Type{SA}, ::StaticArray{S}) where {SA<:StaticArray, S<:Tuple} at ~/.julia/packages/StaticArrays/0T5rI/src/convert.jl:164
  ...
Stacktrace:
  [1] maybe_convert_elt(#unused#::Type{SVector{2, Float64}}, vals::Float64)
    @ StructArrays ~/.julia/packages/StructArrays/rICDm/src/utils.jl:197
  [2] setindex!
    @ ~/.julia/packages/StructArrays/rICDm/src/structarray.jl:363 [inlined]
  [3] macro expansion
    @ ./broadcast.jl:961 [inlined]
  [4] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [5] copyto!
    @ ./broadcast.jl:960 [inlined]
  [6] copyto!
    @ ./broadcast.jl:913 [inlined]
  [7] f
    @ ~/.julia/packages/RecursiveArrayTools/YoTgv/src/array_partition.jl:327 [inlined]
  [8] ntuple
    @ ./ntuple.jl:49 [inlined]
  [9] copyto!
    @ ~/.julia/packages/RecursiveArrayTools/YoTgv/src/array_partition.jl:329 [inlined]
 [10] materialize!
    @ ./broadcast.jl:871 [inlined]
 [11] materialize!
    @ ./broadcast.jl:868 [inlined]
 [12] fast_materialize!
    @ ~/.julia/packages/FastBroadcast/fkwMa/src/FastBroadcast.jl:47 [inlined]
 [13] ode_determine_initdt(u0::ArrayPartition{Float64, Tuple{StructVector{SVector{2, Float64}, Tuple{Vector{Float64}, ....

I'm not sure, but this seems related to the fact that the eltype of an ArrayPartition of AbstractArray{<:StaticArray{N,T}} is T rather than StaticArray{N,T}.
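To make the eltype mismatch concrete, here is a small sketch. It assumes plain Vectors of SVectors in place of the StructArrays from the MWE (StructArrays are not needed to see the effect). In the RecursiveArrayTools versions discussed in this thread, the convenience constructor `ArrayPartition(u1, u2)` derives `T` via `recursive_bottom_eltype`, while the explicit `ArrayPartition{T, S}(x)` form lets the caller choose `T`.

```julia
using StaticArrays, RecursiveArrayTools

# Plain Vectors of SVectors stand in for the StructArrays of the MWE
u1 = [SVector(1.0, 2.0) for _ in 1:3]
u2 = deepcopy(u1)

# Convenience constructor: T is found by recursing down to the scalar type
u = ArrayPartition(u1, u2)

# Explicit constructor: T is set to the element type of the partitions
args = (u1, u2)
v = ArrayPartition{eltype(u1), typeof(args)}(args)

@show eltype(u1)  # SVector{2, Float64}
@show eltype(u)   # Float64
@show eltype(v)   # SVector{2, Float64}
```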

jlchan, Jun 16 '22 02:06

Note to self: the issue disappears when an initial dt is specified together with adaptive=false. If no initial dt is specified, the error occurs at https://github.com/SciML/OrdinaryDiffEq.jl/blob/b15f3c53601c274a69216ffea242a46561bbe121/src/initdt.jl#L23

      @.. broadcast=false sk = abstol+internalnorm(u0,t)*reltol

If adaptive=false is not specified, a similar error occurs in the fast broadcast over compute_residuals.
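A minimal sketch of the workaround from the note above, reusing the MWE: with an explicit `dt` and `adaptive = false`, the solver neither estimates an initial step (the `initdt.jl` code quoted above) nor computes residuals for error control. The step size `0.01` is an arbitrary choice for illustration, and this mirrors the behavior reported in this thread rather than something checked against current releases.

```julia
using StaticArrays, StructArrays, RecursiveArrayTools, OrdinaryDiffEq

# Same state as in the MWE: an ArrayPartition of two StructArrays of SVectors
u1 = StructArray{SVector{2, Float64}}(ntuple(x -> x * ones(10), 2))
u = ArrayPartition(u1, deepcopy(u1))

function rhs!(du, u, p, t)
    du .= u
    return nothing
end

prob = ODEProblem(rhs!, u, (0.0, 1.0))

# Fixed step size: no automatic initial dt and no adaptive error control
sol = solve(prob, Tsit5(); dt = 0.01, adaptive = false)
```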

Since the eltype of an ArrayPartition of Arrays of SVector is Float64, it triggers the code at https://github.com/SciML/OrdinaryDiffEq.jl/blob/b15f3c53601c274a69216ffea242a46561bbe121/src/initdt.jl#L15, and the broadcast assignment to sk then errors.
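The failing assignment can be reproduced without the solver. This sketch assumes a plain Vector instead of a StructArray and uses `LinearAlgebra.norm` as a stand-in for the solver's `internalnorm`; the tolerances are arbitrary. Each SVector is reduced to a Float64, and writing those scalars back into a container whose elements are SVectors hits the same `convert` MethodError as in the stacktrace.

```julia
using StaticArrays, LinearAlgebra

u0 = [SVector(1.0, 2.0) for _ in 1:3]  # elements are SVector{2, Float64}
sk = similar(u0)                        # same container type, same SVector eltype
abstol, reltol = 1e-6, 1e-3

# norm.(u0) reduces each SVector to a Float64, so vals is a Vector{Float64}
vals = abstol .+ norm.(u0) .* reltol

err = try
    sk .= vals  # needs convert(SVector{2, Float64}, ::Float64), which has no method
    nothing
catch e
    e
end
println(typeof(err))  # MethodError
```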

However, manually defining an ArrayPartition of StructArray{<:SVector} with eltype = SVector leads to conflicting broadcast rules between ArrayPartition and StructArrays.

jlchan, Jun 17 '22 20:06

Yeah, I think we need to extend the ArrayPartition broadcast overload so that it generates an ArrayPartition{SVector}.

ChrisRackauckas, Jun 18 '22 02:06

Turns out this is just because the constructor used in the example calls recursive_bottom_eltype to determine T. Using the default constructor instead and specifying T manually fixes things, and broadcasting seems to work fine.

using StaticArrays, RecursiveArrayTools
u1 = [SVector{2, Float64}(1,2) for _ in 1:5]
u2 = zeros(SVector{2, Float64}, (2, 2))
args = (u1, u2)
u = ArrayPartition{eltype(u1), typeof(args)}(args)
du = similar(u)
du .= u

I'm OK to close this if you are.

EDIT: never mind, this doesn't fix the original issue.

jlchan, Dec 31 '22 20:12

So you're saying Arrays of StaticArrays are fine in the differential equation solve? That's what I would have expected since there are tests covering it. I'm not sure what the StructArray is doing in the first example?

ChrisRackauckas, Jan 02 '23 05:01

Oops, I copied and pasted the wrong example code. StructArrays were just in the MWE for similarity to the full application.

Switching to the other constructor fixed an unrelated issue, so scratch the previous comment.

jlchan, Jan 02 '23 14:01

I'm confused as to what the problem is and whether it's solved 😅

ChrisRackauckas, Jan 02 '23 14:01

Sorry, to clarify: the problem remains. I recently solved an unrelated issue, and it had been so long since I looked at this one that I thought I'd solved it too.

jlchan, Jan 02 '23 14:01