
How to train an NNX network using scan and jit?

cdelv opened this issue 9 months ago · 8 comments

Hi Flax team,

I'm writing as I've been struggling the past week trying to implement PPO for reinforcement learning using NNX. I'm at the point where I need some help from the experts.

I've been following the implementation of purejaxrl and the one in Brax. The thing is that these two use the old Linen API. I wanted to move to the new NNX API but ran into lots of problems with JIT. Before I start showing what I have: I also tried to follow the performance guidelines in the docs and this reply by @puct9: https://github.com/google/flax/issues/4045#issuecomment-2350903096.

Ok then, first a simple example:

def rollout_trajectories(env: rl.Env, graphdef: nnx.GraphDef, gstate: nnx.GraphState, env_state: rl.EnvState, rng: jnp.ndarray, num_steps: int):
    model, _, _    = nnx.merge(graphdef, gstate)
    
    def _env_step(carry, _):
        env_state, rng = carry

        obs            = env.observation(env_state)
        pi, value      = model(obs)
        rng, subkey    = jax.random.split(rng)
        action         = pi.sample(seed=subkey)
        logp           = pi.log_prob(action)

        state_jdem, system_jdem, env_params_jdem = env.step(env_state, action)
        new_env_state  = rl.EnvState(state_jdem, system_jdem, env_params_jdem)
        reward         = env.reward(new_env_state)
        done           = env.done(new_env_state)
        info           = env.info(new_env_state)

        trans = TrajectoryData(done, action, value, reward, logp, obs, info)
        return (new_env_state, rng), trans

    return jax.lax.scan(_env_step, (env_state, rng), None, num_steps)

This code is very straightforward: just use the scan to collect information about the trajectories. The problem comes when I try to JIT it. No matter what I do (defining the nnx parameters as static, using nnx.jit and capturing them in the enclosing scope, etc.), the function always errors with something similar to TypeError: Argument 'args[0]' of shape float32[1,2] of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.

Is there a way to pass NNX objects into, or do inference inside, a JITed function? I know that this works in the MNIST tutorial, but I'm unable to replicate the same result.
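For reference, the pattern I've been trying to replicate from the docs looks roughly like this (a boiled-down toy model for illustration, not my actual code):

import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# Split the module into a static graph definition and a pytree of arrays,
# so only JAX-compatible values cross the jit boundary.
graphdef, state = nnx.split(model)

@jax.jit
def forward(state, x):
    model = nnx.merge(graphdef, state)  # rebuild the module inside the traced function
    return model(x)

y = forward(state, jnp.ones((1, 2)))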

Now, the important problem: training using scan. Related to the problem above, no matter whether I use nnx.scan or jax.lax.scan, I'm having trouble passing the NNX objects around. This makes it impossible for me to even attempt to train the network. My code looks something like this:

    rng, sub  = jax.random.split(rng)
    network   = SharedActorCritic(env.observation_space, env.action_space, nnx.Rngs(sub))
    tx        = optax.chain(
                    optax.clip_by_global_norm(max_grad_norm),
                    optax.adam(learning_rate, eps=1e-5)
                )
    optimizer = nnx.Optimizer(network, tx)
    metrics   = nnx.MultiMetric(loss=nnx.metrics.Average())
    graphdef, init_gstate = nnx.split((network, optimizer, metrics))

    # reset envs
    rng, sub       = jax.random.split(rng)
    reset_rngs     = jax.random.split(sub, num_envs)
    env_state      = env.reset(reset_rngs, env_params)

    train_state = (init_gstate, env_state, rng)

    def train_step(train_state, _):
        gstate, env_state, rng = train_state
        m, opt, met = nnx.merge(graphdef, gstate)

        (env_state, rng), trajectory_data = rollout_trajectories(env, graphdef, gstate, env_state, rng, num_steps)
        advantages, targets = calculate_general_advantage(trajectory_data, gamma, gae_lambda)


        # PPO logic

        return (nnx.state((m, opt, met)), env_state, rng), 0

I can only make this function work if I don't JIT rollout_trajectories and don't return or pass any of the NNX objects. The problem with this is that I'm unable to update them. I could use a regular for loop to call train_step and only use scan for the minibatches, but I was hoping to achieve high-performance code. I tried using TrainState, and that helped me make some progress, but in the end I still ran into the same problem.
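Schematically, what I'm aiming for is to drive the whole training loop with a single scan and only leave JAX land at the very end, something like this (a sketch; num_updates is illustrative, and I'm assuming nnx.update accepts the same tuple that nnx.split did):

# Scan train_step over the number of updates, carrying the NNX state,
# the environment state, and the rng in the scan carry.
(final_gstate, final_env_state, rng), _ = jax.lax.scan(
    train_step, train_state, None, num_updates
)

# Write the trained parameters back into the live objects.
nnx.update((network, optimizer, metrics), final_gstate)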

If you want to play with this code, clone this repo and execute this script inside. I think everything should work (you will get the nnx error with the JIT though).

I see that with the Linen API this moving in and out of JIT land is not that hard. I would appreciate it if someone could explain how to do it with NNX. Additionally, if there is a recommended approach to these kinds of problems, I would appreciate a pointer to it.

Best regards.

cdelv avatar Apr 20 '25 23:04 cdelv

I currently lack the capacity to answer this question very specifically and in depth, but I happen to have recently done something extremely similar (training PPO for RL in Flax NNX, also featuring the scan trick). The code is in a single notebook here.

Apologies for being unable to answer the question in a more targeted manner, but I hope this helps.

puct9 avatar Apr 21 '25 07:04 puct9

Hi @puct9, your answer might be just what I needed. Thank you very much. I'll review the code and use it to make my own. I'll come back if I have any questions.

cdelv avatar Apr 21 '25 13:04 cdelv

Hey @cdelv, it would be great to have a reproducible example we can run to simplify debugging. Based on what you showed, it's a little bit tricky to figure out. The only thing I can say is that passing model as a capture is generally not a good idea, as any mutation would raise an error; it's better to pass it as an explicit input to the nnx.scan-ed function.
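For example, in pure jax.lax.scan terms the equivalent is to carry the module's state explicitly and merge inside the step; nnx.scan handles this plumbing for you. A toy sketch (the model and step body are illustrative only):

import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

def step(carry, x):
    state, total = carry
    model = nnx.merge(graphdef, state)  # rebuild the module from the carried state
    y = model(x)
    # re-extract the state in case the module mutated anything (e.g. RNG counts, stats)
    return (nnx.state(model), total + y.sum()), y

xs = jnp.ones((10, 2))
(state, total), ys = jax.lax.scan(step, (state, jnp.array(0.0)), xs)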

cgarciae avatar Apr 21 '25 21:04 cgarciae

Hi @cgarciae, thank you for your response.

I wrote this standalone script that should reproduce the problem I'm talking about. I haven't quite finished going through @puct9's script, but it looks like they use a regular for loop for the training loop and scan for the batches. What I'm trying to do is use a scan for everything.

I want to have the highest performance possible. I appreciate any suggestions on how to achieve this.

I'll come back when I finish adapting @puct9's script.

cdelv avatar Apr 22 '25 19:04 cdelv

Hi @cgarciae,

After trying different approaches, I found this example in the documentation using TrainState. I attempted to make the code work using TrainState, and it got me really far. However, the code fails at the end with TypeError: Argument 'Traced<ShapedArray(float32[1,2])>with<DynamicJaxprTrace>' of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> is not a valid JAX type.

Here is the complete script: main.txt

There is an issue with passing the train state around, as it becomes a tracer that JAX cannot use, and I am unable to get around it. I know that in Linen this is possible; I don't understand why it does not work in NNX.

cdelv avatar Apr 22 '25 22:04 cdelv

I was unable to run the provided script due to a potentially unrelated (?) issue -- ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field pos is not allowed: use default_factory (Python 3.11) -- so I don't think I can directly address the problem this time, but I have a couple of points anyway.

I haven't quite finished going through @puct9's script, but it looks like they use a regular for loop for the training loop and scan for the batches. What I'm trying to do is use a scan for everything.

I updated the train loop to perform any number of iterations at a time 😉

Now more on the actual question

I want to have the highest performance possible. I appreciate any suggestions on how to achieve this.

I notice a small edit you've attempted to make from what I assume was the reference material:

-    def train(rng):
+    @jax.jit
+    def train(key):

And your wish for my original notebook's training loop to be in a scan.

Hence, I think the point of performance, and why we sometimes want to jit, has been somewhat missed. For example:

  1. In extremely quick train loops, it can be insufficient to simply jit the function that applies a single update. The reason is that we will spend more time traversing between the compiled function and Python. In this case, due to the more computationally intensive nature of needing to run the environment, compute losses, and update the network, it's unlikely we can achieve more than a few dozen iterations per second. This means the theoretical difference in performance between repeatedly calling a jitted function from Python and having it in something like a scan decreases to the point of being negligible.
  2. Jax's performance certainly is solid, but it's not magic. The greatest gain from being able to describe the environment in Jax is the convenience you get for that level of performance. The only other competitive alternative (actually superior in this case) is probably to handle the environment and a lot of what Jax does yourself in a low-level language like C++, at great inconvenience. Likewise, expecting that going beyond jitting a significant portion of the train loop (e.g., the train step function) and jitting the train loop itself will bring much performance benefit will likely result in disappointment and a non-negligible amount of wasted time. If you want to understand where you're gaining/losing time, consider profiling the program.
  3. Whether deliberate or by coincidence, I didn't find a single instance of nnx.scan or nnx.jit in your script. There is a difference between the nnx and jax versions of those functions, such as being able to pass nnx.Module and nnx.Rngs objects correctly (see the sketch right after this list). I will make no assumptions about your intent, but "don't use nnx.jit because it's slower" would certainly be the wrong lesson to take from https://github.com/google/flax/issues/4045. The source of the slowdown is the traversal between Python and a function compiled with nnx.jit, and even that cost is negligible as long as the operation is sufficiently expensive (which in this case, it is).
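For reference, a minimal sketch of the kind of nnx.jit train step I mean, shaped after the MNIST tutorial (the model, loss, and batch are placeholders, and the exact Optimizer.update signature may differ between Flax versions):

import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

@nnx.jit  # nnx.jit accepts nnx.Module / nnx.Optimizer arguments directly
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        return jnp.mean((model(x) - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # updates the params of `model` in place
    return loss

loss = train_step(model, optimizer, jnp.ones((8, 2)), jnp.zeros((8, 1)))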

This also isn't to say that if you remove jit, things are much more likely to work. There are a number of other dubious practices, some of which I'm not sure are meant to work at all:

  1. Constructing nnx.Rngs from a vanilla Jax key
            key, subkey = jax.random.split(key)
            network = SharedActorCritic(env.observation_space, env.action_space, nnx.Rngs(subkey))
    
  2. Use of TrainState. I'd recommend attempting the method shown in the Flax NNX vs JAX transformations docs (click on JAX transforms) or here as the ideal way of using objects originating from Flax or NNX in places like jax.jit-ed functions; a rough sketch of that split/merge pattern follows after this list. Though, your hand may be forced -- https://github.com/google/flax/issues/4545#issuecomment-2657917039, since optax.chain returns GradientTransformationExtraArgs.
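Roughly, that split/merge route looks like this (a sketch only, with a toy model and loss; the optimizer construction would need adapting for your optax.chain case per the linked issue):

import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
graphdef, state = nnx.split((model, optimizer))

@jax.jit
def train_step(state, x, y):
    model, optimizer = nnx.merge(graphdef, state)

    def loss_fn(model):
        return jnp.mean((model(x) - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    # return the updated state so it can be threaded through jit or a scan carry
    return nnx.state((model, optimizer)), loss

state, loss = train_step(state, jnp.ones((8, 2)), jnp.zeros((8, 1)))
nnx.update((model, optimizer), state)  # sync the live objects afterwards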

Regarding points 1 and 3, I've personally observed a complete lack of difference in performance in my script (attached previously) after lowering the train loop into a scan.

puct9 avatar Apr 23 '25 16:04 puct9

Hi @puct9, Thank you for getting back to me.

Indeed, the error you see is the main reason why I opened this issue. No matter what I did, I was unable to escape from it. That's why I tried the usual NNX way, the functional API, and the legacy TrainState before giving up and opening the issue. However, I just discovered that the source of the error was completely unrelated to the training function or to jax.lax.scan, jit, or nnx.scan. It turns out that the culprit was my actor-critic implementation:

from typing import List

import jax
import jax.numpy as jnp
import distrax
from flax import nnx

class SharedActorCritic(nnx.Module):
    def __init__(self, 
        observation_space: int, 
        action_space: int, 
        key: nnx.Rngs, 
        architecture: List[int] = [64, 64], 
        in_scale: float = jnp.sqrt(2), 
        actor_scale: float = 1.0, 
        critic_scale: float = 0.01,
        activation = nnx.relu
    ):
        layers = []
        input_dim = observation_space

        for output_dim in architecture:
            layers.append(
                nnx.Linear(
                    in_features=input_dim,
                    out_features=output_dim,
                    kernel_init=nnx.initializers.orthogonal(in_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key
                )
            )
            layers.append(activation)  
            input_dim = output_dim

        self.network = nnx.Sequential(*layers)
        self.actor = nnx.Linear(
                    in_features=input_dim,
                    out_features=action_space,
                    kernel_init=nnx.initializers.orthogonal(actor_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key
                )
        self.critic = nnx.Linear(
                    in_features=input_dim,
                    out_features=1,
                    kernel_init=nnx.initializers.orthogonal(critic_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key
                )
        self.log_std = nnx.Param(jax.random.uniform(key.params(), (1, action_space,)))

    def __call__(self, x):
        x = self.network(x)
        pi = distrax.MultivariateNormalDiag(self.actor(x), jnp.exp(self.log_std.value))  # the error was here, .value is required
        return pi, self.critic(x)

When constructing the probability distribution, I had to access .value on the log_std parameter. If you change this, the script that uses TrainState should work. I'm not entirely sure why, but that seems to resolve the issue. I figured it out by removing all the scans and jitted functions until the error became clear.

Thank you for your comments; I will implement them. I spent some time writing a pure JAX physics engine, which is quite fast, to create my environments, so using @jax.jit instead of the nnx versions is just a way to keep everything consistent. I wasn't trying to avoid @nnx.jit; the script just came out that way after many modifications.

About creating the nnx.Rngs from a plain JAX key: I just did that so I only had to define one key. I didn't know it was a dubious practice. What do you recommend for key handling, given that I have other things that require a random key?

Please let me know if you have more recommendations. I would love to hear them. Also, thanks for the update in the notebook. I'll be checking it out.

Best regards.

cdelv avatar Apr 23 '25 18:04 cdelv

When constructing the probability distribution, I had to access .value on the log_std parameter. If you change this, the script that uses TrainState should work. I'm not entirely sure why, but that seems to resolve the issue.

@cdelv the reason most likely is that nnx.Variable (e.g. nnx.Param) implements the __jax_array__ protocol, so JAX functions treat it as an Array. The main issue is that the protocol is not fully supported, so you might get errors from time to time; we're evaluating whether we should continue using it or have users explicitly access .value or use other syntax like [...].
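For example (a toy illustration of the distinction, not an API guarantee):

import jax.numpy as jnp
from flax import nnx

log_std = nnx.Param(jnp.zeros((1, 2)))

jnp.exp(log_std.value)  # always safe: pass the underlying jax.Array
jnp.exp(log_std)        # relies on __jax_array__; often works inside JAX,
                        # but third-party code may not accept the Param wrapper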

cgarciae avatar Apr 23 '25 20:04 cgarciae

Let's close this issue as resolved. @cdelv, feel free to reopen if more discussion on the topic is needed.

vfdev-5 avatar Oct 24 '25 08:10 vfdev-5