
Domain randomization with mjx backend

Open · yardenas opened this issue on Jul 18, 2024 · 0 comments

I have the following test file, which I use to check whether my domain randomization implementation works. The idea is to condition the (randomized) environments on the same action and initial state, and then expect to see different, perturbed states after taking one step forward.

from typing import Any

import jax
import jax.numpy as jnp

from brax import base, envs
from brax.envs import register_environment
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf

xml = """
<!-- https://github.com/google-deepmind/dm_control/blob/main/dm_control/suite/cartpole.xml -->
<mujoco model="cart-pole">
  <option timestep="0.01">
    <flag contact="disable"/>
  </option>
  <default>
    <default class="pole">
      <joint type="hinge" axis="0 1 0"  damping="2e-6"/>
      <geom type="capsule" fromto="0 0 0 0 0 1" size="0.045" mass=".1"/>
    </default>
  </default>
  <worldbody>
    <light name="light" pos="0 0 6"/>
    <camera name="fixed" pos="0 -4 1" zaxis="0 -1 0"/>
    <camera name="lookatcart" mode="targetbody" target="cart" pos="0 -2 2"/>
    <geom name="floor" pos="0 0 -.05" size="4 4 .2" type="plane"/>
    <geom name="rail1" type="capsule" pos="0  .07 1" zaxis="1 0 0" size="0.02 2"/>
    <geom name="rail2" type="capsule" pos="0 -.07 1" zaxis="1 0 0" size="0.02 2"/>
    <body name="cart" pos="0 0 1">
      <joint name="slider" type="slide" limited="true" axis="1 0 0" range="-1.8 1.8" solreflimit=".08 1" damping="5e-4"/>
      <geom name="cart" type="box" size="0.2 0.15 0.1" mass="1"/>
      <body name="pole_1" childclass="pole">
        <joint name="hinge_1"/>
        <geom name="pole_1"/>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor name="slide" joint="slider" gear="10" ctrllimited="true" ctrlrange="-1 1" />
  </actuator>
</mujoco>
"""


_PARALLEL_ENVS = 128


def domain_randomization(sys, rng):
    @jax.vmap
    def randomize(rng):
        # Perturb the actuator gear with unit-variance Gaussian noise around its
        # nominal value; vmap over the rng batch gives one sample per environment.
        gear_sample = jax.random.normal(rng) * 1.0 + sys.actuator.gear[0]
        gear = sys.actuator.gear.at[0].set(gear_sample)
        return gear, gear_sample

    gear, samples = randomize(rng)
    # Only the randomized leaf gets a batch axis; every other leaf stays shared.
    in_axes = jax.tree_map(lambda x: None, sys)
    in_axes = in_axes.tree_replace({"actuator.gear": 0})
    sys = sys.tree_replace({"actuator.gear": gear})
    return sys, in_axes, samples[:, None]


class Cartpole(PipelineEnv):
    def __init__(self, backend="mjx", **kwargs):
        sys = mjcf.loads(xml)
        self.sparse = kwargs.pop("sparse", False)
        self.swingup = kwargs.pop("swingup", False)
        super().__init__(sys=sys, backend=backend, n_frames=1, **kwargs)

    def reset(self, rng: jax.Array) -> State:
        """Resets the environment to an initial state."""
        rng, rng1, rng2, rng3 = jax.random.split(rng, 4)
        if self.swingup:
            q = self.sys.init_q + jax.random.normal(rng1, (self.sys.q_size(),)) * 0.01
            q = q.at[1].add(jnp.pi)
        else:
            q = self.sys.init_q
            q = q.at[0].set(jax.random.uniform(rng1, shape=(), minval=-1.0, maxval=1.0))
            # Use a separate key for the pole angle so it is not perfectly
            # correlated with the cart position sample above.
            q = q.at[1].set(
                jax.random.uniform(rng2, shape=(), minval=-0.034, maxval=0.034)
            )
        qd = jax.random.normal(rng3, (self.sys.qd_size(),)) * 0.01
        pipeline_state = self.pipeline_init(q, qd)
        obs = self._get_obs(pipeline_state)
        reward, done = jnp.zeros(2)
        metrics: dict[str, Any] = {}
        return State(pipeline_state, obs, reward, done, metrics)

    def step(self, state: State, action: jax.Array) -> State:
        """Run one timestep of the environment's dynamics."""
        # Scale action from [-1,1] to actuator limits
        action_min = self.sys.actuator.ctrl_range[:, 0]
        action_max = self.sys.actuator.ctrl_range[:, 1]
        action = (action + 1) * (action_max - action_min) * 0.5 + action_min
        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = self._get_obs(pipeline_state)
        done = jnp.zeros_like(state.done)
        return state.replace(
            pipeline_state=pipeline_state,
            obs=obs,
            reward=self._reward(pipeline_state, action),
            done=done,
        )

    def cart_position(self, pipeline_state: base.State) -> jax.Array:
        return pipeline_state.q[0]

    def pole_angle_components(
        self, pipeline_state: base.State
    ) -> tuple[jax.Array, jax.Array]:
        return jnp.cos(pipeline_state.q[1]), jnp.sin(pipeline_state.q[1])

    def bounded_position(self, pipeline_state: base.State) -> jax.Array:
        return jnp.hstack(
            (
                self.cart_position(pipeline_state),
                *self.pole_angle_components(pipeline_state),
            )
        )

    def _reward(self, pipeline_state: base.State, action: jax.Array) -> jax.Array:
        if self.sparse:
            cart_in_bounds = 1.0
            angle_in_bounds = 1.0
            return cart_in_bounds * angle_in_bounds
        else:
            upright = (self.pole_angle_components(pipeline_state)[0] + 1) / 2
            centered = 1.0
            centered = (1 + centered) / 2
            small_control = 1.0
            small_control = (4 + small_control) / 5
            small_velocity = 1.0
            small_velocity = (1 + small_velocity) / 2
            return upright.mean() * small_control * small_velocity * centered

    @property
    def action_size(self):
        return 1

    def _get_obs(self, pipeline_state: base.State) -> jax.Array:
        """Observe cartpole body position and velocities."""
        return jnp.concatenate(
            [self.bounded_position(pipeline_state), pipeline_state.qd]
        )


register_environment(
    "cartpole_swingup_sparse",
    lambda **kwargs: Cartpole(swingup=True, sparse=True, **kwargs),
)
register_environment(
    "cartpole_swingup", lambda **kwargs: Cartpole(swingup=True, **kwargs)
)
register_environment(
    "cartpole_balance", lambda **kwargs: Cartpole(swingup=False, **kwargs)
)


def test_parameterization():
    def policy(*_, **__):
        return jnp.ones((1,)), None

    environment = envs.get_environment(env_name="cartpole_balance")
    rng = jax.random.PRNGKey(1)
    rng = jax.random.split(rng, _PARALLEL_ENVS)
    new_sys, in_axes, samples = domain_randomization(environment.sys, rng)
    environment = envs.training.wrap(
        environment,
        action_repeat=1,
        randomization_fn=lambda *_, **__: (new_sys, in_axes),
    )
    policy = jax.vmap(policy, in_axes=(0, None))
    state = environment.reset(jnp.asarray([jax.random.PRNGKey(0)] * _PARALLEL_ENVS))
    next_state = environment.step(state, policy(state, None)[0])
    count = sum(
        jnp.allclose(next_state.obs[0, :], obs[0, :])
        for obs in jnp.split(next_state.obs[1:, :], _PARALLEL_ENVS - 1, axis=0)
    )
    assert (
        count / _PARALLEL_ENVS < 0.2
    ), "Different environment initializations should have different trajectories"


if __name__ == "__main__":
    test_parameterization()

As you can see, the environment is based on the DeepMind Control Suite cartpole. For some reason, when I use the mjx backend (the backend argument in Cartpole's constructor), I get exactly the same next state across all environments, while I would expect different states since the dynamics are perturbed. With the generalized backend, the test passes.
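As a side note, here is an equivalent, vectorized way to express the check at the end of the test, which is how I confirm that all batched environments produce the same observation as environment 0 (the helper name and tolerance are my own, not from brax):

def count_identical_to_first(obs, atol=1e-6):
    # How many rows of the batched observation match environment 0's observation.
    # With working randomization this should be close to 1 (env 0 itself);
    # with the mjx backend I currently get all _PARALLEL_ENVS environments.
    diffs = jnp.abs(obs - obs[0]).max(axis=-1)
    return int((diffs <= atol).sum())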

Any idea why this may happen? The Barkour tutorial suggests that domain randomization with the mjx backend should work.
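For reference, my understanding is that the tutorial hands the randomization function to the PPO training entry point rather than wrapping the environment manually. Roughly like this (a sketch from memory, so the exact arguments may differ; the small adapter is mine, since my domain_randomization returns an extra samples value):

import functools
from brax.training.agents.ppo import train as ppo

def randomization_fn(sys, rng):
    # Adapter: the trainer expects (sys, in_axes), so drop the extra samples.
    new_sys, in_axes, _ = domain_randomization(sys, rng)
    return new_sys, in_axes

train_fn = functools.partial(
    ppo.train,
    num_timesteps=1_000_000,
    episode_length=1000,
    randomization_fn=randomization_fn,
)

train_fn would then be called with the environment as in the tutorial, and, if I read the code correctly, the trainer splits the rng keys per environment internally.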

Any help would be very much appreciated!

EDIT: I have now tried running a similar experiment with the InvertedPendulum environment, and I see the same behavior there. Happy to create standalone running code for it as well.
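This is roughly what I have in mind for that standalone version, reusing the domain_randomization function from above with brax's built-in inverted_pendulum environment (untested sketch; whether actuator.gear is the right field to perturb for that model is an assumption on my part):

def test_inverted_pendulum_parameterization():
    # Untested sketch: same check as above, but on the built-in inverted_pendulum env.
    environment = envs.get_environment(env_name="inverted_pendulum", backend="mjx")
    rng = jax.random.split(jax.random.PRNGKey(1), _PARALLEL_ENVS)
    new_sys, in_axes, _ = domain_randomization(environment.sys, rng)
    environment = envs.training.wrap(
        environment,
        action_repeat=1,
        randomization_fn=lambda *_, **__: (new_sys, in_axes),
    )
    state = environment.reset(jnp.asarray([jax.random.PRNGKey(0)] * _PARALLEL_ENVS))
    action = jnp.ones((_PARALLEL_ENVS, environment.action_size))
    next_state = environment.step(state, action)
    diffs = jnp.abs(next_state.obs - next_state.obs[0]).max(axis=-1)
    assert int((diffs <= 1e-6).sum()) / _PARALLEL_ENVS < 0.2, (
        "Different environment initializations should have different trajectories"
    )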

yardenas · Jul 18, 2024, 17:07