`_generate_eval_unroll` throws "INTERNAL: the requested functionality is not supported"

Open varunagrawal opened this issue 10 months ago • 9 comments

I am trying to run a simple SAC example for quadruped locomotion.

The script I use is

    env_name = 'barkour'
    env = envs.get_environment(env_name)

    num_timesteps = 100000

    train_fn = functools.partial(
        sac.train,
        num_timesteps=num_timesteps,
        episode_length=env.eps_length,
        num_envs=1, #4096,
        learning_rate=3e-4,
        discounting=0.99,
        batch_size=256,
        num_evals=1, #10,
        # normalize_observations=True,
        min_replay_size=1000,
        max_replay_size=num_timesteps,
        network_factory=sac_networks.make_sac_networks,
        randomization_fn=domain_randomize,
    )

    # Reset environments since internals may be overwritten by tracers from the
    # domain randomization function.
    env = envs.get_environment(env_name)
    eval_env = envs.get_environment(env_name)

    make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)

where the environment name is the same as in the Quadruped Colab example.

However, I keep getting the error

    E0122 13:56:59.833734  391093 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: INTERNAL: the requested functionality is not supported
    Traceback (most recent call last):
      File "scripts/train_barkour_straight_sac.py", line 94, in <module>
        main()
      File "scripts/train_barkour_straight_sac.py", line 67, in main
        make_inference_fn, params, _ = train_fn(environment=env,
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/training/agents/sac/train.py", line 577, in train
        metrics = evaluator.run_evaluation(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/training/acting.py", line 134, in run_evaluation
        eval_state = self._generate_eval_unroll(policy_params, unroll_key)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    jaxlib.xla_extension.XlaRuntimeError: INTERNAL: the requested functionality is not supported
    --------------------
    For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

I am using jax==0.5.0 and brax==0.12.1.
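
As the last line of the log suggests, the filtered JAX frames can be restored by setting `JAX_TRACEBACK_FILTERING=off`. A minimal sketch of doing that from inside the training script follows. The helper name `enable_full_jax_tracebacks` is hypothetical, and the sketch assumes the variable is set before `jax` is first imported so that JAX picks it up when it loads:

```python
import os


def enable_full_jax_tracebacks(env=os.environ):
    """Ask JAX to keep its internal frames in tracebacks.

    Hypothetical helper: call it before the first `import jax`, since the
    setting is read from the environment when JAX loads its configuration.
    """
    env["JAX_TRACEBACK_FILTERING"] = "off"
    return env["JAX_TRACEBACK_FILTERING"]


# Set the variable first, then import jax / brax and run training as usual.
enable_full_jax_tracebacks()
```

The same effect can be had without touching the script by exporting the variable in the shell before launching it.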

varunagrawal avatar Jan 22 '25 19:01 varunagrawal

@btaba you mentioned you've been using SAC without any issues. Anything seem obvious about this off the top of your head?

Additionally, how much GPU RAM do you use? I seem to also be running into memory issues using brax's implementation vs the implementation of A Walk in the Park.

varunagrawal avatar Jan 24 '25 22:01 varunagrawal

Hi @varunagrawal, nothing looks obviously wrong to me in the small code snippet you provided. What's different compared to the quadruped example? And which quadruped example are you using at this point (we have several)?

I don't usually futz with the JAX / XLA memory allocation, I let JAX do its thing (I'm assuming you mean via XLA_PYTHON_CLIENT_MEM_FRACTION).
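
For reference, a minimal sketch of how the XLA client memory variables are typically set, for anyone who does want to tune them. `XLA_PYTHON_CLIENT_MEM_FRACTION` and `XLA_PYTHON_CLIENT_PREALLOCATE` are the environment variables described in the JAX GPU memory allocation docs. The helper name `configure_xla_gpu_memory` and the value `0.5` are illustrative assumptions, not settings recommended anywhere in this thread:

```python
import os


def configure_xla_gpu_memory(fraction=None, preallocate=True, env=os.environ):
    """Write the XLA client memory settings into `env`.

    Hypothetical helper: it only sets environment variables, so it must run
    before the first `import jax` to have any effect on the GPU allocator.
    """
    if fraction is not None:
        if fraction <= 0:
            raise ValueError("fraction must be positive")
        # Fraction of total GPU memory that JAX preallocates.
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"
    # "false" makes JAX allocate on demand instead of preallocating.
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    return {k: v for k, v in env.items() if k.startswith("XLA_PYTHON_CLIENT_")}


# Example: cap preallocation at half of GPU memory (illustrative value only).
settings = configure_xla_gpu_memory(fraction=0.5, env={})
```

Passing a plain dict as `env`, as in the example, shows the resulting settings without changing the real process environment. Omit the argument to apply them.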

btaba avatar Jan 24 '25 22:01 btaba

Interesting. I'll spend some time making a Colab example. If I can reproduce it there, that will be telling.

varunagrawal avatar Jan 24 '25 22:01 varunagrawal

What example are you basing this on?

btaba avatar Jan 24 '25 22:01 btaba

This DeepMind example, specifically the Quadruped section.

PPO works great for me, but SAC throws the above error so this is a really curious issue.

varunagrawal avatar Jan 24 '25 23:01 varunagrawal

Hi @varunagrawal, let us know when you have a reproducible example.

btaba avatar Jan 28 '25 19:01 btaba

@btaba here is a repo in which I am able to reproduce the issue: https://github.com/varunagrawal/locomotion

Please let me know if you are able to get the same error. Again, PPO works great with the same environment, so I am very lost on why SAC is giving this issue.

varunagrawal avatar Mar 06 '25 23:03 varunagrawal

In your example I get:

    python train_barkour_straight_sac.py
    Traceback (most recent call last):
      File "locomotion/train_barkour_straight_sac.py", line 20, in <module>
        from locomotion.envs import domain_randomize
      File "locomotion/locomotion/__init__.py", line 3, in <module>
        from .envs import BarkourStraightEnv
      File "locomotion/locomotion/envs/__init__.py", line 1, in <module>
        from .barkour_straight import BarkourStraightEnv, domain_randomize
      File "locomotion/locomotion/envs/barkour_straight.py", line 14, in <module>
        from fill.envs.add_obstacles import add_rand_loc
    ModuleNotFoundError: No module named 'fill'

Let us know if you can update it.

btaba avatar Apr 07 '25 23:04 btaba

@btaba just updated it! Thanks for catching that and looking into this.

varunagrawal avatar Apr 11 '25 17:04 varunagrawal