Jacobian of State Dynamics
Hi, in a similar direction to #83, I'm attempting to use Brax for a form of optimal control, but I need access to the Jacobian of the dynamics step function with respect to both the input state and the action. Specifically, if
f(state, action) -> new_state
I'd like the Jacobian of new_state with respect to certain elements of state, and the Jacobian of new_state with respect to action. If this were a linear model of the form
x_next = Ax + Bu
I'd want A and B. For the ant sample environment, I was able to compute a gradient with respect to the action by defining a helper step function
def trimmed_state_step(state, action):
  new_state = env.step(state, action)
  return new_state.pipeline_state.q
and calling
jax.jacobian(trimmed_state_step, argnums=1)(state, act)
However, attempting to achieve a similar result for the relevant fields of the state (q in this example) via
jax.jacobian(trimmed_state_step, argnums=0)(state, act).pipeline_state.q
yields an all-zeros result. Is there a better way to compute these Jacobians, and is this supported behavior in Brax?
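For concreteness, here is a toy sketch of the A and B I'm describing; linear_step and its matrices are made up purely for illustration and are not part of Brax.
import jax
from jax import numpy as jp

# Hypothetical linear dynamics, only to illustrate the A/B analogy above.
A = jp.array([[1.0, 0.1],
              [0.0, 1.0]])
B = jp.array([[0.0],
              [0.1]])

def linear_step(x, u):
  return A @ x + B @ u

x0 = jp.zeros(2)
u0 = jp.zeros(1)
A_hat = jax.jacobian(linear_step, argnums=0)(x0, u0)  # recovers A
B_hat = jax.jacobian(linear_step, argnums=1)(x0, u0)  # recovers B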
Hi @Jaldrich2426 - which pipeline are you using? Yes, I would expect the jacobian wrt state to give meaningful results. Can you post a colab or code snippet?
Thanks @erikfrey! I'm working off a trimmed-down example of this notebook here. I've primarily been testing in the positional pipeline, but I've tried the others with no luck. I've also moved to using the observation instead of the pipeline state, since in this example it contains the pipeline state's position and velocity (q and qd). Here's my trimmed version:
import functools

import jax
from jax import numpy as jp

from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo

env_name = 'ant'
backend = 'positional'  # @param ['generalized', 'positional', 'spring']

env = envs.get_environment(env_name=env_name, backend=backend)
state = env.reset(rng=jax.random.PRNGKey(seed=0))

# Train a PPO policy so inference_fn can produce actions for the test below.
train_fn = functools.partial(
    ppo.train, num_timesteps=50_000_000, num_evals=10, reward_scaling=10,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=4096,
    batch_size=2048, seed=1)
make_inference_fn, params, _ = train_fn(environment=env)
inference_fn = make_inference_fn(params)

env = envs.create(env_name=env_name, backend=backend)

def trimmed_state_step(state, action):
  # Step the environment and return only the observation.
  new_state = env.step(state, action)
  return new_state.obs

rng = jax.random.PRNGKey(seed=1)
state = env.reset(rng=rng)
act_rng, rng = jax.random.split(rng)
act, _ = inference_fn(state.obs, act_rng)

new_q = trimmed_state_step(state, act)
print(new_q)

# Jacobian of the next observation wrt the action (argnums=1) and wrt the
# input state (argnums=0).
print(jax.jacobian(trimmed_state_step, argnums=1)(state, act))
print(jax.jacobian(trimmed_state_step, argnums=0)(state, act).obs)
print(jp.linalg.norm(jax.jacobian(trimmed_state_step, argnums=0)(state, act).obs))
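As a sketch for narrowing this down (using only the names defined above), the per-leaf norms of the Jacobian pytree show which fields, if any, carry nonzero entries:
# Diagnostic sketch: per-leaf norms of the Jacobian of the next observation
# with respect to the full input state pytree.
jac_state = jax.jacobian(trimmed_state_step, argnums=0)(state, act)
leaf_norms = jax.tree_util.tree_map(jp.linalg.norm, jac_state.pipeline_state)
print(leaf_norms)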
Update: the issue ended up being twofold:
1. The Jacobian of a state with respect to the previous state only seems to work on the generalized pipeline.
2. Even on the generalized pipeline, the value does not propagate to the observation field. To remedy this, you can take the pipeline_state field of the state gradient and obtain the relevant information from there.
For example, if you want to use the observation field, you can create a gradient version of your environment's _get_obs method to convert the pipeline state gradient into an observation gradient. For the Ant environment, the observation function involves simple slicing, so you only need a slightly modified version that handles the extra leading dimension introduced by the Jacobian. I've included a working example below.
import functools

import jax
from jax import numpy as jp

from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo

env_name = 'ant'
backend = 'generalized'  # @param ['generalized', 'positional', 'spring']

env = envs.get_environment(env_name=env_name, backend=backend)
state = env.reset(rng=jax.random.PRNGKey(seed=0))

# Train a PPO policy so inference_fn can produce actions for the test below.
train_fn = functools.partial(
    ppo.train, num_timesteps=50_000_000, num_evals=10, reward_scaling=10,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=4096,
    batch_size=2048, seed=1)
make_inference_fn, params, _ = train_fn(environment=env)
inference_fn = make_inference_fn(params)

env = envs.create(env_name=env_name, backend=backend)

def trimmed_state_step(state, action):
  # Step the environment and return only the observation.
  new_state = env.step(state, action)
  return new_state.obs

def example_get_obs_grad(pipeline_grad):
  # Mirror the Ant observation slicing (q[2:] concatenated with qd) on the
  # Jacobian pytree; each leaf carries an extra leading observation dimension.
  return jp.concatenate([pipeline_grad.q[:, 2:], pipeline_grad.qd], axis=-1)

rng = jax.random.PRNGKey(seed=1)
state = env.reset(rng=rng)
act_rng, rng = jax.random.split(rng)
act, _ = inference_fn(state.obs, act_rng)

new_q = trimmed_state_step(state, act)
print(f"observation: {new_q.shape}")
print(new_q)

# Take the Jacobian wrt the input state, then convert its pipeline_state part
# into an observation-space gradient.
pipeline_grad = jax.jacobian(trimmed_state_step, argnums=0)(state, act).pipeline_state
dobs_dstate = example_get_obs_grad(pipeline_grad)
print(f"pipeline state gradient: {dobs_dstate.shape}")
print(dobs_dstate)
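As a possible follow-up sketch (reusing env, state, act, and example_get_obs_grad from above), both Jacobians can be requested in a single call by passing a tuple to argnums; the A/B names below simply mirror the x_next = Ax + Bu analogy from the original question.
# Sketch: request d(obs)/d(state) and d(obs)/d(action) together.
dstate, daction = jax.jacobian(trimmed_state_step, argnums=(0, 1))(state, act)
A_mat = example_get_obs_grad(dstate.pipeline_state)  # d(new obs) / d(q[2:], qd)
B_mat = daction                                      # d(new obs) / d(action)
print(A_mat.shape, B_mat.shape)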
Thanks for the investigation @Jaldrich2426! Converting this to a discussion thread for others to find.