Performance problem with deeply nested nonlinear expressions

Open pekpuglia opened this issue 6 months ago • 9 comments

Hello, I am interested in switching from CasADi to JuMP for optimal control (the system is quaternion kinematics, so not a lot of variables). My MATLAB CasADi model does multistep RK4 integration for multiple shooting (4 steps/interval, 10 intervals) and builds and solves the problem in < 2s. The equivalent JuMP code takes many minutes to build the model, uses up to 10GB of RAM, and is basically unusable. This is the JuMP code I wrote:

using JuMP, Ipopt
using ReferenceFrameRotations, DifferentialEquations, Plots
##
function qdot(X, u)
    collect(dquat(Quaternion(X), u))
end

function RK4(f, X, u, dt)
    k1 = dt*f(X     , u)
    k2 = dt*f(X+k1/2, u)
    k3 = dt*f(X+k2/2, u)
    k4 = dt*f(X+k3  , u)
    X + (k1+2k2+2k3+k4)/6
end

model = Model(Ipopt.Optimizer)

DCMf = [
    -1 0 0
    0 0 1
    0 1 0
]

# two quaternions represent the same orientation
qf = -convert(Quaternion, DCM(DCMf)) |> collect

N = 2

T = 1
dt = T / (N-1)

tabq = @variable(model, [1:4, 1:N], base_name="q")

tabu = @variable(model, [1:3, 1:N-1], base_name="ω_b")

cost = 0

for i = 1:N-1
    F = [tabq[:, i]; cost]
    for M = 1:4
        F = RK4((X, u) -> [qdot(X[1:4], u); 1/2 * u' * u], F, tabu[:, i], dt)
    end
    @constraint(model, tabq[:, i+1] == F[1:4])
    cost = F[end]
end

#traj constraints
@constraint(model, tabq[:, 1] == q0)
@constraint(model, tabq[:, N] == qf)

# @constraint(model, c[i=1:N], tabq[:, i]' * tabq[:, i] == 1)

@objective(model, Min, cost)

#start values
set_start_value.(tabq, q0)

I set N = 2 to show that it's slow even for a single interval (similar problems happen with single-step high-order RK schemes, such as RK8, which leads me to believe JuMP can't handle big expressions well). For comparison, if I build a single-step model with N = 50 (roughly comparable to the 4 steps/interval x 10 intervals of the working CasADi model) like this:

model = Model(Ipopt.Optimizer)

DCMf = [
    -1 0 0
    0 0 1
    0 1 0
]

# two quaternions represent the same orientation
qf = -convert(Quaternion, DCM(DCMf)) |> collect

N = 50

T = 1
dt = T / (N-1)

tabq = @variable(model, [1:4, 1:N], base_name="q")

tabu = @variable(model, [1:3, 1:N-1], base_name="ω_b")

cost = 0

for i = 1:N-1
    #can't do multistep integration
    F = RK4((X, u) -> [qdot(X[1:4], u); 1/2 * u' * u], [tabq[:, i]; cost], tabu[:, i], dt)
    @constraint(model, tabq[:, i+1] == F[1:4])
    cost = F[end]
end

#traj constraints
@constraint(model, tabq[:, 1] == q0)
@constraint(model, tabq[:, N] == qf)

# @constraint(model, c[i=1:N], tabq[:, i]' * tabq[:, i] == 1)

@objective(model, Min, cost)

#start values
set_start_value.(tabq, q0)
model
##
optimize!(model)

It then works reasonably fast. I don't understand why JuMP can't handle the first code; in my view, this is a big limitation. Is there a particular way of coding in JuMP I'm missing, or is this truly JuMP's fault?

pekpuglia avatar Jul 07 '25 08:07 pekpuglia

Hi @pekpuglia, good question. The issue here is that we expand everything into an expression graph at the level of elementary operations: *, +, ... But that graph is way too big. So you would want to add more allowed operators to this graph, such as RK4. Let's even add the whole multistep integration as one operator. One limitation, though, is that JuMP currently only supports operators with scalar inputs and a scalar output, so we have to scalarize like so:

using ReferenceFrameRotations, DifferentialEquations, Plots
##
function qdot(X, u)
    collect(dquat(Quaternion(X), u))
end

function RK4(f, X, u, dt)
    k1 = dt*f(X     , u)
    k2 = dt*f(X+k1/2, u)
    k3 = dt*f(X+k2/2, u)
    k4 = dt*f(X+k3  , u)
    X + (k1+2k2+2k3+k4)/6
end

using JuMP
import Ipopt

model = Model(Ipopt.Optimizer)

DCMf = [
    -1 0 0
    0 0 1
    0 1 0
]

# two quaternions represent the same orientation
qf = -convert(Quaternion, DCM(DCMf)) |> collect

function integrate(x, u, dt)
    cost = 0
    F = [x; cost]
    for _ = 1:4
        F = RK4((X, u) -> [qdot(X[1:4], u); 1/2 * u' * u], F, u, dt)
    end
    return F
end

int1(dt, u1, u2, u3, x1, x2, x3, x4) = integrate([x1, x2, x3, x4], [u1, u2, u3], dt)[1]
int2(dt, u1, u2, u3, x1, x2, x3, x4) = integrate([x1, x2, x3, x4], [u1, u2, u3], dt)[2]
int3(dt, u1, u2, u3, x1, x2, x3, x4) = integrate([x1, x2, x3, x4], [u1, u2, u3], dt)[3]
int4(dt, u1, u2, u3, x1, x2, x3, x4) = integrate([x1, x2, x3, x4], [u1, u2, u3], dt)[4]
int5(dt, u1, u2, u3, x1, x2, x3, x4) = integrate([x1, x2, x3, x4], [u1, u2, u3], dt)[5]

@operator(model, op_int1, 8, int1)
@operator(model, op_int2, 8, int2)
@operator(model, op_int3, 8, int3)
@operator(model, op_int4, 8, int4)
@operator(model, op_int5, 8, int5)
op_inti = [op_int1, op_int2, op_int3, op_int4, op_int5]

N = 2

T = 1
dt = T / (N-1)

tabq = @variable(model, [1:4, 1:N], base_name="q")

tabu = @variable(model, [1:3, 1:N-1], base_name="ω_b")

cost = 0
for i = 1:N-1
    @constraint(model, [j = 1:4], tabq[j, i+1] == op_inti[j](dt, tabu[:,i]..., tabq[:,i]...))
    cost += op_inti[5](dt, tabu[:,i]..., tabq[:,i]...)
end

#traj constraints
#@constraint(model, tabq[:, 1] .== q0) # Commented out otherwise Ipopt says `TOO_FEW_DEGREES_OF_FREEDOM`
@constraint(model, tabq[:, N] .== qf)

# @constraint(model, c[i=1:N], tabq[:, i]' * tabq[:, i] == 1)

@objective(model, Min, cost)

#start values
optimize!(model)

value.(tabq)
value.(tabu)

Since we created separate functions for each output to make it scalar output, there might be duplicated computation, see this tutorial to avoid it.
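
One way to avoid the duplicated computation is to cache the most recent evaluation, so that all scalar outputs at a given point share a single call. The following is a minimal sketch in plain Julia, not the tutorial's code: `memoize_outputs` and `two_outputs` are hypothetical names, and the single-entry cache is a simplification (a fuller version would likely keep a separate cache for the dual-number inputs used by automatic differentiation).

```julia
# Hypothetical helper: expose the outputs of a vector-valued function `f` as
# scalar functions that share one evaluation of `f` per input point.
function memoize_outputs(f::Function, n_outputs::Int)
    last_x = nothing   # most recent input tuple
    last_f = nothing   # outputs of `f` at `last_x`
    function output(i::Int, x...)
        if x !== last_x               # recompute only when the input changes
            last_x, last_f = x, f(x...)
        end
        return last_f[i]
    end
    return [(x...) -> output(i, x...) for i in 1:n_outputs]
end

# Stand-in for an expensive vector-valued integrator (not the RK4 code above).
n_calls = Ref(0)
function two_outputs(a, b)
    n_calls[] += 1
    return [a + b, a * b]
end

fs = memoize_outputs(two_outputs, 2)
y1 = fs[1](2.0, 3.0)   # evaluates `two_outputs`
y2 = fs[2](2.0, 3.0)   # reuses the cached result
```

Under these assumptions, each returned closure could be registered in place of `int1` ... `int5`, for example `@operator(model, op_int1, 8, fs[1])` after wrapping an 8-argument function that returns all five outputs.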

blegat avatar Jul 07 '25 14:07 blegat

Thank you! I will try this in a bit. However, I still don't understand what difference between JuMP and CasADi allows CasADi, but not JuMP, to trace multi-step integration. I believe CasADi also expands expressions, as can be verified interactively in MATLAB.

pekpuglia avatar Jul 07 '25 14:07 pekpuglia

Can you post the expression graph you obtain with CasADi (so F) after one or two steps ?

blegat avatar Jul 07 '25 15:07 blegat

I'm tempted to close this issue as out-of-scope. JuMP is not built for this sort of modeling approach, with large nested expressions.

The solution is to introduce temporary variables to represent intermediate expressions:

using JuMP, Ipopt, ReferenceFrameRotations

qdot(X, u) = collect(dquat(Quaternion(X), u))

function RK4(f, X, u, dt)
    k1 = dt*f(X     , u)
    k2 = dt*f(X+k1/2, u)
    k3 = dt*f(X+k2/2, u)
    k4 = dt*f(X+k3  , u)
    return X + (k1+2k2+2k3+k4)/6
end

DCMf = [-1 0 0; 0 0 1; 0 1 0]
qf = -convert(Quaternion, DCM(DCMf)) |> collect
N = 2
T = 1
dt = T / (N - 1)

model = Model(Ipopt.Optimizer)
@variable(model, x[1:4, 1:N])
@variable(model, u[1:3, 1:N-1])
@variable(model, F[1:5, 1:N-1, 1:5])
for i in 1:N-1
    @constraint(model, F[:, i, 1] .== [x[:, i]; 0.0])
    for M in 1:4
        F_m = RK4(
            (X, u) -> [qdot(X[1:4], u); 1/2 * u' * u],
            F[:, i, M],
            u[:, i],
            dt,
        )
        @constraint(model, F[:, i, M + 1] .== F_m)
    end
    @constraint(model, x[:, i+1] .== F[1:4, i, end])
end
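# `i` is not defined outside the loop; each interval's cost starts from zero,
# so sum the final cost entries over all intervals.
cost = sum(F[5, i, end] for i in 1:N-1)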

Alternatively, you might want to look into https://github.com/infiniteopt/InfiniteOpt.jl

odow avatar Jul 07 '25 22:07 odow

I just ran the equivalent CasADi code (using a casadi.Function for multistep RK4, which I'm not sure is important) and it builds this expression:

@1=0.00462963, @2=0.5, @3=(@2*(((tab_u_0*tab_X_1)+(tab_u_1*tab_X_2))+(tab_u_2*tab_X_3))), @4=0.0138889, @5=(@2*(((tab_u_0*tab_X_0)-(tab_u_1*tab_X_3))+(tab_u_2*tab_X_2))), @6=(tab_X_1+(@4*@5)), @7=(@2*(((tab_u_1*tab_X_0)+(tab_u_0*tab_X_3))-(tab_u_2*tab_X_1))), @8=(tab_X_2+(@4*@7)), @9=(@2*(((tab_u_1*tab_X_1)-(tab_u_0*tab_X_2))+(tab_u_2*tab_X_0))), @10=(tab_X_3+(@4*@9)), @11=(((tab_u_0*@6)+(tab_u_1*@8))+(tab_u_2*@10)), @12=(tab_X_0-(@4*@3)), @13=(((tab_u_0*@12)-(tab_u_1*@10))+(tab_u_2*@8)), @14=(tab_X_1+(@4*(@2*@13))), @15=(((tab_u_1*@12)+(tab_u_0*@10))-(tab_u_2*@6)), @16=(tab_X_2+(@4*(@2*@15))), @17=(((tab_u_1*@6)-(tab_u_0*@8))+(tab_u_2*@12)), @18=(tab_X_3+(@4*(@2*@17))), @19=(((tab_u_0*@14)+(tab_u_1*@16))+(tab_u_2*@18)), @20=0.0277778, @21=(tab_X_0-(@4*(@2*@11))), @22=(((tab_u_0*@21)-(tab_u_1*@18))+(tab_u_2*@16)), @23=(tab_X_1+(@20*(@2*@22))), @24=(((tab_u_1*@21)+(tab_u_0*@18))-(tab_u_2*@14)), @25=(tab_X_2+(@20*(@2*@24))), @26=(((tab_u_1*@14)-(tab_u_0*@16))+(tab_u_2*@21)), @27=(tab_X_3+(@20*(@2*@26))), @28=(tab_X_0-(@1*(((@3+@11)+@19)+(@2*(((tab_u_0*@23)+(tab_u_1*@25))+(tab_u_2*@27)))))), @29=(tab_X_0-(@20*(@2*@19))), @30=(tab_X_1+(@1*(((@5+@13)+@22)+(@2*(((tab_u_0*@29)-(tab_u_1*@27))+(tab_u_2*@25)))))), @31=(tab_X_2+(@1*(((@7+@15)+@24)+(@2*(((tab_u_1*@29)+(tab_u_0*@27))-(tab_u_2*@23)))))), @32=(tab_X_3+(@1*(((@9+@17)+@26)+(@2*(((tab_u_1*@23)-(tab_u_0*@25))+(tab_u_2*@29)))))), @33=(@2*(((tab_u_0*@30)+(tab_u_1*@31))+(tab_u_2*@32))), @34=(@2*(((tab_u_0*@28)-(tab_u_1*@32))+(tab_u_2*@31))), @35=(@30+(@4*@34)), @36=(@2*(((tab_u_1*@28)+(tab_u_0*@32))-(tab_u_2*@30))), @37=(@31+(@4*@36)), @38=(@2*(((tab_u_1*@30)-(tab_u_0*@31))+(tab_u_2*@28))), @39=(@32+(@4*@38)), @40=(((tab_u_0*@35)+(tab_u_1*@37))+(tab_u_2*@39)), @41=(@28-(@4*@33)), @42=(((tab_u_0*@41)-(tab_u_1*@39))+(tab_u_2*@37)), @43=(@30+(@4*(@2*@42))), @44=(((tab_u_1*@41)+(tab_u_0*@39))-(tab_u_2*@35)), @45=(@31+(@4*(@2*@44))), @46=(((tab_u_1*@35)-(tab_u_0*@37))+(tab_u_2*@41)), @47=(@32+(@4*(@2*@46))), 
@48=(((tab_u_0*@43)+(tab_u_1*@45))+(tab_u_2*@47)), @49=(@28-(@4*(@2*@40))), @50=(((tab_u_0*@49)-(tab_u_1*@47))+(tab_u_2*@45)), @51=(@30+(@20*(@2*@50))), @52=(((tab_u_1*@49)+(tab_u_0*@47))-(tab_u_2*@43)), @53=(@31+(@20*(@2*@52))), @54=(((tab_u_1*@43)-(tab_u_0*@45))+(tab_u_2*@49)), @55=(@32+(@20*(@2*@54))), @56=(@28-(@1*(((@33+@40)+@48)+(@2*(((tab_u_0*@51)+(tab_u_1*@53))+(tab_u_2*@55)))))), @57=(@28-(@20*(@2*@48))), @58=(@30+(@1*(((@34+@42)+@50)+(@2*(((tab_u_0*@57)-(tab_u_1*@55))+(tab_u_2*@53)))))), @59=(@31+(@1*(((@36+@44)+@52)+(@2*(((tab_u_1*@57)+(tab_u_0*@55))-(tab_u_2*@51)))))), @60=(@32+(@1*(((@38+@46)+@54)+(@2*(((tab_u_1*@51)-(tab_u_0*@53))+(tab_u_2*@57)))))), @61=(@2*(((tab_u_0*@58)+(tab_u_1*@59))+(tab_u_2*@60))), @62=(@2*(((tab_u_0*@56)-(tab_u_1*@60))+(tab_u_2*@59))), @63=(@58+(@4*@62)), @64=(@2*(((tab_u_1*@56)+(tab_u_0*@60))-(tab_u_2*@58))), @65=(@59+(@4*@64)), @66=(@2*(((tab_u_1*@58)-(tab_u_0*@59))+(tab_u_2*@56))), @67=(@60+(@4*@66)), @68=(((tab_u_0*@63)+(tab_u_1*@65))+(tab_u_2*@67)), @69=(@56-(@4*@61)), @70=(((tab_u_0*@69)-(tab_u_1*@67))+(tab_u_2*@65)), @71=(@58+(@4*(@2*@70))), @72=(((tab_u_1*@69)+(tab_u_0*@67))-(tab_u_2*@63)), @73=(@59+(@4*(@2*@72))), @74=(((tab_u_1*@63)-(tab_u_0*@65))+(tab_u_2*@69)), @75=(@60+(@4*(@2*@74))), @76=(((tab_u_0*@71)+(tab_u_1*@73))+(tab_u_2*@75)), @77=(@56-(@4*(@2*@68))), @78=(((tab_u_0*@77)-(tab_u_1*@75))+(tab_u_2*@73)), @79=(@58+(@20*(@2*@78))), @80=(((tab_u_1*@77)+(tab_u_0*@75))-(tab_u_2*@71)), @81=(@59+(@20*(@2*@80))), @82=(((tab_u_1*@71)-(tab_u_0*@73))+(tab_u_2*@77)), @83=(@60+(@20*(@2*@82))), @84=(@56-(@1*(((@61+@68)+@76)+(@2*(((tab_u_0*@79)+(tab_u_1*@81))+(tab_u_2*@83)))))), @85=(@56-(@20*(@2*@76))), @86=(@58+(@1*(((@62+@70)+@78)+(@2*(((tab_u_0*@85)-(tab_u_1*@83))+(tab_u_2*@81)))))), @87=(@59+(@1*(((@64+@72)+@80)+(@2*(((tab_u_1*@85)+(tab_u_0*@83))-(tab_u_2*@79)))))), @88=(@60+(@1*(((@66+@74)+@82)+(@2*(((tab_u_1*@79)-(tab_u_0*@81))+(tab_u_2*@85)))))), @89=(@2*(((tab_u_0*@86)+(tab_u_1*@87))+(tab_u_2*@88))), 
@90=(@86+(@4*(@2*(((tab_u_0*@84)-(tab_u_1*@88))+(tab_u_2*@87))))), @91=(@87+(@4*(@2*(((tab_u_1*@84)+(tab_u_0*@88))-(tab_u_2*@86))))), @92=(@88+(@4*(@2*(((tab_u_1*@86)-(tab_u_0*@87))+(tab_u_2*@84))))), @93=(((tab_u_0*@90)+(tab_u_1*@91))+(tab_u_2*@92)), @94=(@84-(@4*@89)), @95=(@86+(@4*(@2*(((tab_u_0*@94)-(tab_u_1*@92))+(tab_u_2*@91))))), @96=(@87+(@4*(@2*(((tab_u_1*@94)+(tab_u_0*@92))-(tab_u_2*@90))))), @97=(@88+(@4*(@2*(((tab_u_1*@90)-(tab_u_0*@91))+(tab_u_2*@94))))), @98=(@84-(@4*(@2*@93))), ((@84-(@1*(((@89+@93)+(((tab_u_0*@95)+(tab_u_1*@96))+(tab_u_2*@97)))+(@2*(((tab_u_0*(@86+(@20*(@2*(((tab_u_0*@98)-(tab_u_1*@97))+(tab_u_2*@96))))))+(tab_u_1*(@87+(@20*(@2*(((tab_u_1*@98)+(tab_u_0*@97))-(tab_u_2*@95)))))))+(tab_u_2*(@88+(@20*(@2*(((tab_u_1*@95)-(tab_u_0*@96))+(tab_u_2*@98)))))))))))-tab_X_4)

So CasADi apparently builds sub-expressions for the multi-step integration.

So, in practice, there are feasible ways of doing this with JuMP, and I can go on working on my problem (thanks!). It just feels frustrating/limiting that, even as Julia's main optimization/optimal control package, JuMP needs workarounds (such as operator registering) where CasADi just works. The memoization approach is somewhat cumbersome and I think there would be interest from the community in making this more streamlined. Maybe from the point of view of the developers this is out-of-scope, but as a Julia user, I would reach for JuMP as a drop-in replacement (meaning, the same mathematical algorithms should work in JuMP as easily as in CasADi).

pekpuglia avatar Jul 08 '25 07:07 pekpuglia

Thanks, it seems CasADi is automatically creating the intermediate variables, as @odow mentioned. Actually, building the JuMP functions is quite fast. The fact that they have common sub-expressions is not an issue because they are shared by reference, not copied. So as long as you don't print them, you are fine. The bottleneck is moi_function, because it's being called many times on the same GenericNonlinearExpr. If we had a dict that checked the reference to see whether we had already built the MOI.ScalarNonlinearFunction for it, that would speed up the generation of the MOI model, but the AD would still be slow because it would not exploit the common sub-expressions. So just improving the build time without creating sub-expressions is not helpful. We discussed a syntax for the user to explicitly add sub-expressions in https://github.com/jump-dev/JuMP.jl/issues/3738. Actually, maybe we should just automatically create sub-expressions when the same GenericNonlinearExpr is referenced several times. This seems to be what CasADi is doing and, as mentioned by @pekpuglia, this would help for the problems users write naturally.
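
To illustrate the difference between walking a shared expression as a tree and as a graph, here is a small sketch in plain Julia. `ToyNode`, `count_tree`, `count_dag`, and `nest` are hypothetical and are not JuMP's data structures; the point is only that a traversal without an identity-keyed cache revisits a shared child once per reference, while a cached traversal visits each distinct object once.

```julia
# Toy expression node (hypothetical; not JuMP's GenericNonlinearExpr).
mutable struct ToyNode
    op::Symbol
    args::Vector{Any}
end

# Naive traversal: a shared child is visited once per reference (tree view).
count_tree(::Any) = 0
count_tree(n::ToyNode) = 1 + sum(count_tree, n.args; init = 0)

# Cached traversal: an identity-keyed dict ensures each distinct node is
# visited only once (graph view).
function count_dag(n, seen = IdDict{Any,Bool}())
    (n isa ToyNode && !haskey(seen, n)) || return 0
    seen[n] = true
    return 1 + sum(a -> count_dag(a, seen), n.args; init = 0)
end

# Build a nested expression in which both arguments of every product
# reference the same child object, as repeated integrator steps would.
function nest(depth::Int)
    e = ToyNode(:+, Any[:x, :u])
    for _ in 1:depth
        e = ToyNode(:*, Any[e, e])
    end
    return e
end

e = nest(10)
tree_visits = count_tree(e)   # 2^11 - 1 = 2047 node visits
dag_nodes = count_dag(e)      # 11 distinct nodes
```

The gap grows exponentially with nesting depth, which is consistent with a model that is cheap to construct but expensive to convert or print.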

blegat avatar Jul 08 '25 12:07 blegat

I explicitly considered this case when designing the data structure, and we chose not to handle it. If you're solving optimal control type problems, CasADi is an excellent tool for the job. JuMP doesn't have to be great at everything.

> The fact that they have common sub-expressions is not an issue because they are shared as reference, not copied.

This is very much "it depends." We don't detect equivalent subexpressions, and subexpressions are mutable, so they don't have object equality.

> Actually, maybe we should just automatically create sub-expressions when we have several times a reference to the same GenericNonlinearExpr.

The problem is that this is not easily decidable.

We have chosen to force people to explicitly implement their own subexpressions, like I did above. If JuMP got into that game, it would open a massive can of worms.

As one example: we couldn't detect that two subexpressions are equivalent and replace them by one shared reference because the user may later modify one of them.
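
The aliasing hazard can be sketched with a toy mutable type in plain Julia. `ToyExpr` is hypothetical and only stands in for a mutable expression object; it makes no claim about how JuMP stores expressions.

```julia
# Toy mutable expression (hypothetical; for illustration only).
mutable struct ToyExpr
    head::Symbol
    args::Vector{Any}
end

a = ToyExpr(:+, Any[:x, 1.0])
b = ToyExpr(:+, Any[:x, 1.0])   # same structure, but a separate object

structurally_equal = a.head == b.head && a.args == b.args   # true
identical = a === b                                          # false

# Suppose a modeling layer silently replaced `b` with a reference to `a`.
merged_b = a
push!(merged_b.args, :y)        # the user intends to change only "b"

# The edit is now visible through `a` as well.
a_was_changed = length(a.args) == 3
```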

odow avatar Jul 08 '25 22:07 odow

As a follow-up question, since I'm not completely sure from reading the docs: does going the operator-registration route disable the use of the Hessian?

pekpuglia avatar Jul 28 '25 09:07 pekpuglia

If you don't provide a Hessian operator, yes. But you can provide a Hessian callback for the operator (see the docs).
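
As a rough sketch of what that looks like for a multivariate operator: the gradient and Hessian callbacks fill their first argument in place, and the Hessian callback only needs to fill the lower triangle. The function `f` below is a made-up example unrelated to the RK4 model, and the names `∇f`, `∇²f`, `op_f`, and `z` are placeholders.

```julia
using JuMP

# Illustrative function of two scalars (hypothetical example).
f(x...) = x[1]^2 * x[2] + x[2]^3

# Gradient callback: fill `g` in place.
function ∇f(g::AbstractVector{T}, x::T...) where {T}
    g[1] = 2 * x[1] * x[2]
    g[2] = x[1]^2 + 3 * x[2]^2
    return
end

# Hessian callback: fill only the lower triangle of `H` in place.
function ∇²f(H::AbstractMatrix{T}, x::T...) where {T}
    H[1, 1] = 2 * x[2]
    H[2, 1] = 2 * x[1]
    H[2, 2] = 6 * x[2]
    return
end

model = Model()
@variable(model, z[1:2])
# Passing both callbacks registers an operator with user-supplied derivatives.
@operator(model, op_f, 2, f, ∇f, ∇²f)
@objective(model, Min, op_f(z[1], z[2]))
```

For an operator that wraps a whole integrator, writing these callbacks by hand is the hard part; the sketch only shows the calling convention.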

odow avatar Jul 28 '25 09:07 odow