JAX-RL

JAX implementations of various deep reinforcement learning algorithms.

Main libraries used:

  • JAX - main framework
  • Haiku - neural networks
  • Optax - gradient-based optimisation

Algorithms implemented

| Algorithm | Paper |
|---|---|
| Proximal Policy Optimization (PPO) | https://arxiv.org/abs/1707.06347 |
| Deep Q-Network (DQN) | https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf |
| Double Deep Q-Network (DDQN) | https://arxiv.org/abs/1509.06461 |
| Deep Recurrent Q-Network (DRQN) | https://arxiv.org/abs/1507.06527 |
| Deep Deterministic Policy Gradient (DDPG) | https://arxiv.org/abs/1509.02971 |
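
To illustrate how DQN and Double DQN differ, here is a minimal JAX sketch of their bootstrap targets. It is not taken from this repository's code: the function names, argument names and batch layout (`[batch, actions]` arrays of next-state Q-values) are assumptions made for illustration only.

```python
import jax.numpy as jnp


def dqn_target(r, done, q_next, gamma=0.99):
    """DQN-style target: bootstrap from the largest next-state Q-value.

    q_next has shape [batch, actions] (hypothetical layout); it would come
    from a target network if one is used.
    """
    return r + gamma * (1.0 - done) * jnp.max(q_next, axis=-1)


def double_dqn_target(r, done, q_next_online, q_next_target, gamma=0.99):
    """Double DQN target: the online network selects the action,
    the target network evaluates it."""
    a_star = jnp.argmax(q_next_online, axis=-1)
    q_eval = jnp.take_along_axis(
        q_next_target, a_star[:, None], axis=-1
    ).squeeze(-1)
    return r + gamma * (1.0 - done) * q_eval


# Toy batch of two transitions; the second one is terminal.
r = jnp.array([1.0, 1.0])
done = jnp.array([0.0, 1.0])
q_online = jnp.array([[1.0, 2.0], [3.0, 0.0]])
q_target = jnp.array([[5.0, 4.0], [1.0, 7.0]])

y_dqn = dqn_target(r, done, q_target, gamma=0.5)
y_ddqn = double_dqn_target(r, done, q_online, q_target, gamma=0.5)
```

Decoupling action selection from action evaluation is what the Double DQN paper proposes to reduce the overestimation caused by taking a max over noisy estimates.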

Tabular algorithms

  • Q-learning
  • Double Q-learning
  • SARSA
  • Expected SARSA
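
The tabular methods listed above differ mainly in the target they bootstrap from. The following is a minimal sketch of three of the update rules on a JAX array Q-table. It is not this repository's API: the function names, the `[states, actions]` table layout and the default `alpha` and `gamma` values are assumptions, and terminal-state handling is omitted for brevity.

```python
import jax.numpy as jnp


def q_learning_update(q, s, a, r, s_next, alpha=0.1, gamma=0.99):
    """Off-policy: bootstrap from the greedy action in s_next."""
    target = r + gamma * jnp.max(q[s_next])
    return q.at[s, a].add(alpha * (target - q[s, a]))


def sarsa_update(q, s, a, r, s_next, a_next, alpha=0.1, gamma=0.99):
    """On-policy: bootstrap from the action actually taken in s_next."""
    target = r + gamma * q[s_next, a_next]
    return q.at[s, a].add(alpha * (target - q[s, a]))


def expected_sarsa_update(q, s, a, r, s_next, probs, alpha=0.1, gamma=0.99):
    """Bootstrap from the expectation of Q(s_next, .) under the policy,
    where probs holds the action probabilities in s_next."""
    target = r + gamma * jnp.dot(probs, q[s_next])
    return q.at[s, a].add(alpha * (target - q[s, a]))


# Toy table with 2 states and 2 actions (values chosen for illustration).
q = jnp.zeros((2, 2)).at[1].set(jnp.array([1.0, 3.0]))

q_ql = q_learning_update(q, 0, 0, 1.0, 1, alpha=0.5, gamma=0.5)
q_sarsa = sarsa_update(q, 0, 0, 1.0, 1, 0, alpha=0.5, gamma=0.5)
q_exp = expected_sarsa_update(
    q, 0, 0, 1.0, 1, jnp.array([0.5, 0.5]), alpha=0.5, gamma=0.5
)
```

JAX arrays are immutable, so each function returns a new table via `q.at[...].add(...)` rather than modifying `q` in place. Double Q-learning would keep two such tables, selecting the greedy action with one and evaluating it with the other.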

Installation

$ pip install git+https://github.com/hamishs/JAX-RL