timemachine
Use jax.config.update("jax_enable_x64", True) more precisely
Goal
- Provide control over the precision of JAX
- ~Avoid flake8 linting issues regarding imports not at top of file (#596)~
Issues
- JAX on the GPU performs poorly in 64-bit mode
- We want to be able to use 32-bit potentials
There are definitely use cases where we may want to use 32-bit versions of the potentials. It's pretty dangerous to force 64-bit mode whenever the timemachine module is imported. We need to be more careful about where this is set, because it is a process-wide setting that is very hard to undo afterwards.
Could you specify how we want this to work? Right now the call enabling 64-bit mode appears in enough places that I would expect almost all code already runs in 64-bit.
What about setting this only where it is needed, but at a broader level (per test suite or per entry point) rather than in individual modules?
For tests (where 64-bit is critical), I think we can call jax.config.update during setUpClass of the TestCase base class, or have it invoked by a pytest hook. For things like the endpoint correction in production (outside of tests), we probably still want to set it in the main runner script.
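A minimal sketch of the setUpClass approach, assuming a shared unittest base class. The names X64TestCase and TestEndpointPrecision are hypothetical and are not taken from the timemachine test suite:

```python
import unittest

import jax
import jax.numpy as jnp


class X64TestCase(unittest.TestCase):
    """Hypothetical base class: enables 64-bit JAX for tests that subclass it."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Process-wide switch: it stays enabled after this class finishes,
        # so ideally it runs before any JAX arrays are created.
        jax.config.update("jax_enable_x64", True)


class TestEndpointPrecision(X64TestCase):
    # Hypothetical test that relies on 64-bit defaults.
    def test_default_float_is_64bit(self):
        x = jnp.ones(3)
        self.assertEqual(str(x.dtype), "float64")
```

The pytest alternative would be a conftest.py that defines the standard pytest_configure(config) hook and makes the same jax.config.update call there, so the flag is set once before tests are collected rather than per class.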
The problem is that JAX currently has no way to specify the precision of an individual operation (similar to how we can specify .impl(precision)). Given that we may at some point run JAX on GPUs, which typically have crippled 64-bit FLOPS, we want to leave ourselves an out.
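To illustrate why the flag behaves like a global switch rather than a per-operation setting, here is a small sketch. harmonic_energy is a hypothetical stand-in for a potential, not a timemachine function; the point is that the same call returns a different dtype depending only on the process-wide jax_enable_x64 value:

```python
import jax
import jax.numpy as jnp


def harmonic_energy(x, k=1.0):
    # Hypothetical stand-in for a potential: E = 0.5 * k * sum(x**2).
    x = jnp.asarray(x)
    return 0.5 * k * jnp.sum(x ** 2)


def energy_dtype_with_x64(enabled):
    # The flag is global: flipping it changes the default dtype of every
    # subsequent JAX call in the process, not just this one.
    jax.config.update("jax_enable_x64", enabled)
    return harmonic_energy([1.0, 2.0, 3.0]).dtype


if __name__ == "__main__":
    print(energy_dtype_with_x64(False))  # float32
    print(energy_dtype_with_x64(True))   # float64
```

Nothing at the call site of harmonic_energy controls its precision; only the global flag does, which is why setting it at import time affects all downstream code.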
Done in https://github.com/proteneer/timemachine/pull/867