clrs
Problems with jax
I installed the required libraries with pip install -r requirement.txt. CUDA works well and TensorFlow can find the GPU. However, when I try to run the code, I get this error:
"Unable to initialize backend 'cuda'": module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
Searching online, I found that it might be caused by incompatible library versions.
Could you please share the versions of the packages you used? Thank you very much.
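One way to collect the versions being asked for is a short script using the standard library's `importlib.metadata`. This is a minimal sketch; the package list is limited to packages mentioned in this thread and is an assumption, not an official requirements list, and `installed_versions` is a hypothetical helper name.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Packages mentioned in this thread (assumption: extend as needed).
PACKAGES = ("jax", "jaxlib", "tensorflow")


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting this output into an issue makes mismatches like the jaxlib/CUDA build tag easier to spot.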
Same for me, @WilliamLi0623. My current workaround: pip install -U jaxlib=={version no}+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
cc @PetarV-
It still doesn't work.
Log dump:
2023-01-21 20:01:53.374343: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-21 20:01:53.537534: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-21 20:01:53.539034: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-21 20:01:54.230774: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-21 20:01:55.316097: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-01-21 20:01:55.316175: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-01-21 20:01:55.316187: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-01-21 20:01:59.913842: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.913949: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublas.so.11'; dlerror: libcublas.so.11: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914012: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublasLt.so.11'; dlerror: libcublasLt.so.11: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914076: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcufft.so.10'; dlerror: libcufft.so.10: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914133: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcurand.so.10'; dlerror: libcurand.so.10: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914191: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusolver.so.11'; dlerror: libcusolver.so.11: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914254: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusparse.so.11'; dlerror: libcusparse.so.11: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914312: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914338: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1934] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
I0121 20:01:59.939431 139832542189376 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0121 20:02:00.399612 139832542189376 xla_bridge.py:355] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Host Interpreter
I0121 20:02:00.400197 139832542189376 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0121 20:02:00.400516 139832542189376 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
2023-01-21 20:02:01.082623: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:109] Couldn't get ptxas version : FAILED_PRECONDITION: Couldn't get ptxas version string: INTERNAL: Couldn't invoke ptxas --version
2023-01-21 20:02:01.083637: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:451] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: Failed to launch ptxas' If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.
Fatal Python error: Aborted
Thread 0x00007f2d4d015740 (most recent call first):
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 1014 in backend_compile
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/profiler.py", line 314 in wrapper
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 1079 in compile_or_get_cached
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 3439 in from_hlo
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 3170 in _compile_unloaded
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 3202 in compile
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 359 in _xla_callable_uncached
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 202 in xla_primitive_callable
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/util.py", line 247 in cached
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/util.py", line 254 in wrapper
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 118 in apply_primitive
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 712 in process_primitive
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 332 in bind_with_trace
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 329 in bind
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 509 in shift_right_logical
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/prng.py", line 827 in threefry_seed
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/prng.py", line 592 in random_seed_impl_base
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/prng.py", line 587 in random_seed_impl
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 712 in process_primitive
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 332 in bind_with_trace
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 329 in bind
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/prng.py", line 575 in random_seed
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/prng.py", line 267 in seed_with_impl
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/random.py", line 133 in PRNGKey
File "/mnt/infonas/data/prateekch/clrs/clrs/examples/run.py", line 379 in main
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/absl/app.py", line 254 in _run_main
File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/absl/app.py", line 308 in run
File "/mnt/infonas/data/prateekch/clrs/clrs/examples/run.py", line 537 in <module>
File "/usr/lib/python3.8/runpy.py", line 87 in _run_code
File "/usr/lib/python3.8/runpy.py", line 194 in _run_module_as_main
Aborted (core dumped)
Sorry that the package sometimes doesn't work on GPU out of the box. You need to make sure that a CUDA-compatible version of JAX is installed. Try the following steps:
python3 -m venv clrs_env
source clrs_env/bin/activate
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install git+https://github.com/deepmind/clrs.git
If the jax[cuda] install is successful, you should be able to run
import jax
print(jax.local_devices())
in a Python interpreter and see your GPU listed among the devices.
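The same check can be made scriptable. This sketch assumes that CUDA devices report `platform == "gpu"` in `jax.local_devices()`, as CUDA builds of JAX do; `has_gpu` is a hypothetical helper, and JAX is imported lazily so the helper itself runs without JAX installed.

```python
from typing import Iterable


def has_gpu(devices: Iterable[object]) -> bool:
    """Return True if any device reports the 'gpu' platform (CUDA builds of JAX)."""
    return any(getattr(d, "platform", None) == "gpu" for d in devices)


def main() -> None:
    try:
        import jax  # lazy import: only needed for the live check
    except ImportError:
        print("JAX is not installed in this environment.")
        return
    devices = jax.local_devices()
    print(devices)
    if has_gpu(devices):
        print("GPU visible to JAX.")
    else:
        # jax.default_backend() names the backend JAX fell back to, e.g. 'cpu'.
        print(f"No GPU found; JAX default backend is {jax.default_backend()!r}.")


if __name__ == "__main__":
    main()
```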
The JAX installation guide has good pointers to potential problems with installing JAX on GPU. Keep in mind that JAX expects the CUDA installation to be at /usr/local/cuda-X.X. If the CUDA libraries are somewhere else, try creating a symlink:
sudo ln -s /path/to/cuda /usr/local/cuda-X.X
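Before creating the symlink, it can help to see which of the CUDA libraries named in the first log dump can actually be loaded, and whether `ptxas` (the tool behind the fatal "Failed to launch ptxas" error in that log) is on `PATH`. A minimal sketch using `ctypes` and `shutil`; the library names are copied from that log, and the versions you actually need depend on your jaxlib build.

```python
import ctypes
import shutil
from typing import Dict, Iterable

# Library names copied from the dso_loader warnings in the log (CUDA 11 / cuDNN 8).
CUDA_LIBS = (
    "libcudart.so.11.0",
    "libcublas.so.11",
    "libcufft.so.10",
    "libcusolver.so.11",
    "libcusparse.so.11",
    "libcudnn.so.8",
)


def loadable(names: Iterable[str]) -> Dict[str, bool]:
    """Try to dlopen each shared library by name, searching the usual loader paths."""
    result: Dict[str, bool] = {}
    for name in names:
        try:
            ctypes.CDLL(name)
            result[name] = True
        except OSError:
            result[name] = False
    return result


if __name__ == "__main__":
    for name, ok in loadable(CUDA_LIBS).items():
        print(f"{name}: {'ok' if ok else 'NOT FOUND'}")
    # The log shows XLA invoking `ptxas --version`, so it must be findable.
    print("ptxas:", shutil.which("ptxas") or "NOT on PATH")
```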
Good Afternoon,
I am still experiencing issues. When I run python3 -m clrs.examples.run, I get this error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure
The full error is shown below. I have had a look but couldn't find anything to fix it. Any help would be appreciated.
Thanks, Sean
I am running Python 3.9 with jax 0.4.4 and jaxlib 0.4.4+cuda11.cudnn86.
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
PyEval_EvalCode
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
Py_RunMain
Py_BytesMain
__libc_start_main
_start
*** End stack trace ***
Traceback (most recent call last):
File "/usr/lib64/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib64/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 537, in <module>
app.run(main)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 380, in main
rng_key = jax.random.PRNGKey(rng.randint(2**32))
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/random.py", line 136, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/prng.py", line 267, in seed_with_impl
return random_seed(seed, impl=impl)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/prng.py", line 570, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 343, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 346, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 789, in process_primitive
return primitive.impl(*tracers, **params)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/prng.py", line 582, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/prng.py", line 587, in random_seed_impl_base
return seed(seeds)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/prng.py", line 822, in threefry_seed
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/lax/lax.py", line 511, in shift_right_logical
return shift_right_logical_p.bind(x, y)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 343, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 346, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 789, in process_primitive
return primitive.impl(*tracers, **params)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 123, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/util.py", line 246, in cached
return f(*args, **kwargs)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 202, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 355, in _xla_callable_uncached
return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3254, in compile
executable = self._compile_unloaded(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3225, in _compile_unloaded
return UnloadedMeshExecutable.from_hlo(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3512, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 1095, in compile_or_get_cached
return backend_compile(backend, serialized_computation, compile_options,
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 1040, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:627) dnn != nullptr
According to this, jax/jaxlib >= 0.4.3 seems to be incompatible with cuDNN 8.6. It seems you have to use either jax/jaxlib 0.4.2 or cuDNN 8.8; give it a try.
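The constraint described above can be checked mechanically from the jaxlib version string. A minimal sketch, assuming the local version tag has the `+cudaXX.cudnnYZ` form seen in this thread (e.g. `0.4.4+cuda11.cudnn86`, read as cuDNN 8.6) and encoding only the single incompatibility reported here; `parse_jaxlib` and `known_bad` are hypothetical helpers, not part of JAX.

```python
import re
from typing import Optional, Tuple

# e.g. "0.4.4+cuda11.cudnn86" -> release 0.4.4, cuDNN tag 8.6 (assumed tag format)
_VERSION = re.compile(r"^(\d+)\.(\d+)\.(\d+)(?:\+cuda(\d+)\.cudnn(\d)(\d+))?$")


def parse_jaxlib(version: str) -> Tuple[Tuple[int, int, int], Optional[Tuple[int, int]]]:
    """Split a jaxlib version into (release, cuDNN tag); the tag is None if absent."""
    m = _VERSION.match(version)
    if m is None:
        raise ValueError(f"unrecognized jaxlib version: {version!r}")
    release = (int(m[1]), int(m[2]), int(m[3]))
    cudnn = (int(m[5]), int(m[6])) if m[5] else None
    return release, cudnn


def known_bad(jaxlib_version: str, cudnn: Optional[Tuple[int, int]] = None) -> bool:
    """True for the combination reported here: jaxlib >= 0.4.3 with cuDNN 8.6.

    If the installed cuDNN version is not given, fall back to the build tag.
    """
    release, tagged = parse_jaxlib(jaxlib_version)
    cudnn = cudnn or tagged
    return cudnn is not None and release >= (0, 4, 3) and cudnn == (8, 6)
```

For the versions reported above, `known_bad("0.4.4+cuda11.cudnn86")` returns True, while pinning jaxlib 0.4.2 or passing `cudnn=(8, 8)` returns False.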
Good Evening,
Thank you for the help; that got the benchmark code started with GPU support, but I am now seeing this error: jax._src.traceback_util.UnfilteredStackTrace: TypeError: Subscripted generics cannot be used with class and instance checks
From searching online, I think this is related to Python 3.9 and above. I don't see anything related in the JAX GitHub issues. Is there a recommended Python version for the CLRS code, or do you think there is another cause?
Thanks, Sean
2023-02-23 19:07:41.985101: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /local/java/cuda-11.6.0/lib64/:/local/java/cudnn-linux-x86_64-8.5.0.96_cuda11-archive/lib/
2023-02-23 19:07:41.985869: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /local/java/cuda-11.6.0/lib64/:/local/java/cudnn-linux-x86_64-8.5.0.96_cuda11-archive/lib/
2023-02-23 19:07:41.985881: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
I0223 19:07:49.846865 139926348851008 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0223 19:07:49.906153 139926348851008 xla_bridge.py:355] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Host Interpreter CUDA
I0223 19:07:49.906481 139926348851008 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0223 19:07:49.906545 139926348851008 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
I0223 19:07:55.778950 139926348851008 run.py:299] Creating samplers for algo bfs
W0223 19:07:55.779258 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
W0223 19:07:55.779516 139926348851008 samplers.py:100] Sampling dataset on-the-fly, unlimited samples.
W0223 19:07:56.005018 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
W0223 19:07:56.005211 139926348851008 samplers.py:100] Sampling dataset on-the-fly, unlimited samples.
W0223 19:07:56.241070 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
W0223 19:07:56.241253 139926348851008 samplers.py:100] Sampling dataset on-the-fly, unlimited samples.
W0223 19:07:56.525957 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
W0223 19:07:56.526148 139926348851008 samplers.py:100] Sampling dataset on-the-fly, unlimited samples.
W0223 19:07:56.848360 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
W0223 19:07:56.848542 139926348851008 samplers.py:100] Sampling dataset on-the-fly, unlimited samples.
W0223 19:07:57.231205 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
I0223 19:07:57.231393 139926348851008 samplers.py:112] Creating a dataset with 64 samples.
I0223 19:07:57.256827 139926348851008 run.py:158] Dataset not found in /tmp/CLRS30/CLRS30_v1.0.0. Downloading...
I0223 19:08:11.790027 139926348851008 dataset_info.py:482] Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/bfs_test/1.0.0
I0223 19:08:11.791674 139926348851008 dataset_info.py:482] Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/bfs_test/1.0.0
I0223 19:08:11.792154 139926348851008 dataset_builder.py:366] Reusing dataset clrs_dataset (/tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/bfs_test/1.0.0)
I0223 19:08:11.792221 139926348851008 logging_logger.py:44] Constructing tf.data.Dataset clrs_dataset for split test, from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/bfs_test/1.0.0
WARNING:tensorflow:From /dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0223 19:08:11.990011 139926348851008 deprecation.py:350] From /dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Traceback (most recent call last):
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 537, in <module>
app.run(main)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 464, in main
cur_loss = train_model.feedback(rng_key, feedback, length_and_algo_idx)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 370, in feedback
loss, self._device_params, self._device_opt_state = self.jitted_feedback(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/api.py", line 564, in cache_miss
execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 241, in _xla_call_impl_lazy
return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
ans = call(fun, *args)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 357, in _xla_callable_uncached
computation = sharded_lowering(fun, device, backend, name, donated_invars,
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 348, in sharded_lowering
return pxla.lower_sharding_computation(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/interpreters/pxla.py", line 2790, in lower_sharding_computation
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2073, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2006, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 318, in _feedback
params, opt_state = self._update_params(params, grads, opt_state,
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 428, in _update_params
updates, opt_state = filter_null_grads(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 769, in filter_null_grads
flat_opt_state = jax.tree_util.tree_map(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/tree_util.py", line 207, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/tree_util.py", line 207, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 771, in <lambda>
if not isinstance(x, _Array) else x, opt_state_skeleton, opt_state)
File "/usr/lib64/python3.9/typing.py", line 720, in __instancecheck__
return self.__subclasscheck__(type(obj))
File "/usr/lib64/python3.9/typing.py", line 723, in __subclasscheck__
raise TypeError("Subscripted generics cannot be used with"
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Subscripted generics cannot be used with class and instance checks
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib64/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib64/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 537, in <module>
app.run(main)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 464, in main
cur_loss = train_model.feedback(rng_key, feedback, length_and_algo_idx)
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 370, in feedback
loss, self._device_params, self._device_opt_state = self.jitted_feedback(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 318, in _feedback
params, opt_state = self._update_params(params, grads, opt_state,
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 428, in _update_params
updates, opt_state = filter_null_grads(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 769, in filter_null_grads
flat_opt_state = jax.tree_util.tree_map(
File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 771, in <lambda>
if not isinstance(x, _Array) else x, opt_state_skeleton, opt_state)
File "/usr/lib64/python3.9/typing.py", line 720, in __instancecheck__
return self.__subclasscheck__(type(obj))
File "/usr/lib64/python3.9/typing.py", line 723, in __subclasscheck__
raise TypeError("Subscripted generics cannot be used with"
TypeError: Subscripted generics cannot be used with class and instance checks
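The failing call is `isinstance(x, _Array)` inside `filter_null_grads`. Assuming `_Array` is a subscripted `typing` alias (its definition is not shown in this thread), the error can be reproduced without CLRS, and one generic workaround is to reduce the alias to plain classes before calling `isinstance`. This sketches that technique only; it is not the change made in the commit linked in the reply below, and `runtime_types` is a hypothetical helper.

```python
import types
import typing
from typing import Any, Tuple

# typing.Union[...] plus, on Python 3.10+, the `A | B` union type.
_UNION_ORIGINS: Tuple[Any, ...] = (typing.Union,) + (
    (types.UnionType,) if hasattr(types, "UnionType") else ()
)


def runtime_types(alias: Any) -> Tuple[type, ...]:
    """Reduce a typing alias to a tuple of plain classes usable with isinstance()."""
    origin = typing.get_origin(alias)
    if origin in _UNION_ORIGINS:
        flat = []
        for arg in typing.get_args(alias):
            flat.extend(runtime_types(arg))  # flatten nested aliases
        return tuple(flat)
    if origin is not None:
        return (origin,)  # e.g. List[float] -> (list,)
    return (alias,)  # already a plain class


if __name__ == "__main__":
    Hypothetical = typing.List[float]  # hypothetical stand-in for _Array
    try:
        isinstance([1.0], Hypothetical)
    except TypeError as err:
        print("raw alias:", err)
    print("workaround:", isinstance([1.0], runtime_types(Hypothetical)))
```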
@mcleish7 This should be fixed in https://github.com/deepmind/clrs/commit/2b37ff3f6d56b2e2e43806b7c5635282e888f505
@hbq1 @bibarzgoogle Thank you for your help, it is now working.