OSQP crashing on unexpected params

Status: Open. Illviljan opened this issue 1 year ago • 3 comments

I'm trying to move some QP code over to jaxopt, but I'm struggling to understand cryptic errors that only seem to happen in the jaxopt implementation. I've tried the same parameters with other packages, and they work there.

Here's a minimal example:

```python
import numpy as np
import jax.numpy as jnp
from jaxopt import OSQP

from qpsolvers import solve_qp


def to_numpy(*args):
    return tuple(np.asarray(v) for v in args)


P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[-1.0]])
h_ = jnp.array([2.0])
A_ = jnp.array([[]], dtype=float).T
b_ = jnp.array([], dtype=float)

x = solve_qp(*to_numpy(P_, q_, G_, h_, A_, b_), solver="osqp")  # works

qp = OSQP()
deltas = qp.run(
    params_obj=(P_, q_),
    params_eq=(A_, b_),
    params_ineq=(G_, h_),
).params.primal  # Crashes with cryptic error.
# TypeError: dot_general requires contracting dimensions to have the same shape, got (1,) and (2,).

# jax-0.4.23 jaxlib-0.4.23 jaxopt-0.8.3 ml-dtypes-0.3.2 opt-einsum-3.3.0
```

Illviljan, Jan 16 '24 18:01

Hi Illviljan,

Sorry for the cryptic error message. It occurs because the matrix A_ = jnp.array([[]], dtype=float).T, which has shape (0, 1), is not a valid linear operator. If you don't need equality constraints, just pass None to params_eq:

```python
import jax.numpy as jnp
from jaxopt import OSQP

P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[-1.0]])
h_ = jnp.array([2.0])
# A_ = jnp.array([[]], dtype=float).T
# b_ = jnp.array([], dtype=float)

qp = OSQP()
deltas = qp.run(
    params_obj=(P_, q_),
    params_eq=None,   # CHANGE HERE.
    params_ineq=(G_, h_),
).params.primal
```

Similarly, if you don't need inequality constraints, just pass None to params_ineq. Thank you for your message; it made me realize that I forgot to document this functionality.
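A small guard can translate the empty-array style accepted by other QP packages into the None convention described above. This is a hypothetical helper (`as_constraint_params` is not part of jaxopt). It looks only at shapes, so it should accept NumPy or jax.numpy arrays, and it reports mismatched row counts early instead of leaving them to surface as a `dot_general` error deep inside the solver. Treating a zero-row matrix as "no constraints" is an assumption based on this thread: `np.array([[]]).T` has shape (0, 1).

```python
import numpy as np


def as_constraint_params(M, v):
    """Hypothetical helper: return (M, v), or None when the block has zero rows.

    Only shapes are inspected, so M and v are returned unchanged
    (NumPy or jax.numpy arrays both work with np.shape).
    """
    if M is None:
        return None
    m_shape, v_shape = np.shape(M), np.shape(v)
    if len(m_shape) != 2 or len(v_shape) != 1:
        raise ValueError(f"expected a 2-D matrix and a 1-D vector, got {m_shape} and {v_shape}")
    if m_shape[0] != v_shape[0]:
        raise ValueError(f"row mismatch: matrix has {m_shape[0]} rows, vector has {v_shape[0]} entries")
    if m_shape[0] == 0:
        return None  # zero constraint rows: pass None instead, per the reply above
    return M, v


# Usage with the arrays from the first post (NumPy versions).
A_ = np.array([[]], dtype=float).T  # shape (0, 1)
b_ = np.array([], dtype=float)
G_ = np.array([[-1.0]])
h_ = np.array([2.0])
print(as_constraint_params(A_, b_))  # None
print(as_constraint_params(G_, h_) is not None)  # True
```

The results can then be passed straight through, e.g. `params_eq=as_constraint_params(A_, b_)`.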

Algue-Rythme, Jan 17 '24 10:01

Thank you, that's a simple fix. Perhaps I just need to keep all constraints active in my larger project.

I was surprised, because jaxopt seems to be the odd one out: this A_ is accepted by other QP packages.

Using None is fine, I guess; the annoying part is that jaxopt doesn't allow both constraints to be None. Other packages allow that, and I think it fits better with how I build a new solution: start simple, without any constraints, make sure it works, then slowly add constraints until the solution makes sense.
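One way to keep that incremental workflow while still using jaxopt is a thin dispatcher that skips OSQP when no constraints are given. This is a hypothetical sketch (`solve_qp_stepwise` is not a jaxopt function). It assumes the objective 0.5 * xᵀPx + qᵀx with P symmetric positive definite, solves P x = -q directly in the unconstrained case, and otherwise forwards to OSQP with the same `run(...)` call used in this thread.

```python
import numpy as np


def solve_qp_stepwise(P, q, params_eq=None, params_ineq=None):
    """Hypothetical wrapper: minimize 0.5 * x^T P x + q^T x.

    With no constraints, the minimizer satisfies P x = -q (P assumed
    symmetric positive definite), so a direct linear solve is used.
    Otherwise the problem is forwarded to jaxopt's OSQP.
    """
    if params_eq is None and params_ineq is None:
        return np.linalg.solve(np.asarray(P, dtype=float), -np.asarray(q, dtype=float))
    # Imported lazily so the unconstrained path needs only NumPy.
    from jaxopt import OSQP
    return OSQP().run(
        params_obj=(P, q),
        params_eq=params_eq,
        params_ineq=params_ineq,
    ).params.primal


# Step 1: no constraints at all (data from the first post).
x = solve_qp_stepwise(np.array([[576.0]]), np.array([-216.0]))  # approximately [0.375]
```

Note that the two branches return different array types (NumPy versus JAX), which may matter if the result feeds into jitted code.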

```python
import numpy as np
import jax.numpy as jnp
from jaxopt import OSQP

from qpsolvers import solve_qp


def to_numpy(*args):
    return tuple(np.asarray(v) for v in args)


P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[]], dtype=float).T
h_ = jnp.array([], dtype=float)
A_ = jnp.array([[]], dtype=float).T
b_ = jnp.array([], dtype=float)

x = solve_qp(*to_numpy(P_, q_, G_, h_, A_, b_), solver="osqp")  # works
print(x)

qp = OSQP()
x = qp.run(
    params_obj=(P_, q_),
    params_eq=None,
    params_ineq=None,
).params.primal  # Unnecessarily strict crash
```

Illviljan, Jan 17 '24 20:01

That's true, but using OSQP when you have no constraints is overkill. In that case the OSQP algorithm degenerates into an inefficient way of solving a linear system.
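To spell out that point with the data from the first post: assuming the usual objective ½xᵀPx + qᵀx (the convention qpsolvers uses) with P symmetric positive definite, setting the gradient to zero leaves a single linear system.

$$
\nabla_x\left(\tfrac{1}{2}x^\top P x + q^\top x\right) = P x + q = 0
\quad\Longrightarrow\quad
x^\star = -P^{-1} q
$$

$$
P = [576],\; q = [-216]
\quad\Longrightarrow\quad
x^\star = \frac{216}{576} = 0.375
$$

The inequality in the first post, G x ≤ h, reads -x ≤ 2, i.e. x ≥ -2. It is inactive at x* = 0.375, so that example should have the same minimizer with or without the constraint.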

As the documentation argues, you should fall back to conjugate gradient in this case.
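For reference, conjugate gradient itself is short. The following is a textbook NumPy sketch for a symmetric positive definite P, not jaxopt's implementation; inside JAX code one would more likely use an existing solver such as `jax.scipy.sparse.linalg.cg`. For the unconstrained QP, the right-hand side is b = -q.

```python
import numpy as np


def conjugate_gradient(P, b, x0=None, tol=1e-10, max_iter=None):
    """Solve P x = b for symmetric positive definite P (textbook CG sketch)."""
    P = np.asarray(P, dtype=float)
    b = np.asarray(b, dtype=float)
    n = b.shape[0]
    x = np.zeros(n) if x0 is None else np.asarray(x0, dtype=float).copy()
    max_iter = 10 * n if max_iter is None else max_iter
    r = b - P @ x  # residual
    p = r.copy()   # first search direction is the residual
    rs_old = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs_old) <= tol:
            break  # residual norm small enough
        Pp = P @ p
        alpha = rs_old / (p @ Pp)  # exact line search along p
        x = x + alpha * p
        r = r - alpha * Pp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next P-conjugate direction
        rs_old = rs_new
    return x


# Unconstrained QP from the thread: solve P x = -q.
P_ = np.array([[576.0]])
q_ = np.array([-216.0])
x = conjugate_gradient(P_, -q_)  # approximately [0.375]
```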

Algue-Rythme, Jan 18 '24 09:01