DiffEqBase.jl
StackOverflowError when adding a structure inside a DEDataVector in a module
Hi! With the following code:
```julia
module MyTest2

using ArrayInterface

export A
struct A{T}
    a::T
end
ArrayInterface.ismutable(::Type{<:A}) = true

using RecursiveArrayTools
using OrdinaryDiffEq
using DiffEqBase

export Workspace
mutable struct Workspace{T} <: DEDataVector{T}
    x::Vector{T}
    a::A{T}
end

export test
function test()
    w1 = Workspace(zeros(3), A(0.0))
    w2 = Workspace(zeros(3), A(0.0))
    return RecursiveArrayTools.recursivecopy!(w1, w2)
end

end
```
I am getting the following error when calling the function test():
```
julia> test()
ERROR: StackOverflowError:
Stacktrace:
 [1] ismutable(::Type{T} where T) at /Users/ronan.arraes/.julia/packages/ArrayInterface/YFV07/src/ArrayInterface.jl:19 (repeats 79978 times)
```
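One possible reading of this trace, assuming ArrayInterface's generic fallback at that version forwarded a value to its type (something like `ismutable(x) = ismutable(typeof(x))`; this is not confirmed against that exact release): if the `ismutable(::Type{<:A})` method is not visible where the call happens, `ismutable(A{Float64})` falls through to the fallback with a `DataType`, and since `typeof(DataType)` is `DataType` again, the call never terminates. The names `is_mut`, `B`, and `overflows` below are hypothetical stand-ins, not ArrayInterface API.

```julia
# Hypothetical trait with a generic fallback that asks about the value's type.
is_mut(x) = is_mut(typeof(x))
is_mut(::Type{<:Array}) = true
is_mut(::Type{<:Number}) = false

struct B{T}
    b::T
end

# Returns true if calling f(args...) ends in a StackOverflowError.
function overflows(f, args...)
    try
        f(args...)
        return false
    catch e
        return e isa StackOverflowError
    end
end

# No method for B yet: B{Float64} is a DataType, so the fallback calls
# is_mut(DataType), which calls is_mut(DataType) again, forever.
overflowed_before = overflows(is_mut, B{Float64})

# Once a method for B exists and is visible, the same call terminates.
is_mut(::Type{<:B}) = false
result_after = is_mut(B{Float64})
```

This mirrors the repeating `ismutable(::Type{T} where T)` frame: the user-defined method exists, but the code path that calls it does not see it.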
I think it is related to the `@generated` function `recursivecopy!`, because this code works outside a module:
```julia
using ArrayInterface

struct A{T}
    a::T
end
ArrayInterface.ismutable(::Type{<:A}) = false

using RecursiveArrayTools
using OrdinaryDiffEq
using DiffEqBase

mutable struct Workspace{T} <: DEDataVector{T}
    x::Vector{T}
    a::A{T}
end

function test()
    w1 = Workspace(zeros(3), A(0.0))
    w2 = Workspace(zeros(3), A(0.0))
    return RecursiveArrayTools.recursivecopy!(w1, w2)
end
```
```
julia> test()
3-element Workspace{Float64}:
 0.0
 0.0
 0.0
```
This strange behavior seems to be related to precompilation: starting Julia with `julia --compiled-modules=no` does not show the problem. Is there any workaround?
A workaround is to define, inside `MyTest2`, a generated method of `recursivecopy!` that is more specialized than the one in DiffEqBase:
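```julia
# Define inside MyTest2, after Workspace and the ArrayInterface.ismutable method.
@generated function RecursiveArrayTools.recursivecopy!(dest::Workspace{T}, src::Workspace{T}) where T
    fields = fieldnames(src)    # inside the generator, src is the type Workspace{T}
    expressions = Vector{Expr}(undef, length(fields))
    @inbounds for i = 1:length(fields)
        f = fields[i]
        Tf = src.types[i]
        qf = Meta.quot(f)
        if !ArrayInterface.ismutable(Tf)
            # immutable field: plain assignment is enough
            expressions[i] = :( dest.$f = getfield( src, $qf ) )
        elseif Tf <: AbstractArray
            # mutable array: copy recursively into the existing array
            expressions[i] = :( recursivecopy!(dest.$f, getfield( src, $qf ) ) )
        else
            # other mutable value: deep copy so dest does not alias src
            expressions[i] = :( dest.$f = deepcopy( getfield( src, $qf ) ) )
        end
    end
    :($(expressions...); dest)
end
```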
EDIT: For the general case, the same must be done for DiffEqBase.copy_fields.
EDIT2: After thinking for a while, I have absolutely no idea how this can be fixed. In fact, I think it cannot be fixed, given the way Julia handles generated functions. If I am right, we need to document that these definitions must be provided inside the user's module.
EDIT3: Sorry for the long comments and many edits, I am in the middle of a brainstorm here :D @ChrisRackauckas what do you think about having a macro like `@DataArrayStruct` or something similar that makes a structure compatible with the data array interface? This macro would make the structure a subtype of `DEDataVector` and declare the required generated functions. Hence, the definitions of `recursivecopy!` and `copy_fields!` would move to the user code. This would eliminate the problem described here, and ArrayInterface could still be used as is.
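A minimal sketch of the macro idea, under several assumptions: the macro `@copyable` and the function `fieldwise_copy!` are hypothetical stand-ins (the actual proposal would subtype `DEDataVector` and extend `recursivecopy!`/`copy_fields!`, which this self-contained sketch does not do), and `isbitstype` is swapped in for `ArrayInterface.ismutable` as the "plain assignment is safe" test. What it shows is the structural point: the `@generated` method is emitted into the caller's module, right after the struct, instead of living inside DiffEqBase.

```julia
# Hypothetical stand-in for RecursiveArrayTools.recursivecopy!.
function fieldwise_copy! end

# Hypothetical macro: emits the struct unchanged plus a @generated
# fieldwise_copy! method specialized on it, both in the caller's module.
macro copyable(structdef)
    head = structdef.args[2]          # e.g. :(Name{T} <: Super{T}) or :(Name{T})
    if head isa Expr && head.head === :(<:)
        head = head.args[1]           # drop the supertype
    end
    # drop type parameters to get the bare struct name
    name = head isa Expr && head.head === :curly ? head.args[1] : head
    return esc(quote
        $structdef
        @generated function fieldwise_copy!(dest::$name, src::$name)
            exprs = Expr[]
            for (f, Tf) in zip(fieldnames(src), fieldtypes(src))
                qf = QuoteNode(f)
                if isbitstype(Tf)
                    # plain immutable data: assignment is enough
                    push!(exprs, :(setfield!(dest, $qf, getfield(src, $qf))))
                else
                    # anything else: deep copy so dest does not alias src
                    push!(exprs, :(setfield!(dest, $qf, deepcopy(getfield(src, $qf)))))
                end
            end
            return :($(exprs...); dest)
        end
    end)
end
```

Usage would be `@copyable mutable struct Demo{T} ... end`, after which `fieldwise_copy!(d1, d2)` copies every field of `d2` into `d1`.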
I'm wondering: couldn't we just define `recursivecopy!` as a regular function that loops over the fields? Intuitively it shouldn't affect performance too drastically, but maybe I'm wrong. One could even add a function barrier if that would help the compiler.
Well, that would certainly solve everything in the easiest way! And if anyone wants more performance, they can always overload those methods with a more specialized signature.
Looping over the fields won't be type stable. For small arrays that could be pretty bad for performance. For large DEDataArrays, sure it's fine, but smaller ones is an issue.
Wouldn't it help if we add a function barrier and replace the if branches with a function call?
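A minimal sketch of the non-generated alternative being discussed, with the `if` branches replaced by dispatch on a helper as suggested. The names `copy_field` and `loop_copy!` are hypothetical, and the sketch assumes a mutable struct whose array fields have matching sizes. Since `getfield(src, f)` with a runtime field name is not type stable, every `copy_field` call is dynamically dispatched; the function barrier only makes the work inside `copy_field` specialized, so the per-field dispatch cost raised above for small arrays remains.

```julia
# Fallback: non-array fields are replaced by a deep copy of the source value.
copy_field(dest_val, src_val) = deepcopy(src_val)
# Array fields are copied in place, so dest keeps its own buffer.
copy_field(dest_val::AbstractArray, src_val::AbstractArray) = copyto!(dest_val, src_val)

function loop_copy!(dest::T, src::T) where {T}
    for f in fieldnames(T)
        # getfield with a runtime Symbol is not type stable; the call to
        # copy_field is the function barrier, resolved by dynamic dispatch.
        setfield!(dest, f, copy_field(getfield(dest, f), getfield(src, f)))
    end
    return dest
end
```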
I cannot reproduce the stack overflow error.
Did you paste the code inside a module? I tried here with a clean env and I reproduce it every time.
Hi!
Has anyone thought of a better workaround than the one I posted?