Corax: Core RL in JAX

Installation | Examples | Agents | Datasets

Corax is a library for reinforcement learning (RL) algorithms in JAX. It aims to provide modular, pure and functional components for RL algorithms that can be easily used in different training loops and accelerator configurations. We are currently exploring a LearnerCore and ActorCore design that allows easy composition and scaling of RL algorithms. At the same time, Corax aims to provide strong baseline agents that can be forked and customized for future RL research.
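To illustrate what "pure and functional components" means in practice, here is a minimal, dependency-free sketch of an actor written as a bundle of pure functions that take state explicitly and return new state, so the same object can be driven by any training loop. The names `ActorState`, `ActorCore` and `make_epsilon_greedy_actor` are hypothetical and are not Corax's API. A toy linear congruential generator stands in for JAX's explicit PRNG keys (`jax.random`) so the example runs without JAX; a learner could follow the same init/step pattern.

```python
from typing import Callable, NamedTuple, Sequence, Tuple


class ActorState(NamedTuple):
    rng: int        # explicit PRNG state, threaded through every call (JAX-style)
    num_steps: int  # number of actions selected so far


class ActorCore(NamedTuple):
    # Hypothetical interface: both fields are pure functions with no hidden state.
    init: Callable[[int], ActorState]
    select_action: Callable[[Sequence[float], ActorState], Tuple[int, ActorState]]


def _lcg(rng: int) -> Tuple[float, int]:
    """Toy pure PRNG: returns a uniform sample in [0, 1) and the next state."""
    next_rng = (1103515245 * rng + 12345) % (2**31)
    return next_rng / 2**31, next_rng


def make_epsilon_greedy_actor(epsilon: float) -> ActorCore:
    def init(seed: int) -> ActorState:
        return ActorState(rng=seed % (2**31), num_steps=0)

    def select_action(
        q_values: Sequence[float], state: ActorState
    ) -> Tuple[int, ActorState]:
        explore, rng = _lcg(state.rng)
        pick, rng = _lcg(rng)
        if explore < epsilon:
            # Uniformly random action; pick is in [0, 1) so the index is in range.
            action = int(pick * len(q_values))
        else:
            # Greedy action (ties broken by the lowest index).
            action = max(range(len(q_values)), key=lambda a: q_values[a])
        return action, ActorState(rng=rng, num_steps=state.num_steps + 1)

    return ActorCore(init=init, select_action=select_action)


actor = make_epsilon_greedy_actor(epsilon=0.1)
state = actor.init(seed=0)
for q_values in ([0.1, 0.9, 0.3], [0.5, 0.2, 0.4]):
    action, state = actor.select_action(q_values, state)
```

Because `select_action` never mutates its inputs, calling it twice with the same state returns the same result, which is what makes such components easy to compose, jit-compile or replicate across devices in a JAX setting.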

Corax started as a fork of the dm-acme library, with the goal of providing a better experience for researchers working on online and offline RL in JAX. Future development of Corax may diverge from the design of Acme.

Installation

You can install Corax with

pip install 'git+https://github.com/ethanluoyc/corax#egg=corax[tf,jax]'

To use Corax with a GPU, you need to install JAX with GPU support. Follow the official JAX installation instructions for how to do so.

Note on TensorFlow dependency

The base corax package does not depend on a specific deep learning framework. However, the JAX agents depend on TensorFlow for efficient data processing.

We provide optional extras for installing tensorflow-cpu and compatible versions of TensorFlow Probability and Reverb.

However, if these extras are incompatible with your own dependency requirements, you can opt out of them and specify these dependencies yourself. See pyproject.toml for examples of how to specify compatible TensorFlow versions. Here is an example workflow for determining compatible versions of TensorFlow, TensorFlow Probability and Reverb. Suppose you want to use tensorflow-cpu~=2.13.0:

  1. From https://github.com/tensorflow/probability/releases, the tensorflow-probability version compatible with tensorflow~=2.13.0 is 0.21.0.
  2. From https://github.com/google-deepmind/reverb/tree/master#reverb-releases, the compatible dm-reverb version is 0.12.0.

Therefore, as an application developer, you should put the following in your requirements.txt:

tensorflow-cpu~=2.13.0
tensorflow-probability~=0.21.0
dm-reverb~=0.12.0
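The pins above use PEP 440's compatible-release operator: `~=2.13.0` means `>=2.13.0` and `==2.13.*`. The sketch below spells that rule out and checks a set of installed versions against the pins from this workflow. It is a simplified illustration that assumes plain numeric versions only (suffixes such as `rc0` raise `ValueError`); `satisfies_compatible_release` and `find_mismatches` are hypothetical helpers, not part of Corax.

```python
from typing import Dict, List, Tuple


def _release(version: str) -> Tuple[int, ...]:
    """Parse a plain numeric release such as '2.13.0' into (2, 13, 0).

    Assumption: no pre/post/dev suffixes; e.g. '0.6.0rc0' raises ValueError.
    """
    return tuple(int(part) for part in version.split("."))


def satisfies_compatible_release(version: str, spec: str) -> bool:
    """Return True if `version` satisfies `~=spec` (e.g. ~=2.13.0 -> >=2.13.0, ==2.13.*)."""
    wanted = _release(spec)
    if len(wanted) < 2:
        raise ValueError("~= requires at least two release components")
    have = _release(version)
    width = max(len(have), len(wanted))
    # Zero-pad so that '2.13' compares equal to '2.13.0'.
    have_padded = have + (0,) * (width - len(have))
    wanted_padded = wanted + (0,) * (width - len(wanted))
    prefix = wanted[:-1]  # every component except the last must match exactly
    return have_padded >= wanted_padded and have_padded[: len(prefix)] == prefix


def find_mismatches(installed: Dict[str, str], pins: Dict[str, str]) -> List[str]:
    """List packages that are missing or whose version violates its ~= pin."""
    problems = []
    for name, spec in pins.items():
        version = installed.get(name)
        if version is None:
            problems.append(f"{name}: not installed")
        elif not satisfies_compatible_release(version, spec):
            problems.append(f"{name}: {version} does not satisfy ~={spec}")
    return problems


# Pins from the tensorflow-cpu 2.13 workflow above.
PINS = {
    "tensorflow-cpu": "2.13.0",
    "tensorflow-probability": "0.21.0",
    "dm-reverb": "0.12.0",
}
```

In a real environment, `installed` could be filled with `importlib.metadata.version(name)`; for full PEP 440 handling (pre-releases, local versions), the `packaging` library's `SpecifierSet` is the more robust choice.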

If you use dm-launchpad, the workflow is similar, although as of 17 Oct 2023, Launchpad does not provide an official build for TensorFlow 2.13.0. However, we have an unofficial manylinux build for Python 3.9 and 3.10 available at https://github.com/ethanluoyc/launchpad/releases/tag/v0.6.0rc0. If you intend to use this version, include the following in your requirements.txt:

# Use the exact link to the wheel file for your Python version
dm-launchpad @ https://github.com/ethanluoyc/launchpad/releases/download/v0.6.0rc0/dm_launchpad-0.6.0rc0-cp39-cp39-manylinux2014_x86_64.whl
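Since the wheel has to match your Python version, the requirement line can be derived from the interpreter tag. The sketch below builds it for CPython 3.9 or 3.10 only, as noted above. It assumes that the 3.10 wheel follows the same file-name pattern as the 3.9 wheel shown (`cp310-cp310-manylinux2014_x86_64`), so verify the exact file name on the release page; `launchpad_requirement` is a hypothetical helper.

```python
import sys
from typing import Optional, Tuple

BASE_URL = "https://github.com/ethanluoyc/launchpad/releases/download/v0.6.0rc0"
# Unofficial wheels exist only for CPython 3.9 and 3.10 on x86_64 Linux.
SUPPORTED = {(3, 9), (3, 10)}


def launchpad_requirement(version_info: Optional[Tuple[int, int]] = None) -> str:
    """Return a requirements.txt line pointing at the matching wheel.

    Assumption: wheel names follow the cp39 pattern for every supported version.
    """
    major, minor = version_info if version_info is not None else sys.version_info[:2]
    if (major, minor) not in SUPPORTED:
        raise RuntimeError(f"No unofficial dm-launchpad wheel for Python {major}.{minor}")
    tag = f"cp{major}{minor}"
    wheel = f"dm_launchpad-0.6.0rc0-{tag}-{tag}-manylinux2014_x86_64.whl"
    return f"dm-launchpad @ {BASE_URL}/{wheel}"


if __name__ == "__main__":
    print(launchpad_requirement())
```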

We currently do not have a build for TensorFlow 2.14.0 due to https://github.com/google-deepmind/launchpad/issues/44.

Examples

Examples can be found in the projects directory.

Development

git clone https://github.com/ethanluoyc/corax
cd corax
# Create a virtual environment with the method of your choice.
python3 -m venv .venv
source .venv/bin/activate
# Then run
pip install -e '.[tf,jax,test,dev]'
# Install pre-commit hooks if you intend to create PRs.
pre-commit install
# Install the baselines by running
pip install -r projects/baselines/requirements.txt -e projects/baselines

Agents

Corax includes high-quality implementations of many popular RL agents. These agents are meant to be forked and customized for future RL research.

These implementations have been used in numerous research projects, and we intend to provide benchmark results for them in the future.

Corax currently implements the following agents in JAX:

Agent                 Paper                   Code
CalQL                 Nakamoto et al., 2023   calql
CQL                   Kumar et al., 2020      calql
IQL                   Kostrikov et al., 2021  iql
RLPD                  Ball et al., 2023       redq
Decision Transformer  Chen et al., 2021a      decision_transformer
DrQ-v2(-BC)           Yarats et al., 2021     drq_v2
ORIL                  Zolna et al., 2020      oril
OTR                   Luo et al., 2023        otr
REDQ                  Chen et al., 2021b      redq
TD3                   Fujimoto et al., 2018   td3
TD3-BC                Fujimoto et al., 2021   td3
TD-MPC                Hansen et al., 2021     tdmpc

More agents, including those implemented in Magi, may be added in the future. Contributions of new agents are welcome!

Datasets

For online RL, Corax agents use Reverb for storing and sampling experience.

When working with offline RL, existing datasets provided by the community may come in different formats. It can be time-consuming to integrate existing algorithms with different datasets.

Therefore, for offline RL, Corax provides additional TFDS dataset builders that build datasets stored in the RLDS format. This makes it easy to run the same offline RL algorithm across different offline RL datasets in a consistent manner. You may also want to check out the list of datasets officially supported by the TFDS/RLDS team.
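A common format pays off because every dataset reduces to the same episode structure: a sequence of steps carrying RLDS fields such as `observation`, `action`, `reward`, `discount`, `is_first`, `is_last` and `is_terminal`. Many offline RL algorithms consume transitions, which can be formed by pairing consecutive steps. The sketch below shows that conversion on plain Python dicts; it does not use TFDS, RLDS or tf.data, and `Transition` and `episode_to_transitions` are hypothetical names rather than Corax's loaders.

```python
from typing import Any, Dict, List, NamedTuple


class Transition(NamedTuple):
    observation: Any
    action: Any
    reward: float
    discount: float
    next_observation: Any


def episode_to_transitions(steps: List[Dict[str, Any]]) -> List[Transition]:
    """Pair consecutive RLDS-style steps into (s, a, r, discount, s') transitions.

    Step t holds the observation, the action taken from it, and the reward and
    discount that followed; step t + 1 supplies the next observation. The final
    step (is_last=True) only contributes its observation.
    """
    return [
        Transition(
            observation=step["observation"],
            action=step["action"],
            reward=step["reward"],
            discount=step["discount"],
            next_observation=next_step["observation"],
        )
        for step, next_step in zip(steps[:-1], steps[1:])
    ]


# A toy three-step episode that terminates after the second action.
episode = [
    {"observation": 0, "action": 1, "reward": 0.0, "discount": 1.0,
     "is_first": True, "is_last": False, "is_terminal": False},
    {"observation": 1, "action": 0, "reward": 1.0, "discount": 0.0,
     "is_first": False, "is_last": False, "is_terminal": False},
    {"observation": 2, "action": 0, "reward": 0.0, "discount": 0.0,
     "is_first": False, "is_last": True, "is_terminal": True},
]
transitions = episode_to_transitions(episode)
```

Once every dataset yields episodes in this shape, the same conversion and the same learner can be reused unchanged across datasets.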

In addition to the official RLDS datasets, the following datasets can be built with Corax:

Dataset           Paper                   Code
V-D4RL            Lu et al., 2023         vd4rl
Watch and Match   Haldar et al., 2022     rot
ExoRL             Yarats et al., 2022     exorl
GWIL              Fickinger et al., 2022  gwil
Adroit Binary     Nair et al., 2022       adroit_binary

NOTE: Some of these datasets do not yet cover all splits provided by the original dataset. Missing splits will be added as the need arises.

Acknowledgements

We would like to thank the Acme authors, who provided a great starting point for Corax. Corax would not exist without them, as a significant portion of the current code is forked from Acme. You should check out Acme if you are looking for more RL agent implementations.

We would also like to thank the authors of the original papers for open-sourcing their code, which has been a great help in our re-implementations.