jax
Test documentation fails
Description
I'm seeing an automated test fail in the documentation (https://github.com/google/jax/actions/runs/3055125249/jobs/4927830518).
The log of the failure is
Run pytest -n 1 --tb=short docs
============================= test session starts ==============================
platform linux -- Python 3.7.13, pytest-7.1.3, pluggy-1.0.0
rootdir: /home/runner/work/jax/jax, configfile: pytest.ini
plugins: xdist-2.5.0, forked-1.4.0
gw0 I
gw0 [8]
.F...... [100%]
=================================== FAILURES ===================================
______________________________ [doctest] faq.rst _______________________________
[gw0] linux -- Python 3.7.13 /opt/hostedtoolcache/Python/3.7.13/x64/bin/python
244 relevant class attributes, and we've also defined the ``__eq__`` method because it's
245 good practice to do so any time you override ``__hash__`` (see
246 `Python Data Model: __hash__ <https://docs.python.org/3/reference/datamodel.html#object.__hash__>`_
247 for more information on this). With this addition, the example works correctly::
248
249 >>> c = CustomClass(2, True)
250 >>> print(c.calc(3))
251 6
252 >>> c.mul = False
253 >>> print(c.calc(3))
Expected:
3
Got:
6
/home/runner/work/jax/jax/docs/faq.rst:253: DocTestFailure
=========================== short test summary info ============================
FAILED docs/faq.rst::faq.rst
========================= 1 failed, 7 passed in 2.10s ==========================
Error: Process completed with exit code 1.
What jax/jaxlib version are you using?
No response
Which accelerator(s) are you using?
No response
Additional System Info
No response
Thanks - echoing our conversation from elsewhere: this is a failure we've seen periodically before; it's a strange one that is hard to reproduce. It almost looks like there's some sort of race condition within doctest which causes out-of-order execution, but I can't imagine a mechanism that would produce that and not also show up elsewhere...
OK, I was able to reproduce it... it has something to do with jit caching, I think:
from functools import partial
from jax import jit

class CustomClass:
  def __init__(self, x: int, mul: bool):
    self.x = x
    self.mul = mul

  @partial(jit, static_argnums=0)
  def calc(self, y):
    if self.mul:
      return self.x * y
    return y

  def __hash__(self):
    return hash((self.x, self.mul))

  def __eq__(self, other):
    return (isinstance(other, CustomClass) and
            (self.x, self.mul) == (other.x, other.mul))

for i in range(1000):
  c = CustomClass(2, True)
  assert c.calc(3) == 6
  c.mul = False
  assert c.calc(3) == 3, f"failed on iteration {i}"
Output:
Traceback (most recent call last):
  File "tmp.py", line 26, in <module>
    assert c.calc(3) == 3, f"failed on iteration {i}"
AssertionError: failed on iteration 27
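For context on why caching is a plausible suspect: the reproduction above mutates a field that feeds `__hash__` after the object has already been used as a static argument. The sketch below is not a diagnosis of the JAX failure and does not use JAX's actual cache. It uses a plain `dict` and a hypothetical `Key` class to show the general Python hazard of changing a key's hash while it is stored in a hash-keyed container.

```python
class Key:
    """Hypothetical stand-in for a static argument with value-based hashing."""

    def __init__(self, x, mul):
        self.x = x
        self.mul = mul

    def __hash__(self):
        return hash((self.x, self.mul))

    def __eq__(self, other):
        return (isinstance(other, Key) and
                (self.x, self.mul) == (other.x, other.mul))


# A plain dict standing in for any hash-keyed cache (an assumption for
# illustration; this is not how JAX stores compiled functions).
cache = {}
k = Key(2, True)
cache[k] = "result for mul=True"

# Mutate a field that feeds __hash__ while k is still a key in the dict.
k.mul = False

# The entry is still stored, but it is filed under the old hash while the
# key's fields now say (2, False). A fresh key equal to the original value
# lands on the old hash, then fails the __eq__ check against the mutated key.
stale_entry_reachable = Key(2, True) in cache
print(len(cache), stale_entry_reachable)
```

The dict still holds one entry, yet looking it up with a fresh `Key(2, True)` fails. Whether the mutated `k` itself is still found depends on where its new hash lands in the table, which is the kind of layout-dependent behavior that could make a failure intermittent. Constructing a new object instead of mutating one that has already been hashed avoids the problem in this sketch.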