
Correct usage with SB3 / Callbacks?

Open · emrul opened this issue 1 year ago · 1 comment

Hi, this looks like a really interesting set of algorithms. I wanted to try some of them out with the SB3 Zoo and was hoping for a plug-and-play approach. I wondered if I could integrate rlexplore using callbacks, so I came up with the following:

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from rlexplore import REVD

class RLeXploreCallback(BaseCallback):
    def __init__(self):
        super().__init__()
        # Both are set in init_callback, once the model is available
        self.explorer = None
        self.buffer = None

    def init_callback(self, model: BaseAlgorithm) -> None:
        super().init_callback(model)
        env = self.training_env
        self.explorer = REVD(
            obs_shape=env.observation_space.shape,
            action_shape=env.action_space.shape,
            device=model.device,
            latent_dim=128,
            beta=1e-2,
            kappa=1e-5,
        )

        # On-policy algorithms (e.g. PPO, A2C) store transitions in a rollout
        # buffer; off-policy ones (e.g. DQN, SAC) use a replay buffer
        if isinstance(self.model, OnPolicyAlgorithm):
            self.buffer = self.model.rollout_buffer
        elif isinstance(self.model, OffPolicyAlgorithm):
            self.buffer = self.model.replay_buffer

    def _on_rollout_end(self) -> None:
        intrinsic_rewards = self.explorer.compute_irs(
            rollouts={'observations': self.buffer.observations},
            time_steps=self.num_timesteps,
            k=3,
        )
        # Drop the trailing singleton dim and add the intrinsic rewards
        # onto the extrinsic ones in place
        self.buffer.rewards += intrinsic_rewards[:, :, 0]

    def _on_step(self) -> bool:
        # TODO maybe log to TensorBoard?
        return True


Then I include it in my list of callbacks (see the registration sketch below) and it seems to run. However, I'm still poking around without fully understanding what I'm doing (dangerous!), so does the above look correct? If it is, maybe it could be added as an example for others.
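
For reference, this is roughly how I register it. PPO and the environment id here are just placeholders for whatever the zoo config actually builds:

from stable_baselines3 import PPO

# Placeholder model; in practice the SB3 Zoo constructs this from its config
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=100_000, callback=[RLeXploreCallback()])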

Second question: is time_steps=self.num_timesteps the right value to pass here?

Third question: the sample in the examples directory uses rollout_buffer, but is it valid to use this approach for off-policy algorithms like DQN, switching to the replay_buffer instead? A sketch of my concern follows.
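
What makes me unsure about the replay-buffer route (this is just my own reading of the SB3 source, not anything from the examples) is that ReplayBuffer is circular: buffer.observations holds every allocated slot, not only the transitions collected since the last rollout. If that's right, _on_rollout_end would need something like the following for off-policy models. Untested, it ignores buffer wraparound, and self.last_pos is a counter I'd have to add myself (initialised to 0 in __init__):

    def _on_rollout_end(self) -> None:
        if isinstance(self.model, OffPolicyAlgorithm):
            # Score only the slice written since the last call, so older
            # transitions don't get intrinsic reward added to them twice
            start, end = self.last_pos, self.buffer.pos
            intrinsic_rewards = self.explorer.compute_irs(
                rollouts={'observations': self.buffer.observations[start:end]},
                time_steps=self.num_timesteps,
                k=3,
            )
            self.buffer.rewards[start:end] += intrinsic_rewards[:, :, 0]
            self.last_pos = end

Is that the right direction, or is recomputing over the whole buffer the intended usage?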

emrul · Feb 02 '23 13:02