
Test failures: "RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ensure that `set_n_cpu_devices` is executed before any JAX operation."

Open • samuela opened this issue 1 year ago • 1 comment

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04.6 LTS
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax 0.6.5, jax 0.4.5, jaxlib 0.4.4
  • Python version: 3.10
  • GPU/TPU model and memory: n/a
  • CUDA version (if applicable): n/a

Problem you have encountered:

I'm seeing a number of failures when running the test suite:

__________ ERROR at setup of PadShardUnpadTest.test_static_argnames8 ___________
[gw18] linux -- Python 3.10.12 /nix/store/jhflvwr40xbb0xr6jx4311icp9cym1fp-python3-3.10.12/bin/python3.10

    def setUpModule():
>     chex.set_n_cpu_devices(NDEV)

tests/jax_utils_test.py:29: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

n = 4

    def set_n_cpu_devices(n: Optional[int] = None) -> None:
      """Forces XLA to use `n` CPU threads as host devices.
    
      This allows `jax.pmap` to be tested on a single-CPU platform.
      This utility only takes effect before XLA backends are initialized, i.e.
      before any JAX operation is executed (including `jax.devices()` etc.).
      See https://github.com/google/jax/issues/1408.
    
      Args:
        n: A required number of CPU devices (``FLAGS.chex_n_cpu_devices`` is used by
          default).
    
      Raises:
        RuntimeError: If XLA backends were already initialized.
      """
      n = n or FLAGS['chex_n_cpu_devices'].value
    
      n_devices = get_n_cpu_devices_from_xla_flags()
      cpu_backend = (jax.lib.xla_bridge._backends or {}).get('cpu', None)  # pylint: disable=protected-access
      if cpu_backend is not None and n_devices != n:
>       raise RuntimeError(
            f'Attempted to set {n} devices, but {n_devices} CPUs already available:'
            ' ensure that `set_n_cpu_devices` is executed before any JAX operation.'
        )
E       RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ensure that `set_n_cpu_devices` is executed before any JAX operation.

/nix/store/pkv6vvr3hcd9zanwi1r0psmpmdaqb9bp-python3.10-chex-0.1.6/lib/python3.10/site-packages/chex/_src/fake.py:74: RuntimeError
=========================== short test summary info ============================
ERROR tests/checkpoints_test.py
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics12 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics23 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum9 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum19 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum2 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames22 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames23 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_session - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics3 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics4 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees11 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees15 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames11 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum11 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum5 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum6 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum7 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics2 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics6 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics0 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames9 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum0 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum20 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames19 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees16 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees0 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees1 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames12 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees17 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames2 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames0 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames1 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees9 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames5 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees5 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees6 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames15 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames16 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees19 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees2 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics13 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics14 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees20 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees7 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames13 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames17 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum1 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames20 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum12 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees12 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees13 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics16 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics17 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum13 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames3 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum22 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum23 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics7 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum15 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics8 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics20 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics21 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum16 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum17 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames6 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics1 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics10 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees22 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees23 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees3 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum3 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_min_device_batch_avoids_recompile0 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_min_device_batch_avoids_recompile1 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_min_device_batch_avoids_recompile2 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames7 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics18 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames4 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum10 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum8 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum21 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames10 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees18 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics5 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees21 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees10 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames14 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames18 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames21 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees8 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees14 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics15 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum14 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics9 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum18 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics11 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics22 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum4 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees4 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_min_device_batch_avoids_recompile3 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics19 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames8 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
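The "1 CPUs" in the message comes from the `get_n_cpu_devices_from_xla_flags()` call in the `set_n_cpu_devices` source shown in the traceback. The error fires only when a CPU backend already exists and that parsed count differs from the requested `n`. The following is a minimal pure-Python sketch of that check. It assumes, to our understanding of chex, that the count is read from the `--xla_force_host_platform_device_count` entry in the `XLA_FLAGS` environment variable and defaults to 1 when the entry is absent. `cpu_devices_from_xla_flags` and `would_raise` are hypothetical names, not chex APIs.

```python
import os
import re
from typing import Mapping, Optional

# Assumed XLA flag that chex uses to request multiple host CPU devices.
_FLAG = "--xla_force_host_platform_device_count"


def cpu_devices_from_xla_flags(env: Optional[Mapping[str, str]] = None) -> int:
    """Hypothetical stand-in for chex's get_n_cpu_devices_from_xla_flags.

    Returns the device count requested via XLA_FLAGS, or 1 when the flag is
    absent (a single host CPU device).
    """
    env = os.environ if env is None else env
    match = re.search(rf"{_FLAG}=(\d+)", env.get("XLA_FLAGS", ""))
    return int(match.group(1)) if match else 1


def would_raise(requested: int, backend_initialized: bool,
                env: Optional[Mapping[str, str]] = None) -> bool:
    """Mirrors the `cpu_backend is not None and n_devices != n` condition."""
    return backend_initialized and cpu_devices_from_xla_flags(env) != requested


if __name__ == "__main__":
    print(cpu_devices_from_xla_flags({}))  # no flag set: 1
    print(would_raise(4, True, {}))  # the situation in this issue: True
    print(would_raise(4, True, {"XLA_FLAGS": f"{_FLAG}=4"}))  # False
```

Under this reading, an empty `XLA_FLAGS` together with a backend that some earlier JAX call already initialized is enough to produce exactly "Attempted to set 4 devices, but 1 CPUs already available".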

What you expected to happen:

The test suite should pass.

Logs, error messages, etc.:

Steps to reproduce:

Check out flax at version 0.6.5 and run the test suite.

samuela • Jul 27 '23 21:07

How are you running the test? Is it with pytest -nauto tests/jax_utils_test.py? Are you running any JAX operations before running this test? I was able to recreate the error by inserting the JAX operation jax.numpy.array([[1, 2, 3]]) right before chex.set_n_cpu_devices(NDEV).
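Because the device count has to be fixed before any JAX operation runs, one possible workaround is to set it before pytest imports the test modules at all. The following is a minimal sketch that assumes a `tests/conftest.py`. That file is hypothetical here: we do not know whether flax ships one or what it contains. The sketch also assumes, as above, that chex reads `--xla_force_host_platform_device_count` from `XLA_FLAGS`. `force_host_cpu_devices` is a hypothetical helper, and `NDEV = 4` is taken from the `n = 4` shown in the traceback. Per the `n_devices != n` check in `set_n_cpu_devices`, once the parsed count already equals `NDEV`, `chex.set_n_cpu_devices(NDEV)` should no longer raise even if a backend was initialized earlier.

```python
# Hypothetical tests/conftest.py: pytest imports it before collecting test
# modules, so this runs before any test module can execute a JAX operation.
import os

NDEV = 4  # assumed to match NDEV in tests/jax_utils_test.py
_FLAG = "--xla_force_host_platform_device_count"


def force_host_cpu_devices(n: int) -> str:
    """Sets XLA_FLAGS so XLA exposes `n` host CPU devices.

    Only effective if it runs before JAX initializes its backends. Any existing
    device-count entry is replaced; other XLA flags are kept as they were.
    """
    kept = [
        flag
        for flag in os.environ.get("XLA_FLAGS", "").split()
        if not flag.startswith(_FLAG + "=")
    ]
    kept.append(f"{_FLAG}={n}")
    os.environ["XLA_FLAGS"] = " ".join(kept)
    return os.environ["XLA_FLAGS"]


force_host_cpu_devices(NDEV)
```

Setting the variable in the shell before launching pytest would have the same effect. The conftest variant simply keeps the setting next to the tests that depend on it.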

chiamp • Sep 19 '23 00:09