brax
brax copied to clipboard
Domain randomization with mjx backend
I have the following test file, which I use to check whether my domain randomization implementation works. The idea is that I run the (randomized) environments from the same initial state with the same action, and expect each environment to end up in a different (perturbed) state after taking one step forward.
from typing import Any
import jax
import jax.numpy as jnp
from brax import base
from brax.io import mjcf
from brax.envs import register_environment
from brax.envs.base import PipelineEnv, State
from brax import envs
# MJCF description of the cart-pole model that `Cartpole.__init__` parses with
# `mjcf.loads`. The comment inside the string links to the dm_control suite
# file. Contacts are disabled, and a single motor (gear 10, control range
# [-1, 1]) drives the cart's slider joint.
xml = """
<!-- https://github.com/google-deepmind/dm_control/blob/main/dm_control/suite/cartpole.xml -->
<mujoco model="cart-pole">
<option timestep="0.01">
<flag contact="disable"/>
</option>
<default>
<default class="pole">
<joint type="hinge" axis="0 1 0" damping="2e-6"/>
<geom type="capsule" fromto="0 0 0 0 0 1" size="0.045" mass=".1"/>
</default>
</default>
<worldbody>
<light name="light" pos="0 0 6"/>
<camera name="fixed" pos="0 -4 1" zaxis="0 -1 0"/>
<camera name="lookatcart" mode="targetbody" target="cart" pos="0 -2 2"/>
<geom name="floor" pos="0 0 -.05" size="4 4 .2" type="plane"/>
<geom name="rail1" type="capsule" pos="0 .07 1" zaxis="1 0 0" size="0.02 2"/>
<geom name="rail2" type="capsule" pos="0 -.07 1" zaxis="1 0 0" size="0.02 2"/>
<body name="cart" pos="0 0 1">
<joint name="slider" type="slide" limited="true" axis="1 0 0" range="-1.8 1.8" solreflimit=".08 1" damping="5e-4"/>
<geom name="cart" type="box" size="0.2 0.15 0.1" mass="1"/>
<body name="pole_1" childclass="pole">
<joint name="hinge_1"/>
<geom name="pole_1"/>
</body>
</body>
</worldbody>
<actuator>
<motor name="slide" joint="slider" gear="10" ctrllimited="true" ctrlrange="-1 1" />
</actuator>
</mujoco>
"""
# Number of randomized environments built by `test_parameterization`
# (one PRNG key is generated per environment).
_PARALLEL_ENVS = 128
def domain_randomization(sys, rng):
    """Sample a perturbed gear for actuator 0, once per key in ``rng``.

    For every PRNG key the gear of actuator 0 is redrawn from a unit-variance
    normal distribution centred on its nominal value.

    Args:
        sys: System to randomize. It must expose ``actuator.gear`` and a
            ``tree_replace`` method. A flat ``actuator_gear`` field, when
            present, is randomized with the same samples.
        rng: Batch of PRNG keys, one per randomized environment.

    Returns:
        A ``(sys, in_axes, samples)`` tuple. ``sys`` carries the randomized
        fields stacked along a new leading axis. ``in_axes`` mirrors the
        structure of ``sys`` with ``0`` on the randomized fields and ``None``
        everywhere else. ``samples`` holds the sampled gears as a column of
        shape ``(len(rng), 1)``.
    """

    @jax.vmap
    def randomize(key):
        gear = jax.random.normal(key) * 1.0 + sys.actuator.gear[0]
        return sys.actuator.gear.at[0].set(gear), gear

    gears, samples = randomize(rng)
    replacements = {"actuator.gear": gears}

    # NOTE(review): the MJX pipeline appears to read the flat MuJoCo-style
    # `actuator_gear` field (one row of 6 values per actuator) rather than
    # `actuator.gear`; randomizing only the latter would leave MJX rollouts
    # identical. Keep both in sync -- confirm against the installed brax.
    if hasattr(sys, "actuator_gear"):
        replacements["actuator_gear"] = jax.vmap(
            lambda gear: sys.actuator_gear.at[0, 0].set(gear)
        )(samples)

    # `jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is the
    # stable spelling.
    in_axes = jax.tree_util.tree_map(lambda x: None, sys)
    in_axes = in_axes.tree_replace({path: 0 for path in replacements})
    sys = sys.tree_replace(replacements)
    return sys, in_axes, samples[:, None]
