`_generate_eval_unroll` throws "INTERNAL: the requested functionality is not supported"
I am trying to run a simple SAC example for quadruped locomotion.
The script I use is:
env_name = 'barkour'
env = envs.get_environment(env_name)
num_timesteps = 100000
train_fn = functools.partial(
    sac.train,
    num_timesteps=num_timesteps,
    episode_length=env.eps_length,
    num_envs=1,  # 4096,
    learning_rate=3e-4,
    discounting=0.99,
    batch_size=256,
    num_evals=1,  # 10,
    # normalize_observations=True,
    min_replay_size=1000,
    max_replay_size=num_timesteps,
    network_factory=sac_networks.make_sac_networks,
    randomization_fn=domain_randomize,
)
# Reset environments since internals may be overwritten by tracers from the
# domain randomization function.
env = envs.get_environment(env_name)
eval_env = envs.get_environment(env_name)
make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)
where the environment name is the same as in the Quadruped Colab example.
However, I keep getting the error:
E0122 13:56:59.833734 391093 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: INTERNAL: the requested functionality is not supported
Traceback (most recent call last):
  File "scripts/train_barkour_straight_sac.py", line 94, in <module>
    main()
  File "scripts/train_barkour_straight_sac.py", line 67, in main
    make_inference_fn, params, _ = train_fn(environment=env,
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/training/agents/sac/train.py", line 577, in train
    metrics = evaluator.run_evaluation(
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/training/acting.py", line 134, in run_evaluation
    eval_state = self._generate_eval_unroll(policy_params, unroll_key)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: the requested functionality is not supported
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I am using jax==0.5.0 and brax==0.12.1.
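For anyone trying to reproduce this, a minimal sketch of how the unfiltered traceback and the installed versions could be collected before training starts. The helper name `installed_version` is hypothetical, and the sketch assumes the environment variable is set before `jax` is first imported so that it is picked up.

```python
import importlib.metadata
import os

# Ask JAX to keep its internal frames in tracebacks, as the message above
# suggests. Set this before the first `import jax` in the process.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"


def installed_version(package: str) -> str:
    """Return the installed version of `package`, or a marker if it is absent."""
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return "not installed"


# Packages worth quoting in a bug report for this kind of XLA runtime error.
versions = {name: installed_version(name) for name in ("jax", "jaxlib", "brax")}
for name, version in versions.items():
    print(f"{name}=={version}")
```

Pasting the printed versions next to the full traceback makes it easier to tell whether the failure depends on a particular jax/jaxlib combination.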
@btaba you mentioned you've been using SAC without any issues. Anything seem obvious about this off the top of your head?
Additionally, how much GPU RAM do you use? I seem to also be running into memory issues using brax's implementation vs the implementation of A Walk in the Park.
Hi @varunagrawal, nothing looks obviously wrong to me in the small code snippet you provided. What's different compared to the quadruped example? And which quadruped example are you using at this point (we have several)?
I don't usually futz with the JAX / XLA memory allocation, I let JAX do its thing (I'm assuming you mean via XLA_PYTHON_CLIENT_MEM_FRACTION).
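For reference, a minimal sketch of what adjusting the allocator settings could look like if memory does turn out to be the problem. `XLA_PYTHON_CLIENT_MEM_FRACTION` and `XLA_PYTHON_CLIENT_PREALLOCATE` are the environment variables described in the JAX GPU memory allocation docs; the helper `configure_gpu_memory` and the example fraction of 0.5 are assumptions for illustration, not recommended values.

```python
import os
from typing import Dict, Optional


def configure_gpu_memory(
    preallocate: bool = True, fraction: Optional[float] = None
) -> Dict[str, str]:
    """Set JAX/XLA GPU allocator environment variables (hypothetical helper).

    Call this before the first `import jax`, since the settings are read
    when the GPU backend is initialized.
    """
    settings = {"XLA_PYTHON_CLIENT_PREALLOCATE": "true" if preallocate else "false"}
    if fraction is not None:
        if not 0.0 < fraction <= 1.0:
            raise ValueError("fraction must be in (0, 1]")
        # Fraction of total GPU memory to preallocate when preallocation is on.
        settings["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"
    os.environ.update(settings)
    return settings


# Example: keep preallocation but cap it at half of the GPU memory.
applied = configure_gpu_memory(preallocate=True, fraction=0.5)
print(applied)
```

Disabling preallocation (`preallocate=False`) makes JAX allocate on demand instead, which can help when comparing memory use between two implementations but may increase fragmentation.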
Interesting. I'll spend some time making a Colab example. If I can reproduce it there, that will be telling.
What example are you basing this on?
This DeepMind example, specifically the Quadruped section.
PPO works great for me, but SAC throws the above error so this is a really curious issue.
Hi @varunagrawal, let us know when you have a reproducible example.
@btaba here is a repo in which I am able to reproduce the issue: https://github.com/varunagrawal/locomotion
Please let me know if you are able to get the same error. Again, PPO works great with the same environment, so I am very lost on why SAC is giving this issue.
In your example I get:
python train_barkour_straight_sac.py
Traceback (most recent call last):
  File "locomotion/train_barkour_straight_sac.py", line 20, in <module>
    from locomotion.envs import domain_randomize
  File "locomotion/locomotion/__init__.py", line 3, in <module>
    from .envs import BarkourStraightEnv
  File "locomotion/locomotion/envs/__init__.py", line 1, in <module>
    from .barkour_straight import BarkourStraightEnv, domain_randomize
  File "locomotion/locomotion/envs/barkour_straight.py", line 14, in <module>
    from fill.envs.add_obstacles import add_rand_loc
ModuleNotFoundError: No module named 'fill'
Let us know if you can update it.
@btaba just updated it! Thanks for catching that and looking into this.