
Running pytest locally -- AttributeError: num_generated_cases

Open • wookayin opened this issue 3 years ago • 2 comments

Description

I'm trying to run pytest locally, but the test session fails during collection with AttributeError: num_generated_cases.

Command (both variants fail): $ pytest or $ JAX_NUM_GENERATED_CASES=1 pytest

...
ERROR tests/state_test.py - AttributeError: num_generated_cases
ERROR tests/stax_test.py - AttributeError: num_generated_cases
ERROR tests/svd_test.py - AttributeError: num_generated_cases
ERROR tests/x64_context_test.py - AttributeError: num_generated_cases
ERROR tests/xmap_test.py - AttributeError: num_generated_cases
ERROR tests/third_party/scipy/line_search_test.py - AttributeError: num_generated_cases

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 49 errors during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!
========================================= 49 errors in 6.24s ========================================

The stack trace for one of the failing test modules is:

___________ ERROR collecting tests/third_party/scipy/line_search_test.py ___________
tests/third_party/scipy/line_search_test.py:14: in <module>
    class TestLineSearch(jtu.JaxTestCase):
tests/third_party/scipy/line_search_test.py:64: in TestLineSearch
    @parameterized.named_parameters(jtu.cases_from_list(
jax/_src/test_util.py:661: in cases_from_list
    k = min(n, FLAGS.num_generated_cases)
jax/_src/config.py:453: in __getattr__
    return self._getter(name)
jax/_src/config.py:97: in read
    return self._read(name)
jax/_src/config.py:101: in _read
    return getattr(self.absl_flags.FLAGS, name)
.../lib/python3.10/site-packages/absl/flags/_flagvalues.py:471: in __getattr__
    raise AttributeError(name)
E   AttributeError: num_generated_cases

But FLAGS.num_generated_cases is defined in jax/_src/test_util.py:

import os
from absl import flags

flags.DEFINE_integer(
  'num_generated_cases',
  int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
  help='Number of generated cases to test')

Running an individual test file, e.g. pytest tests/third_party/scipy/line_search_test.py, works fine. During test collection, however, the ABSL flags do not seem to be registered properly (possibly an import-time issue).
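The AttributeError at the bottom of the trace is raised by absl's FlagValues.__getattr__ when the requested name is not registered in that FlagValues object. A minimal sketch, assuming absl-py is installed (the report uses 1.2.0), reproduces the same message and shows that the read succeeds once the flag has been defined. It uses a private FlagValues instance named fv (a hypothetical name) instead of the global FLAGS, and hard-codes the default rather than reading JAX_NUM_GENERATED_CASES.

```python
from absl import flags

# Private registry so the demo does not touch the global absl.flags.FLAGS.
fv = flags.FlagValues()

# Reading a flag that was never registered fails like the traceback:
# FlagValues.__getattr__ raises AttributeError(name).
try:
    fv.num_generated_cases
except AttributeError as e:
    missing_error = str(e)

# Register the flag (default hard-coded here for the sketch).
flags.DEFINE_integer('num_generated_cases', 10,
                     'Number of generated cases to test', flag_values=fv)

# Parse an argv with no flags so the default value becomes readable.
fv(['prog'])
value_after_define = fv.num_generated_cases
```

Because absl's global FLAGS is a single process-wide object, whether such a read succeeds depends on whether the module that calls DEFINE_integer has already been imported by the time the read happens.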

What jax/jaxlib version are you using?

latest master (0.3.17+)

Which accelerator(s) are you using?

N/A

Additional System Info

macOS

  • Python: 3.10.5 (miniconda3)
  • pytest 7.1.3
  • absl-py 1.2.0

wookayin • Sep 19 '22 01:09

absl's flag mechanism has always been a big pain for me because it has module-level side effects, e.g., it uses global variables as state. I wonder whether the absl flags are reset somewhere while a module is being collected or executed, or whether absl flag parsing is skipped because already_configured_with_absl is already set. It looks like once parse_flags_with_absl has been called, all subsequent FLAGS.define_XXX calls are ignored.
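The suspected failure mode can be made concrete with a toy model of a one-shot configuration guard. This is a hypothetical simplification written only to illustrate the hypothesis above; it is not JAX's config code, and the names define, configure_once, read, early_flag, and late_flag are all invented for the sketch.

```python
# Hypothetical model of a one-shot "configure once" guard (NOT JAX's code).
_already_configured = False
_defined = {}  # flag name -> default value, filled by define()
_values = {}   # flag name -> value, filled only by the first configure call


def define(name, default):
    """Record a flag definition (stands in for FLAGS.define_XXX)."""
    _defined[name] = default


def configure_once(argv):
    """Copy defined flags into _values, applying --name=value overrides.

    Every call after the first returns immediately, so flags defined
    later never reach _values.
    """
    global _already_configured
    if _already_configured:
        return
    _values.update(_defined)
    for arg in argv:
        if arg.startswith('--') and '=' in arg:
            key, val = arg[2:].split('=', 1)
            if key in _values:
                # Convert the string to the type of the default value.
                _values[key] = type(_values[key])(val)
    _already_configured = True


def read(name):
    """Mimic FlagValues.__getattr__: unknown names raise AttributeError."""
    if name not in _values:
        raise AttributeError(name)
    return _values[name]


define('early_flag', 10)
configure_once(['--early_flag=1'])
define('late_flag', 2)   # defined after the one-time configuration
configure_once([])       # no-op: the guard is already set
```

Under this model, read('late_flag') raises AttributeError: late_flag, the same shape of error as in the report. Whether JAX's real config behaves this way is exactly what the comment above is unsure about.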

wookayin avatar Sep 19 '22 02:09 wookayin

Given the above observation, a workaround I found is to add the following lines to conftest.py, which pytest executes before collecting the test suites:

# Ensure JAX test flags are registered before collecting test suites.
# see https://github.com/google/jax/issues/12411
# pylint: disable-next=unused-import
import jax._src.test_util

This eliminates all of the AttributeErrors during test collection.
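The workaround relies on ordinary Python import semantics: a module's top-level code, including any flags.DEFINE_* calls, runs once on first import, and later imports reuse the cached entry in sys.modules. The stdlib-only sketch below shows that effect with two throwaway modules; demo_flag_registry and demo_flag_defs are hypothetical stand-ins for absl's flag registry and jax/_src/test_util.py. It does not reproduce the collection failure itself.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Write two throwaway modules: a global "registry", and a module whose
# top level registers a flag as an import-time side effect.
tmp = Path(tempfile.mkdtemp())
(tmp / 'demo_flag_registry.py').write_text('define_calls = 0\nflags = {}\n')
(tmp / 'demo_flag_defs.py').write_text(
    'import demo_flag_registry as reg\n'
    'reg.define_calls += 1\n'
    "reg.flags['num_generated_cases'] = 10\n"
)
sys.path.insert(0, str(tmp))
importlib.invalidate_caches()  # make the freshly written files importable

import demo_flag_registry

# Before anything imports demo_flag_defs, the flag is not registered.
registered_before = 'num_generated_cases' in demo_flag_registry.flags

# First import (the role conftest.py plays): top-level code runs.
importlib.import_module('demo_flag_defs')
# Later import (the role a test module plays): served from sys.modules,
# so the registration code does not run a second time.
importlib.import_module('demo_flag_defs')

registered_after = 'num_generated_cases' in demo_flag_registry.flags
define_calls = demo_flag_registry.define_calls
```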

wookayin • Sep 19 '22 02:09

I suspect the issue is that we always run 'pytest tests' rather than plain 'pytest'. But it seems like a reasonable thing to want to do.

hawkinsp • Sep 23 '22 17:09