
Is single precision a mistake?

Open matillda123 opened this issue 7 months ago • 8 comments

Hey, I am using optimistix mainly interactively. Thus I am overriding the default(?) double precision with single precision. I am wondering if this is a mistake. Occasionally solvers like BFGS seem to get stuck where I don't expect them to stagnate. Could single precision be responsible for this? And how much efficiency does one lose with this switch?

Thanks for the help. :)

matillda123 avatar May 19 '25 11:05 matillda123

Optimistix does not set a default, and you do not need to override anything to get single precision. (There is a section on that in the sharp bits.) Single precision is the JAX default, and double precision requires setting a flag directly after you have imported your packages.

That said, for many scientific computations - differential equations, nonlinear solves - double precision is often required to get an accurate solution. How much this matters, and whether it is worth the cost, depends on your particular setup. If you're wondering about runtime, then you could (micro-)benchmark this with %timeit, with and without

```python
import jax

jax.config.update("jax_enable_x64", True)
```

at startup.
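
For example, here is a rough sketch of what such a micro-benchmark could look like outside of IPython (the Rosenbrock objective, tolerances and problem size are placeholders, not anything from your setup):

```python
import time

import jax.numpy as jnp
import optimistix as optx

# Toggle this between runs (it must be set before any arrays are created):
# import jax
# jax.config.update("jax_enable_x64", True)

def rosenbrock(y, args):
    # Stand-in objective; swap in your own function.
    return jnp.sum(100.0 * (y[1:] - y[:-1] ** 2) ** 2 + (1.0 - y[:-1]) ** 2)

solver = optx.BFGS(rtol=1e-6, atol=1e-6)
y0 = jnp.zeros(100)

# The first call includes compilation; time a second call for the steady-state cost.
optx.minimise(rosenbrock, solver, y0, throw=False)
start = time.perf_counter()
sol = optx.minimise(rosenbrock, solver, y0, throw=False)
sol.value.block_until_ready()
print(f"{time.perf_counter() - start:.4f} s, dtype={sol.value.dtype}, result={sol.result}")
```

Comparing the two runs also shows `sol.result`, so you can see whether the single-precision solve actually converged or just stalled.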

johannahaffner avatar May 19 '25 13:05 johannahaffner

> Optimistix does not set a default, and you do not need to override anything to get single precision.

Ah okay. I just thought that was the case since I am getting this warning.

```
/miniconda3/lib/python3.11/site-packages/optimistix/_misc.py:51: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  fn = lambda x: jnp.zeros(x.shape, x.dtype)
```

Thanks for the advice. :)

matillda123 avatar May 19 '25 14:05 matillda123

Huh. Do you have an MWE for this? Or a longer traceback?

The function that throws this warning uses a shape-dtype struct that must come from somewhere, and I believe it is called in solver.init.
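(For reference, the warning itself just means that a float64 dtype reached that lambda while x64 is disabled; here is a minimal sketch that reproduces it, nothing to do with your code:)

```python
import jax
import jax.numpy as jnp
import numpy as np

# With jax_enable_x64 at its default (False), explicitly requesting float64
# triggers the same UserWarning and silently truncates to float32.
struct = jax.ShapeDtypeStruct((3,), np.float64)
x = jnp.zeros(struct.shape, struct.dtype)
print(x.dtype)  # float32
```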

johannahaffner avatar May 19 '25 14:05 johannahaffner

I'm not sure how to reduce this to something minimal. But you are right, the warning originates when I initialize the optimistix solver for an interactive solve with solver.init. I checked, however, that none of my inputs have double precision. But I might be overlooking something.

matillda123 avatar May 19 '25 16:05 matillda123

Sorry, I forgot to add the traceback. Here I cut off the traceback right before my function is mentioned.

File "/miniconda3/lib/python3.11/site-packages/equinox/_module.py", line 1060, in call return self.func(self.self, *args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/miniconda3/lib/python3.11/site-packages/optimistix/_solver/gradient_methods.py", line 151, in init descent_state=self.descent.init(y, f_info_struct), ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/miniconda3/lib/python3.11/site-packages/equinox/_module.py", line 1060, in call return self.func(self.self, *args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/miniconda3/lib/python3.11/site-packages/optimistix/_solver/nonlinear_cg.py", line 107, in init y_diff=tree_full_like(y, 0), ^^^^^^^^^^^^^^^^^^^^ File "/miniconda3/lib/python3.11/site-packages/optimistix/_misc.py", line 61, in tree_full_like return jtu.tree_map(fn, struct) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/miniconda3/lib/python3.11/site-packages/jax/_src/tree_util.py", line 358, in tree_map return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/miniconda3/lib/python3.11/site-packages/jax/_src/tree_util.py", line 358, in return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) ^^^^^^ File "/miniconda3/lib/python3.11/site-packages/optimistix/_misc.py", line 51, in fn = lambda x: jnp.zeros(x.shape, x.dtype) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/miniconda3/lib/python3.11/site-packages/jax/_src/numpy/array_creation.py", line 77, in zeros dtypes.check_user_dtype_supported(dtype, "zeros") File "/miniconda3/lib/python3.11/site-packages/jax/_src/dtypes.py", line 919, in check_user_dtype_supported warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3) UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.

matillda123 avatar May 19 '25 16:05 matillda123

Thanks for the traceback! What do you see when you print jax.eval_shape(lambda: y)? It seems to come from y.

And very good to know that this is an interactive solve!
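
If it helps, one low-tech way to hunt for a stray float64 leaf is to map over the pytree and print each leaf's dtype; a sketch with a made-up pytree:

```python
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

# Hypothetical inputs: the plain numpy leaf defaults to float64,
# while the jax.numpy leaf is float32 unless x64 is enabled.
y = {"guess": np.linspace(0.0, 1.0, 5), "scale": jnp.ones(5)}

print(jtu.tree_map(lambda leaf: leaf.dtype, y))
# {'guess': dtype('float64'), 'scale': dtype('float32')}
```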

johannahaffner avatar May 19 '25 18:05 johannahaffner

Hey, (sorry for the long wait) I just figured it out. I was creating a numpy.array in some function for a custom initial guess and never explicitly converted it into a jax.numpy.array before assembling my pytree. It seems a bit weird, though, that JAX was not complaining about this.

matillda123 avatar Jun 01 '25 15:06 matillda123

Ah, that explains it! Numpy defaults to 64-bit. I think you're not getting an error about this not being a JAX array because the shape checks out, and you are getting the warning about the requested double-precision dtype before you ever use the value. Using Numpy arrays as input values is also not inherently wrong! They're just pytrees, same as many other things (lists, tuples).

(Aside from that, many JAX functions will happily accept anything ArrayLike and return a JAX array, on which the next computation then operates. If you'd like to check for this explicitly, you might like jaxtyping, which allows you to specify granular requirements such as specific array types.)
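
A small sketch of both points, assuming recent versions of jaxtyping and beartype are installed (the loss function is made up):

```python
import jax.numpy as jnp
import numpy as np
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped

# ArrayLike inputs (numpy arrays, lists, scalars) are accepted and converted:
print(jnp.sum(np.ones(3)).dtype)  # float32 under the default config

# jaxtyping plus a runtime type-checker can insist on a JAX array specifically:
@jaxtyped(typechecker=beartype)
def loss(y: Float[Array, "n"]) -> Float[Array, ""]:
    return jnp.sum(y**2)

loss(jnp.ones(3))     # fine
# loss(np.ones(3))    # would fail the type check: not a JAX Array
```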

johannahaffner avatar Jun 01 '25 21:06 johannahaffner