torchquad

Update `jax` config import

Open alberthli opened this issue 1 year ago • 6 comments

This PR changes the import statement for `jax.config`, which resolves the following error:

```
File "/usr/local/lib/python3.10/dist-packages/torchquad/utils/set_precision.py", line 58, in set_precision
    from jax.config import config
ImportError: cannot import name 'config' from 'jax.config' (/usr/local/lib/python3.10/dist-packages/jax/config.py)
```
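A minimal sketch of the kind of import change involved, assuming recent JAX releases where the config object is importable from the top-level `jax` package (`from jax import config`). The fallback to the legacy `from jax.config import config` is an assumption about older releases that has not been verified here, and `get_jax_config` is a hypothetical helper name, not part of torchquad.

```python
import importlib.util


def get_jax_config():
    """Hypothetical helper: return JAX's config object, or None if JAX is absent."""
    if importlib.util.find_spec("jax") is None:
        return None
    try:
        # Recent JAX releases expose the config object at the top level.
        from jax import config
    except ImportError:
        # Assumed legacy path; this is the import that fails on newer JAX.
        from jax.config import config
    return config


config = get_jax_config()
if config is not None:
    # Turn on 64-bit floats in JAX (needed for float64 precision).
    config.update("jax_enable_x64", True)
```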

alberthli · Apr 25 '24

Hi @alberthli , thanks for the PR!

To understand this a bit better: my understanding is that jax deprecated this import recently.

Did you by any chance check whether this also works for older jax versions? Currently, we require `jax>=0.2.22`, so we may need to update the version requirement.

(P.S. Don't worry about the failing tests; I think the failures are unrelated to this PR and are due to a regression in the CI.)

gomezzz · Apr 25 '24

Hi @gomezzz, I haven't checked whether it works for earlier versions. This commit from 6 months ago seems relevant if you want to control versioning, though.

alberthli · Apr 25 '24

Hi @alberthli! Sorry for the long delay. Yes, it looks good, and it mostly works with `jax>=0.4.17` (I had trouble testing with earlier versions due to issues with jax). Maybe we should bump the recommended jax version, though.

Changes for that would be:

  • [ ] Update the jax version in the README prerequisites: https://github.com/esa/torchquad?tab=readme-ov-file#prerequisites
  • [ ] And in `environment_all_backends.yml`: https://github.com/esa/torchquad/blob/6c1b8cd17830c80457477bf33c429a4b06fab3f2/environment_all_backends.yml#L23
  • [ ] And in the installation docs (`install.rst`): https://github.com/esa/torchquad/blob/main/docs/source/install.rst

Would you mind updating these in this PR, @alberthli? Otherwise I can do it.

I have also been trying to run the tests locally on CPU, but now I have a problem with jax:

```
FAILED integrator_types_test.py::test_integrate_jax - AssertionError: assert 'float64' == 'float32'
FAILED monte_carlo_test.py::test_integrate_jax - assert 0.01089246934838961 < 0.01
FAILED utils_integration_test.py::test_setup_integration_domain - AssertionError: assert 'float64' == 'float32'
```

I think the middle one just has too aggressive a threshold, but the other two seem to reflect changes in the way jax behaves?
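To illustrate why a fixed absolute threshold can make a Monte Carlo test flaky, here is a plain-Python sketch (not torchquad's integrator or its actual test). It estimates the integral of x^2 over [0, 1], which is 1/3, with N = 1000 uniform samples. The standard error is sqrt((1/5 - 1/9) / N) ≈ 0.0094, so a 0.01 cutoff sits at roughly one standard error and is exceeded for a sizable fraction of seeds. The integrand, sample count and seeds are arbitrary choices for illustration.

```python
import math
import random


def mc_estimate(fn, n, seed):
    """Plain Monte Carlo estimate of the integral of fn over [0, 1].

    Returns (estimate, standard_error).
    """
    rng = random.Random(seed)
    samples = [fn(rng.random()) for _ in range(n)]
    mean = sum(samples) / n
    # Unbiased sample variance of the integrand values.
    var = sum((s - mean) ** 2 for s in samples) / (n - 1)
    return mean, math.sqrt(var / n)


def square(x):
    return x * x


exact = 1.0 / 3.0
n = 1000

# Count how many seeds violate a fixed absolute threshold of 0.01.
errors = [abs(mc_estimate(square, n, seed)[0] - exact) for seed in range(200)]
flaky = sum(e >= 0.01 for e in errors)
print(f"{flaky} of {len(errors)} seeds exceed an absolute error of 0.01")

# A tolerance expressed in standard errors (here 4) rarely fails for an unbiased estimator.
estimate, se = mc_estimate(square, n, seed=0)
within_4_sigma = abs(estimate - exact) < 4 * se
print(f"estimate={estimate:.4f}, standard error={se:.4f}, within 4 sigma: {within_4_sigma}")
```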

The errors are raised here:

```
integrator_types_test.py:90: in _run_simple_integrations
    result = integrator.integrate(
../integration/trapezoid.py:25: in integrate
    return super().integrate(fn, dim, N, integration_domain, backend)
../integration/grid_integrator.py:50: in integrate
    function_values, num_points = self.evaluate_integrand(
../integration/base_integrator.py:65: in evaluate_integrand
    result = fn(points, *args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

x = Array([[ 0.        , -2.        ],
       [ 0.08333333, -2.        ],
       [ 0.16666667, -2.        ],
       [ 0.25...      [ 0.83333333,  0.        ],
       [ 0.91666667,  0.        ],
       [ 1.        ,  0.        ]], dtype=float64)

    def fn_const(x):
        assert infer_backend(x) == backend
>       assert get_dtype_name(x) == expected_dtype_name
E       AssertionError: assert 'float64' == 'float32'
E         
E         - float32
E         + float64

integrator_types_test.py:43: AssertionError
```

and

```
utils_integration_test.py:48: in _run_tests_with_all_backends
    func(dtype_name=dtype_name, backend=backend, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

dtype_name = 'float32', backend = 'jax'

    def _run_setup_integration_domain_tests(dtype_name, backend):
        """
        Test _setup_integration_domain with the given dtype and numerical backend
        """
        print(
            f"Testing _setup_integration_domain; backend: {backend}, precision: {dtype_name}"
        )
    
        # Domain given as List with Python floats
        domain = _setup_integration_domain(2, [[0.0, 1.0], [1.0, 2.0]], backend)
        assert infer_backend(domain) == backend
>       assert get_dtype_name(domain) == dtype_name
E       AssertionError: assert 'float64' == 'float32'
E         
E         - float32
E         + float64
```

So either we just change the expected dtype in the tests, or we might have to make a change in `set_precision.py`. If this is too much, @alberthli, we can also move it to a dedicated issue, since it is not directly related to your changes, if you prefer.
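One possible direction for the `set_precision.py` side, sketched under assumptions: JAX's default float dtype follows the global `jax_enable_x64` flag, so a float32 run has to make sure the flag is explicitly switched off rather than left enabled (for example, by an earlier float64 test in the same process). That leak is only one hypothesis for the failures above, not a confirmed cause, and `set_jax_precision` is a hypothetical name, not torchquad's function.

```python
import importlib.util


def set_jax_precision(dtype_name):
    """Hypothetical helper: set JAX's global float precision.

    Returns the dtype name JAX now infers for Python floats,
    or None if JAX is not installed.
    """
    if dtype_name not in ("float32", "float64"):
        raise ValueError(f"unsupported dtype: {dtype_name!r}")
    if importlib.util.find_spec("jax") is None:
        return None
    from jax import config
    import jax.numpy as jnp

    # Set the flag in both directions so a previous float64 setting
    # cannot leak into a later float32 run in the same process.
    config.update("jax_enable_x64", dtype_name == "float64")
    # Python floats are converted to the default float dtype for the current flag.
    return jnp.asarray([0.0, 1.0]).dtype.name
```

Usage: calling `set_jax_precision("float64")` and then `set_jax_precision("float32")` should report `float64` and then `float32` when JAX is installed. Whether this belongs in `set_precision` or the tests should simply expect a different dtype is the open question here.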

Thanks and sorry again for the delays!

gomezzz · Sep 12 '24

Hi @gomezzz, my bandwidth is very limited over the next couple of weeks, so I probably won't be able to make these changes in this PR. I think moving the testing problems into a separate issue is a good idea, though perhaps that fix should be merged together with this one.

alberthli · Sep 12 '24

Hi, I would like to try torchquad with JAX for my project. Shouldn't we require a minimum JAX version of 0.4.25, following https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-25-feb-26-2024 ? I can make the PR if needed.
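If the minimum version is raised, a runtime guard could give a clearer message than the `ImportError` above. A minimal standard-library sketch: the simplified parser ignores pre-release and local suffixes (a real project would more likely use `packaging.version`), and `MIN_JAX`, `parse_version` and `check_jax_version` are hypothetical names, with 0.4.25 taken from the suggestion in this comment.

```python
from importlib import metadata

MIN_JAX = (0, 4, 25)  # value proposed in this thread, not a released requirement


def parse_version(text, width=3):
    """Turn '0.4.25' or '0.4.30rc1' into a fixed-length tuple of ints."""
    parts = []
    for piece in text.split(".")[:width]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    # Pad so that '0.5' compares as (0, 5, 0).
    parts += [0] * (width - len(parts))
    return tuple(parts)


def check_jax_version(minimum=MIN_JAX):
    """Return the installed jax version string, or None if jax is absent.

    Raises RuntimeError if the installed version is older than `minimum`.
    """
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None
    if parse_version(installed) < minimum:
        wanted = ".".join(map(str, minimum))
        raise RuntimeError(f"jax>={wanted} is required, found {installed}")
    return installed
```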

HGangloff · Sep 13 '24

Hi @HGangloff, ah yes, I was looking for that information without success. I don't know jax that well :).

Sounds good, please go ahead, thanks! If you have an idea why the dtypes in the tests changed (they should have been set via the `set_precision` function modified here, I think), I would appreciate your insight!

gomezzz · Sep 13 '24