Reinforcement Learning?
Hey, I'm quite new to Axon and have only experimented with reinforcement learning a tiny bit, so please excuse me if I've misunderstood something.
From what I can gather, there is no way to use Axon.Loop for reinforcement learning, since it requires a batch of training data? Am I meant to implement the training / evaluation loop myself? Or jump straight to some sort of experience replay and provide that as training data? I'm honestly not sure - and I couldn't really figure out how to do training without the convenience of Axon.Loop (though this may be because I've never done such a thing outside of Axon either).
I believe the simplest way to implement reinforcement learning would be to use Deep Q-Learning? I think I've figured out most of what's required to do that, though I'm not sure how to properly update the weights. I would've thought Axon has some kind of helpers built in? (I've never dealt with updating weights directly myself, so I'm just guessing, please correct me.)
I'd very much appreciate any help & tips :smile:
Hey @HurricanKai! I don't have significant experience with reinforcement learning, though I think it should be possible with Axon's loop API and/or some custom implementation. Here is a detailed write-up I did about how to make a hypothetical DQN: https://elixirforum.com/t/dqn-rl-with-axon-nx/43760
Some things have changed about the Loop API, but the concepts more or less remain the same. I believe @vans163 has messed around with reinforcement learning in Axon as well. If you are interested in providing an example, I would welcome that and will definitely offer guidance wherever necessary!
As far as conveniences for updating the model go, Axon offers the Axon.Updates and Axon.Optimizers APIs. There are more details in the documentation, but here's a short example of how they work:
defn objective(model_state, inputs, targets) do
  preds = predict_function(model_state, inputs)
  Axon.Losses.mean_squared_error(preds, targets, reduction: :mean)
end

defn update_model(model_state, optimizer_state, inputs, targets, model_update_fn) do
  # Take the gradient of the objective w.r.t. the model state; the objective is
  # basically a parameterized loss function
  grads = grad(model_state, &objective(&1, inputs, targets))
  # Scale the updates according to the optimizer
  {scaled_updates, new_optimizer_state} = model_update_fn.(grads, optimizer_state, model_state)
  # Apply the updates
  new_model_state = Axon.Updates.apply_updates(model_state, scaled_updates)
  # Return the new model state and optimizer state to use in the next iteration
  {new_model_state, new_optimizer_state}
end
You can construct an optimizer from one of the optimizers in Axon.Optimizers:
{init_optimizer, update_fn} = Axon.Optimizers.adam(1.0e-3)
optimizer_state = init_optimizer.(model_state)
update_model(model_state, optimizer_state, inputs, targets, update_fn)
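For illustration, here is a rough sketch of how that objective might be specialized into a DQN loss, following the same predict_function placeholder as above; q_target_next, gamma, dones, and the batch layout are assumptions about how a replay batch might be shaped, not Axon API. The next-state Q-values would be computed outside this function from a separate target copy of the parameters:
defn dqn_objective(model_state, states, actions, rewards, dones, q_target_next, gamma) do
  # Online-network Q-values for the replay batch, shape {batch, num_actions}
  q_online = predict_function(model_state, states)

  # Q(s, a) for the (integer) actions actually taken in the batch
  q_taken =
    q_online
    |> Nx.take_along_axis(Nx.new_axis(actions, -1), axis: 1)
    |> Nx.squeeze(axes: [1])

  # Bellman target: r + gamma * max_a' Q_target(s', a'), zeroed at terminal transitions
  target = rewards + gamma * Nx.reduce_max(q_target_next, axes: [1]) * (1 - dones)

  Axon.Losses.mean_squared_error(target, q_taken, reduction: :mean)
end
The gradient of this with respect to model_state then plugs into update_model/5 exactly like the generic objective above.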
Thank you so much, that helps a lot. I haven't quite figured out how the target / main net would work with this setup, but I'm sure I can make it work with some further research.
I'm building Snake in Livebook and adding Deep Q-Learning to it right now. I'd be very happy to contribute that, though I'm not too sure about the code quality, so some help cleaning up that code once it's done would be appreciated.
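On the target / main net question: since Axon model state is just a (nested) map of tensors, one common convention (a sketch of a typical setup, not something Axon prescribes) is to keep two copies of the parameters and periodically overwrite the target copy with the online one:
# Hypothetical helper: a "hard" target update simply replaces the target copy
# every sync_every steps; otherwise the old target parameters are kept
def maybe_sync_target(step, sync_every, online_params, target_params) do
  if rem(step, sync_every) == 0, do: online_params, else: target_params
end
The target copy is only ever used to compute the Q-values for the Bellman targets (q_target_next in the sketch above); gradients are never taken with respect to it.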
There is some stuff with playing Klondike Solitaire using PPO here: https://github.com/gpuedge/lab/blob/main/lib/solitaire_ppo.ex. DQN did not seem to work well when there is a large input space (or I misconfigured it). DQN did learn really simple things well, like the game of solitaire reduced to 2 actions, for example.
I tried to do DQN input masking with stop_grad, but it did not seem to have any effect (possibly because I implemented it wrong).
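For reference, a more common way to mask invalid actions in DQN (just a sketch of the usual trick, not what the linked repo does) is to overwrite the Q-values of invalid actions with a very large negative number, so they can never win the max or argmax:
# valid_mask is assumed to be a {batch, num_actions} tensor of 1s (valid) and 0s (invalid)
defn mask_invalid_actions(q_values, valid_mask) do
  q_values * valid_mask + (1 - valid_mask) * -1.0e9
end
Masked actions then drop out of both action selection and the max in the Bellman target.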
After that I moved to PPO, but the implementation is still incomplete; that's the work in progress linked above.
NOTE: The SolitaireEasy module provides a dumbed-down version of solitaire where all the single-suit cards are in row_0. The model needs to learn that the optimal solution is moving row_0 to pile_0 until row_0 has 0 cards; any other action is suboptimal. I was using this to test growing the input space: ideally, even if the input space is, say, 8k actions, the network should still learn to take that same row_0 -> pile_0 action. I could not achieve this.
PPO is expected to do much better than vanilla DQN, though (and is also correspondingly more complex).
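For anyone comparing the two, the core of PPO is the clipped surrogate objective; a minimal sketch of the textbook formulation (not code from the linked repo) looks like this:
defn ppo_clip_loss(new_log_probs, old_log_probs, advantages, opts \\ []) do
  opts = keyword!(opts, clip_eps: 0.2)
  # Probability ratio between the current policy and the policy that collected the data
  ratio = Nx.exp(new_log_probs - old_log_probs)
  clipped = Nx.clip(ratio, 1 - opts[:clip_eps], 1 + opts[:clip_eps])
  # Maximize the clipped surrogate, i.e. minimize its negation
  -Nx.mean(Nx.min(ratio * advantages, clipped * advantages))
end
The clipping keeps each policy update close to the policy that collected the data, which is where much of PPO's extra complexity (and stability) comes from.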
Cleaning up the issue tracker. Closing this as tracked in #47