flax
Test failures: "RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ensure that `set_n_cpu_devices` is executed before any JAX operation."
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04.6 LTS
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): flax 0.6.5, jax 0.4.5, jaxlib 0.4.4
- Python version: 3.10
- GPU/TPU model and memory: n/a
- CUDA version (if applicable): n/a
Problem you have encountered:
I'm seeing a number of failures when running the test suite:
__________ ERROR at setup of PadShardUnpadTest.test_static_argnames8 ___________
[gw18] linux -- Python 3.10.12 /nix/store/jhflvwr40xbb0xr6jx4311icp9cym1fp-python3-3.10.12/bin/python3.10
def setUpModule():
> chex.set_n_cpu_devices(NDEV)
tests/jax_utils_test.py:29:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
n = 4
def set_n_cpu_devices(n: Optional[int] = None) -> None:
"""Forces XLA to use `n` CPU threads as host devices.
This allows `jax.pmap` to be tested on a single-CPU platform.
This utility only takes effect before XLA backends are initialized, i.e.
before any JAX operation is executed (including `jax.devices()` etc.).
See https://github.com/google/jax/issues/1408.
Args:
n: A required number of CPU devices (``FLAGS.chex_n_cpu_devices`` is used by
default).
Raises:
RuntimeError: If XLA backends were already initialized.
"""
n = n or FLAGS['chex_n_cpu_devices'].value
n_devices = get_n_cpu_devices_from_xla_flags()
cpu_backend = (jax.lib.xla_bridge._backends or {}).get('cpu', None) # pylint: disable=protected-access
if cpu_backend is not None and n_devices != n:
> raise RuntimeError(
f'Attempted to set {n} devices, but {n_devices} CPUs already available:'
' ensure that `set_n_cpu_devices` is executed before any JAX operation.'
)
E RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ensure that `set_n_cpu_devices` is executed before any JAX operation.
/nix/store/pkv6vvr3hcd9zanwi1r0psmpmdaqb9bp-python3.10-chex-0.1.6/lib/python3.10/site-packages/chex/_src/fake.py:74: RuntimeError
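The condition that fails above (`cpu_backend is not None and n_devices != n`) can be modeled in plain Python to make the failure mode explicit. This is a simplified, hypothetical sketch, not chex's actual code: the function names mirror those in the traceback, but `env` stands in for `os.environ` and `backends` for JAX's private backend registry, and the device count is assumed to come from the `--xla_force_host_platform_device_count` entry in `XLA_FLAGS`. Once a CPU backend exists, rewriting the flag can no longer change the device count, so the function raises instead.

```python
import re

# XLA flag that controls how many host (CPU) devices XLA exposes.
_FLAG = "--xla_force_host_platform_device_count"


def get_n_cpu_devices_from_xla_flags(env: dict) -> int:
    """Returns the CPU device count requested in XLA_FLAGS, defaulting to 1."""
    match = re.search(_FLAG + r"=(\d+)", env.get("XLA_FLAGS", ""))
    return int(match.group(1)) if match else 1


def set_n_cpu_devices(n: int, env: dict, backends: dict) -> None:
    """Simplified model of the guard shown in the traceback (hypothetical).

    `env` stands in for os.environ; `backends` stands in for JAX's
    registry of already-initialized backends.
    """
    n_devices = get_n_cpu_devices_from_xla_flags(env)
    if backends.get("cpu") is not None and n_devices != n:
        # The CPU backend already exists, so changing the flag has no effect.
        raise RuntimeError(
            f"Attempted to set {n} devices, but {n_devices} CPUs already available"
        )
    # Drop any previous device-count flag and prepend the new one.
    other_flags = re.sub(_FLAG + r"=\S+", "", env.get("XLA_FLAGS", "")).split()
    env["XLA_FLAGS"] = " ".join([f"{_FLAG}={n}"] + other_flags)


env = {}
set_n_cpu_devices(4, env, backends={})  # no CPU backend yet, so this succeeds
print(env["XLA_FLAGS"])  # --xla_force_host_platform_device_count=4
```

Calling the same function with `backends={"cpu": object()}` and an empty `env` reproduces the message from the log: 4 devices requested, but the flag-derived count is still the default of 1.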
=========================== short test summary info ============================
ERROR tests/checkpoints_test.py
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics12 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics23 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum9 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum19 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum2 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames22 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames23 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_session - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics3 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics4 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees11 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees15 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames11 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum11 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum5 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum6 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum7 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics2 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics6 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics0 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames9 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum0 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum20 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames19 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees16 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees0 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees1 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames12 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees17 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames2 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames0 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames1 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees9 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames5 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees5 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees6 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames15 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames16 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees19 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees2 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics13 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics14 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees20 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees7 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames13 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames17 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum1 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames20 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum12 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees12 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees13 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics16 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics17 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum13 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames3 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum22 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum23 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics7 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum15 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics8 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics20 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics21 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum16 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum17 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames6 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics1 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics10 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees22 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees23 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees3 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum3 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_min_device_batch_avoids_recompile0 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_min_device_batch_avoids_recompile1 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_min_device_batch_avoids_recompile2 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames7 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics18 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames4 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum10 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum8 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum21 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames10 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees18 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics5 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees21 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees10 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames14 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames18 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames21 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees8 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees14 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics15 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum14 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics9 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum18 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics11 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics22 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnum4 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_trees4 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_min_device_batch_avoids_recompile3 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_basics19 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
ERROR tests/jax_utils_test.py::PadShardUnpadTest::test_static_argnames8 - RuntimeError: Attempted to set 4 devices, but 1 CPUs already available: ens...
What you expected to happen:
The test suite to pass.
Logs, error messages, etc.:
Steps to reproduce:
Check out flax at 0.6.5 and run the test suite.
How are you running the test? Is it with `pytest -nauto tests/jax_utils_test.py`? Are you running any JAX operations before running this test? I was able to recreate the error by inserting the JAX operation `jax.numpy.array([[1, 2, 3]])` right before `chex.set_n_cpu_devices(NDEV)`.
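The reproduction described above comes down to ordering. As the chex docstring notes, the device-count setting only takes effect before XLA backends are initialized, and the first JAX operation triggers that initialization. The toy class below is a hypothetical model, not JAX's or chex's API (`ToyXlaRuntime` and `run_op` are made-up names): `run_op` plays the role of `jax.numpy.array([[1, 2, 3]])` and reads the flag exactly once, on first use.

```python
import re


class ToyXlaRuntime:
    """Hypothetical stand-in for a lazily initialized CPU backend."""

    FLAG = "--xla_force_host_platform_device_count"

    def __init__(self) -> None:
        self.xla_flags = ""
        self.cpu_device_count = None  # None means "backend not initialized yet"

    def run_op(self) -> int:
        """Any operation initializes the backend once, reading the flags then."""
        if self.cpu_device_count is None:
            match = re.search(self.FLAG + r"=(\d+)", self.xla_flags)
            self.cpu_device_count = int(match.group(1)) if match else 1
        return self.cpu_device_count

    def set_n_cpu_devices(self, n: int) -> None:
        """Only effective while the backend is still uninitialized."""
        if self.cpu_device_count is not None and self.cpu_device_count != n:
            raise RuntimeError(
                f"Attempted to set {n} devices, but "
                f"{self.cpu_device_count} CPUs already available"
            )
        self.xla_flags = f"{self.FLAG}={n}"


runtime = ToyXlaRuntime()
runtime.set_n_cpu_devices(4)  # configure before the first operation
print(runtime.run_op())  # 4
```

Reversing the order on a fresh instance (calling `run_op()` first, then `set_n_cpu_devices(4)`) locks the count at 1 and raises the same `RuntimeError` shown in the traceback.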