SymbolicUtils.jl

Stack overflow in custom interface

Open MilesCranmer opened this issue 1 year ago • 0 comments

Hey all,

I am trying to build a direct interface in DynamicExpressions.jl to speed up simplification: https://github.com/SymbolicML/DynamicExpressions.jl/pull/42.

I am seeing a stack overflow at the moment, even for empty rule sets. I am wondering if this may be because my type is constrained to binary trees ($\text{arity} \in \{0, 1, 2\}$)? If so, perhaps some initial expansion step is what fails.

Here's an example, using this commit: https://github.com/SymbolicML/DynamicExpressions.jl/pull/42/commits/c12d5a7672a7a7898cca106fee1a89a2accfbe80

using SymbolicUtils, DynamicExpressions

# Operator set the expression tree may use: four binary and two unary operators.
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])

# Leaf node with element type Float64 that refers to feature 1.
x1 = Node(Float64; feature=1)

# Expression under test: the same leaf added to itself.
expression = x1 + x1

# Wrap the tree together with its operators and simplify with an empty rule set;
# this is the call that triggers the StackOverflowError reported below.
simplify(SelfContainedNode(expression, operators), RuleSet([]))

which triggers the following error:

ERROR: StackOverflowError:
Stacktrace:
     [1] (::SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false})(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}})
       @ SymbolicUtils.Rewriters ~/.julia/packages/SymbolicUtils/H684H/src/rewriters.jl:191
     [2] (::SymbolicUtils.Rewriters.PassThrough{SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false}})(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}})
       @ SymbolicUtils.Rewriters ~/.julia/packages/SymbolicUtils/H684H/src/rewriters.jl:188
     [3] iterate
       @ ./generator.jl:47 [inlined]
     [4] _collect(c::Vector{Any}, itr::Base.Generator{Vector{Any}, SymbolicUtils.Rewriters.PassThrough{SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
       @ Base ./array.jl:802
     [5] collect_similar
       @ ./array.jl:711 [inlined]
     [6] map
       @ ./abstractarray.jl:3261 [inlined]
     [7] (::SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false})(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}})
       @ SymbolicUtils.Rewriters ~/.julia/packages/SymbolicUtils/H684H/src/rewriters.jl:198
--- the last 6 lines are repeated 12138 more times ---
 [72836] macro expansion
       @ ~/.julia/packages/SymbolicUtils/H684H/src/utils.jl:11 [inlined]
 [72837] (::SymbolicUtils.Rewriters.Fixpoint{SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false}})(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}})
       @ SymbolicUtils.Rewriters ~/.julia/packages/SymbolicUtils/H684H/src/rewriters.jl:122
 [72838] PassThrough
       @ ~/.julia/packages/SymbolicUtils/H684H/src/rewriters.jl:188 [inlined]
 [72839] simplify(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}}; expand::Bool, polynorm::Nothing, threaded::Bool, simplify_fractions::Bool, thread_subtree_cutoff::Int64, rewriter::SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false})
       @ SymbolicUtils ~/.julia/packages/SymbolicUtils/H684H/src/simplify.jl:41
 [72840] simplify
       @ ~/.julia/packages/SymbolicUtils/H684H/src/simplify.jl:16 [inlined]
 [72841] simplify(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}}, ctx::SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
       @ SymbolicUtils ./deprecated.jl:105
 [72842] simplify(x::SelfContainedNode{Float64, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}, Tuple{}, Tuple{}}}, ctx::SymbolicUtils.Rewriters.Walk{:post, SymbolicUtils.Rewriters.Chain, typeof(similarterm), false})
       @ SymbolicUtils ./deprecated.jl:103
 [72843] #simplify#22
       @ ~/Documents/DynamicExpressions.jl/ext/DynamicExpressionsSymbolicUtilsExt.jl:352 [inlined]

My interface is as follows (full code here):

"""
    arity(x::SelfContainedNode)

Return the `degree` field of the tree wrapped by `x`.
"""
function arity(x::SelfContainedNode)
    return x.tree.degree
end
"""
    istree(x::SelfContainedNode)

Return `true` when `x` has a positive arity, `false` otherwise.
"""
function istree(x::SelfContainedNode)
    return 0 < arity(x)
end
"""
    symtype(x::SelfContainedNode{T})

Return `T`, the first type parameter of the wrapper.
"""
function symtype(::SelfContainedNode{T}) where {T}
    return T
