Symbolics.jl

Incorrect substitution assignment in MTK systems with ShardedForm/MultithreadedForm

Open · wsphillips opened this issue 2 years ago · 3 comments

If you have a large MTK system that needs ShardedForm for compilation or performance reasons, it will fail if it contains observed variables, which normally get assigned/evaluated in the `let` block preceding the main function body. This is because the child (sharded) functions do not inherit those assignments from the parent scope (if I understand correctly, because they are RuntimeGeneratedFunctions, i.e. RGFs).
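The scoping problem itself can be reproduced without MTK. The sketch below is a plain-Julia analogy, not the code Symbolics actually generates: it assumes that a child function compiled from an expression at module scope (here via `eval`, standing in for an RGF) behaves like a shard with respect to its caller's `let` bindings. `child` and `parent!` are hypothetical names.

```julia
# Plain-Julia analogy (no Symbolics/MTK): the child is compiled from an
# expression at module scope, as an RGF is, so it cannot see local bindings
# made by whoever calls it.
child = eval(quote
    function (du, u)
        du[1] = obs * u[1]   # `obs` resolves to a module-level global here
        return nothing
    end
end)

function parent!(du, u)
    let obs = 2 * u[2]       # observed variable bound in the parent's `let` block
        child(du, u)         # the child never sees this local binding
    end
end

du = zeros(2)
u = [1.0, 3.0]
err = try
    parent!(du, u)
    nothing
catch e
    e
end
println(err)   # an UndefVarError for `obs`
```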

You can reproduce this on a smaller system with observed variables by setting a small sharding cutoff (for MWE purposes only):

```julia
using ModelingToolkit, Plots, OrdinaryDiffEq, LinearAlgebra
using Symbolics: scalarize

@variables t
D = Differential(t)

function Mass(; name, m = 1.0, xy = [0.0, 0.0], u = [0.0, 0.0])
    ps = @parameters m = m
    sts = @variables pos(t)[1:2]=xy v(t)[1:2]=u
    eqs = scalarize(D.(pos) .~ v)
    ODESystem(eqs, t, [pos..., v...], ps; name)
end

function Spring(; name, k = 1e4, l = 1.0)
    ps = @parameters k=k l=l
    # `x` and `dir` become observed variables after structural_simplify
    @variables x(t), dir(t)[1:2]
    ODESystem(Equation[], t, [x, dir...], ps; name)
end

function connect_spring(spring, a, b)
    [spring.x ~ norm(scalarize(a .- b))
        scalarize(spring.dir .~ scalarize(a .- b))]
end

function spring_force(spring)
    -spring.k .* scalarize(spring.dir) .* (spring.x - spring.l) ./ spring.x
end

m = 1.0
xy = [1.0, -1.0]
k = 1e4
l = 1.0
center = [0.0, 0.0]
g = [0.0, -9.81]
@named mass = Mass(m = m, xy = xy)
@named spring = Spring(k = k, l = l)

eqs = [connect_spring(spring, mass.pos, center)
    scalarize(D.(mass.v) .~ spring_force(spring) / mass.m .+ g)]

@named _model = ODESystem(eqs, t, [spring.x; spring.dir; mass.pos], [])
@named model = compose(_model, mass, spring)
sys = structural_simplify(model)

# Small ShardedForm parameters force sharding even on this tiny system (MWE only)
prob = ODEProblem(sys, [], (0.0, 3.0); parallel = Symbolics.ShardedForm(2,2))
sol = solve(prob, Rosenbrock23())
```

I can think of two solutions (though you may well have a better one):

  1. The easiest would be to recursively apply the postprocessing `let` block to every child function. This is probably a quicker fix and could be done on the Symbolics.jl side, but every child would then recompute the observed variables redundantly.

  2. A more efficient (but potentially more complicated) solution would be to collect the values of the variables assigned in the top-level function once they are evaluated, pass them to the child functions as an extra argument, and rebind them in each child's scope.
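Both options can be sketched in a small plain-Julia analogy. This is not how Symbolics actually builds sharded functions: `eval` stands in for RGF construction, and `child_recompute`, `child_args`, and `parent_args!` are hypothetical names chosen for illustration.

```julia
# Option 1 (sketch): copy the parent's `let` assignments into every child,
# so each shard recomputes the observed variable itself.
child_recompute = eval(quote
    function (du, u)
        let obs = 2 * u[2]        # duplicated observed assignment
            du[1] = obs * u[1]
        end
        return nothing
    end
end)

# Option 2 (sketch): the parent evaluates observed variables once and passes
# them to the child as an extra argument, which rebinds them locally.
child_args = eval(quote
    function (du, u, observed)
        let obs = observed[1]     # rebind the value computed by the parent
            du[1] = obs * u[1]
        end
        return nothing
    end
end)

function parent_args!(du, u)
    let obs = 2 * u[2]            # computed once in the parent
        child_args(du, u, (obs,))
    end
end

u = [1.0, 3.0]
du1 = zeros(2); child_recompute(du1, u)
du2 = zeros(2); parent_args!(du2, u)
println(du1 == du2)   # true
```

Option 1 leaves each child's signature unchanged at the cost of recomputing `obs` in every shard; option 2 computes it once but changes the calling convention of every child function.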

Since the lack of sharding greatly limits scaling in Conductor.jl, I would be willing to help with this, but I would need advice on the preferred solution.

wsphillips, Jul 21 '23 23:07

I noticed that this also seems to crop up in situations unrelated to sharding: basically any time a nested RGF is built (see the code generated for an ODAEProblem in https://github.com/SciML/ModelingToolkit.jl/issues/2173#issuecomment-1567090435).

wsphillips, Jul 22 '23 00:07

@shashi

wsphillips, Jul 22 '23 00:07

@YingboMa @shashi, is there a hard requirement for `postprocess_fbody` to be a function on expressions for MTK models? Right now MTK defines constants and substitutions manually with SymbolicUtils.Code types (e.g. `Let`, `Assignment`) before calling into `build_function`. In principle, this could just as easily be done with symbolic equations, which would make scope easier to track. Is that right?
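One possible reading of the symbolic-equation suggestion: if observed definitions are substituted into each expression before code generation, every generated body is self-contained and no shared `let` scope is needed. A minimal plain-Julia sketch on raw `Expr`s follows; `inline_observed` is a hypothetical helper standing in for symbolic substitution (not a Symbolics or MTK API), and it assumes the observed definitions are acyclic.

```julia
# Hypothetical helper: recursively replace observed symbols by their defining
# expressions, so the resulting body needs no enclosing `let`.
inline_observed(ex, subs::AbstractDict) = ex
inline_observed(ex::Symbol, subs::AbstractDict) =
    haskey(subs, ex) ? inline_observed(subs[ex], subs) : ex
inline_observed(ex::Expr, subs::AbstractDict) =
    Expr(ex.head, (inline_observed(a, subs) for a in ex.args)...)

# Observed definitions, analogous to `obs ~ 2 * u[2]`; `obs2` depends on `obs`.
subs = Dict{Symbol,Any}(:obs => :(2 * u[2]), :obs2 => :(obs + 1))

# Substitute first, then compile the child from the self-contained body.
body = inline_observed(:(du[1] = obs2 * u[1]), subs)
child = eval(:((du, u) -> ($body; nothing)))

du = zeros(2)
child(du, [1.0, 3.0])
println(du[1])   # (2 * 3.0 + 1) * 1.0 = 7.0
```

Like the first solution proposed in the issue, inlining recomputes shared observed values in every child (and at every use site), so it trades efficiency for bodies that do not depend on any parent scope.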

wsphillips, Jul 29 '23 16:07