Minor (?) test failure with JAX 0.8.0: no warning emitted when a JAX-transformed callable is set as an attribute

Open · johannahaffner opened this issue 3 months ago · 5 comments

I'm getting the following error with JAX 0.8.0:

(.jax-080) jhaffner@bs-mbpas-0019 equinox % pytest tests/test_module.py 
======================================================================================== test session starts ========================================================================================
platform darwin -- Python 3.13.3, pytest-8.4.2, pluggy-1.6.0
rootdir: /Users/jhaffner/Desktop/projects/equinox
configfile: pyproject.toml
plugins: jaxtyping-0.3.3
collected 64 items                                                                                                                                                                                  

tests/test_module.py ................................................F...............                                                                                                         [100%]

============================================================================================= FAILURES ==============================================================================================
______________________________________________________________________________________ test_jax_transform_warn ______________________________________________________________________________________

getkey = GetKey(seed=1816561341, call=11, key=Array([         0, 1816561341], dtype=uint32))

    def test_jax_transform_warn(getkey):
        class A(eqx.Module):
            linear: Callable
    
        class B(eqx.Module):
            linear: Callable
    
            def __init__(self, linear):
                self.linear = linear
    
        for cls in (A, B):
            for transform in (
                jax.jit,
                jax.grad,
                jax.vmap,
                jax.value_and_grad,
                jax.jacfwd,
                jax.jacrev,
                jax.hessian,
                jax.custom_jvp,
                jax.custom_vjp,
                jax.checkpoint,  # pyright: ignore
                jax.pmap,
            ):
>               with pytest.warns(
                    match="Possibly assigning a JAX-transformed callable as an attribute"
                ):
E               Failed: DID NOT WARN. No warnings of type (<class 'Warning'>,) were emitted.
E                Emitted warnings: [].

tests/test_module.py:883: Failed
--------------------------------------------------------------------------------------- Captured stdout call ----------------------------------------------------------------------------------------
testing for: <function jit at 0x10826e8e0>
testing for: <function grad at 0x10826eac0>
testing for: <function vmap at 0x10826f6a0>
testing for: <function value_and_grad at 0x10826eb60>
testing for: <function jacfwd at 0x10826ee80>
testing for: <function jacrev at 0x10826f060>
testing for: <function hessian at 0x10826f1a0>
testing for: <class 'jax._src.custom_derivatives.custom_jvp'>
testing for: <class 'jax._src.custom_derivatives.custom_vjp'>
testing for: <function checkpoint at 0x108824ea0>
testing for: <function pmap at 0x10826f880>
====================================================================================== short test summary info ======================================================================================
FAILED tests/test_module.py::test_jax_transform_warn - Failed: DID NOT WARN. No warnings of type (<class 'Warning'>,) were emitted.

I think the warning only goes missing at the end of the loop (the last transform printed is jax.pmap). I don't have time to look into this now, so I'm posting it here.

johannahaffner, Oct 16 '25 08:10
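In the test above, pytest.warns sits inside the loop, so the run stops at the first transform that fails to warn. The "testing for:" lines in the captured stdout are the only clue to which transform that was. To list every silent transform in a single run, a small diagnostic like the one below could help. It is a hypothetical sketch, not part of Equinox's test suite. The names missing_warnings, noisy and quiet are invented, and the two stand-in transforms take the place of jax.jit, jax.pmap, etc. so the example runs without JAX. Each recorded message is matched with re.search, which is also how pytest.warns(match=...) matches.

```python
import re
import warnings
from typing import Callable, Iterable


def missing_warnings(
    transforms: Iterable[Callable],
    assign: Callable[[Callable], object],
    match: str,
) -> list[str]:
    """Return the names of transforms for which `assign` emitted no matching warning.

    Hypothetical helper: every transform is checked before anything is reported,
    instead of stopping at the first failure as `pytest.warns` inside a loop does.
    """
    missing = []
    for transform in transforms:
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning, even ones the registry has already seen.
            warnings.simplefilter("always")
            assign(transform)
        if not any(re.search(match, str(w.message)) for w in caught):
            missing.append(getattr(transform, "__name__", repr(transform)))
    return missing


# Hypothetical stand-in transforms: one warns when used, the other stays silent.
def noisy(fn):
    warnings.warn("Possibly assigning a JAX-transformed callable as an attribute")
    return fn


def quiet(fn):
    return fn


print(missing_warnings([noisy, quiet], lambda t: t(abs), match="Possibly assigning"))
# ['quiet']
```

With Equinox installed, assign would construct the module instead, roughly lambda t: A(t(some_callable)). Here some_callable is a placeholder for whatever the real test passes in, because that line of the test does not appear in the traceback.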

PS: no other tests fail!

johannahaffner, Oct 16 '25 08:10

P.P.S.: We also don't get any failures in Lineax, Optimistix, or Diffrax. 🥳

johannahaffner, Oct 16 '25 13:10

This looks like a minor issue in JAX. I've isolated it and opened the GitHub issue linked above.

It's not the end of the world for us if it isn't fixed: we would just stop emitting the helpful warning for jax.pmap, which is a fairly uncommon function to use these days anyway.

patrick-kidger, Oct 16 '25 21:10

I can confirm that with the latest commit on JAX's main branch as of today (jax==0.8.1.dev20251116), this test passes again, both on CUDA:

(equinox) arturgalstyan@talon:~/Workspace/equinox$ uv pip install -U --pre jax jaxlib "jax-cuda13-plugin[with-cuda]" jax-cuda13-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Resolved 22 packages in 22.70s
Prepared 4 packages in 18.44s
Uninstalled 4 packages in 15ms
Installed 4 packages in 11ms
 - jax==0.8.0
 + jax==0.8.1.dev20251116
 - jax-cuda13-pjrt==0.8.0
 + jax-cuda13-pjrt==0.8.1.dev20251116
 - jax-cuda13-plugin==0.8.0
 + jax-cuda13-plugin==0.8.1.dev20251116
 - jaxlib==0.8.0
 + jaxlib==0.8.1.dev20251116
(equinox) arturgalstyan@talon:~/Workspace/equinox$ pytest -x -s tests/test_module.py
==================================================================================== test session starts =====================================================================================
platform linux -- Python 3.14.0, pytest-9.0.1, pluggy-1.6.0
rootdir: /home/arturgalstyan/Workspace/equinox
configfile: pyproject.toml
plugins: jaxtyping-0.3.3
collected 64 items

tests/test_module.py .....W1116 13:16:32.605954   17634 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1116 13:16:32.607793   17533 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
...........................................................

===================================================================================== 64 passed in 1.56s =====================================================================================
(equinox) arturgalstyan@talon:~/Workspace/equinox$

And on CPU:

(equinox) ➜  equinox git:(main) ✗ uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Resolved 6 packages in 10.34s
Prepared 2 packages in 9.06s
Uninstalled 2 packages in 83ms
Installed 2 packages in 6ms
 - jax==0.8.0
 + jax==0.8.1.dev20251116
 - jaxlib==0.8.0
 + jaxlib==0.8.1.dev20251116
(equinox) ➜  equinox git:(main) ✗  pytest -x -s tests/test_module.py
===================================================================================== test session starts ======================================================================================
platform darwin -- Python 3.13.9, pytest-9.0.1, pluggy-1.6.0
rootdir: /Users/arturgalstyan/Workspace/equinox
configfile: pyproject.toml
plugins: jaxtyping-0.3.3
collected 64 items

tests/test_module.py ................................................................

====================================================================================== 64 passed in 0.72s ======================================================================================
(equinox) ➜  equinox git:(main) ✗

Artur-Galstyan, Nov 16 '25 12:11

Thank you!

johannahaffner, Nov 16 '25 12:11