torchquad
Update `jax` config import
This PR changes the import statement for `jax.config`, which resolves the following error:
```
File "/usr/local/lib/python3.10/dist-packages/torchquad/utils/set_precision.py", line 58, in set_precision
    from jax.config import config
ImportError: cannot import name 'config' from 'jax.config' (/usr/local/lib/python3.10/dist-packages/jax/config.py)
```
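For anyone hitting the same error, here is a minimal sketch of the updated import. It assumes a recent JAX release (0.4.25 or later, the minimum proposed further down this thread). The `jax_enable_x64` update is only an example of a setting applied through this import. It is not meant to reproduce what torchquad's `set_precision` actually does.

```python
import jax.numpy as jnp

# Old import path, which raises ImportError on recent JAX releases:
#   from jax.config import config
# Replacement: import the config object from the top-level jax package.
from jax import config

# Example setting: allow 64-bit floats (JAX defaults to 32-bit).
config.update("jax_enable_x64", True)

# Arrays created with the default float dtype are now float64.
x = jnp.zeros(3)
print(x.dtype)  # float64
```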
Hi @alberthli , thanks for the PR!
To understand this a bit better: my understanding is that jax deprecated this import recently.
Did you happen to check whether this also works with older jax versions? Currently we require jax>=0.2.22, so we may need to update that requirement.
(P.S. Don't worry about the failing tests. I think the failures are unrelated to this PR and are due to a regression in the CI.)
Hi @gomezzz, I haven't checked whether it works for earlier versions. This commit from 6 months ago seems relevant if you want to control versioning, though.
Hi @alberthli! Sorry for the long delay. Yes, it looks good and mostly works with jax>=0.4.17 (I'm having trouble testing with earlier versions due to issues with jax). Maybe we should bump the recommended jax version, though.
Changes for that would be:
- [ ] Change jax version mentioned here https://github.com/esa/torchquad?tab=readme-ov-file#prerequisites
- [ ] And here https://github.com/esa/torchquad/blob/6c1b8cd17830c80457477bf33c429a4b06fab3f2/environment_all_backends.yml#L23
- [ ] And here https://github.com/esa/torchquad/blob/main/docs/source/install.rst
Would you mind updating it, @alberthli in this PR? Otherwise I can do it.
I have also been trying to run the tests locally on CPU, but now I have a problem with jax:
```
FAILED integrator_types_test.py::test_integrate_jax - AssertionError: assert 'float64' == 'float32'
FAILED monte_carlo_test.py::test_integrate_jax - assert 0.01089246934838961 < 0.01
FAILED utils_integration_test.py::test_setup_integration_domain - AssertionError: assert 'float64' == 'float32'
```
I think the middle one just has an overly aggressive threshold, but the other two seem to reflect a change in how jax behaves?
The errors are raised here:
```
integrator_types_test.py:90: in _run_simple_integrations
    result = integrator.integrate(
../integration/trapezoid.py:25: in integrate
    return super().integrate(fn, dim, N, integration_domain, backend)
../integration/grid_integrator.py:50: in integrate
    function_values, num_points = self.evaluate_integrand(
../integration/base_integrator.py:65: in evaluate_integrand
    result = fn(points, *args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

x = Array([[ 0. , -2. ],
       [ 0.08333333, -2. ],
       [ 0.16666667, -2. ],
       [ 0.25... [ 0.83333333, 0. ],
       [ 0.91666667, 0. ],
       [ 1. , 0. ]], dtype=float64)

    def fn_const(x):
        assert infer_backend(x) == backend
>       assert get_dtype_name(x) == expected_dtype_name
E       AssertionError: assert 'float64' == 'float32'
E
E         - float32
E         + float64

integrator_types_test.py:43: AssertionError
```
and
```
utils_integration_test.py:48: in _run_tests_with_all_backends
    func(dtype_name=dtype_name, backend=backend, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

dtype_name = 'float32', backend = 'jax'

    def _run_setup_integration_domain_tests(dtype_name, backend):
        """
        Test _setup_integration_domain with the given dtype and numerical backend
        """
        print(
            f"Testing _setup_integration_domain; backend: {backend}, precision: {dtype_name}"
        )
        # Domain given as List with Python floats
        domain = _setup_integration_domain(2, [[0.0, 1.0], [1.0, 2.0]], backend)
        assert infer_backend(domain) == backend
>       assert get_dtype_name(domain) == dtype_name
E       AssertionError: assert 'float64' == 'float32'
E
E         - float32
E         + float64
```
So either we just change the expected dtype in the test, or we may have to change `set_precision.py`. If this is too much, @alberthli, we can also move it to a dedicated issue if you prefer, since it is not directly related to your changes.
Thanks and sorry again for the delays!
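To illustrate the `set_precision.py` option, here is a sketch under stated assumptions. `set_jax_precision` and `default_float_dtype_name` are hypothetical names, not torchquad's API, and this does not claim to mirror what torchquad's `set_precision` currently does. It assumes a recent JAX where `jax_enable_x64` can be toggled at runtime. The sketch shows that the default dtype of arrays built from Python floats follows the process-wide `jax_enable_x64` flag. That would be consistent with the float64 results above if the flag were left enabled, for example by an earlier float64 test, and never switched back.

```python
import jax.numpy as jnp
from jax import config


def set_jax_precision(dtype_name: str) -> None:
    """Hypothetical helper: make JAX's default float dtype match dtype_name.

    jax_enable_x64 is a process-wide flag, so it is set explicitly in both
    directions; otherwise a float64 setting can leak into later float32 runs.
    """
    if dtype_name not in ("float32", "float64"):
        raise ValueError(f"Unsupported dtype: {dtype_name}")
    config.update("jax_enable_x64", dtype_name == "float64")


def default_float_dtype_name() -> str:
    # Python floats are converted using the current default float dtype.
    return str(jnp.array([0.0, 1.0]).dtype)


set_jax_precision("float64")
print(default_float_dtype_name())  # float64
set_jax_precision("float32")
print(default_float_dtype_name())  # float32
```

The alternative, changing the expected dtype in the tests, would hide the leak rather than fix it. Explicitly resetting the flag keeps each test's requested precision independent of test order.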
Hi @gomezzz, my bandwidth is very limited for the next couple of weeks, so I probably won't be able to make these changes in this PR. I think moving the testing issues to a separate issue is a good idea, though perhaps that fix should be merged together with this one.
Hi, I would like to try torchquad with JAX for my project. Shouldn't we require a minimum JAX version of 0.4.25, following https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-25-feb-26-2024 ? I can make the PR if needed.
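Besides bumping the requirement in the README, `environment_all_backends.yml` and `install.rst`, one could also guard at runtime. Below is a hypothetical sketch of such a check against the proposed 0.4.25 minimum. `parse_release` and `jax_version_ok` are illustrative names, not part of torchquad. The parsing is deliberately simple: it keeps only the leading numeric components, so pre-release and dev suffixes are ignored.

```python
import re

MIN_JAX_VERSION = (0, 4, 25)  # minimum proposed in this thread


def parse_release(version: str) -> tuple:
    """Return the leading numeric components, e.g. '0.4.26.dev1' -> (0, 4, 26)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Cannot parse version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def jax_version_ok(version: str, minimum: tuple = MIN_JAX_VERSION) -> bool:
    """Hypothetical check: True if `version` is at least `minimum`."""
    # Tuple comparison orders versions component by component.
    return parse_release(version) >= minimum
```

In practice the argument would be `jax.__version__` or `importlib.metadata.version("jax")`. For correct ordering of pre-releases, `packaging.version.Version` is the more robust choice.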
Hi @HGangloff, ah yes, I had looked for that info without success. I don't know jax that well :).
Sounds good, please go ahead, thanks! If you have an idea why the dtypes in the tests changed (which, I think, should have been set via the `set_precision` function modified here), I would appreciate your insight!