Changes to support image-based environments

Open qxcv opened this issue 5 years ago • 5 comments

This PR adds some features necessary to get clean support for image-based environments:

  • Switched to a (temporary) fork of SB3 that removes transpose magic. This makes it simpler for our algorithms to interoperate with SB3 when working with image-based environments.
  • BC policies and GAIL discriminators can now use augmentations at training time. I haven't added augmentation support for GAIL policies because (1) doing so is hard, and (2) I haven't found that it works well in MAGICAL.
  • A bunch of other related changes that we needed while working on the il-representations project.
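
To illustrate what "augmentations at training time" means in the second bullet, here is a minimal sketch, not the code from this PR. The helpers `random_shift` and `prepare_batch` are hypothetical names, and images are plain nested lists so the example has no dependencies; a real pipeline would operate on batched tensors instead.

```python
import random
from typing import List

Image = List[List[int]]  # single-channel image: H rows of W pixel values


def random_shift(image: Image, max_shift: int, rng: random.Random) -> Image:
    """Translate an image by a random offset, repeating edge pixels to fill gaps."""
    height, width = len(image), len(image[0])
    dy = rng.randint(-max_shift, max_shift)
    dx = rng.randint(-max_shift, max_shift)

    def clamp(value: int, size: int) -> int:
        # Out-of-range source coordinates fall back to the nearest edge pixel.
        return min(max(value, 0), size - 1)

    return [
        [image[clamp(y + dy, height)][clamp(x + dx, width)] for x in range(width)]
        for y in range(height)
    ]


def prepare_batch(batch: List[Image], training: bool, rng: random.Random) -> List[Image]:
    """Augment observations only while training; leave evaluation inputs untouched."""
    if not training:
        return batch
    return [random_shift(image, max_shift=1, rng=rng) for image in batch]


rng = random.Random(0)
batch = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
train_batch = prepare_batch(batch, training=True, rng=rng)
eval_batch = prepare_batch(batch, training=False, rng=rng)
```

The point of the `training` flag is that the loss sees perturbed observations while rollouts and evaluation see the originals.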

qxcv • Aug 17 '20 21:08

Codecov Report

Merging #221 (3d895c9) into master (4ccc96f) will decrease coverage by 6.96%. The diff coverage is 85.51%.

Current head 3d895c9 differs from pull request most recent head 5860e5c. Consider uploading reports for the commit 5860e5c to get more accurate results.

@@            Coverage Diff             @@
##           master     #221      +/-   ##
==========================================
- Coverage   96.53%   89.57%   -6.97%     
==========================================
  Files          73       81       +8     
  Lines        5539     5725     +186     
==========================================
- Hits         5347     5128     -219     
- Misses        192      597     +405     
Impacted Files                             Coverage Δ
tests/test_dagger.py                       100.00% <ø> (ø)
src/imitation/augment/color.py             41.83% <41.83%> (ø)
src/imitation/data/rollout.py              95.15% <50.00%> (-4.85%)
tests/test_envs.py                         87.50% <57.14%> (-12.50%)
src/imitation/algorithms/bc.py             91.76% <80.59%> (-8.24%)
src/imitation/util/util.py                 96.34% <86.36%> (-3.66%)
src/imitation/algorithms/adversarial.py    95.16% <92.30%> (ø)
tests/test_adversarial.py                  98.07% <94.44%> (ø)
src/imitation/augment/convenience.py       97.33% <97.33%> (ø)
src/imitation/augment/__init__.py          100.00% <100.00%> (ø)
... and 121 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 4ccc96f...5860e5c.

codecov[bot] • Aug 17 '20 23:08

Sorry this seems to have fallen by the wayside. I'm happy to review this if it's brought up to date with current master. Do we still need the SB3 fork or have those changes been upstreamed? If not I'd be happy to work with other SB3 maintainers to find a suitable fix.

AdamGleave • Aug 28 '21 23:08

I'm pretty sure the SB3 fork is unnecessary now. IIRC something was just getting force-wrapped with vecenvs previously, but I think it's fixed. Can take another look at this later; going to set a reminder for Thursday.

qxcv • Aug 30 '21 21:08

I'm in the process of (belatedly) bringing our SB3 and imitation branches even with master for the IL representations project. Here is a diff between our SB3 branch and the current SB3 master (as of October 18th). There are some irrelevant/spurious changes here; the main relevant ones that I see are:

  • A change to BaseAlgorithm that was meant to fix the FPS counter (it monotonically increases at the moment).
  • Addition of a dump_logs kwarg to OnPolicyAlgorithm that can be used to prevent it from dumping logs at every iteration (I forget why I wanted this; I think it might have been so I could dump GAIL discriminator and policy optimisation logs at the same time).
  • Added a target standard deviation for reward normalisation, and changed policy optimisation to ensure that the terminal observation is normalised too. The former change was helpful for us when tuning GAIL, but I don't think the latter change (to observation normalisation) was actually needed for our project, so we could remove it if desired.
  • Propagating unknown kwargs from PPO's learn() method to its parent (I'm not sure why it wasn't doing this already).

(I'm not going to try merging these into SB3 right away, but I imagine they'll be easy to merge at some point.)
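
For the FPS item in the list above, here is a toy reproduction of one plausible cause of the "monotonically increases" symptom. It assumes the start time was being reset on every learn() call while num_timesteps kept accumulating (reset_num_timesteps=False). `FpsTracker` and `run` are hypothetical stand-ins with a fake clock, not SB3 code.

```python
from typing import List, Optional


class FpsTracker:
    """Toy stand-in for the timing state a learner keeps (hypothetical, not SB3)."""

    def __init__(self, always_reset_start: bool) -> None:
        self.always_reset_start = always_reset_start
        self.num_timesteps = 0
        self.start_time: Optional[float] = None

    def setup_learn(self, now: float, reset_num_timesteps: bool) -> None:
        # Assumed old behaviour: always_reset_start=True restarts the clock every call.
        # Patched behaviour: restart only on a full reset or on the first call.
        if self.always_reset_start or reset_num_timesteps or self.start_time is None:
            self.start_time = now
        if reset_num_timesteps:
            self.num_timesteps = 0

    def fps(self, now: float) -> float:
        return self.num_timesteps / (now - self.start_time)


def run(tracker: FpsTracker, calls: int = 3) -> List[float]:
    """Simulate repeated learn() calls of 1000 steps taking 10 fake seconds each."""
    now = 0.0
    readings = []
    for _ in range(calls):
        tracker.setup_learn(now, reset_num_timesteps=False)
        tracker.num_timesteps += 1000
        now += 10.0
        readings.append(tracker.fps(now))
    return readings


old = run(FpsTracker(always_reset_start=True))   # [100.0, 200.0, 300.0]
new = run(FpsTracker(always_reset_start=False))  # [100.0, 100.0, 100.0]
```

Under that assumption the reported FPS grows with every call even though the true throughput is constant, which is what the conditional reset avoids.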

diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 1b38555..a667ef7 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -399,7 +399,9 @@ class BaseAlgorithm(ABC):
         :param eval_freq: How many steps between evaluations
         :param n_eval_episodes: How many episodes to play per evaluation
         :param log_path: Path to a folder where the evaluations will be saved
-        :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
+        :param reset_num_timesteps: Whether to reset or not the
+            ``num_timesteps`` attribute and related tracking information (e.g.
+            start time, episode number).
         :param tb_log_name: the name of the run for tensorboard log
         :return:
         """
@@ -413,6 +415,9 @@ class BaseAlgorithm(ABC):
         if self.action_noise is not None:
             self.action_noise.reset()
 
+        if reset_num_timesteps or self.start_time is None:
+            self.start_time = time.time()
+
         if reset_num_timesteps:
             self.num_timesteps = 0
             self._episode_num = 0
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 41e193d..9de57f0 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -223,6 +223,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
         tb_log_name: str = "OnPolicyAlgorithm",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
+        dump_logs: bool = True,
     ) -> "OnPolicyAlgorithm":
         iteration = 0
 
@@ -252,7 +253,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
                 self.logger.record("time/fps", fps)
                 self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
                 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
-                self.logger.dump(step=self.num_timesteps)
+                if dump_logs:
+                    self.logger.dump(step=self.num_timesteps)
 
             self.train()
 
diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py
index 5eae7f5..f0db855 100644
--- a/stable_baselines3/common/vec_env/vec_normalize.py
+++ b/stable_baselines3/common/vec_env/vec_normalize.py
@@ -24,6 +24,8 @@ class VecNormalize(VecEnvWrapper):
     :param clip_reward: Max value absolute for discounted reward
     :param gamma: discount factor
     :param epsilon: To avoid division by zero
+    :param norm_reward_std: Target standard deviation for reward when
+        `norm_reward is True` (default: 1.0).
     """
 
     def __init__(
@@ -36,6 +38,7 @@ class VecNormalize(VecEnvWrapper):
         clip_reward: float = 10.0,
         gamma: float = 0.99,
         epsilon: float = 1e-8,
+        norm_reward_std: float = 1.0,
     ):
         VecEnvWrapper.__init__(self, venv)
 
@@ -54,6 +57,7 @@ class VecNormalize(VecEnvWrapper):
         self.ret_rms = RunningMeanStd(shape=())
         self.clip_obs = clip_obs
         self.clip_reward = clip_reward
+        self.norm_reward_std = norm_reward_std
         # Returns: discounted rewards
         self.returns = np.zeros(self.num_envs)
         self.gamma = gamma
@@ -124,6 +128,15 @@ class VecNormalize(VecEnvWrapper):
 
         obs = self.normalize_obs(obs)
 
+        # terminal_observation doesn't count towards running mean, but we still
+        # normalize it
+        infos = list(infos)
+        for idx, info in enumerate(infos):
+            term_obs = info.get("terminal_observation")
+            if term_obs is not None:
+                infos[idx] = dict(info)
+                infos[idx]["terminal_observation"] = self.normalize_obs(term_obs)
+
         if self.training:
             self._update_reward(rewards)
         rewards = self.normalize_reward(rewards)
@@ -182,7 +195,8 @@ class VecNormalize(VecEnvWrapper):
         Calling this method does not update statistics.
         """
         if self.norm_reward:
-            reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
+            reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon) * self.norm_reward_std,
+                             -self.clip_reward, self.clip_reward)
         return reward
 
     def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py
index 9e16e04..5073636 100644
--- a/stable_baselines3/ppo/ppo.py
+++ b/stable_baselines3/ppo/ppo.py
@@ -294,6 +294,7 @@ class PPO(OnPolicyAlgorithm):
         tb_log_name: str = "PPO",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
+        **kwargs,
     ) -> "PPO":
 
         return super(PPO, self).learn(
@@ -306,4 +307,5 @@ class PPO(OnPolicyAlgorithm):
             tb_log_name=tb_log_name,
             eval_log_path=eval_log_path,
             reset_num_timesteps=reset_num_timesteps,
+            **kwargs,
         )
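
As a standalone illustration of the reward-normalisation hunk, the following sketch applies the same formula to a single scalar reward. The function is a simplified stand-in, not the VecNormalize method itself: `ret_var` plays the role of `self.ret_rms.var`, and the parameter names are borrowed from the diff.

```python
import math


def normalize_reward(
    reward: float,
    ret_var: float,
    norm_reward_std: float = 1.0,
    clip_reward: float = 10.0,
    epsilon: float = 1e-8,
) -> float:
    """Scale a reward to the target standard deviation, then clip it."""
    # Divide by the running std of returns, then rescale to the target std.
    scaled = reward / math.sqrt(ret_var + epsilon) * norm_reward_std
    # Clipping happens after scaling, so clip_reward bounds the scaled value.
    return max(-clip_reward, min(clip_reward, scaled))


unit_std = normalize_reward(2.0, ret_var=4.0)                       # close to 1.0
half_std = normalize_reward(2.0, ret_var=4.0, norm_reward_std=0.5)  # close to 0.5
clipped = normalize_reward(100.0, ret_var=1.0)                      # clipped to 10.0
```

Because clipping follows the rescaling, lowering norm_reward_std also makes the clip threshold less likely to bind for the same raw rewards.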

qxcv • Oct 18 '21 21:10

Thanks for the summary Sam!

> I'm in the process of (belatedly) bringing our SB3 and imitation branches even with master for the IL representations project. Here is a diff between our SB3 branch and the current SB3 master (as of October 18th). There are some irrelevant/spurious changes here; the main relevant ones that I see are:
>
>   • A change to BaseAlgorithm that was meant to fix the FPS counter (it monotonically increases at the moment).

Sounds like an uncontroversial bugfix.

>   • Addition of a dump_logs kwarg to OnPolicyAlgorithm that can be used to prevent it from dumping logs at every iteration (I forget why I wanted this; I think it might have been so I could dump GAIL discriminator and policy optimisation logs at the same time).

Not too objectionable, but I'm not sure maintainers will want this without a clear use case. Do we really need this when we have the hierarchical logger? We can just look at the mean stats and ignore the raw ones (we could even disable the raw ones being logged to stdout).

>   • Added a target standard deviation for reward normalisation, and changed policy optimisation to ensure that the terminal observation is normalised too. The former change was helpful for us when tuning GAIL, but I don't think the latter change (to observation normalisation) was actually needed for our project, so we could remove it if desired.

Target SD for reward normalization: so like VecNormalize on reward, then scaling it up/down? That does seem useful; I'm not sure if it's best as a change to VecNormalize or as an additional wrapper.

Terminal observation not being normalized sounds like a bug, so we should probably upstream that.

>   • Propagating unknown kwargs from PPO's learn() method to its parent (I'm not sure why it wasn't doing this already).

Seems like an unobjectionable fix.
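
For reference, the terminal-observation point can be sketched outside of VecNormalize as follows. This is a simplified illustration under stated assumptions: `normalize_obs` uses a scalar mean and standard deviation instead of running per-dimension statistics, and both function names are hypothetical.

```python
from typing import Any, Dict, List


def normalize_obs(obs: List[float], mean: float, std: float, clip: float = 10.0) -> List[float]:
    """Standardise and clip an observation (scalar statistics, for brevity)."""
    return [max(-clip, min(clip, (value - mean) / std)) for value in obs]


def normalize_terminal_observations(
    infos: List[Dict[str, Any]], mean: float, std: float
) -> List[Dict[str, Any]]:
    """Return infos whose "terminal_observation" entries are normalised.

    The running statistics are not updated here; only the stored value changes.
    """
    result = []
    for info in infos:
        term_obs = info.get("terminal_observation")
        if term_obs is None:
            result.append(info)
            continue
        new_info = dict(info)  # shallow copy so the caller's dict is not mutated
        new_info["terminal_observation"] = normalize_obs(term_obs, mean, std)
        result.append(new_info)
    return result


infos = [{"terminal_observation": [2.0, 4.0]}, {"episode": 1}]
normalized = normalize_terminal_observations(infos, mean=2.0, std=2.0)
```

Without this step, the policy would see normalised observations during an episode but a raw, unnormalised observation at its end.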

AdamGleave • Oct 19 '21 02:10

Closing in favor of #519

AdamGleave • Aug 23 '22 20:08