deepmd-kit
[BUG] import deepmd causes failure in jax.random.PRNGKey
Bug summary
Importing deepmd causes subsequent calls to jax.random.PRNGKey to fail.
This only happens with deepmd-kit >= 2.2.8; versions <= 2.2.7 work.
DeePMD-kit Version
>= 2.2.8
TensorFlow Version
2.14.0
How did you download the software?
pip
Input Files, Running Commands, Error Log, etc.
Minimal code to reproduce the bug, saved as reprod.py:
import jax
import tensorflow as tf
jax.random.PRNGKey(44)
print("OK"*30)
import deepmd
jax.random.PRNGKey(44)
print("OK"*30)
Steps to Reproduce
python reprod.py
2024-02-23 21:18:27.871442: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-23 21:18:27.871482: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-23 21:18:27.871502: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-23 21:18:28.465643: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708694309.049719 14210 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
OKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOK
WARNING:tensorflow:From /opt/mamba/lib/python3.10/site-packages/tensorflow/python/compat/v2_compat.py:108: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
WARNING:root:To get the best performance, it is recommended to adjust the number of threads by setting the environment variables OMP_NUM_THREADS, TF_INTRA_OP_PARALLELISM_THREADS, and TF_INTER_OP_PARALLELISM_THREADS. See https://deepmd.rtfd.io/parallelism/ for more information.
2024-02-23 21:18:32.980326: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1541] Failed to get stream capture info: device kernel image is invalid
2024-02-23 21:18:32.980474: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2644] Execution of replica 0 failed: INVALID_ARGUMENT: stream is uninitialized or in an error state
Traceback (most recent call last):
  File "/mnt/user/wangh/workspace/hydrogen/tests/n16.rs.10141820.nf5120.attn_h2.06.01.w128.node01.main/reprod.py", line 8, in <module>
    jax.random.PRNGKey(44)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/random.py", line 190, in PRNGKey
    return _return_prng_keys(True, _key('PRNGKey', seed, impl))
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/random.py", line 152, in _key
    return prng.seed_with_impl(impl, seed)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/prng.py", line 413, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/prng.py", line 695, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/core.py", line 386, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/core.py", line 389, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/core.py", line 821, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/prng.py", line 707, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/prng.py", line 712, in random_seed_impl_base
    return seed(seeds)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/prng.py", line 941, in threefry_seed
    return _threefry_seed(seed)
ValueError: INVALID_ARGUMENT: stream is uninitialized or in an error state
I0000 00:00:1708694313.208544 14210 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
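Since the failure only shows up after a particular sequence of imports, one way to narrow it down is to run each import combination in its own fresh interpreter, so that GPU or XLA state from one variant cannot affect the next. The following is only a sketch: the variant names and the run_snippet helper are made up for illustration and are not part of deepmd-kit, jax, or TensorFlow, and the last line of stderr is only a hint (in the log above, the final stderr line is an informational message, not the ValueError).

```python
import subprocess
import sys

# Each variant is a tiny script run in its own interpreter, so state left
# behind by one import (CUDA contexts, XLA clients) cannot leak into the next.
# The variant names are illustrative only.
VARIANTS = {
    "jax only": "import jax\njax.random.PRNGKey(44)\n",
    "jax + tensorflow": (
        "import jax\nimport tensorflow as tf\njax.random.PRNGKey(44)\n"
    ),
    "jax + tensorflow + deepmd": (
        "import jax\nimport tensorflow as tf\n"
        "jax.random.PRNGKey(44)\n"
        "import deepmd\n"
        "jax.random.PRNGKey(44)\n"
    ),
}


def run_snippet(code):
    """Run `code` in a fresh interpreter; return (exit code, last stderr line)."""
    proc = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True
    )
    stderr_lines = proc.stderr.strip().splitlines()
    # The last stderr line is only a hint; inspect proc.stderr for full detail.
    return proc.returncode, (stderr_lines[-1] if stderr_lines else "")


if __name__ == "__main__":
    for name, code in VARIANTS.items():
        returncode, last_line = run_snippet(code)
        status = "ok" if returncode == 0 else f"failed: {last_line}"
        print(f"{name}: {status}")
```

If only the last variant fails, that points at the deepmd import itself rather than at the combination of jax and TensorFlow.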
Further Information, Files, and Links
No response
I cannot reproduce.
conda create -n tf214 python=3.11
pip install tensorflow==2.14 jax jaxlib deepmd-kit==2.2.8
python reprod.py
It prints OK.
2024-02-23 16:05:02.287719: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-23 16:05:02.287752: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-23 16:05:02.287773: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-23 16:05:02.875749: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
OKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOK
WARNING:tensorflow:From /home/jz748/anaconda3/envs/tf214/lib/python3.11/site-packages/tensorflow/python/compat/v2_compat.py:108: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
WARNING:root:To get the best performance, it is recommended to adjust the number of threads by setting the environment variables OMP_NUM_THREADS, TF_INTRA_OP_PARALLELISM_THREADS, and TF_INTER_OP_PARALLELISM_THREADS. See https://deepmd.rtfd.io/parallelism/ for more information.
OKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOK
I was using a GPU machine, and jax was installed with
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
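The earlier attempt to reproduce logged that jax fell back to CPU, while the failing run used a CUDA-enabled jaxlib, so it is worth confirming which backend jax actually picks in each environment. A minimal sketch, assuming only the public jax.default_backend() and jax.devices() calls; the describe_jax_backend helper is a made-up name for illustration:

```python
def describe_jax_backend():
    """Return a one-line summary of the backend jax would run on."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    # jax.devices() lists the devices of the default backend.
    platforms = sorted({device.platform for device in jax.devices()})
    return (
        f"jax {jax.__version__}: default backend = {jax.default_backend()}, "
        f"device platforms = {platforms}"
    )


if __name__ == "__main__":
    print(describe_jax_backend())
```

If the default backend is reported as cpu, the GPU code path that fails in the original report is not being exercised at all.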
I still cannot reproduce. My pip list output is below:
Package Version
---------------------------- ---------------------
absl-py 2.1.0
astunparse 1.6.3
bracex 2.4
cachetools 5.3.2
certifi 2024.2.2
charset-normalizer 3.3.2
dargs 0.4.4
deepmd-kit 2.2.8
flatbuffers 23.5.26
gast 0.5.4
google-auth 2.28.1
google-auth-oauthlib 1.0.0
google-pasta 0.2.0
grpcio 1.62.0
h5py 3.10.0
idna 3.6
jax 0.4.24
jaxlib 0.4.24+cuda11.cudnn86
keras 2.14.0
libclang 16.0.6
Markdown 3.5.2
MarkupSafe 2.1.5
ml-dtypes 0.2.0
numpy 1.26.4
nvidia-cublas-cu11 11.11.3.6
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-nvcc-cu11 11.8.89
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cudnn-cu11 8.9.6.50
nvidia-cufft-cu11 10.9.0.58
nvidia-cusolver-cu11 11.4.1.48
nvidia-cusparse-cu11 11.7.5.86
nvidia-nccl-cu11 2.19.3
oauthlib 3.2.2
opt-einsum 3.3.0
packaging 23.2
pip 24.0
protobuf 4.25.3
pyasn1 0.5.1
pyasn1-modules 0.3.0
python-hostlist 1.23.0
PyYAML 6.0.1
requests 2.31.0
requests-oauthlib 1.3.1
rsa 4.9
scipy 1.12.0
setuptools 69.1.1
six 1.16.0
tensorboard 2.14.1
tensorboard-data-server 0.7.2
tensorflow 2.14.0
tensorflow-estimator 2.14.0
tensorflow-io-gcs-filesystem 0.36.0
termcolor 2.4.0
typeguard 4.1.5
typing_extensions 4.9.0
urllib3 2.2.1
wcmatch 8.5.1
Werkzeug 3.0.1
wheel 0.42.0
wrapt 1.14.1
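To compare two environments without diffing full pip list outputs, the versions of just the packages involved here can be printed from within Python. This is a sketch using the standard-library importlib.metadata module; the package selection and the collect_versions helper are assumptions made for illustration, and the nvidia-*-cu11 wheels may also be worth adding to the list.

```python
from importlib import metadata

# Packages that appear relevant in this thread; adjust as needed.
PACKAGES = ["deepmd-kit", "tensorflow", "jax", "jaxlib", "numpy", "ml-dtypes"]


def collect_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in collect_versions(PACKAGES).items():
        print(f"{name:<12} {version or 'not installed'}")
```

Running this in both the failing and the working environment gives a short list that is easy to compare side by side.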
I stopped trying to use tf, dp and jax in the same Python code...
I would suggest providing a Docker or conda environment that can fully reproduce the problem.