stable-baselines3
stable-baselines3 copied to clipboard
Support VecEnv for gymnasium.vector.VectorEnv and Brax
🚀 Feature
It would be nice to have a wrapper that ingested gymnasium.vector.VectorEnv and gave back a VecEnv.
Motivation
I want to do highly parallelized, hardware-accelerated simulation. This pretty much leaves Isaac or Brax. Brax has a lighter-weight setup and also runs on TPUs. Stable Baselines has well-documented and tested implementations of most of the algorithms that I'm interested in using, as well as deep integration with the imitation library. I'd like to use both of these libraries.
Pitch
Brax currently provides a wrapper for legacy OpenAI gym vectorized environments. I have a request up to support the Gymnasium vectorized API (pretty much just change the imports to Gymnasium instead of Gym). Stable Baselines requires vectorized environments to be implemented against its specific VecEnv specification. As far as I can tell, it's pretty simple to migrate between the gymnasium vectorized env API and sb3's representation.
I'd like a wrapper class to be provided that implements VecEnv with an underlying gymnasium vectorized env.
Alternatives
Given the public API allows users to extend the library to write this themselves, that would be the chief alternative.
Additional context
No response
Checklist
- [X] I have checked that there is no similar issue in the repo
- [X] If I'm requesting a new feature, I have proposed alternatives
Where is the duplicate? I searched for it but couldn't find it. Would appreciate a pointer
Partial duplicate of https://github.com/DLR-RM/stable-baselines3/issues/1568#issuecomment-1600595147 and https://github.com/DLR-RM/stable-baselines3/issues/229
In short: a VecEnvWrapper would indeed be a good idea, but only after gymnasium 1.0 is released and fully tested. Would you be willing to contribute such a wrapper?
Related doc: https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#sb3-with-envpool-or-isaac-gym
Related issues: https://github.com/DLR-RM/stable-baselines3/issues/1712 and https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002
I happened to step away from using Gymnasium APIs. I was focused on Brax.
from typing import ClassVar, Optional, Sequence