end
"""
    operation(x::SelfContainedNode)

Return the operator stored at index `x.tree.op`: taken from `x.operators.unaops`
when the arity is 1 and from `x.operators.binops` when the arity is 2.

Raises an error for any other arity.
"""
function operation(x::SelfContainedNode)
    nchildren = arity(x)
    nchildren == 1 && return x.operators.unaops[x.tree.op]
    nchildren == 2 && return x.operators.binops[x.tree.op]
    return error("Unexpected arity $(arity(x)).")
end
"""
    unsorted_arguments(x::SelfContainedNode{T})

Return the children of `x` as a `Vector{Any}`, left child first.

A child for which `isconstant` holds is returned as its raw value (`child.val`,
asserted to be a `T`); any other child is re-wrapped in the same wrapper type as
`x`, sharing `x.operators`. An arity of 0 yields an empty vector.

Raises an error for an arity outside 0–2. The original fell through and returned
`nothing` in that case, while `operation` raises for it.
"""
function unsorted_arguments(x::S) where {T,S<:SelfContainedNode{T}}
    # Single place that decides how a child is exposed to the caller.
    expose(child) = isconstant(child) ? child.val::T : S(child, x.operators)

    n = arity(x)
    if n == 0
        return Any[]
    elseif n == 1
        return Any[expose(x.tree.l)]
    elseif n == 2
        return Any[expose(x.tree.l), expose(x.tree.r)]
    else
        # Fail loudly instead of silently handing `nothing` to the caller.
        error("Unexpected arity $(n).")
    end
end
"""
    arguments(x::SelfContainedNode)

Return the children of `x`; this simply forwards to `unsorted_arguments`.
"""
arguments(x::SelfContainedNode) = unsorted_arguments(x)
# Build a new term of the same wrapper type `S` as `t`, with head `f` and children
# `args`, reusing `t.operators`. The result is asserted/converted to `S` by the
# return annotation.
#
# `symtype` and `kws` are only forwarded to the recursive calls; nothing else in
# this method reads them.
#
# NOTE(review): `to_node`, `mustfindfirst` and `isconstant` are defined elsewhere;
# the comments below describe only how they are called here.
function similarterm(
    t::S, f::F, args::AbstractArray, symtype=nothing; kws...
)::S where {T,S<:SelfContainedNode{T},F<:Function}
    # More than two children: combine the first two under `f`, then recurse with
    # that result followed by the remaining arguments, so each recursive call sees
    # one argument fewer than the call before it.
    if length(args) > 2
        l = similarterm(t, f, args[begin:(begin + 1)], symtype; kws...)
        return similarterm(t, f, [l, args[(begin + 2):end]...], symtype; kws...)
    end
    if length(args) == 1
        # Unary case: find the index of `f` among the unary operators.
        # Presumably `mustfindfirst` raises when `f` is absent — confirm.
        op_index = mustfindfirst(f, t.operators.unaops)
        # NOTE(review): `to_node` receives the operator index as its second
        # argument in every call in this method — confirm that matches its signature.
        new_node = Node(op_index, to_node(T, op_index, args[1]))
        return S(new_node, t.operators)
    elseif length(args) == 2
        # Binary case: find the index of `f` among the binary operators.
        op_index = mustfindfirst(f, t.operators.binops)
        # When both arguments satisfy `isconstant`, `f` is evaluated on them right
        # away and only the result is converted; otherwise a binary node is built
        # from the two converted children. Only this binary branch does so.
        # NOTE(review): `unsorted_arguments` exposes constant children as raw `T`
        # values — confirm `isconstant` accepts those as well as wrapped nodes.
        new_node = if all(isconstant, args)
            to_node(T, op_index, f(args...))
        else
            Node(op_index, [to_node(T, op_index, arg) for arg in args]...)
        end
        return S(new_node, t.operators)
    else
        # Only reachable when `args` is empty (lengths 1, 2 and >2 are handled above).
        error("Unexpected length $(length(args)).")
    end
end

Basically, SelfContainedNode stores a Node (a binary tree in which 1-ary nodes are also allowed; see the type description here) together with an OperatorEnum. For x::Node, x.l is the left child and x.r is the right child; x.op is an index into the OperatorEnum.

My guess is that the first part of similarterm is triggering infinite recursion:

    if length(args) > 2
        l = similarterm(t, f, args[begin:(begin + 1)], symtype; kws...)
        return similarterm(t, f, [l, args[(begin + 2):end]...], symtype; kws...)
    end

This is required because a Node can only store two children at a time, so we recursively generate the tree here. Perhaps this breaks some assumption in SymbolicUtils.jl?

This might just be incompatible with the package, so feel free to close if so.

MilesCranmer avatar Jun 20 '23 07:06 MilesCranmer