How to train an NNX network using scan and jit?
Hi Flax team,
I'm writing as I've been struggling the past week trying to implement PPO for reinforcement learning using NNX. I'm at the point where I need some help from the experts.
I've been following the implementations in purejaxrl and in Brax. The thing is that those two use the old Linen API. I wanted to move to the new NNX API but have had lots of problems with JIT. Before showing what I have, I should mention that I also tried to follow the performance guidelines in the docs and this reply by @puct9: https://github.com/google/flax/issues/4045#issuecomment-2350903096.
Ok then, first a simple example:
```python
def rollout_trajectories(env: rl.Env, graphdef: nnx.GraphDef, gstate: nnx.GraphState, env_state: rl.EnvState, rng: jnp.ndarray, num_steps: int):
    model, _, _ = nnx.merge(graphdef, gstate)

    def _env_step(carry, _):
        env_state, rng = carry
        obs = env.observation(env_state)
        pi, value = model(obs)
        rng, subkey = jax.random.split(rng)
        action = pi.sample(seed=subkey)
        logp = pi.log_prob(action)
        state_jdem, system_jdem, env_params_jdem = env.step(env_state, action)
        new_env_state = rl.EnvState(state_jdem, system_jdem, env_params_jdem)
        reward = env.reward(new_env_state)
        done = env.done(new_env_state)
        info = env.info(new_env_state)
        trans = TrajectoryData(done, action, value, reward, logp, obs, info)
        return (new_env_state, rng), trans

    return jax.lax.scan(_env_step, (env_state, rng), None, num_steps)
```
This code is very straightforward: it just uses a scan to collect information about the trajectories. The problem comes when I try to JIT it. No matter what I do (defining the NNX parameters as static, using nnx.jit, capturing them in the scope, etc.), the function always errors with something similar to `TypeError: Argument 'args[0]' of shape float32[1,2] of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type`.
Is there a way to pass a model into a JITed function and do inference inside it? I know that it works in the MNIST tutorial, but I'm unable to replicate the same result.
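For reference, this is roughly the kind of pattern I've been trying to get working (a simplified, self-contained sketch, not my actual code):

```python
# Simplified sketch: do inference inside jax.jit by passing the split-out state
# as a pytree argument and rebuilding the module inside the jitted function.
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, gstate = nnx.split(model)

@jax.jit
def predict(gstate, x):
    model = nnx.merge(graphdef, gstate)  # graphdef is static, captured from the outer scope
    return model(x)

y = predict(gstate, jnp.ones((1, 2)))
print(y.shape)  # (1, 3)
```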
Now, the important problem: training using scan. Because of the same problem as before, no matter whether I use nnx.scan or jax.lax.scan, I have trouble passing the NNX objects. This makes it impossible for me even to attempt to train the network. My code looks something like this:
```python
rng, sub = jax.random.split(rng)
network = SharedActorCritic(env.observation_space, env.action_space, nnx.Rngs(sub))

tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adam(learning_rate, eps=1e-5),
)
optimizer = nnx.Optimizer(network, tx)
metrics = nnx.MultiMetric(loss=nnx.metrics.Average())
graphdef, init_gstate = nnx.split((network, optimizer, metrics))

# reset envs
rng, sub = jax.random.split(rng)
reset_rngs = jax.random.split(sub, num_envs)
env_state = env.reset(reset_rngs, env_params)

train_state = (init_gstate, env_state, rng)

def train_step(train_state, _):
    gstate, env_state, rng = train_state
    m, opt, met = nnx.merge(graphdef, gstate)
    (env_state, rng), trajectory_data = rollout_trajectories(env, graphdef, gstate, env_state, rng, num_steps)
    advantages, targets = calculate_general_advantage(trajectory_data, gamma, gae_lambda)
    # PPO logic
    return (nnx.state((m, opt, met)), env_state, rng), 0
```
I can only make this function work if I don't JIT rollout_trajectories and don't return or pass any of the NNX objects. The problem with this is that I am then unable to update them. I could use a regular for loop to call train_step and only use scan for the minibatches, but I was hoping to write high-performance code. I tried using train_state, and that helped me make some progress, but in the end I still encountered the same problem.
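For clarity, the driver I'm ultimately hoping to end up with looks roughly like this (just a sketch; `num_updates` is a placeholder name for the number of training iterations):

```python
# Sketch only: scan train_step over num_updates iterations with the whole loop
# under jit. The NNX objects travel as a split-out State pytree in the carry
# and are merged back into live objects once at the end.
@jax.jit
def train(init_gstate, env_state, rng):
    carry = (init_gstate, env_state, rng)
    (final_gstate, env_state, rng), _ = jax.lax.scan(train_step, carry, None, num_updates)
    return final_gstate, env_state, rng

final_gstate, env_state, rng = train(init_gstate, env_state, rng)
network, optimizer, metrics = nnx.merge(graphdef, final_gstate)  # recover the updated objects
```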
If you want to play with this code, clone this repo and execute this script inside. I think everything should work (you will get the nnx error with the JIT though).
I see that with the Linen API, this moving in and out of JIT land is not that hard. I would appreciate it if someone could explain to me how to do it with NNX. Additionally, if there is a recommended approach to these types of problems in NNX, I would appreciate it if you could point me to it.
Best regards.
I currently lack the capacity to answer this question very specifically and in depth, but I happen to have recently done something extremely similar (train RL PPO in Flax NNX, also featuring the scan trick). The code is in a single notebook here.
Apologies for being unable to answer the question in a more targeted manner, but I hope this helps.
Hi @puct9, your answer might be just what I needed. Thank you very much. I'll review the code and use it to make my own. I'll come back if I have any questions.
Hey @cdelv, it would be great to have a reproducible example we can run to simplify debugging. Based on what you showed, it's a little bit tricky to figure out. The only thing I can say is that capturing the model in the closure is generally not a good idea, as any mutation would raise an error; it's better to pass it as an explicit input to the nnx.scan-ed function.
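For illustration, a minimal toy sketch of what I mean (not your PPO code): the module is passed as an explicit, non-scanned argument (`in_axes=None` broadcasts it across steps) rather than captured from the enclosing scope.

```python
# Toy sketch: pass the module explicitly into the nnx.scan-ed function instead
# of closing over it; nnx.scan then handles the Module's state correctly.
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 2, rngs=nnx.Rngs(0))

def step_fn(carry, model, x):
    y = model(x) + carry
    return y, y

xs = jnp.ones((5, 2))
carry0 = jnp.zeros((2,))
carry, ys = nnx.scan(
    step_fn, in_axes=(nnx.Carry, None, 0), out_axes=(nnx.Carry, 0)
)(carry0, model, xs)
print(ys.shape)  # (5, 2)
```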
Hi @cgarciae, thank you for your response.
I wrote this standalone script that should reproduce the problem I'm talking about. I haven't quite finished going through @puct9's script, but it looks like they use a regular for loop for the training loop and scan for the batches. What I'm trying to do is use a scan for everything.
I want to have the highest performance possible. I appreciate any suggestions on how to achieve this.
I'll come back when I finish adapting @puct9's script.
Hi @cgarciae,
After trying different approaches, I found this example in the documentation using TrainState. I attempted to make the code work using TrainState, and it got me really far. However, the code fails at the end with TypeError: Argument 'Traced<ShapedArray(float32[1,2])>with<DynamicJaxprTrace>' of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> is not a valid JAX type.
Here is the complete script: main.txt
There is an issue with passing the train state around, as it becomes a tracer that JAX cannot use. I am unable to get around it. I know that in linen this is possible. I don't understand why in nnx it does not work.
I was unable to run the provided script due to a potentially unrelated (?) issue -- `ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field pos is not allowed: use default_factory` (Python 3.11) -- so I don't think I can directly address the problem this time, but I have a couple of points anyway.
> I haven't quite finished going through @puct9's script, but it looks like they use a regular for loop for the training loop and scan for the batches. What I'm trying to do is use a scan for everything.
I updated the train loop to perform any number of iterations at a time 😉
Now, more on the actual question.
> I want to have the highest performance possible. I appreciate any suggestions on how to achieve this.
I notice a small edit you've attempted to make from what I assume was the reference material:
```diff
- def train(rng):
+ @jax.jit
+ def train(key):
```
And your wish for my original notebook's training loop to be in a scan.
Hence, I think the point of performance and why we sometimes want to jit has been somewhat missed. For example:

- In extremely quick train loops, it can be insufficient to simply `jit` the function that applies a single update, because we end up spending more time traversing between the compiled function and Python. In this case, due to the more computationally intensive nature of needing to run the environment, compute losses, and update the network, it's unlikely we can achieve more than a few dozen iterations per second. This means that the theoretical difference in performance between repeatedly calling a `jit`ted function from Python and having it in something like a `scan` shrinks to the point where it is negligible.
- Jax's performance certainly is solid, but it's not magic. The greatest gain from being able to describe the environment in Jax is the level of convenience relative to the performance you get. The only other competitive alternative (actually superior in this case) is probably to handle the environment and a lot of what Jax does yourself in a low-level language like C++, at great inconvenience. Likewise, the expectation that going beyond `jit`ting a significant portion of the train loop (e.g., the train step function) to `jit`ting the train loop itself will bring much performance benefit will likely end in disappointment and a nonnegligible amount of wasted time. If you want to understand where you're gaining/losing time, consider profiling the program.
- Whether deliberate or by coincidence, I didn't find a single instance of `nnx.scan` or `nnx.jit` in your script. There is a difference between the `nnx` and `jax` versions of those functions, such as being able to pass `nnx.Module` and `nnx.Rngs` objects correctly (see the short sketch after this list). I will make no assumptions about your intent, but "don't use `nnx.jit` because it's slower" would certainly be the wrong lesson to take from https://github.com/google/flax/issues/4045. The source of the slowdown is the traversal between Python and a function compiled with `nnx.jit`, and even that cost is negligible as long as the operation is sufficiently expensive (which in this case, it is).
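A toy sketch of why that difference matters (my own minimal example, not from your script): under `nnx.jit` a `Module` can be passed in and mutated, and the state updates propagate back to the object outside.

```python
# Toy sketch: nnx.jit accepts Modules directly and propagates in-place state
# updates, which plain jax.jit cannot do with NNX objects that aren't pytrees.
import jax.numpy as jnp
from flax import nnx

class Count(nnx.Variable):
    pass

class Counter(nnx.Module):
    def __init__(self):
        self.count = Count(jnp.array(0))

@nnx.jit
def bump(counter: Counter):
    counter.count.value += 1  # in-place update, handled by nnx.jit

counter = Counter()
bump(counter)
bump(counter)
print(counter.count.value)  # 2
```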
This also isn't to say that if you remove jit, things are much more likely to work. There are a number of other dubious practices, some of which I'm not sure are meant to work at all:
- Constructing `nnx.Rngs` from a vanilla Jax key (see the sketch after this list for the more conventional construction):

  ```python
  key, subkey = jax.random.split(key)
  network = SharedActorCritic(env.observation_space, env.action_space, nnx.Rngs(subkey))
  ```

- Use of `TrainState`. I'd recommend attempting the method shown in the Flax NNX vs JAX transformations guide (click on `JAX transforms`) or here as the ideal way of using objects originating from Flax or NNX in places like `jax.jit`ted functions. Though, your hand may be forced -- https://github.com/google/flax/issues/4545#issuecomment-2657917039, since `optax.chain` returns `GradientTransformationExtraArgs`.
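On the first point, for reference, a minimal sketch of the more conventional `nnx.Rngs` construction as I understand it -- integer seeds, optionally with named streams:

```python
# Sketch: construct nnx.Rngs from seed ints, with named streams for different
# purposes; each call to a stream yields a fresh key.
from flax import nnx

rngs = nnx.Rngs(0)                   # a single default stream seeded with 0
rngs = nnx.Rngs(params=0, sample=1)  # named streams
param_key = rngs.params()            # fresh key from the 'params' stream
sample_key = rngs.sample()           # fresh key from the 'sample' stream
```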
Regarding points [1, 3], I've personally observed a complete lack of difference in performance in my script (attached previously) after lowering the train loop into a scan.
Hi @puct9, Thank you for getting back to me.
Indeed, the error you see is the main reason I opened this issue. No matter what I did, I was unable to escape from it. That's why I tried the usual NNX way, the functional API, and the legacy TrainState before giving up and opening the issue. However, I just discovered that the source of the error was completely unrelated to the training function or to jax.lax.scan, jit, or nnx.scan. It turns out that the culprit was my actor-critic implementation:
```python
class SharedActorCritic(nnx.Module):
    def __init__(
        self,
        observation_space: int,
        action_space: int,
        key: nnx.Rngs,
        architecture: List[int] = [64, 64],
        in_scale: float = jnp.sqrt(2),
        actor_scale: float = 1.0,
        critic_scale: float = 0.01,
        activation=nnx.relu,
    ):
        layers = []
        input_dim = observation_space
        for output_dim in architecture:
            layers.append(
                nnx.Linear(
                    in_features=input_dim,
                    out_features=output_dim,
                    kernel_init=nnx.initializers.orthogonal(in_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key,
                )
            )
            layers.append(activation)
            input_dim = output_dim

        self.network = nnx.Sequential(*layers)
        self.actor = nnx.Linear(
            in_features=input_dim,
            out_features=action_space,
            kernel_init=nnx.initializers.orthogonal(actor_scale),
            bias_init=nnx.initializers.constant(0.0),
            rngs=key,
        )
        self.critic = nnx.Linear(
            in_features=input_dim,
            out_features=1,
            kernel_init=nnx.initializers.orthogonal(critic_scale),
            bias_init=nnx.initializers.constant(0.0),
            rngs=key,
        )
        self.log_std = nnx.Param(jax.random.uniform(key.params(), (1, action_space)))

    def __call__(self, x):
        x = self.network(x)
        pi = distrax.MultivariateNormalDiag(self.actor(x), jnp.exp(self.log_std.value))  # the error was here: .value is required
        return pi, self.critic(x)
```
When calling the probability distribution, I had to access .value on the log_std parameter. If you change this, the script that uses TrainState should work. I'm not entirely sure why, but that seems to resolve the issue. I discovered this by removing all the scans and jitted functions until the error became clear.
Thank you for your comments. I will implement them. I spent some time writing a pure JAX physics engine, which is quite fast, to create my environments. So, using @jax.jit instead of the nnx versions was just a way to keep everything consistent; I wasn't trying to avoid @nnx.jit. The script simply came out that way after many modifications.
About creating the nnx.Rngs from a JAX key: I did that just so I only have to define one key. I didn't know it was a dubious practice. What do you recommend for key handling, given that I have other things that also require a random key?
Please let me know if you have more recommendations. I would love to hear them. Also, thanks for the update in the notebook. I'll be checking it out.
Best regards.
> When calling the probability distribution, I had to access .value on the log_std parameter. If you change this, the script that uses TrainState should work. I'm not entirely sure why, but that seems to resolve the issue.
@cdelv the reason is most likely that nnx.Variable (e.g. nnx.Param) implements the `__jax_array__` protocol, so JAX functions treat it as an Array. The main issue is that the protocol is not fully supported, so you might get errors from time to time; we're evaluating whether we should continue using it or have users explicitly access `.value` or use other syntax like `[...]`.
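A tiny sketch of the explicit access pattern (my own example):

```python
# nnx.Param wraps an array; unwrapping with .value always yields a plain
# jax.Array, whereas passing the Param itself relies on __jax_array__,
# which is not supported everywhere.
import jax.numpy as jnp
from flax import nnx

log_std = nnx.Param(jnp.zeros((1, 2)))
scale = jnp.exp(log_std.value)  # explicit unwrap: a plain array of shape (1, 2)
```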
Let's close this issue as resolved. @cdelv, feel free to reopen if more discussion on the topic is needed.