Re-Introduce `ShardedForm` for large expressions

Open lassepe opened this issue 2 years ago • 5 comments

The removal of ShardedForm for large arrays in https://github.com/JuliaSymbolics/Symbolics.jl/commit/616ef521b93d2b521f6954e94e570f9f8af3a464 has caused a major performance regression in one of my projects. Compilation time is now about 10x longer. Beyond that, calling the resulting function from 23 distributed workers in parallel previously worked just fine. Now it quickly runs out of memory on a machine with 64 GB of RAM, I guess because every worker has to compile the function on its first call and compilation is more memory-intensive for the SerialForm.

Is there a way to default again to ShardedForms for large functions or is this a fundamental limitation of RuntimeGeneratedFunctions.jl?
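
For context, the form can be requested per call through the `parallel` keyword of `build_function`, independent of the default. The following is a minimal sketch, assuming a Symbolics.jl version in which `ShardedForm(cutoff, ncalls)` is still available; the toy expressions and the `(2, 2)` arguments are illustrative only and much smaller than the expressions this issue is about.

```julia
using Symbolics

@variables a b c d
vars = [a, b, c, d]
# Toy right-hand sides; a real use case would have thousands of entries.
exprs = [a^2 + b, b^2 + c, c^2 + d, d^2 + a]

# For array-valued expressions build_function returns (out-of-place, in-place).
_, f_serial! = build_function(exprs, vars; expression = Val{false})
_, f_sharded! = build_function(exprs, vars;
    parallel = Symbolics.ShardedForm(2, 2), expression = Val{false})

vals = [1.0, 2.0, 3.0, 4.0]
out_serial = zeros(4)
out_sharded = zeros(4)
f_serial!(out_serial, vals)
f_sharded!(out_sharded, vals)

# Both forms should produce the same numbers; only the code layout differs.
println(out_serial == out_sharded)
```

Passing the form explicitly does not address the correctness concerns discussed in this thread; it only avoids depending on the default.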

lassepe avatar Jul 26 '23 13:07 lassepe

Someone just needs to fix it. We cannot default to it if it's not correct. If you're willing to fix its dependency analysis then we'd be happy to re-enable it.

ChrisRackauckas avatar Aug 08 '23 23:08 ChrisRackauckas

I am happy to take a look. I have little exposure to this area, so I have no idea whether I can be of actual use here. Are there any more pointers to, or evidence of, what exactly the issue is?

lassepe avatar Aug 09 '23 08:08 lassepe

The observed equations are not appended to the front of the sharded equations so it errors with any observables. At least that's what someone mentioned to me 2 weeks ago (@shashi?)

ChrisRackauckas avatar Aug 13 '23 00:08 ChrisRackauckas

I would like to see a reproducer from you guys for the multithreading deadlock. The problem might have gone away by now. @wsphillips was saying it gave wrong answers. But just from the code, I don't see a chance of deadlock or race conditions, unless you are passing repeated indices into build_function, in which case even the serial version would be wrong.
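
The repeated-index case can be illustrated with plain Julia. This is a hypothetical stand-in for generated code that writes each value into an output slot, not the actual Symbolics.jl internals; the names `write_serial!` and `write_threaded!` are made up for the sketch.

```julia
# Writes vals[k] into out[idxs[k]], one entry after another.
function write_serial!(out, idxs, vals)
    for k in eachindex(idxs)
        out[idxs[k]] = vals[k]
    end
    return out
end

# Same writes, but spread over threads.
function write_threaded!(out, idxs, vals)
    Threads.@threads for k in eachindex(idxs)
        out[idxs[k]] = vals[k]   # data race if idxs contains duplicates
    end
    return out
end

idxs = [1, 2, 2]            # index 2 appears twice
vals = [10.0, 20.0, 30.0]

serial = write_serial!(zeros(2), idxs, vals)
threaded = write_threaded!(zeros(2), idxs, vals)
```

In the serial version the second write to slot 2 silently overwrites the first, so one value is lost either way. In the threaded version, which of the two values ends up in slot 2 depends on scheduling.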

shashi avatar Aug 16 '23 19:08 shashi

The fix from @shashi solves the observable scoping issue. But when I tested it with some Conductor.jl models, it returned wrong solutions with MultithreadedForm() (the serially executed sharded form is fine).

You can use this MWE I adapted from the MTK docs to reproduce:

Edit: using Shashi's opaque closure PR

```julia
using ModelingToolkit, Plots, OrdinaryDiffEq, LinearAlgebra
using Symbolics: scalarize

@variables t
D = Differential(t)

function Mass(; name, m = 1.0, xy = [0.0, 0.0], u = [0.0, 0.0])
    ps = @parameters m = m
    sts = @variables pos(t)[1:2]=xy v(t)[1:2]=u
    eqs = scalarize(D.(pos) .~ v)
    ODESystem(eqs, t, [pos..., v...], ps; name)
end

function Spring(; name, k = 1e4, l = 1.0)
    ps = @parameters k=k l=l
    @variables x(t), dir(t)[1:2]
    ODESystem(Equation[], t, [x, dir...], ps; name)
end

function connect_spring(spring, a, b)
    [spring.x ~ norm(scalarize(a .- b))
        scalarize(spring.dir .~ scalarize(a .- b))]
end

function spring_force(spring)
    -spring.k .* scalarize(spring.dir) .* (spring.x - spring.l) ./ spring.x
end

m = 1.0
xy = [1.0, -1.0]
k = 1e4
l = 1.0
center = [0.0, 0.0]
g = [0.0, -9.81]
@named mass = Mass(m = m, xy = xy)
@named spring = Spring(k = k, l = l)

eqs = [connect_spring(spring, mass.pos, center)
    scalarize(D.(mass.v) .~ spring_force(spring) / mass.m .+ g)]

@named _model = ODESystem(eqs, t, [spring.x; spring.dir; mass.pos], [])
@named model = compose(_model, mass, spring)
sys = structural_simplify(model)

# if parallel = `Symbolics.ShardedForm(2,2)` or default serial form this works fine
prob = ODEProblem(sys, [], (0.0, 3.0); parallel = Symbolics.MultithreadedForm(2,2))
sol = solve(prob, Rosenbrock23())
plot(sol) # no oscillations/wrong output when multithreaded
```

wsphillips avatar Aug 16 '23 20:08 wsphillips