from brax.envs.base import PipelineEnv, State, Wrapper
from brax.io import image
# import gym
# from gym import spaces
import gymnasium
import gymnasium as gym
from gymnasium import spaces
from gymnasium.vector import utils
import jax
import jax.numpy as jp
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices
class SB3Wrapper(VecEnv):
    """Expose a batched Brax environment through the Stable-Baselines3 ``VecEnv`` API.

    The wrapped env must already be batched (expose a non-``None``
    ``batch_size``) and wrapped with an episode wrapper (expose
    ``episode_length``). Resetting and stepping are jit-compiled once in the
    constructor; observations, rewards and dones are converted to NumPy arrays
    before being handed back to the caller.

    Args:
        env: the batched Brax environment to wrap.
        seed: seed for the JAX PRNG key consumed by ``reset``.
        info_keys: names of entries of the merged ``state.metrics`` /
            ``state.info`` dict to copy into the per-env info dicts returned
            by ``step_wait``. ``None`` (default) copies nothing.
        backend: JAX backend passed to ``jax.jit`` (``None`` = JAX default).

    Raises:
        ValueError: if ``env`` is not batched or has no ``episode_length``.
    """

    def __init__(self,
                 env: PipelineEnv,
                 seed: int = 0,
                 info_keys: Optional[Sequence[str]] = None,
                 backend: Optional[str] = None):
        # `batch_size` may exist but be None (nothing was actually batched);
        # treat that the same as a missing attribute.
        if getattr(env, 'batch_size', None) is None:
            raise ValueError('underlying env must be batched')
        if not hasattr(env, 'episode_length'):
            raise ValueError('underlying env must be wrapped with an episode wrapper')

        self._env = env
        self.info_keys = info_keys
        self.backend = backend
        self._state = None

        # Unbounded float32 observations; actions bounded by the actuator
        # control range (column 0 = low, column 1 = high).
        obs_high = np.inf * np.ones(env.observation_size, dtype='float32')
        obs_space = spaces.Box(-obs_high, obs_high, dtype='float32')
        ctrl_range = jax.tree_util.tree_map(np.array, env.sys.actuator.ctrl_range)
        action_space = spaces.Box(ctrl_range[:, 0], ctrl_range[:, 1], dtype='float32')

        # NOTE(review): recent SB3 versions query `render_mode` through
        # get_attr() inside VecEnv.__init__ -- defining it first avoids the
        # "not defined" warning. Confirm against the installed SB3 version.
        self.render_mode = None
        super().__init__(num_envs=env.batch_size,
                         observation_space=obs_space,
                         action_space=action_space)
        # Assigned after super().__init__() so that the base class cannot
        # overwrite it.
        self.metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second': 1 / self._env.dt
        }
        # self.batch_observation_space = utils.batch_space(obs_space, self.num_envs)
        # self.batch_action_space = utils.batch_space(action_space, self.num_envs)
        self.seed(seed)

        # Captured by the jitted step so that only requested entries are
        # returned from the compiled function.
        selected_keys = frozenset(info_keys) if info_keys is not None else frozenset()

        def reset(key):
            # Split so that every reset consumes a fresh sub-key.
            key1, key2 = jax.random.split(key)
            state = self._env.reset(key2)
            return state, state.obs, key1

        self._reset = jax.jit(reset, backend=self.backend)

        def step(state, action):
            state = self._env.step(state, action)
            merged = {**state.metrics, **state.info}
            info = {k: v for k, v in merged.items() if k in selected_keys}
            # 'truncation' is expected to be provided by the episode wrapper.
            return state, state.obs, state.reward, state.done, state.info['truncation'], info

        self._step = jax.jit(step, backend=self.backend)

    def reset(self, **kwargs):
        """Reset every env and return the batched observation as a NumPy array.

        Keyword arguments are accepted for API compatibility and ignored.
        """
        self._state, obs, self._key = self._reset(self._key)
        return np.array(obs)

    def step_async(self, action):
        """Store the batched action; the simulation runs in ``step_wait``."""
        self.action = jp.array(action)

    def step_wait(self):
        """Advance all envs with the stored action.

        Returns:
            ``(obs, reward, done, infos)`` where the first three are NumPy
            arrays and ``infos`` is a list with one dict per env. Each dict
            holds the requested ``info_keys`` entries plus
            ``'TimeLimit.truncated'``.

        Raises:
            RuntimeError: if called before ``reset``.
        """
        if self._state is None:
            raise RuntimeError('must call reset before step')
        self._state, obs, reward, done, truncation, info = self._step(self._state, self.action)
        infos = self._unbatch_info(info)
        # SB3 reads the time-limit flag from this info key.
        truncated = np.asarray(truncation).astype(bool)
        for env_info, flag in zip(infos, truncated):
            env_info['TimeLimit.truncated'] = bool(flag)
        return np.array(obs), np.array(reward), np.array(done), infos

    def _unbatch_info(self, batched_dict):
        """Turn a dict of batched arrays into a list of per-env dicts.

        Entries that are not arrays, or whose leading dimension is not
        ``num_envs``, are skipped.
        """
        per_key = {}
        for key, value in batched_dict.items():
            if not isinstance(value, (jp.ndarray, np.ndarray)):
                continue
            value = np.asarray(value)
            if value.ndim == 0 or value.shape[0] != self.num_envs:
                continue
            per_key[key] = value
        return [{key: value[i] for key, value in per_key.items()}
                for i in range(self.num_envs)]

    def seed(self, seed: int = 0):
        """Re-create the PRNG key used by subsequent resets."""
        self._key = jax.random.PRNGKey(seed)

    def render(self, mode='human'):
        """Render env 0 as a 256x256 image when ``mode == 'rgb_array'``.

        Any other mode is delegated to ``VecEnv.render``.

        Raises:
            RuntimeError: if ``mode == 'rgb_array'`` and no state exists yet.
        """
        if mode == 'rgb_array':
            sys, state = self._env.sys, self._state
            if state is None:
                raise RuntimeError('must call reset or step before rendering')
            return image.render_array(sys, state.pipeline_state.take(0), 256, 256)
        else:
            return super().render(mode=mode)

    def close(self):
        """Nothing to release."""
        pass

    def env_is_wrapped(self, wrapper_class, indices: VecEnvIndices = None):
        """Report ``False`` for each requested env (no per-env wrappers exist)."""
        return [False for _ in self._get_indices(indices)]

    def step(self, actions):
        """Synchronous step: ``step_async`` followed by ``step_wait``."""
        self.step_async(actions)
        return self.step_wait()

    def get_attr(self, attr_name, indices: VecEnvIndices = None):
        """Return this wrapper's ``attr_name`` once per requested env index.

        All envs live in one batched object, so the same value is repeated.
        """
        value = getattr(self, attr_name)
        return [value for _ in self._get_indices(indices)]

    def set_attr(self, attr_name, value, indices: VecEnvIndices = None):
        """Set ``attr_name`` on this wrapper; ``indices`` is ignored."""
        setattr(self, attr_name, value)

    def env_method(self, method_name, *method_args, indices: VecEnvIndices = None, **method_kwargs):
        """Call ``method_name`` on this wrapper once and repeat the result per index."""
        result = getattr(self, method_name)(*method_args, **method_kwargs)
        return [result for _ in self._get_indices(indices)]