class Cartpole(PipelineEnv):
    """Cart-pole environment built from the module-level MJCF ``xml`` model.

    Two flags, popped from the constructor's keyword arguments, select the
    task variant:

    * ``swingup``: start near ``init_q`` with the pole angle offset by pi,
      instead of sampling the cart position and a small pole angle.
    * ``sparse``: use the sparse branch of the reward.
    """

    def __init__(self, backend="mjx", **kwargs):
        """Load the MJCF model and configure the pipeline.

        Args:
            backend: Physics backend name forwarded to ``PipelineEnv``.
            **kwargs: May contain ``sparse`` and ``swingup`` (both default to
                ``False``); everything else is forwarded to ``PipelineEnv``.
        """
        sys = mjcf.loads(xml)
        self.sparse = kwargs.pop("sparse", False)
        self.swingup = kwargs.pop("swingup", False)
        super().__init__(sys=sys, backend=backend, n_frames=1, **kwargs)

    def reset(self, rng: jax.Array) -> State:
        """Resets the environment to an initial state."""
        rng_angle, rng_q, rng_qd = jax.random.split(rng, 3)
        if self.swingup:
            q = self.sys.init_q + jax.random.normal(rng_q, (self.sys.q_size(),)) * 0.01
            q = q.at[1].add(jnp.pi)
        else:
            q = self.sys.init_q
            q = q.at[0].set(jax.random.uniform(rng_q, shape=(), minval=-1.0, maxval=1.0))
            # The angle gets a key of its own: drawing it with the same key as
            # the cart position would make the two samples perfectly
            # correlated (same random bits, only rescaled).
            q = q.at[1].set(
                jax.random.uniform(rng_angle, shape=(), minval=-0.034, maxval=0.034)
            )
        qd = jax.random.normal(rng_qd, (self.sys.qd_size(),)) * 0.01
        pipeline_state = self.pipeline_init(q, qd)
        obs = self._get_obs(pipeline_state)
        reward, done = jnp.zeros(2)
        metrics: dict[str, Any] = {}
        return State(pipeline_state, obs, reward, done, metrics)

    def step(self, state: State, action: jax.Array) -> State:
        """Run one timestep of the environment's dynamics."""
        # Scale action from [-1,1] to actuator limits
        action_min = self.sys.actuator.ctrl_range[:, 0]
        action_max = self.sys.actuator.ctrl_range[:, 1]
        action = (action + 1) * (action_max - action_min) * 0.5 + action_min
        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = self._get_obs(pipeline_state)
        # Episodes never terminate from within the environment itself.
        done = jnp.zeros_like(state.done)
        return state.replace(
            pipeline_state=pipeline_state,
            obs=obs,
            reward=self._reward(pipeline_state, action),
            done=done,
        )

    def cart_position(self, pipeline_state: base.State) -> jax.Array:
        """Return the first generalized coordinate (the slider joint)."""
        return pipeline_state.q[0]

    def pole_angle_components(
        self, pipeline_state: base.State
    ) -> tuple[jax.Array, jax.Array]:
        """Return ``(cos, sin)`` of the second generalized coordinate."""
        return jnp.cos(pipeline_state.q[1]), jnp.sin(pipeline_state.q[1])

    def bounded_position(self, pipeline_state: base.State) -> jax.Array:
        """Stack the cart position with the cosine and sine of the pole angle."""
        return jnp.hstack(
            (
                self.cart_position(pipeline_state),
                *self.pole_angle_components(pipeline_state),
            )
        )

    def _reward(self, pipeline_state: base.State, action: jax.Array) -> jax.Array:
        """Compute the reward for ``pipeline_state``.

        The bounds, centering, control and velocity terms are fixed at ``1.0``
        in this file, so the sparse reward is constant and the dense reward
        varies only with the cosine of the pole angle. ``action`` is unused.
        """
        if self.sparse:
            cart_in_bounds = 1.0
            angle_in_bounds = 1.0
            # Return an array (not a Python float) to honour the annotation
            # and match the array reward produced by `reset`.
            return jnp.asarray(cart_in_bounds * angle_in_bounds)
        else:
            upright = (self.pole_angle_components(pipeline_state)[0] + 1) / 2
            centered = 1.0
            centered = (1 + centered) / 2
            small_control = 1.0
            small_control = (4 + small_control) / 5
            small_velocity = 1.0
            small_velocity = (1 + small_velocity) / 2
            return upright.mean() * small_control * small_velocity * centered

    @property
    def action_size(self):
        """Number of action dimensions (a single control input)."""
        return 1

    def _get_obs(self, pipeline_state: base.State) -> jax.Array:
        """Observe cartpole body position and velocities."""
        return jnp.concatenate(
            [self.bounded_position(pipeline_state), pipeline_state.qd]
        )
