Minor (?) test failure with JAX 0.8.0: no warning emitted when a JAX-transformed callable is set as an attribute
I'm getting the following error with JAX 0.8.0:
(.jax-080) jhaffner@bs-mbpas-0019 equinox % pytest tests/test_module.py
======================================================================================== test session starts ========================================================================================
platform darwin -- Python 3.13.3, pytest-8.4.2, pluggy-1.6.0
rootdir: /Users/jhaffner/Desktop/projects/equinox
configfile: pyproject.toml
plugins: jaxtyping-0.3.3
collected 64 items
tests/test_module.py ................................................F............... [100%]
============================================================================================= FAILURES ==============================================================================================
______________________________________________________________________________________ test_jax_transform_warn ______________________________________________________________________________________
getkey = GetKey(seed=1816561341, call=11, key=Array([ 0, 1816561341], dtype=uint32))
    def test_jax_transform_warn(getkey):
        class A(eqx.Module):
            linear: Callable
        class B(eqx.Module):
            linear: Callable
            def __init__(self, linear):
                self.linear = linear
        for cls in (A, B):
            for transform in (
                jax.jit,
                jax.grad,
                jax.vmap,
                jax.value_and_grad,
                jax.jacfwd,
                jax.jacrev,
                jax.hessian,
                jax.custom_jvp,
                jax.custom_vjp,
                jax.checkpoint, # pyright: ignore
                jax.pmap,
            ):
>               with pytest.warns(
                    match="Possibly assigning a JAX-transformed callable as an attribute"
                ):
E               Failed: DID NOT WARN. No warnings of type (<class 'Warning'>,) were emitted.
E               Emitted warnings: [].
tests/test_module.py:883: Failed
--------------------------------------------------------------------------------------- Captured stdout call ----------------------------------------------------------------------------------------
testing for: <function jit at 0x10826e8e0>
testing for: <function grad at 0x10826eac0>
testing for: <function vmap at 0x10826f6a0>
testing for: <function value_and_grad at 0x10826eb60>
testing for: <function jacfwd at 0x10826ee80>
testing for: <function jacrev at 0x10826f060>
testing for: <function hessian at 0x10826f1a0>
testing for: <class 'jax._src.custom_derivatives.custom_jvp'>
testing for: <class 'jax._src.custom_derivatives.custom_vjp'>
testing for: <function checkpoint at 0x108824ea0>
testing for: <function pmap at 0x10826f880>
====================================================================================== short test summary info ======================================================================================
FAILED tests/test_module.py::test_jax_transform_warn - Failed: DID NOT WARN. No warnings of type (<class 'Warning'>,) were emitted.
Judging from the captured stdout, the failure happens on the last transform in the loop (jax.pmap), after all the other transforms have been checked. I don't have time to look into this now, so I'm posting it here.
P.S.: we don't get any other errata!
P.P.S.: We also don't get any errata in Lineax, Optimistix, or Diffrax. 🥳
Looks like a minor issue in JAX; I've isolated it and opened the GitHub issue above.
It's not the end of the world for us if it isn't fixed: we'd just stop supporting nice warning messages for jax.pmap, which is a fairly uncommon function to use these days anyway.
Can confirm that on the latest commit on JAX's main branch (jax==0.8.1.dev20251116) as of today, this test passes again on both CUDA:
(equinox) arturgalstyan@talon:~/Workspace/equinox$ uv pip install -U --pre jax jaxlib "jax-cuda13-plugin[with-cuda]" jax-cuda13-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Resolved 22 packages in 22.70s
Prepared 4 packages in 18.44s
Uninstalled 4 packages in 15ms
Installed 4 packages in 11ms
- jax==0.8.0
+ jax==0.8.1.dev20251116
- jax-cuda13-pjrt==0.8.0
+ jax-cuda13-pjrt==0.8.1.dev20251116
- jax-cuda13-plugin==0.8.0
+ jax-cuda13-plugin==0.8.1.dev20251116
- jaxlib==0.8.0
+ jaxlib==0.8.1.dev20251116
(equinox) arturgalstyan@talon:~/Workspace/equinox$ pytest -x -s tests/test_module.py
==================================================================================== test session starts =====================================================================================
platform linux -- Python 3.14.0, pytest-9.0.1, pluggy-1.6.0
rootdir: /home/arturgalstyan/Workspace/equinox
configfile: pyproject.toml
plugins: jaxtyping-0.3.3
collected 64 items
tests/test_module.py .....W1116 13:16:32.605954 17634 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1116 13:16:32.607793 17533 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
...........................................................
===================================================================================== 64 passed in 1.56s =====================================================================================
(equinox) arturgalstyan@talon:~/Workspace/equinox$
And CPU:
(equinox) ➜ equinox git:(main) ✗ uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Resolved 6 packages in 10.34s
Prepared 2 packages in 9.06s
Uninstalled 2 packages in 83ms
Installed 2 packages in 6ms
- jax==0.8.0
+ jax==0.8.1.dev20251116
- jaxlib==0.8.0
+ jaxlib==0.8.1.dev20251116
(equinox) ➜ equinox git:(main) ✗ pytest -x -s tests/test_module.py
===================================================================================== test session starts ======================================================================================
platform darwin -- Python 3.13.9, pytest-9.0.1, pluggy-1.6.0
rootdir: /Users/arturgalstyan/Workspace/equinox
configfile: pyproject.toml
plugins: jaxtyping-0.3.3
collected 64 items
tests/test_module.py ................................................................
====================================================================================== 64 passed in 0.72s ======================================================================================
(equinox) ➜ equinox git:(main) ✗
Thank you!