equinox
Unstable PPO?
Hello Patrick,
I am implementing the PPO algorithm for a custom environment and first wanted to test things on a standard example, so I chose CartPole-v1 as implemented in gymnax.
I compared the Flax implementation from gymnax-blines, which is stable, tested, and solves many environments, against an Equinox-based version. Essentially, I removed the `flax.TrainState` and the Flax NN model and replaced them with an `equinox.Module`-based model.
The code runs without errors, but something is silently going wrong and I can't tell what. Even though the agent is learning, the following "problems" remain:
- With exactly the same configuration (learning rate, clipping ratio, etc.) and the same agent model (layers and initialization), the Flax implementation converges to the optimum much faster.
- The Flax implementation stays at the optimum, while the Equinox agent is unstable: it gets close to the maximum but then oscillates. For instance, the maximum return in CartPole-v1 is 500, and the Equinox agent can go from 499 to 300 to 150 and so on over consecutive training epochs, even late in training. For the same training length (number of epochs), the Flax implementation locks in at 500 after some point and stays there.
I checked how the PRNG keys progress, and I get exactly the same keys as the Flax implementation, since I start from the same seed and split the keys at exactly the same places.
Could it be something in the way the gradients update the model parameters? My input is batched, so according to the docs I have to use `jax.vmap`. The problem is that if I use `jax.vmap` and then define the optimizer state as:

```python
opt_state = optimizer.init(eqx.filter(equinox_model, eqx.is_array))
```

I then get `None` in the gradients, because the output of `jax.vmap` is of type `types.FunctionType`, which `eqx.is_array` does not detect. This pushed me to use `eqx.filter_vmap` and to compute the gradients with `eqx.filter_value_and_grad`, but am I messing things up somewhere?
FYI, the model is defined as:

```python
from typing import List

import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import tensorflow_probability.substrates.jax as tfp


class Agent(eqx.Module):
    critic: List
    actor: List

    def __init__(self, env, env_params, key):
        obs_shape = env.observation_space(env_params).shape
        keys = jrandom.split(key, 6)
        # critic similar to actor but with out_dim=1...
        # keys[0:2] are used in the critic layers.
        self.actor = [
            eqx.filter_vmap(
                eqx.nn.Linear(
                    int(jnp.array(obs_shape).prod()),  # cast to a Python int for in_features
                    64,
                    key=keys[3],
                ),
                in_axes=0,
            ),
            jnn.relu,
            eqx.filter_vmap(eqx.nn.Linear(64, 64, key=keys[4]), in_axes=0),
            jnn.relu,
            eqx.filter_vmap(
                eqx.nn.Linear(64, env.num_actions, key=keys[5]),
                in_axes=0,
            ),
        ]

    def get_action_and_value(self, x):
        x_a = x
        x_v = x
        for layer in self.actor:
            x_a = layer(x_a)
        for layer in self.critic:
            x_v = layer(x_v)
        pi = tfp.distributions.Categorical(logits=x_a)
        return x_v, pi
```
Any suggestions would be greatly appreciated. Thanks!
It's hard to say exactly what the issue is from this description (a minimal reproducible example would help). I usually define the neural network for a single input and then vmap outside of it (instead of internally), but that's a smaller point. You could also check out: https://github.com/patrick-kidger/rl-test/blob/master/src/algorithms.py#L61.
This isn't super helpful in the near term, but I am working on https://github.com/sotetsuk/pgx/issues/1059, which will result in a stable Equinox PPO version (the PR for that issue won't be done soon, though).
Hey @lockwo. Thanks for the comment.
I have seen the PPO algo that Patrick wrote in the past as well as other implementations like CleanRL.
I reported this because I thought there might be an internal issue with Equinox. Of course, I realize that most probably my revisions broke the algorithm, not Equinox, but I wanted a second opinion.
I am still working on the issue and will update with new info as I go.
> Hard to say exactly what the issue is (if you have the MVC that would be helpful) from this description. I usually do the neural network defined for a single input, then vmap outside of that (instead of internally), but that's a smaller point.
I would note that this design pattern also means you don't need to worry about the `None` gradients you ran into when wrapping layers in `jax.vmap`. (But if you do want to use that pattern, then using `filter_vmap` as you have is indeed the appropriate solution.)
One possible difference between the implementations may be that Equinox and Flax use different initialisations for the same seed (e.g. sampling from a normal vs a uniform distribution).
I have ensured that both models start with the same initialization and all PPO parameters are of course the same.
What I did:
- [x] 1. Have a common config file to ensure same PPO parameters
- [x] 2. Initialize identical equinox and flax models before training.
- [x] 3. Copy the weights and biases from the Equinox model into the Flax model parameters to ensure identical initialization. I use orthogonal initialization, but Flax also converges with LeCun initialization.
- [x] 4. Check that both models produce the same outputs given the same observations.
- [x] 5. Check that the PRNG keys used throughout PPO are always the same.
What is baffling is that even though the first batch of data is the same, the gradients begin to diverge ever so slightly. After the first batch I cannot keep the inputs identical, as the two models choose different actions and therefore receive different observations.
The result is again that the Equinox model oscillates around the optimum and then makes huge dips in performance, while the Flax model converges.
I also tried enabling x64 in the JAX config; the result is the same.
Any other thoughts would be greatly appreciated. Still looking into this.
I tried defining the model for a single observation and applying vmap outside it. I get exactly the same performance as before, down to the last decimal place. The model is:

```python
import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom


class Agent(eqx.Module):
    critic: eqx.nn.Sequential
    actor: eqx.nn.Sequential

    def __init__(self, env, env_params, key):
        obs_shape = env.observation_space(env_params).shape
        in_features = int(jnp.array(obs_shape).prod())
        keys = jrandom.split(key, 6)
        self.critic = eqx.nn.Sequential(
            [
                eqx.nn.Linear(in_features, 64, key=keys[0]),
                eqx.nn.Lambda(jnn.relu),
                eqx.nn.Linear(64, 64, key=keys[1]),
                eqx.nn.Lambda(jnn.relu),
                eqx.nn.Linear(64, 1, key=keys[2]),
            ]
        )
        self.actor = eqx.nn.Sequential(
            [
                eqx.nn.Linear(in_features, 64, key=keys[3]),
                eqx.nn.Lambda(jnn.relu),
                eqx.nn.Linear(64, 64, key=keys[4]),
                eqx.nn.Lambda(jnn.relu),
                eqx.nn.Linear(64, env.num_actions, key=keys[5]),
            ]
        )

    @eqx.filter_jit
    def __call__(self, x):
        # x is a single (unbatched) observation; batching is done with vmap outside.
        return self.critic(x), self.actor(x)
```
If you meant something other than these two Sequential models, please explain further. Thanks again.
That is what I meant. It was mostly a note on design patterns; I wouldn't expect any numerical differences.
Yeah, I was just being paranoid and checked everything.
I managed to make it work, and it is stable, with the following hacky approach. I make a namedtuple train state as follows:

```python
from collections import namedtuple

eqxTrainState = namedtuple("eqxTrainState", ["params", "static", "tx", "opt_state"])
```

Then I basically carry this train state around everywhere instead of the Equinox model. Whenever needed, I reconstruct the model from `params` and `static` and then do a forward pass. I will now test whether I can do this without the partition.
I attach the following image so you can get an idea of what I have been seeing so far (green) and what the performance is now (orange) for Equinox vs Flax (blue).
As can be seen, the "correct" Equinox and the Flax implementations are nearly identical at the beginning. They differ at some points; is that to be expected? This remains a bit of a mystery. The dips are also sharper for the Equinox agent, but it is much better than the green tragedy :)
Edit: it also works without the partition.