def _cartpole_factory(**fixed):
    """Build an environment constructor that pins ``fixed`` Cartpole flags."""

    def make(**kwargs):
        return Cartpole(**fixed, **kwargs)

    return make


# Register the three task variants under their public names, in the same
# order as before: sparse swing-up, dense swing-up, and balance.
for _env_name, _fixed in (
    ("cartpole_swingup_sparse", {"swingup": True, "sparse": True}),
    ("cartpole_swingup", {"swingup": True}),
    ("cartpole_balance", {"swingup": False}),
):
    register_environment(_env_name, _cartpole_factory(**_fixed))
def test_parameterization():
    """Check that randomized environments diverge after one identical step.

    All environments are reset with the same key and stepped with the same
    constant action; fewer than 20% of them may end up with an observation
    matching that of the first environment.
    """
    num_envs = _PARALLEL_ENVS

    def constant_policy(*_, **__):
        return jnp.ones((1,)), None

    env = envs.get_environment(env_name="cartpole_balance")
    keys = jax.random.split(jax.random.PRNGKey(1), num_envs)
    new_sys, in_axes, _ = domain_randomization(env.sys, keys)
    env = envs.training.wrap(
        env,
        action_repeat=1,
        randomization_fn=lambda *_, **__: (new_sys, in_axes),
    )
    batched_policy = jax.vmap(constant_policy, in_axes=(0, None))

    # Identical reset keys give every environment the same initial state.
    reset_keys = jnp.asarray([jax.random.PRNGKey(0)] * num_envs)
    state = env.reset(reset_keys)
    action, _ = batched_policy(state, None)
    next_state = env.step(state, action)

    # Count the environments whose observation matches environment 0.
    reference = next_state.obs[0, :]
    others = next_state.obs[1:, :]
    count = jnp.all(jnp.isclose(reference, others), axis=-1).sum()
    assert (
        count / num_envs < 0.2
    ), "Different environment initializations should have different trajectories"


if __name__ == "__main__":
    test_parameterization()
As seen, the environment is based on the DeepMind Control Suite cartpole. Now, for some reason, when I use the mjx backend (selected through the `backend` argument of Cartpole's constructor), I get the exact same next state in every environment (while I would expect different states, as the dynamics are perturbed), whereas if I use the generalized backend, the test passes.
Any idea why this may happen? The barkour tutorial suggests that domain randomization with the mjx backend should work.
Any help would be very much appreciated!
EDIT: I have now tried running a similar experiment with the InvertedPendulum environment, and I seem to get the same behavior. Happy to create standalone, runnable code for it as well.