
Test problems with JAX

Open matteobachetti opened this issue 1 year ago • 3 comments

@dhuppenkothen @Gaurav17Joshi

  1. We are getting this new deprecation warning. Better to fix it ASAP, so that we have a stable enough API:

    DeprecationWarning: jax.linear_util.transformation is deprecated. Use jax.extend.linear_util.transformation instead.

  2. New test problem:

ERROR ../../.tox/py311-test-alldeps-cov/lib/python3.11/site-packages/stingray/modeling/tests/test_gpmodeling.py::TestGPResult::test_sample - jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[5] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was sample_U at /home/runner/work/stingray/stingray/.tox/py311-test-alldeps-cov/lib/python3.11/site-packages/jaxns/model.py:47 traced for jit.
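The rule stated in the error can be reproduced in isolation. The following is a minimal sketch that uses toy functions (`side_effecting` and `pure` are hypothetical, not stingray or jaxns code). A jitted function that appends a traced value to a global list leaks a tracer, and using that value after tracing has finished raises the same `jax.errors.UnexpectedTracerError`. Returning the intermediate value explicitly avoids the problem.

```python
import jax
import jax.numpy as jnp

leaked = []  # global state: stashing a traced value here is what JAX forbids


@jax.jit
def side_effecting(x):
    y = jnp.sin(x)
    leaked.append(y)  # side effect: the tracer `y` escapes the jit trace
    return y * 2.0


@jax.jit
def pure(x):
    y = jnp.sin(x)
    return y * 2.0, y  # intermediate is returned explicitly instead of stashed


def reproduce_leak():
    """Return True if touching the escaped tracer raises UnexpectedTracerError."""
    side_effecting(jnp.arange(5.0))
    try:
        leaked[0] + 1.0  # the tracer has outlived the trace that created it
    except jax.errors.UnexpectedTracerError:
        return True
    return False
```

In the reported failure, the function being traced was jaxns's `sample_U`. This toy example only shows the general pattern, not what jaxns does internally.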

matteobachetti · Oct 05 '23 14:10

OK, the deprecation warning comes from a dependency. I'm fixing it in #763 by adding an ignore line to setup.cfg.
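The actual line added in #763 isn't quoted here. This sketch assumes it is a pytest `filterwarnings` entry, such as `ignore:jax.linear_util.transformation is deprecated:DeprecationWarning` under `[tool:pytest]` in setup.cfg. Under that assumption, the entry behaves like a standard-library filter that drops only warnings whose message starts with that text and whose category is `DeprecationWarning`. `call_dependency` is a hypothetical stand-in for the third-party code that emits the warning.

```python
import warnings


def call_dependency():
    # Hypothetical stand-in for the dependency that triggers the warning.
    warnings.warn(
        "jax.linear_util.transformation is deprecated. "
        "Use jax.extend.linear_util.transformation instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return "ok"


def run_with_filter():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning by default...
        # ...except this one deprecation. The message is a regex matched at the
        # start of the warning text; filterwarnings() inserts at the front, so
        # this filter takes precedence over "always".
        warnings.filterwarnings(
            "ignore",
            message=r"jax\.linear_util\.transformation is deprecated",
            category=DeprecationWarning,
        )
        result = call_dependency()
        warnings.warn("an unrelated deprecation", DeprecationWarning)
    return result, [str(w.message) for w in caught]
```

The filter is deliberately narrow, so other deprecation warnings still surface. Following the message itself, the lasting fix is for the dependency to switch to `jax.extend.linear_util.transformation`.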

matteobachetti · Oct 05 '23 15:10

I'm silencing the second problem by marking the gpmodeling tests as xfail. It still needs a proper fix, maybe in #767.
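The commit isn't quoted, so this is a sketch that assumes the xfail is applied module-wide in test_gpmodeling.py. The reason string is invented for illustration. Marking the tests this way keeps CI green while the tests still run, so the failure remains visible in the report as XFAIL.

```python
import pytest

# Module-level marker: pytest applies it to every test collected from this file.
pytestmark = pytest.mark.xfail(
    # Hypothetical reason text pointing at the follow-up mentioned in the thread.
    reason="jaxns sample_U leaks a JAX tracer (UnexpectedTracerError); see #767",
    # strict=False: if a test starts passing once the bug is fixed, it is
    # reported as XPASS instead of failing the run.
    strict=False,
)
```

To avoid masking unrelated failures, `xfail` also accepts `raises=`, for example `raises=jax.errors.UnexpectedTracerError`. With that argument, any other exception still counts as a real failure.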

matteobachetti · Oct 05 '23 15:10

I think the second error happens in the nested sampling step (lines 286 to 295) of test_gpmodeling.py. The error message says the leak occurs in `sample_U`, a function that JAXNS uses while sampling.

Gaurav17Joshi · Oct 05 '23 17:10