batch_rl
JAX code
Hi,
I would like to ask whether there is a JAX-based version of the code, and whether you have any recommendations for JAX-based offline RL algorithms.
Thanks!
Releasing the JAX code might take some time, but it should be easy to modify the existing Dopamine agents. In the meantime, here are some tips to get started with JAX agents:
- You can easily modify the JAX Dopamine agents to run offline RL on the Atari datasets.
- For the replay buffer, you can use the FixedReplayBuffer class in fixed_replay_buffer.py to create the offline buffer:
import gin
import numpy as onp
from absl import logging

from dopamine.jax.agents.dqn import dqn_agent
# Import path assumes the batch_rl layout; adjust it to wherever your
# fixed_replay_buffer.py lives.
from batch_rl.fixed_replay.replay_memory import fixed_replay_buffer


@gin.configurable
class OfflineJaxDQNAgent(dqn_agent.JaxDQNAgent):
  """A JAX implementation of the Offline DQN agent."""

  def __init__(self,
               num_actions,
               replay_data_dir,
               summary_writer=None):
    """Initializes the agent and constructs the necessary components.

    Args:
      num_actions: int, number of actions the agent can take at any state.
      replay_data_dir: str, log directory from which to load the replay buffer.
      summary_writer: SummaryWriter object for outputting training statistics.
    """
    logging.info('Creating %s agent with the following parameters:',
                 self.__class__.__name__)
    logging.info('\t replay directory: %s', replay_data_dir)
    self.replay_data_dir = replay_data_dir
    super().__init__(
        num_actions, update_period=1, summary_writer=summary_writer)

  def _build_replay_buffer(self):
    """Creates the fixed replay buffer used by the agent."""
    return fixed_replay_buffer.FixedReplayBuffer(
        data_dir=self.replay_data_dir,
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype)

  def reload_data(self):
    # This needs to be called every iteration to subsample a portion of the
    # offline dataset.
    self._replay.reload_data()

  def step(self, reward, observation):
    """Returns the agent's next action and updates the agent's state.

    Args:
      reward: float, the reward received from the agent's most recent action.
      observation: numpy array, the most recent observation.

    Returns:
      int, the selected action.
    """
    self._record_observation(observation)
    self._rng, self.action = dqn_agent.select_action(
        self.network_def, self.online_params, self.state, self._rng,
        self.num_actions, self.eval_mode, self.epsilon_eval,
        self.epsilon_train, self.epsilon_decay_period, self.training_steps,
        self.min_replay_history, self.epsilon_fn)
    self.action = onp.asarray(self.action)
    return self.action

  def train_step(self):
    """Exposes the train step for offline learning."""
    super()._train_step()
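For completeness, here is a minimal sketch of how this agent could be driven for offline training; the action count, data path, and step budgets below are placeholders, not values from the repo:

# Hypothetical driver for the OfflineJaxDQNAgent defined above.
agent = OfflineJaxDQNAgent(
    num_actions=4,  # e.g. an Atari game with a 4-action space
    replay_data_dir='/path/to/dqn_replay/Breakout/1/replay_logs')  # placeholder

num_iterations = 200                   # placeholder budget
training_steps_per_iteration = 62_500  # placeholder budget

for _ in range(num_iterations):
  # Re-subsample a portion of the offline dataset, as noted in reload_data().
  agent.reload_data()
  for _ in range(training_steps_per_iteration):
    agent.train_step()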
Thank you very much! May I ask whether I can also run the JAX code on a TPU VM?
Best, Lucas
I think so -- you would probably want to use the TFDS datasets or much larger batch sizes with the Dopamine codebase.
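If you do increase the batch size, a gin override is one option; a minimal sketch, assuming the replay is backed by Dopamine's OutOfGraphReplayBuffer (check the binding name against your own config):

import gin

# Hypothetical override: sample larger batches when training on a TPU.
# Assumes FixedReplayBuffer forwards its kwargs to OutOfGraphReplayBuffer.
gin.bind_parameter('OutOfGraphReplayBuffer.batch_size', 256)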
Thank you very much! I'll have a try.
Dear agarwl,
I tried to follow the code you provided and reproduce the offline DQN results with JAX, but I find that training in JAX is quite slow compared with TensorFlow. May I ask what the possible reason is? I changed the following parts of the vanilla Dopamine code:
(1) I rewrote the Runner in dopamine/dopamine/discrete_domains/run_experiment.py based on the code in batch_rl/batch_rl/fixed_replay/run_experiment.py:
import time

import gin
import tensorflow.compat.v1 as tf

from dopamine.discrete_domains import checkpointer
from dopamine.discrete_domains import iteration_statistics
from dopamine.discrete_domains import run_experiment


@gin.configurable
class FixedReplayRunner(run_experiment.Runner):
  """Object that handles running Dopamine experiments with fixed replay buffer."""

  def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
    super(FixedReplayRunner, self)._initialize_checkpointer_and_maybe_resume(
        checkpoint_file_prefix)

    # Code for loading a checkpoint at initialization.
    init_checkpoint_dir = self._agent._init_checkpoint_dir  # pylint: disable=protected-access
    if (self._start_iteration == 0) and (init_checkpoint_dir is not None):
      if checkpointer.get_latest_checkpoint_number(self._checkpoint_dir) < 0:
        # No checkpoint loaded yet, read init_checkpoint_dir.
        init_checkpointer = checkpointer.Checkpointer(
            init_checkpoint_dir, checkpoint_file_prefix)
        latest_init_checkpoint = checkpointer.get_latest_checkpoint_number(
            init_checkpoint_dir)
        if latest_init_checkpoint >= 0:
          experiment_data = init_checkpointer.load_checkpoint(
              latest_init_checkpoint)
          if self._agent.unbundle(
              init_checkpoint_dir, latest_init_checkpoint, experiment_data):
            if experiment_data is not None:
              assert 'logs' in experiment_data
              assert 'current_iteration' in experiment_data
              self._logger.data = experiment_data['logs']
              self._start_iteration = experiment_data['current_iteration'] + 1
            tf.logging.info(
                'Reloaded checkpoint from %s and will start from iteration %d',
                init_checkpoint_dir, self._start_iteration)

  def _run_train_phase(self):
    """Run training phase."""
    self._agent.eval_mode = False
    start_time = time.time()
    for _ in range(self._training_steps):
      self._agent._train_step()  # pylint: disable=protected-access
    time_delta = time.time() - start_time
    tf.logging.info('Average training steps per second: %.2f',
                    self._training_steps / time_delta)

  def _run_one_iteration(self, iteration):
    """Runs one iteration of agent/environment interaction."""
    statistics = iteration_statistics.IterationStatistics()
    tf.logging.info('Starting iteration %d', iteration)
    # pylint: disable=protected-access
    if not self._agent._replay_suffix:
      # Reload the replay buffer at the start of every iteration.
      self._agent._replay.memory.reload_buffer(num_buffers=5)
    # pylint: enable=protected-access
    self._run_train_phase()

    num_episodes_eval, average_reward_eval = self._run_eval_phase(statistics)
    self._save_tensorboard_summaries(
        iteration, num_episodes_eval, average_reward_eval)
    return statistics.data_lists

  def _save_tensorboard_summaries(self, iteration,
                                  num_episodes_eval,
                                  average_reward_eval):
    """Save statistics as tensorboard summaries.

    Args:
      iteration: int, the current iteration number.
      num_episodes_eval: int, number of evaluation episodes run.
      average_reward_eval: float, the average evaluation reward.
    """
    summary = tf.Summary(value=[
        tf.Summary.Value(tag='Eval/NumEpisodes',
                         simple_value=num_episodes_eval),
        tf.Summary.Value(tag='Eval/AverageReturns',
                         simple_value=average_reward_eval)
    ])
    self._summary_writer.add_summary(summary, iteration)
(2) I created the offline buffer using fixed_replay_buffer.py.
(3) I created the OfflineJaxDQNAgent using exactly the code you provided above.
(4) I compared the JAX code with the vanilla TF code and found that they use different replay buffers (FixedReplayBuffer in JAX and WrappedFixedReplayBuffer in TF). I'm not sure whether this is the main reason for the slowdown.
Best
Hi, I find that update_period is 1 in the JAX agent while the TF code uses 4. Maybe that is the main reason.
Yeah, update_period = 1 corresponds to 1 gradient step every environment step (the default of 4 corresponds to 1 grad step every 4 env steps). In each iteration we do 62.5K grad steps, so with update_period = 1 we can also set num_training_steps to 62.5K.
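To make the bookkeeping concrete, a quick back-of-the-envelope check (assuming Dopamine's usual 250K training steps per iteration):

online_training_steps = 250_000   # Dopamine's default training steps per iteration
default_update_period = 4         # online DQN: 1 grad step every 4 env steps
grad_steps_per_iteration = online_training_steps // default_update_period
assert grad_steps_per_iteration == 62_500

# With update_period = 1, running 62_500 training steps per iteration keeps
# the number of gradient updates per iteration unchanged.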
Hi @agarwl, thanks for your reply. I will try it. By the way, can I run the TF code on a TPU VM? I find TF is still a little bit faster.
Sure -- you may not see much benefit from using TPUs (due to the small batch size and the Dopamine replay buffer), but the code can be run on a TPU.
Here's some JAX code for reference: https://github.com/google/dopamine/tree/master/dopamine/labs/offline_rl
@agarwl Thank you very much! I will give it a try. Thanks!