Type precision issue in BoxOSQP
When I have float64 support enabled in JAX and try to run:
import jax.numpy as jnp
import jaxopt

optimizer = jaxopt.BoxOSQP()
optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float32),
        jnp.ones((30,), dtype=jnp.float32),
    ),
    params_eq=jnp.ones((1, 30), dtype=jnp.float32),
    params_ineq=(-1, 1),
)
I get an internal error from the implementation (partially redacted):
[.../jaxopt/_src/osqp.py](...) in run(self, init_params, params_obj, params_eq, params_ineq)
763 init_params = self.init_params(None, params_obj, params_eq, params_ineq)
764
--> 765 return super().run(init_params, params_obj, params_eq, params_ineq)
766
767 def l2_optimality_error(
[.../jaxopt/_src/base.py](...) in run(self, init_params, *args, **kwargs)
345 run = decorator(run)
346
--> 347 return run(init_params, *args, **kwargs)
348
349 def __post_init__(self):
[.../jaxopt/_src/implicit_diff.py](...) in wrapped_solver_fun(*args, **kwargs)
249 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
252
253 return wrapped_solver_fun
[.../jaxopt/_src/implicit_diff.py](...) in solver_fun_flat(*flat_args)
205 def solver_fun_flat(*flat_args):
206 args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207 return solver_fun(*args, **kwargs)
208
209 def solver_fun_fwd(*flat_args):
[.../jaxopt/_src/base.py](...) in _run(self, init_params, *args, **kwargs)
307 zero_step = self._make_zero_step(init_params, state)
308
--> 309 opt_step = self.update(init_params, state, *args, **kwargs)
310 init_val = (opt_step, (args, kwargs))
311
[.../jaxopt/_src/osqp.py](...) in update(self, params, state, params_obj, params_eq, params_ineq)
703 # We need our own ifelse cond because automatic jitting of jax.lax.cond branches
704 # could pose problems with non jittable matvecs, or prevent printing when verbose > 0.
--> 705 rho_bar, solver_state = cond(
706 jnp.mod(state.iter_num, self.stepsize_updates_frequency) == 0,
707 lambda _: self._update_stepsize(rho_bar, solver_state, primal_residuals, dual_residuals, Q, c, A, x, y),
[.../jaxopt/_src/cond.py](...) in cond(cond, if_fun, else_fun, jit, *operands)
22 with jax.disable_jit():
23 return jax.lax.cond(cond, if_fun, else_fun, *operands)
---> 24 return jax.lax.cond(cond, if_fun, else_fun, *operands)
TypeError: true_fun and false_fun output must have identical types, got
('DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)', ('ShapedArray(float32[30])', ('ShapedArray(float32[30,30])', 'ShapedArray(float32[1,30])', 'ShapedArray(float64[], weak_type=True)', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)'), None)).
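The failing check is the standard jax.lax.cond requirement that both branches return outputs with identical avals (dtype, shape, and weak_type). A minimal sketch (my own illustration, not the jaxopt internals) that triggers the same class of error once x64 is enabled:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # same setting as in the run above

x32 = jnp.float32(1.0)  # concrete float32 scalar

# One branch returns float32; the other returns a Python float, which JAX
# treats as a weak-typed float64 once x64 is enabled, so the outputs differ.
jax.lax.cond(True, lambda _: x32, lambda _: 2.0, None)
# TypeError: true_fun and false_fun output must have identical types, ...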
Can you try promoting the params_ineq=(-1, 1) tuple to float32 by default? Let me know how it goes.
I tried float32 and float64 there (using the jnp.float32(1) syntax):
float32: TypeError: true_fun and false_fun output must have identical types, got ('DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)', ('ShapedArray(float32[30])', ('ShapedArray(float32[30,30])', 'ShapedArray(float32[1,30])', 'ShapedArray(float64[], weak_type=True)', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)'), None)).
float64: TypeError: body_fun output and input must have identical types, got ('DIFFERENT ShapedArray(float64[30]) vs. ShapedArray(float32[30])', 'ShapedArray(float64[30])', 'ShapedArray(float64[])', 'ShapedArray(float64[30])', 'ShapedArray(int64[], weak_type=True)').
Would you mind sharing your minimal (not) working example in Colab? Thanks in advance.
import jax.numpy as jnp
import jaxopt

optimizer = jaxopt.BoxOSQP()
optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float32),
        jnp.ones((30,), dtype=jnp.float32),
    ),
    params_eq=jnp.ones((1, 30), dtype=jnp.float32),
    params_ineq=(jnp.float32(-1), jnp.float32(1)),
)
You did not give me a Colab link, so I copied/pasted the code into Colab, added a few imports, and in Colab, it works! There are no errors... Which versions of jax/jaxopt/Python are you using? Are you using a GPU?
I am using a TPU and my Colab has a large number of other things in it so I can't share it. Did you turn on float64 in JAX? That is the one thing that might be different from the snippet I posted.
I did not turn on float64 in my initial test; check for yourself! I tested float32 on CPU / GPU / TPU in Colab; it works.
With float64 enabled, and on a TPU, I get: XlaRuntimeError: INVALID_ARGUMENT: 64-bit data types are not yet supported on the TPU driver API. Convert inputs to float32/int32_t before using. This is the expected behavior for TPUs anyway, because they are not intended to support float64 arithmetic. Since you don't encounter this error, I wonder whether you actually enabled the TPU in Colab with jax.tools.colab_tpu.setup_tpu().
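For reference, this is the initialization I mean; a sketch assuming a Colab runtime old enough to still ship this helper. It has to run before any other JAX call, otherwise JAX silently falls back to CPU (where the float32/float64 mixing error shows up instead of the TPU one):

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

import jax
print(jax.devices())  # should list TPU devices rather than CpuDevice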
The error you gave me arises when mixing the float32 objects in your call with the float64 objects that BoxOSQP allocates by default, on CPU (for example after failing to enable the TPU). This is also expected behavior, because JAX's policy is to avoid aggressive type promotion. However, if you force everything to be in float64, it works! Look here
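For reference, a minimal sketch of the all-float64 variant that runs without the type error (same shapes as your snippet; x64 must be enabled before any arrays are created):

import jax
jax.config.update("jax_enable_x64", True)  # must happen before creating arrays

import jax.numpy as jnp
import jaxopt

optimizer = jaxopt.BoxOSQP()
sol = optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float64),
        jnp.ones((30,), dtype=jnp.float64),
    ),
    params_eq=jnp.ones((1, 30), dtype=jnp.float64),
    params_ineq=(jnp.float64(-1.0), jnp.float64(1.0)),
)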
> my Colab has a large number of other things in it so I can't share it
Well, I am not asking for your whole work, just a minimal working example that reproduces the issue.
> That is the one thing that might be different from the snippet I posted.
This is what I meant when I said "share a Colab link": it is not easy to infer what you did in your environment without details, and the code you gave me was clearly insufficient to understand what is really going on. As you can see, on Colab I can trigger different types of errors by juggling types, environments, and startup initialization, and I do not consider any of these behaviors a bug.
Hi, JAX developer here. It looks like you're using a Colab TPU; as of this writing (October 2023), Colab only provides very old TPU hardware and is only compatible with a very old JAX version. I would not recommend running JAX on a Colab TPU until this changes (Colab CPU and GPU are fine, though). I believe this issue is fixed on more modern TPU architectures.
If you'd like to use modern TPUs in a free public notebook, I'd suggest taking a look at Kaggle, which provides more up-to-date TPU runtimes.
Thanks for the heads up.
@jewillco: could you clarify your intent with this code? If my understanding is correct, you need:
- a TPU for performance
- float64 precision enabled by default for some reason (do you expect these computations to run on TPU?)
- but you want the BoxOSQP solver to run in float32 by default anyway?
I want #1 and #2, at least with the option to run other parts of my code in float64 on the TPU (which is semi-supported). I would like BoxOSQP to run in either float32 or float64 depending on what inputs I give it. It turns out that it does work with all float64 inputs to the solver; it still produces NaNs on my problem but that's a different issue.
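To illustrate the behavior I'm after (a sketch of the desired behavior, not what currently happens with x64 enabled; the .params.primal access is my assumption about the KKTSolution layout):

import jax.numpy as jnp
import jaxopt

# With float32 inputs, I would like the returned solution to be float32 as
# well, even when jax_enable_x64 is on for the rest of the program.
sol32 = jaxopt.BoxOSQP().run(
    params_obj=(jnp.eye(30, dtype=jnp.float32), jnp.ones((30,), dtype=jnp.float32)),
    params_eq=jnp.ones((1, 30), dtype=jnp.float32),
    params_ineq=(jnp.float32(-1), jnp.float32(1)),
)
print(sol32.params.primal[0].dtype)  # float32 desired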