
Unstable PPO?

Open stergiosba opened this issue 7 months ago • 7 comments

Hello Patrick,

I am implementing the PPO algorithm for a custom environment, and first wanted to test things out with a standard example, so I chose CartPole-v1 as implemented in gymnax.

I compared the Flax implementation from gymnax-blines, which is stable/tested and solves a lot of environments, against an Equinox-based solution. Essentially, I got rid of the flax TrainState as well as the Flax NN model and replaced them with an eqx.Module-based model.

The code runs with no errors, but something is silently going wrong and I don't know what. To elaborate, even though the agent is learning, the following "problems" are still present:

  1. The Flax implementation converges to the optimum much faster with exactly the same configuration parameters (learning rate, clipping ratio, etc.) and the same agent model (layers/initialization).
  2. The Flax implementation stays at the optimum, while the performance of the Equinox agent is unstable: it gets close to the maximum but then oscillates. For instance, the maximum return in CartPole-v1 is 500, and the Equinox agent can drop from 499 to 300 to 150 and so on in consecutive training epochs, even late in a training session, whereas for the same training length (number of epochs) the Flax implementation stays locked at 500 after some point.

I checked how the randomization progresses, and I get exactly the same keys as the Flax implementation, since I start with the same seed and split the keys at exactly the same places.

Could it be something with the way gradients update the parameters of the model? My input is batched, so I know I have to use jax.vmap according to the docs. The problem is that if I use jax.vmap and then define the optimization state as:

opt_state = optimizer.init(eqx.filter(equinox_model, eqx.is_array))

then I get None gradients, because the output of jax.vmap is of type types.FunctionType, which is not detected by eqx.is_array. This pushed me to use eqx.filter_vmap instead, and I got the gradients with eqx.filter_value_and_grad, but am I messing things up somewhere?

FYI, the model is defined as:

from typing import List

import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import tensorflow_probability.substrates.jax as tfp

class Agent(eqx.Module):
    critic: List
    actor: List

    def __init__(self, env, env_params, key):
        obs_shape = env.observation_space(env_params).shape
        keys = jrandom.split(key, 6)

        # critic similar to actor but with out_dim=1... 
        # keys[0:2] are used in the critic layers.

        self.actor = [
            eqx.filter_vmap(
                eqx.nn.Linear(
                    jnp.array(obs_shape).prod(),
                    64,
                    key=keys[3],
                ),
                in_axes=0,
            ),
            jnn.relu,
            eqx.filter_vmap(
                eqx.nn.Linear(
                    64, 
                    64, 
                    key=keys[4]
                ), 
                in_axes=0),
            jnn.relu,
            eqx.filter_vmap(
                eqx.nn.Linear(
                    64, 
                    env.num_actions, 
                    key=keys[5]
                ),
                in_axes=0,
            ),
        ]
    
    def get_action_and_value(self, x):
        x_a = x
        x_v = x
        for layer in self.actor:
            x_a = layer(x_a)
            
        for layer in self.critic:
            x_v = layer(x_v)
            
        pi = tfp.distributions.Categorical(logits=x_a)

        return x_v, pi

Any suggestions would be greatly appreciated. Thanks!

stergiosba · Nov 21 '23 05:11