Use jax.config.update("jax_enable_x64", True) more precisely

Open badisa opened this issue 3 years ago • 3 comments

Goal

  • Provide control over the precision of jax
  • ~Avoid flake8 linting issues regarding imports not at top of file (#596)~

Issues

  • JAX on the GPU has a hard time with 64-bit precision
  • We want to be able to use 32-bit potentials

badisa • Feb 02 '22 18:02

There are definitely use cases where we may want to use 32-bit versions of the potentials. It's pretty dangerous to force 64-bit mode when the timemachine module is imported. I think we need to be a bit more careful about how this is set, because it will be very hard to undo afterwards.

proteneer • Feb 02 '22 18:02
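A minimal sketch of the "be more careful" idea, assuming two hypothetical helpers that are not part of timemachine or JAX: `x64_is_enabled` probes the current default float dtype, and `x64_scope` turns `jax_enable_x64` on or off for a block and restores the previous setting afterwards. The usage at the bottom also shows why the switch is hard to undo: restoring the flag does not convert arrays that were created while it was on.

```python
import contextlib

import jax
import jax.numpy as jnp
import numpy as np


def x64_is_enabled() -> bool:
    # Hypothetical probe: with x64 on, jnp.zeros defaults to float64.
    return jnp.zeros(()).dtype == np.float64


@contextlib.contextmanager
def x64_scope(enabled: bool = True):
    """Hypothetical helper: set jax_enable_x64 for a block, then restore it."""
    previous = x64_is_enabled()
    jax.config.update("jax_enable_x64", enabled)
    try:
        yield
    finally:
        # Restore the caller's setting even if the block raised.
        jax.config.update("jax_enable_x64", previous)


# Usage, assuming x64 is off when this runs (the JAX default).
with x64_scope(True):
    a = jnp.ones(3)
print(a.dtype)            # float64: the array keeps the dtype it was built with
print(jnp.ones(3).dtype)  # float32: the flag itself has been restored
```

The flag is still process-wide while the block runs, so this only narrows the window in which 64-bit mode is active; it does not give per-operation precision.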

Could you specify how we want this to function? Right now the call that enables 64-bit appears in enough places that I would imagine almost all code uses the 64-bit version.

badisa • Feb 02 '22 18:02

What about setting this where appropriate, but at a broader level, such as once per test suite or entry point?

For tests (where 64-bit is critical), I think we can call jax.config.update in setUpClass of the TestCase base class, or have it invoked by a pytest hook. For things like the endpoint correction in production (outside of tests), we probably still want to include this in the main runner script.
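A minimal sketch of the setUpClass approach, assuming a hypothetical shared base class named `X64TestCase` (timemachine's actual test base class may look different). Test classes that need float64 inherit from it instead of calling `jax.config.update` at module import time.

```python
import unittest

import jax
import jax.numpy as jnp
import numpy as np


class X64TestCase(unittest.TestCase):
    """Hypothetical shared base class for tests that need float64."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Process-wide switch: it stays on for every later test in this run.
        jax.config.update("jax_enable_x64", True)


class ExampleTest(X64TestCase):
    def test_default_float_is_64_bit(self):
        # With x64 enabled, JAX's default floating dtype is float64.
        self.assertEqual(jnp.zeros(3).dtype, np.float64)
```

With pytest, the same `jax.config.update` call could instead live in a `pytest_configure(config)` hook in `conftest.py`, which pytest runs once before collecting tests. Either way the setting is global for the test process, which is acceptable for tests but is exactly what we want to avoid at library import time.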

The problem is that jax currently has no way to specify the precision of a given operation (similar to how we can specify .impl(precision)). Given that we may at some point use jax on the GPU, and GPUs typically have crippled 64-bit FLOPS, we want to leave ourselves an out.

proteneer • Feb 02 '22 18:02
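A small illustration of why there is no per-operation escape hatch: in JAX an explicit `dtype=jnp.float64` does not override the global `jax_enable_x64` flag. When the flag is off, JAX truncates the request to float32 and emits a warning. The helper name `ones_with_requested_f64` is hypothetical, and the flag is toggled mid-process here only to show both behaviors; in real code it would normally be set once, early.

```python
import warnings

import jax
import jax.numpy as jnp
import numpy as np


def ones_with_requested_f64(n: int):
    # Explicitly ask for float64; what we get depends only on the global flag.
    with warnings.catch_warnings():
        # JAX warns when it truncates a requested float64 to float32.
        warnings.simplefilter("ignore")
        return jnp.ones(n, dtype=jnp.float64)


jax.config.update("jax_enable_x64", False)
print(ones_with_requested_f64(3).dtype)  # float32: the dtype argument is overridden

jax.config.update("jax_enable_x64", True)
print(ones_with_requested_f64(3).dtype)  # float64
```

This contrasts with the `.impl(precision)` style mentioned above, where the precision is chosen per potential rather than for the whole process.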

Done in https://github.com/proteneer/timemachine/pull/867

badisa • Oct 04 '22 15:10