RLeXplore
Correct usage with SB3 / Callbacks?
Hi, this looks like a really interesting set of algorithms. I wanted to try some out using the SB3-zoo and was hoping for a plug-and-play approach. I wondered if I could integrate rlexplore using callbacks so I came up with the following:
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm

from rlexplore import REVD


class RLeXploreCallback(BaseCallback):
    def __init__(self):
        super().__init__()
        self.explorer = None
        self.buffer = None

    def _init_callback(self) -> None:
        # BaseCallback.init_callback() sets self.model and self.training_env
        # before it calls this hook.
        env = self.training_env
        self.explorer = REVD(
            obs_shape=env.observation_space.shape,
            action_shape=env.action_space.shape,
            device=self.model.device,
            latent_dim=128,
            beta=1e-2,
            kappa=1e-5,
        )
        if isinstance(self.model, OnPolicyAlgorithm):
            self.buffer = self.model.rollout_buffer
        elif isinstance(self.model, OffPolicyAlgorithm):
            self.buffer = self.model.replay_buffer

    def _on_rollout_end(self) -> None:
        intrinsic_rewards = self.explorer.compute_irs(
            rollouts={"observations": self.buffer.observations},
            time_steps=self.num_timesteps,
            k=3,
        )
        # Drop the trailing dimension so the shape matches buffer.rewards.
        self.buffer.rewards += intrinsic_rewards[:, :, 0]

    def _on_step(self) -> bool:
        # TODO maybe log to TensorBoard?
        return True
Then I include it in my list of callbacks and it seems to run. However, I'm still poking around without fully understanding what I'm doing (dangerous!), so does the above look correct? If it is, maybe it could be added as an example for others.
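As a quick sanity check on the in-place reward update, the sketch below uses NumPy stand-ins for the two arrays. It assumes buffer.rewards has shape (n_steps, n_envs) and that compute_irs returns (n_steps, n_envs, 1), which is what the [:, :, 0] slice in the callback implies. The sizes and values are made up for illustration.

```python
import numpy as np

n_steps, n_envs = 5, 2  # made-up sizes for illustration

# Stand-ins: extrinsic rewards as stored by the buffer, and intrinsic rewards
# with an assumed trailing singleton dimension.
rewards = np.ones((n_steps, n_envs), dtype=np.float32)
intrinsic_rewards = np.full((n_steps, n_envs, 1), 0.25, dtype=np.float32)

bonus = intrinsic_rewards[:, :, 0]

# Check shapes explicitly so a mismatch fails here with a clear message.
assert bonus.shape == rewards.shape, (bonus.shape, rewards.shape)

rewards += bonus  # in-place, so any object holding this array sees the update
print(rewards.shape, float(rewards[0, 0]))
```

A similar assertion inside _on_rollout_end would catch a shape mismatch early if a different exploration module returned rewards in another layout.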
Second question: did I get this bit right, time_steps=self.num_timesteps?
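For context on the second question, here is a minimal sketch of how a time_steps argument might be consumed. It assumes the intrinsic-reward weight decays as beta_t = beta * (1 - kappa) ** time_steps, which would fit the beta and kappa arguments passed to REVD but is an assumption here, not confirmed against the rlexplore source. decayed_beta is a hypothetical helper, not part of the library.

```python
def decayed_beta(beta: float, kappa: float, time_steps: int) -> float:
    """Hypothetical helper: exponentially decayed intrinsic-reward weight.

    Assumes beta_t = beta * (1 - kappa) ** time_steps; this formula is an
    assumption for illustration, not taken from the rlexplore source.
    """
    return beta * (1.0 - kappa) ** time_steps


# Same hyperparameters as in the callback.
BETA, KAPPA = 1e-2, 1e-5

# If time_steps is the global environment-step counter (as num_timesteps is
# in SB3), the weight shrinks smoothly over the course of training.
for steps in (0, 100_000, 1_000_000):
    print(f"time_steps={steps:>9,d}  beta_t={decayed_beta(BETA, KAPPA, steps):.3e}")
```

Under this assumption, a counter that resets every rollout would keep beta_t near its initial value, whereas a global counter lets the exploration bonus fade as training proceeds.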
Third question: the sample in the examples directory uses rollout_buffer, but is it valid to use the same approach for off-policy algorithms like DQN, switching to replay_buffer instead?
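One difference that may matter for the third question: an on-policy rollout buffer is refilled on every rollout, while an off-policy replay buffer keeps old transitions around. The toy sketch below (plain Python, not SB3 code; ToyReplayBuffer and the constant bonus are hypothetical) shows what an in-place rewards += bonus over the whole buffer does when it runs after every collection phase.

```python
class ToyReplayBuffer:
    """Hypothetical fixed-size ring buffer that stores rewards only."""

    def __init__(self, size: int) -> None:
        self.rewards = [0.0] * size
        self.pos = 0  # next write index; wraps around when the buffer is full

    def add(self, reward: float) -> None:
        self.rewards[self.pos] = reward
        self.pos = (self.pos + 1) % len(self.rewards)


def add_bonus_everywhere(buffer: ToyReplayBuffer, bonus: float) -> None:
    """Mimics `buffer.rewards += intrinsic_rewards` applied to every slot."""
    buffer.rewards = [r + bonus for r in buffer.rewards]


buffer = ToyReplayBuffer(size=4)

# Two collection phases of two steps each, extrinsic reward 1.0 per step,
# with the bonus applied to the whole buffer after each phase.
for _ in range(2):
    buffer.add(1.0)
    buffer.add(1.0)
    add_bonus_everywhere(buffer, bonus=0.5)

print(buffer.rewards)  # [2.0, 2.0, 1.5, 1.5]
```

The first two transitions end up with the bonus applied twice. If the real replay buffer behaves like this, an off-policy variant would probably need to add the intrinsic reward only to the transitions collected since the last call, or compute it at sampling time instead.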