batch_rl

JAX code

Status: Open. lucasliunju opened this issue 2 years ago • 11 comments

Hi,

I would like to ask whether there is a JAX-based version of the code.

I would also appreciate any recommendations for JAX-based offline RL algorithms.

Thanks!

lucasliunju, Sep 08 '22

Releasing the JAX code might take some time, but it should be easy to modify the existing Dopamine agents. In the meantime, here are some tips for getting started with JAX agents:

  • You can easily modify the JAX Dopamine agents to run offline RL on the Atari datasets.
  • For the replay buffer, you can use the FixedReplayBuffer class in fixed_replay_buffer.py to create the offline buffer.

```python
from absl import logging
import gin
import numpy as onp

from dopamine.jax.agents.dqn import dqn_agent
# Local copy of fixed_replay_buffer.py; the import path depends on your setup.
import fixed_replay_buffer


@gin.configurable
class OfflineJaxDQNAgent(dqn_agent.JaxDQNAgent):
  """A JAX implementation of the Offline DQN agent."""

  def __init__(self,
               num_actions,
               replay_data_dir,
               summary_writer=None):
    """Initializes the agent and constructs the necessary components.

    Args:
      num_actions: int, number of actions the agent can take at any state.
      replay_data_dir: str, log directory from which to load the replay buffer.
      summary_writer: SummaryWriter object for outputting training statistics.
    """
    logging.info('Creating %s agent with the following parameters:',
                 self.__class__.__name__)
    logging.info('\t replay directory: %s', replay_data_dir)
    # Must be set before super().__init__, which calls _build_replay_buffer.
    self.replay_data_dir = replay_data_dir
    super().__init__(
        num_actions, update_period=1, summary_writer=summary_writer)

  def _build_replay_buffer(self):
    """Creates the fixed replay buffer used by the agent."""
    return fixed_replay_buffer.FixedReplayBuffer(
        data_dir=self.replay_data_dir,
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype)

  def reload_data(self):
    # This needs to be called every iteration to subsample a portion of the dataset.
    self._replay.reload_data()

  def step(self, reward, observation):
    """Returns the agent's next action and updates the agent's state.

    Args:
      reward: float, the reward received from the agent's most recent action.
      observation: numpy array, the most recent observation.

    Returns:
      int, the selected action.
    """
    self._record_observation(observation)
    # Epsilon-greedy action selection with the online network.
    self._rng, self.action = dqn_agent.select_action(
        self.network_def, self.online_params, self.state, self._rng,
        self.num_actions, self.eval_mode, self.epsilon_eval, self.epsilon_train,
        self.epsilon_decay_period, self.training_steps, self.min_replay_history,
        self.epsilon_fn)
    self.action = onp.asarray(self.action)
    return self.action

  def train_step(self):
    """Exposes the train step for offline learning."""
    super()._train_step()
```
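The comment in `reload_data` says the buffer subsamples a portion of the dataset every iteration. A minimal sketch of that idea, assuming a dataset split into shards: the `ToyFixedReplay` class, the shard layout, and the sizes are hypothetical and are not the batch_rl or Dopamine API.

```python
import numpy as np


class ToyFixedReplay:
    """Hypothetical offline buffer illustrating per-iteration subsampling.

    Not batch_rl's or Dopamine's FixedReplayBuffer: the dataset is split into
    shards, `reload_data` keeps a random subset of shards in memory, and
    training batches are drawn only from that subset.
    """

    def __init__(self, shards, num_buffers=5, seed=0):
        self._shards = shards              # list of arrays, one per shard
        self._num_buffers = num_buffers    # shards held in memory at once
        self._rng = np.random.default_rng(seed)
        self.active = np.empty(0)
        self.reload_data()

    def reload_data(self):
        # Pick a fresh subset of shards, without replacement.
        chosen = self._rng.choice(len(self._shards), size=self._num_buffers,
                                  replace=False)
        self.active = np.concatenate([self._shards[i] for i in chosen])

    def sample_batch(self, batch_size):
        # Uniform sampling (with replacement) over the loaded shards only.
        return self._rng.choice(self.active, size=batch_size, replace=True)


# Usage: 50 toy shards of 100 transition ids each; 5 are loaded at a time.
shards = [np.arange(i * 100, (i + 1) * 100) for i in range(50)]
buffer = ToyFixedReplay(shards, num_buffers=5)
for iteration in range(3):
    buffer.reload_data()            # once per iteration
    batch = buffer.sample_batch(32)
```

Holding only a few shards in memory keeps RAM bounded, while reloading every iteration means that, over training, batches are drawn from many different parts of the dataset.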

agarwl, Sep 10 '22

Thank you very much! May I ask whether I can also run the code on a TPU VM with JAX?

Best, Lucas

lucasliunju, Sep 14 '22

I think so. You will probably want to use the TFDS datasets, or much larger batch sizes with the Dopamine codebase.

agarwl, Sep 14 '22

Thank you very much! I'll give it a try.

lucasliunju, Sep 14 '22

Dear agarwl,

I tried to follow your code to reproduce the offline DQN results in JAX, but I found that training in JAX is quite slow compared with TensorFlow. What might be the reason? I changed the following parts of the vanilla Dopamine code:

(1) I rewrote the Runner in dopamine/dopamine/discrete_domains/run_experiment.py based on the code in batch_rl/batch_rl/fixed_replay/run_experiment.py:

```python
import time

from dopamine.discrete_domains import checkpointer
from dopamine.discrete_domains import iteration_statistics
from dopamine.discrete_domains import run_experiment
import gin
import tensorflow.compat.v1 as tf


@gin.configurable
class FixedReplayRunner(run_experiment.Runner):
  """Object that handles running Dopamine experiments with fixed replay buffer."""

  def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
    super(FixedReplayRunner, self)._initialize_checkpointer_and_maybe_resume(
        checkpoint_file_prefix)

    # Load an initial checkpoint, if one is provided.
    # Note: _init_checkpoint_dir comes from batch_rl's TF agents; a JAX agent
    # needs to define this attribute as well.
    init_checkpoint_dir = self._agent._init_checkpoint_dir  # pylint: disable=protected-access
    if (self._start_iteration == 0) and (init_checkpoint_dir is not None):
      if checkpointer.get_latest_checkpoint_number(self._checkpoint_dir) < 0:
        # No checkpoint loaded yet, read init_checkpoint_dir
        init_checkpointer = checkpointer.Checkpointer(
            init_checkpoint_dir, checkpoint_file_prefix)
        latest_init_checkpoint = checkpointer.get_latest_checkpoint_number(
            init_checkpoint_dir)
        if latest_init_checkpoint >= 0:
          experiment_data = init_checkpointer.load_checkpoint(
              latest_init_checkpoint)
          if self._agent.unbundle(
              init_checkpoint_dir, latest_init_checkpoint, experiment_data):
            if experiment_data is not None:
              assert 'logs' in experiment_data
              assert 'current_iteration' in experiment_data
              self._logger.data = experiment_data['logs']
              self._start_iteration = experiment_data['current_iteration'] + 1
            tf.logging.info(
                'Reloaded checkpoint from %s and will start from iteration %d',
                init_checkpoint_dir, self._start_iteration)

  def _run_train_phase(self):
    """Run training phase."""
    self._agent.eval_mode = False
    start_time = time.time()
    for _ in range(self._training_steps):
      self._agent._train_step()  # pylint: disable=protected-access
    time_delta = time.time() - start_time
    tf.logging.info('Average training steps per second: %.2f',
                    self._training_steps / time_delta)

  def _run_one_iteration(self, iteration):
    """Runs one iteration of agent/environment interaction."""
    statistics = iteration_statistics.IterationStatistics()
    tf.logging.info('Starting iteration %d', iteration)
    # pylint: disable=protected-access
    if not self._agent._replay_suffix:
      # Reload the replay buffer. In batch_rl's TF agents, _replay is a
      # WrappedFixedReplayBuffer whose .memory is the FixedReplayBuffer; with a
      # JAX agent, this call may need to become self._agent.reload_data().
      self._agent._replay.memory.reload_buffer(num_buffers=5)
    # pylint: enable=protected-access
    self._run_train_phase()

    num_episodes_eval, average_reward_eval = self._run_eval_phase(statistics)

    self._save_tensorboard_summaries(
        iteration, num_episodes_eval, average_reward_eval)
    return statistics.data_lists

  def _save_tensorboard_summaries(self, iteration,
                                  num_episodes_eval,
                                  average_reward_eval):
    """Save statistics as tensorboard summaries.

    Args:
      iteration: int, The current iteration number.
      num_episodes_eval: int, number of evaluation episodes run.
      average_reward_eval: float, The average evaluation reward.
    """
    summary = tf.Summary(value=[
        tf.Summary.Value(tag='Eval/NumEpisodes',
                         simple_value=num_episodes_eval),
        tf.Summary.Value(tag='Eval/AverageReturns',
                         simple_value=average_reward_eval)
    ])
    self._summary_writer.add_summary(summary, iteration)
```

(2) I created the offline buffer: fixed_replay_buffer.py

(3) I created OfflineJaxDQNAgent, using the same code as in agarwl's comment above.

(4) Comparing the JAX code with the vanilla TF code, I found that they use different replay buffers (FixedReplayBuffer in JAX and WrappedFixedReplayBuffer in TF). I'm not sure whether this is the main reason.
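When comparing steps per second between JAX and TF, as the runner's `_run_train_phase` does, it can help to exclude JIT compilation time and account for JAX's asynchronous dispatch. A minimal sketch using a toy jitted update: `sgd_step` and `steps_per_second` are hypothetical stand-ins for the agent's train step, not how Dopamine measures throughput.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def sgd_step(params, x, y):
    # Toy least-squares update standing in for a DQN gradient step.
    def loss(p):
        return jnp.mean((x @ p - y) ** 2)
    return params - 1e-3 * jax.grad(loss)(params)


def steps_per_second(num_steps, batch_size=32, dim=64):
    kx, ky = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(kx, (batch_size, dim))
    y = jax.random.normal(ky, (batch_size,))
    params = jnp.zeros(dim)
    # Warm-up call: the first call triggers XLA compilation, which is not timed.
    params = sgd_step(params, x, y).block_until_ready()
    start = time.perf_counter()
    for _ in range(num_steps):
        params = sgd_step(params, x, y)
    # JAX dispatches work asynchronously, so wait for the final result.
    params.block_until_ready()
    return num_steps / (time.perf_counter() - start)


print(f'{steps_per_second(1_000):.1f} steps/s')
```

Without the warm-up and the final `block_until_ready()`, short benchmarks can either include compilation time or stop the clock before the device work has finished.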

Best

lucasliunju, Jan 27 '23

Hi, I found that update_period is 1 here, whereas it is 4 in the TF code. Maybe that is the main reason.

lucasliunju, Jan 28 '23

Yeah, update_period 1 corresponds to 1 gradient step every environment step (the default of 4 corresponds to 1 grad step every 4 env steps). In each iteration we do 62.5K grad steps, so with update_period 1 we can also set num_training_steps to 62.5K.
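The arithmetic can be checked with a small sketch. It assumes 250K agent steps per iteration with update_period 4 (the figure implied by 62.5K grad steps) and a simplified gating rule in which an update happens whenever the step counter is a multiple of update_period; `grad_steps_per_iteration` is a hypothetical helper, and warm-up checks such as min_replay_history are ignored.

```python
def grad_steps_per_iteration(training_steps: int, update_period: int) -> int:
    """Count gradient updates in one iteration of `training_steps` agent steps.

    Simplified gating: an update happens whenever the step counter is a
    multiple of `update_period`.
    """
    return sum(1 for step in range(training_steps) if step % update_period == 0)


# TF setup: update_period=4 over 250K agent steps per iteration.
print(grad_steps_per_iteration(250_000, 4))  # 62500
# JAX setup: update_period=1, so 62.5K agent steps give the same count.
print(grad_steps_per_iteration(62_500, 1))   # 62500
```

Both settings perform the same number of gradient updates per iteration. With update_period 1, however, the agent does not spend the three extra non-updating steps per update, so the two setups are comparable only when the training-step budget is scaled down accordingly.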

agarwl, Jan 28 '23

Hi @agarwl, thanks for your reply. I will try it. By the way, can I run the TF code on a TPU VM? I find that TF is still a little faster.

lucasliunju, Jan 28 '23

Sure. You may not see much benefit from using TPUs (due to the small batch size and the Dopamine replay buffer), but the code can be run on a TPU.

agarwl, Feb 09 '23

Here's some JAX code for reference: https://github.com/google/dopamine/tree/master/dopamine/labs/offline_rl

agarwl, Apr 10 '23

@agarwl Thank you very much! I will give it a try.

lucasliunju, Apr 10 '23