
Type precision issue in BoxOSQP

Open jewillco opened this issue 2 years ago • 10 comments

When I have float64 support enabled in JAX and try to run:

import jax
import jax.numpy as jnp
import jaxopt

jax.config.update("jax_enable_x64", True)  # float64 support enabled

optimizer = jaxopt.BoxOSQP()
optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float32),
        jnp.ones((30,), dtype=jnp.float32),
    ),
    params_eq=jnp.ones((1, 30), dtype=jnp.float32),
    params_ineq=(-1, 1),
)

I get an internal error from the implementation (partially redacted):

[.../jaxopt/_src/osqp.py](...) in run(self, init_params, params_obj, params_eq, params_ineq)
    763       init_params = self.init_params(None, params_obj, params_eq, params_ineq)
    764 
--> 765     return super().run(init_params, params_obj, params_eq, params_ineq)
    766 
    767   def l2_optimality_error(

[.../jaxopt/_src/base.py](...) in run(self, init_params, *args, **kwargs)
    345       run = decorator(run)
    346 
--> 347     return run(init_params, *args, **kwargs)
    348 
    349   def __post_init__(self):

[.../jaxopt/_src/implicit_diff.py](...) in wrapped_solver_fun(*args, **kwargs)
    249     args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    250     keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251     return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
    252 
    253   return wrapped_solver_fun

[.../jaxopt/_src/implicit_diff.py](...) in solver_fun_flat(*flat_args)
    205     def solver_fun_flat(*flat_args):
    206       args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207       return solver_fun(*args, **kwargs)
    208 
    209     def solver_fun_fwd(*flat_args):

[.../jaxopt/_src/base.py](...) in _run(self, init_params, *args, **kwargs)
    307     zero_step = self._make_zero_step(init_params, state)
    308 
--> 309     opt_step = self.update(init_params, state, *args, **kwargs)
    310     init_val = (opt_step, (args, kwargs))
    311 

[.../jaxopt/_src/osqp.py](...) in update(self, params, state, params_obj, params_eq, params_ineq)
    703     # We need our own ifelse cond because automatic jitting of jax.lax.cond branches
    704     # could pose problems with non jittable matvecs, or prevent printing when verbose > 0.
--> 705     rho_bar, solver_state = cond(
    706         jnp.mod(state.iter_num, self.stepsize_updates_frequency) == 0,
    707         lambda _: self._update_stepsize(rho_bar, solver_state, primal_residuals, dual_residuals, Q, c, A, x, y),

[.../jaxopt/_src/cond.py](...) in cond(cond, if_fun, else_fun, jit, *operands)
     22     with jax.disable_jit():
     23       return jax.lax.cond(cond, if_fun, else_fun, *operands)
---> 24   return jax.lax.cond(cond, if_fun, else_fun, *operands)

TypeError: true_fun and false_fun output must have identical types, got
('DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)', ('ShapedArray(float32[30])', ('ShapedArray(float32[30,30])', 'ShapedArray(float32[1,30])', 'ShapedArray(float64[], weak_type=True)', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)'), None)).
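For reference, here is a stripped-down sketch (my own distillation, not code from jaxopt) that triggers the same class of jax.lax.cond error once float64 is enabled:

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# One branch returns a float32 scalar, the other a weakly-typed float64
# scalar (a plain Python float), so lax.cond rejects the branch outputs.
jax.lax.cond(True, lambda _: jnp.float32(1.0), lambda _: 1.0, None)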

jewillco avatar Oct 14 '23 02:10 jewillco

Can you try promoting the params_ineq=(-1, 1) tuple to float32 explicitly? Let me know how it goes.

Algue-Rythme avatar Oct 17 '23 08:10 Algue-Rythme

I tried float32 and float64 there (using the jnp.float32(1) syntax):

float32: TypeError: true_fun and false_fun output must have identical types, got ('DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)', ('ShapedArray(float32[30])', ('ShapedArray(float32[30,30])', 'ShapedArray(float32[1,30])', 'ShapedArray(float64[], weak_type=True)', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)'), None)).

float64: TypeError: body_fun output and input must have identical types, got ('DIFFERENT ShapedArray(float64[30]) vs. ShapedArray(float32[30])', 'ShapedArray(float64[30])', 'ShapedArray(float64[])', 'ShapedArray(float64[30])', 'ShapedArray(int64[], weak_type=True)').

jewillco avatar Oct 17 '23 08:10 jewillco

Would you mind sharing your minimal (not) working example in Colab? Thanks in advance.

Algue-Rythme avatar Oct 17 '23 08:10 Algue-Rythme

optimizer = jaxopt.BoxOSQP()
optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float32),
        jnp.ones((30,), dtype=jnp.float32),
    ),
    params_eq=jnp.ones((1, 30), dtype=jnp.float32),
    params_ineq=(jnp.float32(-1), jnp.float32(1)),
)

jewillco avatar Oct 17 '23 16:10 jewillco

You did not give me a Colab link, so I copy/pasted the code into Colab, added a few imports, and in Colab, it works! There are no errors... which versions of jax/jaxopt/Python are you using? Are you using a GPU?

Algue-Rythme avatar Oct 18 '23 08:10 Algue-Rythme

I am using a TPU and my Colab has a large number of other things in it so I can't share it. Did you turn on float64 in JAX? That is the one thing that might be different from the snippet I posted.
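For reference, the toggle in question (it must run at startup, before any arrays are created):

import jax
jax.config.update("jax_enable_x64", True)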

jewillco avatar Oct 18 '23 08:10 jewillco

I did not turn on float64 in my initial test; check for yourself! I tested float32 on CPU / GPU / TPU in Colab; it works.

With float64 enabled, and on a TPU, I get: XlaRuntimeError: INVALID_ARGUMENT: 64-bit data types are not yet supported on the TPU driver API. Convert inputs to float32/int32_t before using. This is the expected behavior for TPUs anyway, because they do not intend to support float64 arithmetic. Since you don't encounter this error, I wonder whether you enabled the TPU in Colab with jax.tools.colab_tpu.setup_tpu().
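(For reference, a sketch of that setup call; this assumes the legacy Colab TPU runtime that setup_tpu() targets:)

import jax
import jax.tools.colab_tpu

jax.tools.colab_tpu.setup_tpu()  # must run before any JAX computation
print(jax.devices())             # should list TPU devices if it worked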

The error you gave me arises when mixing float32 objects (from your call) with float64 objects that BoxOSQP allocates by default, on CPU (for example, after failing to enable the TPU). This is also expected behavior, because JAX's policy is to prevent aggressive type promotion. However, if you force everything to be in float64, it works! Look here:
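(A minimal sketch of that all-float64 variant, assuming float64 is enabled; the linked Colab is not reproduced here:)

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jaxopt

optimizer = jaxopt.BoxOSQP()
# Every input is float64, matching what BoxOSQP allocates internally under x64.
optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float64),
        jnp.ones((30,), dtype=jnp.float64),
    ),
    params_eq=jnp.ones((1, 30), dtype=jnp.float64),
    params_ineq=(jnp.float64(-1), jnp.float64(1)),
)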

my Colab has a large number of other things in it so I can't share it

Well, I am not asking for your whole work, just a minimal working example that reproduces the issue.

That is the one thing that might be different from the snippet I posted.

This is what I meant when I said "share a Colab link": it is not easy to infer what you did in your environment without details, and the code you gave me was insufficient to understand what is really going on. As you can see, on Colab I can trigger different types of errors by juggling types, environments, and startup initialization, and I consider none of these behaviors a bug.

Algue-Rythme avatar Oct 18 '23 12:10 Algue-Rythme

Hi, JAX developer here. It looks like you're using Colab TPU; as of this writing (October 2023), Colab only provides very old TPU hardware and is only compatible with a very old JAX version. I would not recommend running JAX on Colab TPU until this changes (note that Colab CPU and GPU are fine). I believe this issue is fixed on more modern TPU architectures.

If you'd like to use modern TPUs in a free public notebook, I'd suggest taking a look at Kaggle, which provides more up-to-date TPU runtimes.

jakevdp avatar Oct 18 '23 16:10 jakevdp

Thanks for the heads up.

@jewillco: could you clarify your intent with this code? If my understanding is correct, you need:

  1. a TPU for performance
  2. float64 precision enabled by default for some reason (do you expect these computations to run on the TPU?)
  3. but you want the BoxOSQP solver to run in float32 by default anyway?

Algue-Rythme avatar Oct 19 '23 10:10 Algue-Rythme

I want #1 and #2, at least with the option to run other parts of my code in float64 on the TPU (which is semi-supported). I would like BoxOSQP to run in either float32 or float64 depending on the inputs I give it. It turns out that it does work with all-float64 inputs to the solver; it still produces NaNs on my problem, but that's a different issue.

jewillco avatar Oct 19 '23 15:10 jewillco