
Handle double precision in tests more carefully

Open · Michael-T-McCann opened this issue 2 years ago · 1 comment

JAX works in single precision by default, and it won't even let you create double-precision arrays unless an environment variable is set (JAX_ENABLE_X64=True) or a configuration command is run when jax is imported (config.update("jax_enable_x64", True)). To test double precision, these settings are applied in various places in the tests.
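A minimal sketch of the behaviour described above, assuming a recent JAX release where jax.config.update is available and JAX_ENABLE_X64 is not already set in the environment:

```python
import warnings

import jax
import jax.numpy as jnp

# With JAX_ENABLE_X64 unset, asking for float64 is not an error: JAX emits a
# UserWarning and silently truncates the array to float32.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    x32 = jnp.zeros(3, dtype=jnp.float64)
print(x32.dtype)  # float32

# Enabling 64-bit types at runtime, as some of the tests do, makes float64 available.
jax.config.update("jax_enable_x64", True)
x64 = jnp.zeros(3, dtype=jnp.float64)
print(x64.dtype)  # float64
```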

Unfortunately, enabling double precision also makes it the default for new arrays, so a test can behave differently when run on its own than when run as part of the whole suite: the config is "sticky", and setting it in one test affects every test that runs after it.
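To illustrate the ordering dependence, here is a small sketch with two hypothetical functions (default_float_dtype and test_needs_double_precision are illustrative names, not scico code). Calling them in sequence mimics a test runner executing one test after another in the same process:

```python
import jax
import jax.numpy as jnp


def default_float_dtype():
    # Hypothetical helper: the dtype JAX picks when a test does not specify one.
    return jnp.ones(2).dtype


def test_needs_double_precision():
    # Hypothetical test that enables x64 for its own purposes. The setting is
    # global to the process and is never reset, so it outlives this test.
    jax.config.update("jax_enable_x64", True)
    assert jnp.ones(2, dtype=jnp.float64).dtype == jnp.float64


before = default_float_dtype()  # what a test sees when run on its own
test_needs_double_precision()
after = default_float_dtype()   # what the same test sees later in the suite
print(before, after)  # float32 float64
```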

All of this may change in a future JAX release (https://github.com/google/jax/issues/8178), but for now I propose running all tests with JAX_ENABLE_X64=True JAX_DEFAULT_DTYPE_BITS=32 and removing all config.update calls from the test files.
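One way to apply the proposed settings without changing how the suite is invoked (the equivalent of JAX_ENABLE_X64=True JAX_DEFAULT_DTYPE_BITS=32 pytest) might be a root-level conftest.py. This is a sketch under assumptions: the file placement is hypothetical, it only works if nothing (for example a pytest plugin) imports jax before this conftest.py is loaded, and whether a given JAX release still honours JAX_DEFAULT_DTYPE_BITS would need to be checked.

```python
# conftest.py (hypothetical file at the root of the test tree)
import os

# JAX reads these variables when it is first imported, so they must be set
# before any test module imports jax. setdefault leaves any value the user
# has exported explicitly untouched.
os.environ.setdefault("JAX_ENABLE_X64", "True")
os.environ.setdefault("JAX_DEFAULT_DTYPE_BITS", "32")
```

Using environment variables rather than config.update keeps the setting fixed for the whole session, so no test can change it partway through.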

Michael-T-McCann, May 04 '22 18:05

Although there has been no further discussion in google/jax#8178 for more than a year, deprecation of the X64 flag still appears to be under consideration. Rather than simply running all the tests with X64 enabled and a 32-bit default, perhaps we should configure that state in scico/__init__.py so that it applies across the whole code base?
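A hypothetical sketch of what this could look like in scico/__init__.py (not the package's actual contents). It covers only the X64 half of the proposal; selecting a 32-bit default depends on what the installed JAX version supports. As a design choice, it defers to JAX_ENABLE_X64 when the user has set it explicitly:

```python
# Hypothetical excerpt from scico/__init__.py
import os

import jax

# Enable 64-bit types once, at package import time, so every module and every
# test that imports the package sees the same configuration, independent of
# test ordering. An explicit JAX_ENABLE_X64 in the environment takes precedence.
if "JAX_ENABLE_X64" not in os.environ:
    jax.config.update("jax_enable_x64", True)
```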

bwohlberg, Apr 28 '23 15:04