MuyGPyS
NumPy version raises an error while checking for JAX
pip install --upgrade muygpys[hnswlib]
[ins] In [1]: import MuyGPyS
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[1], line 1
----> 1 import MuyGPyS
File ~/miniforge3/lib/python3.10/site-packages/MuyGPyS/__init__.py:12
8 import importlib.metadata
10 __version__ = importlib.metadata.version(__package__)
---> 12 from MuyGPyS._src.config import (
13 config as config,
14 jax_config as jax_config,
15 MPI as MPI,
16 )
File ~/miniforge3/lib/python3.10/site-packages/MuyGPyS/_src/config.py:82
77 config.state.jax_enabled = val
80 # JAX and GPU states
---> 82 enable_jax = config.define_bool_state(
83 name="muygpys_jax_enabled",
84 default=False,
85 help="Enable use of jax implementations of math functions.",
86 update_global_hook=_update_jax_global,
87 update_thread_local_hook=_update_jax_thread_local,
88 )
91 def _update_gpu_global(val):
92 config.state.gpu_enabled = val
AttributeError: 'MuyGPySConfig' object has no attribute 'define_bool_state'
This appears to be caused by this line: https://github.com/LLNL/MuyGPyS/blob/0dad6a882048bcf885c59a2a23ce09181b7e67f4/src/MuyGPyS/_src/config.py#L8
I do have jax installed, and the JaxConfig does not have define_bool_state.
Is this a version compatibility issue?
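One quick way to test this diagnosis without importing MuyGPyS is to check whether the installed JAX config object still exposes define_bool_state. This is a minimal sketch, assuming that jax.config is the object referred to above as JaxConfig; the helper name is hypothetical and not part of MuyGPyS or JAX.

```python
import importlib.util


def jax_config_has_define_bool_state():
    """Hypothetical helper: True/False if JAX is installed, None if it is not."""
    if importlib.util.find_spec("jax") is None:
        return None  # JAX is not installed, so MuyGPyS has nothing to trip over
    import jax

    # Assumption: jax.config is the config object MuyGPyS builds on.
    return hasattr(jax.config, "define_bool_state")


print(jax_config_has_define_bool_state())
```

A result of False on a machine showing the AttributeError above would be consistent with a JAX version mismatch.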
Downgrading to JAX 0.4.24 fixed this.
Thanks @esheldon for investigating. There is a known incompatibility with recent versions of JAX in Python >= 3.9 arising from their config objects. We can fix this in a future release, but in the meantime thank you for identifying a compatible version of JAX.