Running pytest locally fails with AttributeError: num_generated_cases
Description
I'm trying to run pytest locally, but the test session fails during collection with AttributeError: num_generated_cases.
The command: either $ pytest or $ JAX_NUM_GENERATED_CASES=1 pytest
...
ERROR tests/state_test.py - AttributeError: num_generated_cases
ERROR tests/stax_test.py - AttributeError: num_generated_cases
ERROR tests/svd_test.py - AttributeError: num_generated_cases
ERROR tests/x64_context_test.py - AttributeError: num_generated_cases
ERROR tests/xmap_test.py - AttributeError: num_generated_cases
ERROR tests/third_party/scipy/line_search_test.py - AttributeError: num_generated_cases
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 49 errors during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!
========================================= 49 errors in 6.24s ========================================
An example stack trace from one of the failing tests:
___________ ERROR collecting tests/third_party/scipy/line_search_test.py ___________
tests/third_party/scipy/line_search_test.py:14: in <module>
class TestLineSearch(jtu.JaxTestCase):
tests/third_party/scipy/line_search_test.py:64: in TestLineSearch
@parameterized.named_parameters(jtu.cases_from_list(
jax/_src/test_util.py:661: in cases_from_list
k = min(n, FLAGS.num_generated_cases)
jax/_src/config.py:453: in __getattr__
return self._getter(name)
jax/_src/config.py:97: in read
return self._read(name)
jax/_src/config.py:101: in _read
return getattr(self.absl_flags.FLAGS, name)
.../lib/python3.10/site-packages/absl/flags/_flagvalues.py:471: in __getattr__
raise AttributeError(name)
E AttributeError: num_generated_cases
But FLAGS.num_generated_cases is defined in jax/_src/test_util.py:
flags.DEFINE_integer(
'num_generated_cases',
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
help='Number of generated cases to test')
Running an individual test file, e.g. pytest tests/third_party/scipy/line_search_test.py, works fine. While collecting the whole test suite, the absl flags do not seem to have been registered properly (perhaps at import time).
What jax/jaxlib version are you using?
latest master (0.3.17+)
Which accelerator(s) are you using?
N/A
Additional System Info
macOS
- Python: 3.10.5 (miniconda3)
- pytest 7.1.3
- absl-py 1.2.0
absl's flag mechanism has always been a big pain for me because of its module-level side effects, e.g., its use of global variables as state. I wonder whether the absl flags are reset somewhere while the modules are collected/executed, or whether absl flag parsing is skipped due to already_configured_with_absl. It looks like once parse_flags_with_absl is called, all subsequent FLAGS.define_XXX calls are ignored.
Given the above observation, a workaround I found is to add the following lines to conftest.py, which is executed before pytest collects the test suites:
# Ensure JAX test flags are registered before collecting test suites.
# see https://github.com/google/jax/issues/12411
# pylint: disable-next=unused-import
import jax._src.test_util
This eliminates all the AttributeErrors during test collection.
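To make the suspected ordering problem concrete, here is a minimal toy model of it. It is only a sketch of the hypothesis above, not JAX's or absl's actual code: `ToyFlags`, `define_integer`, and `parse` are hypothetical stand-ins, and the "definitions after parsing are dropped" rule is an assumption taken from the observation, not confirmed behaviour.

```python
class ToyFlags:
    """Hypothetical stand-in for a global flag registry (not absl or JAX)."""

    def __init__(self):
        self._values = {}
        self._frozen = False

    def define_integer(self, name, default):
        # Assumed behaviour: a definition made after parsing is silently ignored.
        if self._frozen:
            return
        self._values[name] = default

    def parse(self):
        # After parsing, the registry no longer accepts new definitions.
        self._frozen = True

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, i.e. for flag names.
        try:
            return self.__dict__["_values"][name]
        except KeyError:
            raise AttributeError(name) from None


def failing_order():
    flags = ToyFlags()
    flags.parse()  # flags get parsed before the test utilities are imported
    flags.define_integer("num_generated_cases", 10)  # dropped
    return flags.num_generated_cases  # raises AttributeError


def working_order():
    flags = ToyFlags()
    # Analogous to importing the flag-defining module early from conftest.py.
    flags.define_integer("num_generated_cases", 10)
    flags.parse()
    return flags.num_generated_cases
```

In this model, `working_order()` returns the default while `failing_order()` raises `AttributeError: num_generated_cases`, which is the same symptom as in the collection errors. Whether the real cause is this ordering is still an open question here.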
I suspect the issue is that we always run pytest tests, not plain pytest. But running plain pytest seems like a reasonable thing to want to do.