BayesNewton

Latest versions of JAX and objax cause compile slowdown

Open · wil-j-wil opened this issue 3 years ago

It is recommended to use the following versions of JAX, jaxlib and objax:

jax==0.2.9
jaxlib==0.1.60
objax==1.3.1
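The pins can be installed with `pip install jax==0.2.9 jaxlib==0.1.60 objax==1.3.1`. To confirm that an existing environment actually matches them, here is a minimal sketch using the standard-library `importlib.metadata` (Python 3.8+). The helpers `installed_versions` and `mismatches` are hypothetical, written for illustration, and are not part of BayesNewton or objax.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional, Tuple

# Versions recommended in this issue.
RECOMMENDED = {"jax": "0.2.9", "jaxlib": "0.1.60", "objax": "1.3.1"}


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each package, or None if it is not installed."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def mismatches(
    installed: Dict[str, Optional[str]],
    recommended: Dict[str, str] = RECOMMENDED,
) -> Dict[str, Tuple[Optional[str], str]]:
    """Map each package not at its pinned version to (installed, recommended)."""
    return {
        name: (installed.get(name), wanted)
        for name, wanted in recommended.items()
        if installed.get(name) != wanted
    }


if __name__ == "__main__":
    # Report every package whose installed version differs from the pin.
    for name, (have, want) in mismatches(installed_versions(RECOMMENDED)).items():
        print(f"{name}: installed {have}, recommended {want}")
```

An empty result from `mismatches` means the environment matches the recommended versions; a missing package is reported with `None` as its installed version.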

This is because of this objax issue, which causes the model to JIT compile "twice", i.e. on each of the first two iterations rather than only on the first. For large models, where compilation is expensive, this adds a noticeable slowdown at the start of training, but it is not a problem otherwise.
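The double compilation described above shows up as the first two iterations taking much longer than the rest. A minimal sketch for checking this, assuming `step` is a zero-argument callable that runs one training iteration (for example a closure around your jitted objax training step) and blocks until its JAX outputs are ready, e.g. via `.block_until_ready()`, since JAX dispatches work asynchronously. `time_iterations` and `slow_iterations` are hypothetical helpers, not part of BayesNewton or objax, and the demo below simulates compilation with `time.sleep` rather than running a real model.

```python
import statistics
import time
from typing import Callable, List


def time_iterations(step: Callable[[], object], n_iters: int) -> List[float]:
    """Call `step` n_iters times and return the wall-clock duration of each call."""
    durations = []
    for _ in range(n_iters):
        start = time.perf_counter()
        step()
        durations.append(time.perf_counter() - start)
    return durations


def slow_iterations(durations: List[float], factor: float = 10.0,
                    min_extra: float = 0.05) -> List[int]:
    """Return the indices of iterations much slower than the steady state.

    The steady state is estimated as the median duration over the second half
    of the run, assuming any (re)compilation has finished by then. An iteration
    counts as slow if it exceeds both `factor` times that median and the median
    plus `min_extra` seconds.
    """
    steady = statistics.median(durations[len(durations) // 2:])
    threshold = max(factor * steady, steady + min_extra)
    return [i for i, d in enumerate(durations) if d > threshold]


# Demo: a fake step whose first two calls are slow, mimicking compilation
# on the first two iterations.
_calls = {"n": 0}


def fake_step() -> None:
    _calls["n"] += 1
    if _calls["n"] <= 2:
        time.sleep(0.2)


result = slow_iterations(time_iterations(fake_step, 10))
print(result)
```

With a real model one would expect `[0]` when compilation happens only once and `[0, 1]` when it happens on the first two iterations, as described in this issue; the exact timings depend on the model size and hardware.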

wil-j-wil, Jun 29 '21 08:06