
Jacobian of State Dynamics

Open Jaldrich2426 opened this issue 1 year ago • 3 comments

Hi, in a similar direction to #83, I'm attempting to use Brax for a form of optimal control, but I need access to the Jacobian of the dynamics step function with respect to both the input state and the action. Specifically, if

f(state,action) -> new_state

I'd like the Jacobian of new_state with respect to certain elements of state, and the Jacobian of new_state with respect to action. If this were a linear model of the form

x_next = Ax+Bu

I'd want A and B (the Jacobians of f with respect to state and action, respectively). For the ant sample environment, I was able to compute a gradient with respect to the action by defining a helper step function

def trimmed_state_step(state, action):
    new_state = env.step(state, action)
    return new_state.pipeline_state.q

and calling

jax.jacobian(trimmed_state_step, argnums=1)(state, act)
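
For reference (an illustrative check; the shapes and the dq_dact name assume the default Ant model with 15 generalized coordinates and 8 actuators), this returns one row per q coordinate and one column per action dimension:

dq_dact = jax.jacobian(trimmed_state_step, argnums=1)(state, act)
print(dq_dact.shape)  # (15, 8): d(new q) / d(action)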

However, attempting to achieve similar results for relevant fields of the state (q in this example)

jax.jacobian(trimmed_state_step, argnums=0)(state, act).pipeline_state.q

yields an all-zeros result. Is there a better way to compute these Jacobians, and is this behavior supported by Brax?

Jaldrich2426 avatar Jul 01 '24 19:07 Jaldrich2426

Hi @Jaldrich2426 - which pipeline are you using? Yes, I would expect the Jacobian wrt state to give meaningful results. Can you post a colab or code snippet?

erikfrey avatar Jul 02 '24 18:07 erikfrey

Thanks @erikfrey! I'm working off a trimmed-down example of this notebook here. I've primarily been testing with the positional pipeline but tried the others with no luck. I've also moved to using the observation instead of the pipeline state, since in this example it contains the pipeline state's position and velocity (q and qd). Here's my trimmed version:

import functools
import jax
from jax import numpy as jp

from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo


env_name = 'ant'
backend = 'positional'  # @param ['generalized', 'positional', 'spring']

env = envs.get_environment(env_name=env_name,
                           backend=backend)

state = env.reset(rng=jax.random.PRNGKey(seed=0))

train_fn = functools.partial(
    ppo.train, num_timesteps=50_000_000, num_evals=10, reward_scaling=10,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=4096,
    batch_size=2048, seed=1)

make_inference_fn, params, _ = train_fn(environment=env)
inference_fn = make_inference_fn(params)

env = envs.create(env_name=env_name, backend=backend)

def trimmed_state_step(state, action):
    new_state = env.step(state, action)
    return new_state.obs

rng = jax.random.PRNGKey(seed=1)
state = env.reset(rng=rng)

act_rng, rng = jax.random.split(rng)
act, _ = inference_fn(state.obs, act_rng)

new_q = trimmed_state_step(state, act)
print(new_q)
print(jax.jacobian(trimmed_state_step, argnums=1)(state, act))
print(jax.jacobian(trimmed_state_step, argnums=0)(state, act).obs)
print(jp.linalg.norm(jax.jacobian(trimmed_state_step, argnums=0)(state, act).obs))

Jaldrich2426 avatar Jul 02 '24 19:07 Jaldrich2426

Update: the issue ended up being twofold:

1. The Jacobian of a state with respect to the previous state only seems to work on the generalized pipeline.
2. The value does not propagate to the observation field. To remedy this, you can take the pipeline_state field of the state gradient and obtain the relevant information from there.

For example, if you want to work with the observation field, you can write a gradient version of your environment's _get_obs method to convert the pipeline-state gradient into an observation gradient. For the Ant environment the observation is just a concatenation of slices of q and qd, so you only need a slightly modified version that handles the extra Jacobian dimension. I've included a working example below.

import functools
import jax
from jax import numpy as jp

from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo


env_name = 'ant'
backend = 'generalized'  # @param ['generalized', 'positional', 'spring']

env = envs.get_environment(env_name=env_name,
                           backend=backend)

state = env.reset(rng=jax.random.PRNGKey(seed=0))

train_fn = functools.partial(
    ppo.train, num_timesteps=50_000_000, num_evals=10, reward_scaling=10,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=4096,
    batch_size=2048, seed=1)

make_inference_fn, params, _ = train_fn(environment=env)
inference_fn = make_inference_fn(params)

env = envs.create(env_name=env_name, backend=backend)

def trimmed_state_step(state, action):
    new_state = env.step(state, action)
    return new_state.obs

def example_get_obs_grad(pipeline_grad):
    # The Ant observation is concat(q[2:], qd), so apply the same slicing to the
    # Jacobian leaves; the leading axis is the output (observation) dimension.
    return jp.concatenate([pipeline_grad.q[:, 2:], pipeline_grad.qd], axis=-1)

rng = jax.random.PRNGKey(seed=1)
state = env.reset(rng=rng)

act_rng, rng = jax.random.split(rng)
act, _ = inference_fn(state.obs, act_rng)

new_q = trimmed_state_step(state, act)

print(f"observation: {new_q.shape}")
print(new_q)

pipeline_grad = jax.jacobian(trimmed_state_step, argnums=0)(state, act).pipeline_state
dobs_dstate = example_get_obs_grad(pipeline_grad)

print(f"pipeline state gradient: {dobs_dstate.shape}")
print(dobs_dstate)
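
Tying this back to the x_next = Ax + Bu framing from the original post, the same pieces can be assembled into A and B directly. This is just a sketch reusing trimmed_state_step and example_get_obs_grad above (the A/B/state_jac names are illustrative, and the shapes assume the default Ant model: 27-dim observation, 8-dim action):

# d(new obs) / d(previous state), restricted to the obs-relevant parts of q and qd
state_jac = jax.jacobian(trimmed_state_step, argnums=0)(state, act)
A = example_get_obs_grad(state_jac.pipeline_state)  # shape (27, 27)

# d(new obs) / d(action)
B = jax.jacobian(trimmed_state_step, argnums=1)(state, act)  # shape (27, 8)

print(A.shape, B.shape)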

Jaldrich2426 avatar Jul 11 '24 19:07 Jaldrich2426

Thanks for the investigation @Jaldrich2426! Converting this to a discussion thread for others to find.

btaba avatar Oct 04 '24 17:10 btaba