
Cannot use 'mjx' and 'generalized' backend

Status: Open. eleninisioti opened this issue 5 months ago; 0 comments.

I am trying to run this simple code:

import jax
from brax import envs

env_name = 'ant'
backend = 'mjx'  # 'generalized' also fails; 'positional' works
env = envs.get_environment(env_name=env_name,
                           backend=backend)
print(env.observation_size)
print(env.action_size)
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

which fails with the following error:

  File ".venv/lib/python3.13/site-packages/brax/envs/base.py", line 147, in observation_size
    reset_state = self.unwrapped.reset(rng)
  File ".venv/lib/python3.13/site-packages/brax/envs/ant.py", line 215, in reset
    pipeline_state = self.pipeline_init(q, qd)
  File ".venv/lib/python3.13/site-packages/brax/envs/base.py", line 126, in pipeline_init
    return self._pipeline.init(self.sys, q, qd, act, ctrl, self._debug)
  File ".venv/lib/python3.13/site-packages/brax/mjx/pipeline.py", line 72, in init
    data = mjx.forward(sys, data)
  File ".venv/lib/python3.13/site-packages/mujoco/mjx/_src/forward.py", line 59, in wrapper
    res = fn(*args, **kwargs)
  File ".venv/lib/python3.13/site-packages/mujoco/mjx/_src/forward.py", line 413, in forward
    d = fwd_position(m, d)
  File ".venv/lib/python3.13/site-packages/mujoco/mjx/_src/forward.py", line 59, in wrapper
    res = fn(*args, **kwargs)
  File ".venv/lib/python3.13/site-packages/mujoco/mjx/_src/forward.py", line 74, in fwd_position
    d = smooth.factor_m(m, d)
  File ".venv/lib/python3.13/site-packages/mujoco/mjx/_src/smooth.py", line 305, in factor_m
    qh, _ = jax.scipy.linalg.cho_factor(d.qM)
  File ".venv/lib/python3.13/site-packages/jax/_src/scipy/linalg.py", line 154, in cho_factor
    return (cholesky(a, lower=lower), lower)
  File ".venv/lib/python3.13/site-packages/jax/_src/scipy/linalg.py", line 106, in cholesky
    return _cholesky(a, lower)
jaxlib._jax.XlaRuntimeError: INTERNAL: cuSolver internal error
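The last frames show the failure inside jax.scipy.linalg.cho_factor, which MJX's smooth.factor_m calls on the mass matrix d.qM. A minimal diagnostic sketch, assuming the problem lies in the JAX/CUDA (cuSolver) install rather than in Brax, is to run the same call directly on a small symmetric positive-definite matrix, once per visible device and once on the CPU. The helper name cholesky_residual and the matrix size are illustrative choices, not part of Brax or MJX. If this also raises the cuSolver error on the GPU but not on the CPU, Brax is likely not the culprit.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def cholesky_residual(device, n: int = 8) -> float:
    """Hypothetical helper: factor a small SPD matrix on `device` with
    cho_factor (the routine MJX's smooth.factor_m calls) and return the
    max absolute residual |A x - b| after solving A x = b."""
    a = jax.random.normal(jax.random.PRNGKey(0), (n, n))
    spd = a @ a.T + n * jnp.eye(n)  # symmetric positive definite by construction
    spd = jax.device_put(spd, device)
    b = jax.device_put(jnp.ones((n,)), device)
    c, lower = cho_factor(spd)  # raises the cuSolver error if the GPU solver is broken
    x = cho_solve((c, lower), b)
    return float(jnp.max(jnp.abs(spd @ x - b)))


# Run on every device JAX sees by default (the GPU, if one is detected) ...
for device in jax.devices():
    print(device, cholesky_residual(device))
# ... and explicitly on the CPU for comparison.
print("cpu", cholesky_residual(jax.devices("cpu")[0]))
```

On a working install every line should print a small residual (on the order of float32 round-off).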

When I set the backend to 'positional', it works fine, so I suspect the problem is with my installation. Here is my environment:

absl-py==2.3.0
annotated-types==0.7.0
attrs==25.3.0
blinker==1.9.0
brax==0.12.3
certifi==2025.4.26
charset-normalizer==3.4.2
chex==0.1.89
click==8.2.1
cloudpickle==3.1.1
contourpy==1.3.2
cycler==0.12.1
dm-env==1.6
dm-tree==0.1.9
docker-pycreds==0.4.0
dotmap==1.3.30
equinox==0.12.2
etils==1.12.2
evosax==0.2.0
farama-notifications==0.0.4
flask==3.1.1
flask-cors==6.0.0
flax==0.10.6
fonttools==4.58.1
fsspec==2025.5.1
gitdb==4.0.12
gitpython==3.1.44
glfw==2.9.0
grpcio==1.72.1
gym==0.26.2
gym-notices==0.0.8
gymnasium==1.1.1
gymnax==0.0.9
humanize==4.12.3
idna==3.10
importlib-resources==6.5.2
itsdangerous==2.2.0
jax==0.6.1
jax-cuda12-pjrt==0.6.1
jax-cuda12-plugin==0.6.1
jaxlib==0.6.1
jaxopt==0.8.5
jaxtyping==0.3.2
jinja2==3.1.6
joblib==1.5.1
kiwisolver==1.4.8
markdown-it-py==3.0.0
markupsafe==3.0.2
matplotlib==3.10.3
mdurl==0.1.2
ml-collections==1.1.0
ml-dtypes==0.5.1
msgpack==1.1.0
mujoco==3.3.2
mujoco-mjx==3.3.2
nest-asyncio==1.6.0
networkx==3.5
numpy==2.3.0
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.9.79
nvidia-cuda-nvcc-cu12==12.9.86
nvidia-cuda-runtime-cu12==12.9.79
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.4.1.4
nvidia-cusolver-cu12==11.7.5.82
nvidia-cusparse-cu12==12.5.10.65
nvidia-nccl-cu12==2.27.3
nvidia-nvjitlink-cu12==12.9.86
nvidia-nvshmem-cu12==3.2.5
opt-einsum==3.4.0
optax==0.2.4
orbax-checkpoint==0.11.13
packaging==25.0
pandas==2.2.3
pillow==11.2.1
platformdirs==4.3.8
protobuf==6.31.1
psutil==7.0.0
pyaml==25.5.0
pydantic==2.11.5
pydantic-core==2.33.2
pygments==2.19.1
pyopengl==3.1.9
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytinyrenderer==0.0.14
pytz==2025.2
pyyaml==6.0.2
requests==2.32.3
rich==14.0.0
scikit-learn==1.6.1
scipy==1.15.3
seaborn==0.13.2
sentry-sdk==2.29.1
setproctitle==1.3.6
setuptools==80.9.0
simplejson==3.20.1
six==1.17.0
smmap==5.0.2
tensorboardx==2.6.2.2
tensorstore==0.1.75
threadpoolctl==3.6.0
toolz==1.0.0
tqdm==4.67.1
treescope==0.1.9
trimesh==4.6.10
typing-extensions==4.14.0
typing-inspection==0.4.1
tzdata==2025.2
urllib3==2.4.0
wadler-lindig==0.1.6
wandb==0.19.11
werkzeug==3.1.3
wrapt==1.17.2
zipp==3.22.0
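Most of this list is unrelated to the failing call. A sketch that narrows it to the packages on the path in the traceback (jax, jaxlib, the jax-cuda12 plugin wheels, mujoco/mujoco-mjx, brax) plus the nvidia-* CUDA wheels, and then prints JAX's own environment summary, could look like the following. The function name gpu_stack_versions and the prefix filter are illustrative assumptions; jax.print_environment_info() is JAX's built-in report and includes nvidia-smi output when that tool is available.

```python
from importlib import metadata

import jax


def gpu_stack_versions() -> dict[str, str]:
    """Hypothetical helper: collect installed versions of the packages most
    relevant to a cuSolver error, keyed by lowercase distribution name."""
    prefixes = ("jax", "mujoco", "brax", "nvidia-")
    versions = {}
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        if name and name.lower().startswith(prefixes):
            versions[name.lower()] = dist.version
    return dict(sorted(versions.items()))


for name, version in gpu_stack_versions().items():
    print(f"{name}=={version}")

print("default backend:", jax.default_backend())
jax.print_environment_info()  # JAX's summary: versions, devices, nvidia-smi
```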

eleninisioti, Jun 09 '25 17:06