class AutoResetWrapper2(Wrapper):
    """Automatically resets Brax envs that are done.

    The wrapped env's own state is carried inside ``state.info`` under two
    keys: ``'initial_base_state'`` (the state produced by the last ``reset``)
    and ``'current_base_state'`` (the state the wrapped env is stepped from).
    When a step reports ``done``, the stored current state is swapped for the
    initial one, so the next step starts a new episode.
    """

    def reset(self, rng: jax.Array) -> State:
        """Reset the wrapped env and remember the result as the restart point."""
        base_state = self.env.reset(rng)
        info = base_state.info.copy()
        info.update({
            'initial_base_state': base_state,
            'current_base_state': base_state
        })
        return State(
            pipeline_state=base_state.pipeline_state,
            obs=base_state.obs,
            reward=base_state.reward,
            done=base_state.done,
            metrics=base_state.metrics,
            info=info
        )

    def step(self, state: State, action: jax.Array) -> State:
        """Step the wrapped env and restart it when the step reports ``done``.

        The returned ``pipeline_state`` and ``obs`` come from the initial
        state when the step is done. ``reward``, ``done``, ``metrics`` and the
        remaining ``info`` entries always describe the transition that was
        just taken, so the final reward and any truncation flag stay visible
        to the caller.
        """
        initial_base_state = state.info['initial_base_state']
        current_base_state = state.info['current_base_state']
        next_base_state = self.env.step(current_base_state, action)
        done = next_base_state.done

        def where_done(x, y):
            mask = done
            # When `done` is batched, give it trailing singleton dims so it
            # broadcasts against leaves with extra dimensions.
            if mask.shape:
                mask = jp.reshape(mask, [x.shape[0]] + [1] * (len(x.shape) - 1))
            return jp.where(mask, x, y)

        tree_map = jax.tree_util.tree_map
        info = dict(next_base_state.info)
        info.update({
            'initial_base_state': initial_base_state,
            # The state to step from next is fully reset (including its own
            # info) once the episode is over.
            'current_base_state': tree_map(where_done, initial_base_state, next_base_state),
        })
        return State(
            pipeline_state=tree_map(where_done, initial_base_state.pipeline_state, next_base_state.pipeline_state),
            obs=tree_map(where_done, initial_base_state.obs, next_base_state.obs),
            reward=next_base_state.reward,
            done=next_base_state.done,
            metrics=next_base_state.metrics,
            info=info
        )
from brax.envs.wrappers.training import VmapWrapper, EpisodeWrapper, AutoResetWrapper
from brax.envs.ant import Ant
from brax.envs.humanoid import Humanoid
# Example configuration: one Ant simulation batched into `batch_size` copies.
episode_length = 1000
backend = 'spring'
batch_size = 1024
action_repeat = 1

# Use the `backend` constant rather than repeating the literal, so changing
# it in one place is enough.
env = Ant(backend=backend)
# Wrapping order: episode bookkeeping, then auto-reset, then batching.
env = EpisodeWrapper(env, episode_length, action_repeat=action_repeat)
env = AutoResetWrapper2(env)
env = VmapWrapper(env, batch_size)
^ This is really hacky stuff and there's tons that's terrible about it. This is a high level sketch of everything that would be needed to get this to work.
Hello, thanks for providing the code =) Do you need any help to get it to work? I would be happy to link it in our doc (and maybe integrate it in the zoo or sb3 contrib) as it should be similar to envpool/isaac gym: https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#sb3-with-envpool-or-isaac-gym
@vyeevani did you finalise this into a working version?