OSQP crashing on unexpected params
I'm trying to port some QP code to jaxopt, but I'm struggling to understand the cryptic errors that appear only in the jaxopt implementation. I've tried other packages, and these parameters work with their implementations.
Here's a minimal example:
import numpy as np
import jax.numpy as jnp
from jaxopt import OSQP
from qpsolvers import solve_qp
def to_numpy(*args):
    # Convert JAX arrays to NumPy arrays for qpsolvers.
    return tuple(np.asarray(v) for v in args)
P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[-1.0]])
h_ = jnp.array([2.0])
A_ = jnp.array([[]], dtype=float).T
b_ = jnp.array([], dtype=float)
x = solve_qp(*to_numpy(P_, q_, G_, h_, A_, b_), solver="osqp") # works
qp = OSQP()
deltas = qp.run(
params_obj=(P_, q_),
params_eq=(A_, b_),
params_ineq=(G_, h_),
).params.primal # Crashes with cryptic error.
# TypeError: dot_general requires contracting dimensions to have the same shape, got (1,) and (2,).
# jax-0.4.23 jaxlib-0.4.23 jaxopt-0.8.3 ml-dtypes-0.3.2 opt-einsum-3.3.0
Hi Illviljan
Sorry for the cryptic error message. The error comes from the matrix A_ = jnp.array([[]], dtype=float).T, which has shape (0, 1) and is not a valid linear operator for jaxopt's OSQP. If you don't need equality constraints, just pass None to params_eq:
P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[-1.0]])
h_ = jnp.array([2.0])
# A_ = jnp.array([[]], dtype=float).T
# b_ = jnp.array([], dtype=float)
qp = OSQP()
deltas = qp.run(
params_obj=(P_, q_),
params_eq=None, # CHANGE HERE.
params_ineq=(G_, h_),
).params.primal
Similarly, if you don't need inequality constraints, just pass None to params_ineq. Thank you for your message; I just realized that I forgot to document this functionality.
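When constraint arrays are assembled programmatically and may end up with zero rows (like A_ and b_ in the original example), a small guard can translate them to None before they reach OSQP. This is a minimal sketch; `none_if_empty` is a hypothetical helper, not part of jaxopt, and it only inspects shapes, so it should work the same for NumPy and jax.numpy arrays.

```python
import numpy as np


def none_if_empty(params):
    """Map a (matrix, vector) constraint pair with zero rows to None.

    Hypothetical helper (not part of jaxopt): jaxopt's OSQP expects None
    for an absent constraint group, whereas qpsolvers accepts zero-row
    arrays such as a matrix of shape (0, n).
    """
    if params is None:
        return None
    M, v = params
    # A matrix of shape (0, n) (or an empty right-hand side) encodes
    # zero constraints, so treat the whole group as absent.
    if np.shape(M)[0] == 0 or np.shape(v)[0] == 0:
        return None
    return params
```

It would be used as `params_eq=none_if_empty((A_, b_))` and `params_ineq=none_if_empty((G_, h_))` in the `qp.run(...)` call. As discussed in this thread, OSQP still rejects the case where both groups end up as None.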
Thank you, that's quite a simple fix. I may just continue with all constraints active in my larger project.
I was surprised because jaxopt seems to be the odd one out: this A_ is accepted by other QP packages.
Using None is fine, I guess; the annoying part is that jaxopt doesn't allow both constraint groups to be None. Other packages allow that, and I think it aligns better with how I build a new solution: start simple without any constraints, make sure it works, then slowly add constraints until the solution makes sense.
import numpy as np
import jax.numpy as jnp
from jaxopt import OSQP
from qpsolvers import solve_qp
def to_numpy(*args):
    # Convert JAX arrays to NumPy arrays for qpsolvers.
    return tuple(np.asarray(v) for v in args)
P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[]], dtype=float).T
h_ = jnp.array([], dtype=float)
A_ = jnp.array([[]], dtype=float).T
b_ = jnp.array([], dtype=float)
x = solve_qp(*to_numpy(P_, q_, G_, h_, A_, b_), solver="osqp") # works
print(x)
qp = OSQP()
x = qp.run(
params_obj=(P_, q_),
params_eq=None,
params_ineq=None,
).params.primal # Unnecessarily strict crash
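To keep the "start unconstrained, then add constraints" workflow, one option is a thin wrapper that only calls OSQP when at least one constraint group is present. A minimal sketch, assuming P is symmetric positive definite, so that the unconstrained minimizer of 0.5 x^T P x + q^T x is the solution of P x = -q. `solve_qp_optional` is a hypothetical name, and jaxopt is imported lazily so the unconstrained path needs only NumPy.

```python
import numpy as np


def solve_qp_optional(P, q, params_eq=None, params_ineq=None):
    """Minimize 0.5 * x^T P x + q^T x with optional constraints.

    Hypothetical wrapper (not part of jaxopt). Assumes P is symmetric
    positive definite, so the unconstrained minimizer solves P x = -q.
    """
    if params_eq is None and params_ineq is None:
        # No constraints: skip OSQP and solve the stationarity condition
        # directly. Note this returns a NumPy array, not a JAX array.
        return np.linalg.solve(np.asarray(P, dtype=float),
                               -np.asarray(q, dtype=float))
    # At least one constraint group: defer to jaxopt's OSQP, imported
    # lazily so the unconstrained path does not require jaxopt.
    from jaxopt import OSQP
    return OSQP().run(
        params_obj=(P, q),
        params_eq=params_eq,
        params_ineq=params_ineq,
    ).params.primal
```

For a small dense P a direct solve is the simplest choice; for larger problems the conjugate gradient approach recommended below scales better.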
That's true, but using OSQP when you don't have constraints is overkill. In this case the OSQP algorithm degenerates into an inefficient way of solving a linear system.
As explained in the documentation, you should fall back to conjugate gradient in this case.
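Concretely, without constraints the problem is to minimize 0.5 x^T P x + q^T x, whose minimizer satisfies the linear system P x = -q; conjugate gradient solves such systems when P is symmetric positive definite. Below is a plain NumPy sketch of the textbook CG iteration for illustration. In JAX code, `jax.scipy.sparse.linalg.cg(A, b)` (which returns `(x, info)`) is a ready-made equivalent, and jaxopt also provides CG-based solvers in `jaxopt.linear_solve`.

```python
import numpy as np


def conjugate_gradient(P, b, tol=1e-10, maxiter=None):
    """Solve P x = b for symmetric positive definite P (textbook CG)."""
    P = np.asarray(P, dtype=float)
    b = np.asarray(b, dtype=float)
    x = np.zeros_like(b)
    r = b - P @ x            # residual
    p = r.copy()             # current search direction
    rs_old = r @ r
    if maxiter is None:
        maxiter = 10 * b.shape[0]
    for _ in range(maxiter):
        if np.sqrt(rs_old) < tol:
            break            # residual small enough: converged
        Pp = P @ p
        alpha = rs_old / (p @ Pp)   # exact step length along p
        x = x + alpha * p
        r = r - alpha * Pp
        rs_new = r @ r
        # New direction, P-conjugate to the previous ones.
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x
```

For the 1-D problem in this thread, `conjugate_gradient([[576.0]], [216.0])` (that is, b = -q) gives approximately 0.375 = 216/576, which also satisfies the inequality -x <= 2 from the first example.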