Reinforcement Learning?

Open HurricanKai opened this issue 3 years ago • 3 comments

Hey, I'm quite new to Axon and have only experimented with reinforcement learning a tiny bit, so please excuse me if I've misunderstood something.

From what I can gather, there is no way to use Axon.Loop to do reinforcement learning, since it requires a batch of training data? Am I meant to implement the training / evaluation loop myself? Or jump straight to some sort of experience replay and provide that as training data? I'm honestly not sure, and I couldn't really figure out how to do training without the convenience of Axon.Loop (though this may be because I've never done such a thing outside of Axon either).

I believe the simplest way to implement reinforcement learning would be Deep Q-Learning? I think I've figured out most of what's required to do that, though I'm not sure how to properly update the weights. I would've thought Axon has some kind of helpers built in? (I've never dealt with updating weights directly myself, so I'm just guessing, please correct me.)

I'd very much appreciate any help & tips :smile:

HurricanKai avatar Jun 20 '22 22:06 HurricanKai

Hey @HurricanKai! I don't have significant experience with reinforcement learning, though I think it should be possible with Axon's loop API and/or some custom implementation. Here is a detailed write-up I did about how to make a hypothetical DQN: https://elixirforum.com/t/dqn-rl-with-axon-nx/43760

Some things have changed about the Loop API, but the concepts more or less remain the same. I believe @vans163 has messed around with reinforcement learning in Axon as well. If you are interested in providing an example, I would welcome that and definitely offer my guidance wherever necessary!

As far as conveniences for updating the model, Axon offers the Axon.Updates and Axon.Optimizers APIs. There are more details in the documentation, but here's a short example of how they work:

defn objective(model_state, inputs, targets) do
  preds = predict_function(model_state, inputs)
  Axon.Losses.mean_squared_error(preds, targets, reduction: :mean)
end

defn update_model(model_state, optimizer_state, inputs, targets, model_update_fn) do
  # Take the gradient of the objective function w.r.t. the model state;
  # this is basically a parameterized loss function
  grads = grad(model_state, &objective(&1, inputs, targets))
  # Scale updates according to optimizer
  {scaled_updates, new_optimizer_state} = model_update_fn.(grads, optimizer_state, model_state)
  # Apply updates
  new_model_state = Axon.Updates.apply_updates(model_state, scaled_updates)
  # Return new model state and new optimizer state to use in the next iteration
  {new_model_state, new_optimizer_state}
end

You can construct an optimizer using one of the functions in Axon.Optimizers:

{init_optimizer, update_fn} = Axon.Optimizers.adam(1.0e-3)
optimizer_state = init_optimizer.(model_state)

update_model(model_state, optimizer_state, inputs, targets, update_fn)
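
To give a flavor of how this could plug into a DQN-style update with a separate target network, here is a minimal sketch (not a full implementation; predict_function, gamma, and the replay-batch fields are placeholders rather than anything Axon provides):

defn dqn_objective(model_state, target_state, states, actions, rewards, next_states, dones, gamma) do
  # Q-values of the actions actually taken, under the online network
  q_values = predict_function(model_state, states)

  q_taken =
    q_values
    |> Nx.take_along_axis(Nx.new_axis(actions, -1), axis: 1)
    |> Nx.squeeze(axes: [1])

  # Bootstrapped targets from the target network; stop_grad keeps the
  # target branch out of the gradient computation
  next_q = predict_function(target_state, next_states)
  max_next_q = Nx.reduce_max(next_q, axes: [1])
  targets = rewards + gamma * (1 - dones) * stop_grad(max_next_q)

  Axon.Losses.mean_squared_error(targets, q_taken, reduction: :mean)
end

You would differentiate this w.r.t. model_state only (exactly as in update_model above) and periodically copy model_state into target_state every N steps.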

seanmor5 avatar Jun 21 '22 00:06 seanmor5

Thank you so much, that helps a lot. I haven't quite figured out how the target / main net would work with this setup, but I'm sure I can make it work with some further research.

I'm building Snake in Livebook and adding Deep Q-Learning to it right now. I'd be very happy to contribute that, though I'm not too sure about the code quality, so some help cleaning up that code would be appreciated once it's done.

HurricanKai avatar Jun 21 '22 10:06 HurricanKai

There is some stuff with playing Klondike Solitaire using PPO here: https://github.com/gpuedge/lab/blob/main/lib/solitaire_ppo.ex. DQN did not seem to work well when the input space is large (or I misconfigured it). It learned simple things really well, like the game of solitaire reduced to 2 actions, for example.

I tried to do DQN input masking with stop_grad, but it did not seem to have any effect (either I implemented it wrong, or it simply doesn't help here).
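
For reference, one common way to mask invalid actions, as an alternative to stop_grad, is to overwrite their Q-values with a large negative constant before taking the max/argmax. A minimal sketch, where valid_mask is a hypothetical 0/1 tensor marking legal actions (not something from the linked code):

defn masked_q_values(q_values, valid_mask) do
  # valid_mask has 1 for legal actions and 0 for illegal ones; illegal actions
  # get a large negative value so they never win the max/argmax
  Nx.select(valid_mask, q_values, -1.0e9)
end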

After that I moved to PPO, but the implementation is still incomplete; that's the work in progress linked above.

NOTE: the SolitaireEasy module provides a dumbed-down version of solitaire where all the single-suit cards are in row_0. The model needs to learn that the optimal solution is moving row_0 to pile_0 until row_0 has 0 cards; any other action is suboptimal. I was using this to test growing the input space. Ideally, even if the input space is, say, 8k actions, the network should still learn to take that same row_0 -> pile_0 action. I could not achieve this.

vans163 avatar Jun 25 '22 14:06 vans163

PPO is expected to do much better than vanilla DQN though (and is also correspondingly more complex)

joaogui1 avatar Sep 11 '22 23:09 joaogui1

Cleaning up the issue tracker. Closing this as tracked in #47

seanmor5 avatar Sep 15 '22 17:09 seanmor5