
[BUG] import deepmd causes failure in jax.random.PRNGKey

Open wanghan-iapcm opened this issue 1 year ago • 5 comments

Bug summary

Importing deepmd causes jax.random.PRNGKey to fail.

This only happens with versions >= 2.2.8; versions <= 2.2.7 work.

DeePMD-kit Version

>= 2.2.8

TensorFlow Version

2.14.0

How did you download the software?

pip

Input Files, Running Commands, Error Log, etc.

Minimal code for reproducing the bug. Let's say the script is named reprod.py:

import jax
import tensorflow as tf

jax.random.PRNGKey(44)
print("OK"*30)

import deepmd
jax.random.PRNGKey(44)
print("OK"*30)

Steps to Reproduce

python reprod.py 
2024-02-23 21:18:27.871442: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-23 21:18:27.871482: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-23 21:18:27.871502: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-23 21:18:28.465643: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708694309.049719   14210 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
OKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOK
WARNING:tensorflow:From /opt/mamba/lib/python3.10/site-packages/tensorflow/python/compat/v2_compat.py:108: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
WARNING:root:To get the best performance, it is recommended to adjust the number of threads by setting the environment variables OMP_NUM_THREADS, TF_INTRA_OP_PARALLELISM_THREADS, and TF_INTER_OP_PARALLELISM_THREADS. See https://deepmd.rtfd.io/parallelism/ for more information.
2024-02-23 21:18:32.980326: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1541] Failed to get stream capture info: device kernel image is invalid
2024-02-23 21:18:32.980474: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2644] Execution of replica 0 failed: INVALID_ARGUMENT: stream is uninitialized or in an error state
Traceback (most recent call last):
  File "/mnt/user/wangh/workspace/hydrogen/tests/n16.rs.10141820.nf5120.attn_h2.06.01.w128.node01.main/reprod.py", line 8, in <module>
    jax.random.PRNGKey(44)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/random.py", line 190, in PRNGKey
    return _return_prng_keys(True, _key('PRNGKey', seed, impl))
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/random.py", line 152, in _key
    return prng.seed_with_impl(impl, seed)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/prng.py", line 413, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/prng.py", line 695, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/core.py", line 386, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/core.py", line 389, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/core.py", line 821, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/prng.py", line 707, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/prng.py", line 712, in random_seed_impl_base
    return seed(seeds)
  File "/opt/mamba/lib/python3.10/site-packages/jax/_src/prng.py", line 941, in threefry_seed
    return _threefry_seed(seed)
ValueError: INVALID_ARGUMENT: stream is uninitialized or in an error state
I0000 00:00:1708694313.208544   14210 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.

Further Information, Files, and Links

No response

wanghan-iapcm, Feb 23 '24 13:02

I cannot reproduce.

conda create -n tf214 python=3.11
pip install tensorflow==2.14 jax jaxlib deepmd-kit==2.2.8
python reprod.py

It prints OK.

2024-02-23 16:05:02.287719: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-23 16:05:02.287752: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-23 16:05:02.287773: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-23 16:05:02.875749: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
OKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOK
WARNING:tensorflow:From /home/jz748/anaconda3/envs/tf214/lib/python3.11/site-packages/tensorflow/python/compat/v2_compat.py:108: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
WARNING:root:To get the best performance, it is recommended to adjust the number of threads by setting the environment variables OMP_NUM_THREADS, TF_INTRA_OP_PARALLELISM_THREADS, and TF_INTER_OP_PARALLELISM_THREADS. See https://deepmd.rtfd.io/parallelism/ for more information.
OKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOKOK

njzjz, Feb 23 '24 21:02

I was using a GPU machine, and JAX was installed with:

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

wanghan-iapcm, Feb 24 '24 01:02

I still cannot reproduce. My pip list output is below:

Package                      Version
---------------------------- ---------------------
absl-py                      2.1.0
astunparse                   1.6.3
bracex                       2.4
cachetools                   5.3.2
certifi                      2024.2.2
charset-normalizer           3.3.2
dargs                        0.4.4
deepmd-kit                   2.2.8
flatbuffers                  23.5.26
gast                         0.5.4
google-auth                  2.28.1
google-auth-oauthlib         1.0.0
google-pasta                 0.2.0
grpcio                       1.62.0
h5py                         3.10.0
idna                         3.6
jax                          0.4.24
jaxlib                       0.4.24+cuda11.cudnn86
keras                        2.14.0
libclang                     16.0.6
Markdown                     3.5.2
MarkupSafe                   2.1.5
ml-dtypes                    0.2.0
numpy                        1.26.4
nvidia-cublas-cu11           11.11.3.6
nvidia-cuda-cupti-cu11       11.8.87
nvidia-cuda-nvcc-cu11        11.8.89
nvidia-cuda-nvrtc-cu11       11.8.89
nvidia-cuda-runtime-cu11     11.8.89
nvidia-cudnn-cu11            8.9.6.50
nvidia-cufft-cu11            10.9.0.58
nvidia-cusolver-cu11         11.4.1.48
nvidia-cusparse-cu11         11.7.5.86
nvidia-nccl-cu11             2.19.3
oauthlib                     3.2.2
opt-einsum                   3.3.0
packaging                    23.2
pip                          24.0
protobuf                     4.25.3
pyasn1                       0.5.1
pyasn1-modules               0.3.0
python-hostlist              1.23.0
PyYAML                       6.0.1
requests                     2.31.0
requests-oauthlib            1.3.1
rsa                          4.9
scipy                        1.12.0
setuptools                   69.1.1
six                          1.16.0
tensorboard                  2.14.1
tensorboard-data-server      0.7.2
tensorflow                   2.14.0
tensorflow-estimator         2.14.0
tensorflow-io-gcs-filesystem 0.36.0
termcolor                    2.4.0
typeguard                    4.1.5
typing_extensions            4.9.0
urllib3                      2.2.1
wcmatch                      8.5.1
Werkzeug                     3.0.1
wheel                        0.42.0
wrapt                        1.14.1

njzjz, Feb 24 '24 03:02

I stopped trying to use TF, DP, and JAX in the same Python code...

wanghan-iapcm, Feb 28 '24 01:02

I would suggest providing a Docker or conda environment that can fully reproduce the problem.

njzjz, Mar 16 '24 21:03
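
Short of a full Docker or conda environment, a version report from the failing machine may already help narrow down the difference between the two setups. The following is a minimal sketch, not part of DeePMD-kit: the helper names collect_versions and format_report and the package list are assumptions chosen for illustration. It reads installed versions through importlib.metadata, so it does not import deepmd, TensorFlow, or JAX and should not trigger the failure itself.

```python
import importlib.metadata as md
import platform
import sys

# Packages discussed in this thread (an assumption; extend as needed).
PACKAGES = ["deepmd-kit", "tensorflow", "jax", "jaxlib"]


def collect_versions(packages=PACKAGES):
    """Map each package name to its installed version, or 'not installed'.

    Only package metadata is read; the packages themselves are not imported.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = md.version(name)
        except md.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


def format_report(versions):
    """Render a short plain-text report suitable for pasting into an issue."""
    lines = [f"python {platform.python_version()} ({sys.platform})"]
    lines += [f"{name}=={version}" for name, version in versions.items()]
    return "\n".join(lines)


if __name__ == "__main__":
    print(format_report(collect_versions()))
```

Running the script on both the failing and the working machine and comparing the two reports would show, for example, whether the jaxlib builds differ.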