When using HER + SAC, every call to learn massively decreases performance

Open siferati opened this issue 4 years ago • 6 comments

I didn't test other algorithms, so I'm not sure if this is a problem with HER, with SAC, with the combination of both or if it's a problem with all available algorithms.

I first noticed this on my custom environment, so in order to make sure it wasn't a problem on my end, I also tested it using BitFlippingEnv.

Take the example below:

from stable_baselines import HER, SAC
from stable_baselines.common.bit_flipping_env import BitFlippingEnv

env = BitFlippingEnv(continuous = True)

model = HER(
    policy = 'MlpPolicy',
    env = env,
    model_class = SAC,
    n_sampled_goal = 4,
    goal_selection_strategy = 'future',
    verbose = 1
)

while True:
    model.learn(2000, log_interval = 1)

Every subsequent call to learn runs significantly slower than the previous one. On my computer, the 1st call runs at approximately 315 fps, the 2nd at 275, the 3rd at 150 and the 4th at 50.

Any way to fix this?

siferati avatar Mar 09 '20 13:03 siferati


Thanks for reporting the issue.

I tried the following code (note the learning_starts=0, which avoids a wrong estimate of the FPS):

import time

from stable_baselines import HER, SAC
from stable_baselines.common.bit_flipping_env import BitFlippingEnv

env = BitFlippingEnv(continuous = True)

model = HER(
    policy = 'MlpPolicy',
    env = env,
    model_class = SAC,
    n_sampled_goal = 4,
    goal_selection_strategy = 'future',
    verbose = 0,
    learning_starts = 0
)

i = 0
while True:
    i += 1
    start_time = time.time()
    model.learn(2000, log_interval = 1)
    print(f"Iteration {i} Took {time.time() - start_time:.2f}s")


Iteration 1 Took 19.42s
Iteration 2 Took 20.85s
Iteration 3 Took 25.50s
Iteration 4 Took 48.50s

so a total of ~115s

In comparison, time taken by 4 * 2000 = 8000 steps:


araffin avatar Mar 09 '20 15:03 araffin

I don't have the time to deal with this issue now, but you could use line profiler to check what is taking so much time.

araffin avatar Mar 09 '20 15:03 araffin

I profiled the call to SAC's learn method using the lib you linked. The first experiment consists of training for 2000 timesteps 3 times (using the code you posted). The second experiment consists of training for 6000 timesteps 1 time.

By comparing the two, the largest difference seems to be with the call self.replay_buffer.add(obs, action, reward, new_obs, float(done)), where the 1st experiment takes 10 times longer than the 2nd experiment.
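
One possible explanation, offered as an assumption that this thread does not confirm: line 364 of the profiled learn method (self.replay_buffer = replay_wrapper(self.replay_buffer)) runs once per call to learn, so after three calls the buffer may sit under three layers of wrapper. The toy model below uses hypothetical classes (ListBuffer, RelabelingWrapper), not the stable-baselines implementation, to show how the work done by a single add could multiply with each extra layer.

```python
class ListBuffer:
    """Innermost storage: a plain Python list of transitions."""

    def __init__(self):
        self.storage = []

    def add(self, transition):
        self.storage.append(transition)


class RelabelingWrapper:
    """Hypothetical stand-in for a HER-style wrapper.

    For every transition it receives, it forwards the original plus
    n_sampled_goal relabeled copies to the buffer it wraps.
    """

    def __init__(self, inner, n_sampled_goal=4):
        self.inner = inner
        self.n_sampled_goal = n_sampled_goal

    def add(self, transition):
        self.inner.add(transition)
        for goal_index in range(self.n_sampled_goal):
            self.inner.add((transition, goal_index))


def stored_per_add(n_wraps):
    """Count how many items reach the list after one add through n_wraps layers."""
    base = ListBuffer()
    buffer = base
    for _ in range(n_wraps):
        buffer = RelabelingWrapper(buffer)
    buffer.add("transition")
    return len(base.storage)


for n_wraps in (1, 2, 3):
    print(n_wraps, "layer(s):", stored_per_add(n_wraps), "items stored per add")
```

If this assumption holds, the cost of add grows geometrically with the number of learn calls, whereas a single long call wraps the buffer only once.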

x3 2000ts:

Line #      Hits         Time  Per Hit   % Time  
   419      6000   19194204.0   3199.0     39.6

x1 6000ts:

Line #      Hits         Time  Per Hit   % Time  
   419      6000    1823569.0    303.9      5.8
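
As a quick check of the "10 times longer" figure, the short script below recomputes the per-hit cost and the slowdown from the two rows above, and compares the extra time spent in add with the gap between the two total run times reported in the logs (48.4257 s and 31.2378 s). The variable names are only illustrative.

```python
# Figures copied from the profiler output (timer unit: 1e-06 s).
add_time_3x2000_us = 19_194_204.0  # line 419, 3 calls of 2000 timesteps
add_time_1x6000_us = 1_823_569.0   # line 419, 1 call of 6000 timesteps
hits = 6000

total_3x2000_s = 48.4257  # "Total time" of the 3 x 2000 run
total_1x6000_s = 31.2378  # "Total time" of the 1 x 6000 run

per_hit_3x_us = add_time_3x2000_us / hits
per_hit_1x_us = add_time_1x6000_us / hits
slowdown = add_time_3x2000_us / add_time_1x6000_us

# Extra seconds spent in replay_buffer.add versus extra seconds overall.
add_gap_s = (add_time_3x2000_us - add_time_1x6000_us) * 1e-6
total_gap_s = total_3x2000_s - total_1x6000_s

print(f"per hit: {per_hit_3x_us:.1f} us vs {per_hit_1x_us:.1f} us")
print(f"slowdown of add: {slowdown:.1f}x")
print(f"extra time in add: {add_gap_s:.2f}s, extra time overall: {total_gap_s:.2f}s")
```

The two gaps come out nearly equal, which suggests that replay_buffer.add accounts for essentially all of the difference between the two experiments.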

x3 2000ts logs

Iteration 1 Took 10.88s
Iteration 2 Took 13.01s
Iteration 3 Took 25.11s
Wrote profile results to
Timer unit: 1e-06 s

Total time: 48.4257 s
File: /home/tirafesi/.local/lib/python3.6/site-packages/stable_baselines/sac/
Function: learn at line 356

Line #      Hits         Time  Per Hit   % Time  Line Contents
   356                                               @profile
   357                                               def learn(self, total_timesteps, callback=None,
   358                                                         log_interval=4, tb_log_name="SAC", reset_num_timesteps=True, replay_wrapper=None):
   360         3         15.0      5.0      0.0          new_tb_log = self._init_num_timesteps(reset_num_timesteps)
   361         3         87.0     29.0      0.0          callback = self._init_callback(callback)
   363         3          5.0      1.7      0.0          if replay_wrapper is not None:
   364         3         25.0      8.3      0.0              self.replay_buffer = replay_wrapper(self.replay_buffer)
   366         3        102.0     34.0      0.0          with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
   367         3          6.0      2.0      0.0                  as writer:
   369         3         20.0      6.7      0.0              self._setup_learn()
   371                                                       # Transform to callable if needed
   372         3         18.0      6.0      0.0              self.learning_rate = get_schedule_fn(self.learning_rate)
   373                                                       # Initial learning rate
   374         3          8.0      2.7      0.0              current_lr = self.learning_rate(1)
   376         3          8.0      2.7      0.0              start_time = time.time()
   377         3          5.0      1.7      0.0              episode_rewards = [0.0]
   378         3          5.0      1.7      0.0              episode_successes = []
   379         3          6.0      2.0      0.0              if self.action_noise is not None:
   380                                                           self.action_noise.reset()
   381         3        137.0     45.7      0.0              obs = self.env.reset()
   382         3          5.0      1.7      0.0              n_updates = 0
   383         3          6.0      2.0      0.0              infos_values = []
   385         3         29.0      9.7      0.0              callback.on_training_start(locals(), globals())
   386         3         10.0      3.3      0.0              callback.on_rollout_start()
   388      6003      11853.0      2.0      0.0              for step in range(total_timesteps):
   389                                                           # Before training starts, randomly sample actions
   390                                                           # from a uniform distribution for better exploration.
   391                                                           # Afterwards, use the learned policy
   392                                                           # if random_exploration is set to 0 (normal setting)
   393      6000      45542.0      7.6      0.1                  if self.num_timesteps < self.learning_starts or np.random.rand() < self.random_exploration:
   394                                                               # actions sampled from action space are from range specific to the environment
   395                                                               # but algorithm operates on tanh-squashed actions therefore simple scaling is used
   396                                                               unscaled_action = self.env.action_space.sample()
   397                                                               action = scale_action(self.action_space, unscaled_action)
   398                                                           else:
   399      6000    2997316.0    499.6      6.2                      action = self.policy_tf.step(obs[None], deterministic=False).flatten()
   400                                                               # Add noise to the action (improve exploration,
   401                                                               # not needed in general)
   402      6000      18399.0      3.1      0.0                      if self.action_noise is not None:
   403                                                                   action = np.clip(action + self.action_noise(), -1, 1)
   404                                                               # inferred actions need to be transformed to environment action_space before stepping
   405      6000     127251.0     21.2      0.3                      unscaled_action = unscale_action(self.action_space, action)
   407      6000      22902.0      3.8      0.0                  assert action.shape == self.env.action_space.shape
   409      6000     363003.0     60.5      0.7                  new_obs, reward, done, info = self.env.step(unscaled_action)
   411      6000      15849.0      2.6      0.0                  self.num_timesteps += 1
   413                                                           # Only stop training if return value is False, not when it is None. This is for backwards
   414                                                           # compatibility with callbacks that have no return statement.
   415      6000      35909.0      6.0      0.1                  if callback.on_step() is False:
   416                                                               break
   418                                                           # Store transition in the replay buffer.
   419      6000   19194204.0   3199.0     39.6                  self.replay_buffer.add(obs, action, reward, new_obs, float(done))
   420      6000      11240.0      1.9      0.0                  obs = new_obs
   422                                                           # Retrieve reward and episode length if using Monitor wrapper
   423      6000      12410.0      2.1      0.0                  maybe_ep_info = info.get('episode')
   424      6000      10589.0      1.8      0.0                  if maybe_ep_info is not None:
   425                                                               self.ep_info_buf.extend([maybe_ep_info])
   427      6000      10360.0      1.7      0.0                  if writer is not None:
   428                                                               # Write reward per episode to tensorboard
   429                                                               ep_reward = np.array([reward]).reshape((1, -1))
   430                                                               ep_done = np.array([done]).reshape((1, -1))
   431                                                               tf_util.total_episode_reward_logger(self.episode_reward, ep_reward,
   432                                                                                                   ep_done, writer, self.num_timesteps)
   434      6000      12799.0      2.1      0.0                  if step % self.train_freq == 0:
   435      6000      20897.0      3.5      0.0                      callback.on_rollout_end()
   437      6000      15366.0      2.6      0.0                      mb_infos_vals = []
   438                                                               # Update policy, critics and target networks
   439     11981      36954.0      3.1      0.1                      for grad_step in range(self.gradient_steps):
   440                                                                   # Break if the warmup phase is not over
   441                                                                   # or if there are not enough samples in the replay buffer
   442      6000      43728.0      7.3      0.1                          if not self.replay_buffer.can_sample(self.batch_size) \
   443      5981      11525.0      1.9      0.0                             or self.num_timesteps < self.learning_starts:
   444        19         43.0      2.3      0.0                              break
   445      5981      11551.0      1.9      0.0                          n_updates += 1
   446                                                                   # Compute current learning_rate
   447      5981      14719.0      2.5      0.0                          frac = 1.0 - step / total_timesteps
   448      5981      18428.0      3.1      0.0                          current_lr = self.learning_rate(frac)
   449                                                                   # Update policy and critics (q functions)
   450      5981   20661504.0   3454.5     42.7                          mb_infos_vals.append(self._train_step(step, writer, current_lr))
   451                                                                   # Update target network
   452      5981      21779.0      3.6      0.0                          if (step + grad_step) % self.target_update_interval == 0:
   453                                                                       # Update target network
   454      5981    3861313.0    645.6      8.0                    
   455                                                               # Log losses and entropy, useful for monitor training
   456      6000      13265.0      2.2      0.0                      if len(mb_infos_vals) > 0:
   457      5981     372618.0     62.3      0.8                          infos_values = np.mean(mb_infos_vals, axis=0)
   459      6000      29575.0      4.9      0.1                      callback.on_rollout_start()
   461      6000      16097.0      2.7      0.0                  episode_rewards[-1] += reward
   462      6000      10733.0      1.8      0.0                  if done:
   463       614       1200.0      2.0      0.0                      if self.action_noise is not None:
   464                                                                   self.action_noise.reset()
   465       614       6498.0     10.6      0.0                      if not isinstance(self.env, VecEnv):
   466       614      25621.0     41.7      0.1                          obs = self.env.reset()
   467       614       1472.0      2.4      0.0                      episode_rewards.append(0.0)
   469       614       1262.0      2.1      0.0                      maybe_is_success = info.get('is_success')
   470       614       1070.0      1.7      0.0                      if maybe_is_success is not None:
   471       614       1623.0      2.6      0.0                          episode_successes.append(float(maybe_is_success))
   473      6000      26085.0      4.3      0.1                  if len(episode_rewards[-101:-1]) == 0:
   474        27         64.0      2.4      0.0                      mean_reward = -np.inf
   475                                                           else:
   476      5973     282793.0     47.3      0.6                      mean_reward = round(float(np.mean(episode_rewards[-101:-1])), 1)
   478      6000      14647.0      2.4      0.0                  num_episodes = len(episode_rewards)
   479                                                           # Display training infos
   480      6000      13067.0      2.2      0.0                  if self.verbose >= 1 and done and log_interval is not None and len(episode_rewards) % log_interval == 0:
   481                                                               fps = int(step / (time.time() - start_time))
   482                                                               logger.logkv("episodes", num_episodes)
   483                                                               logger.logkv("mean 100 episode reward", mean_reward)
   484                                                               if len(self.ep_info_buf) > 0 and len(self.ep_info_buf[0]) > 0:
   485                                                                   logger.logkv('ep_rewmean', safe_mean([ep_info['r'] for ep_info in self.ep_info_buf]))
   486                                                                   logger.logkv('eplenmean', safe_mean([ep_info['l'] for ep_info in self.ep_info_buf]))
   487                                                               logger.logkv("n_updates", n_updates)
   488                                                               logger.logkv("current_lr", current_lr)
   489                                                               logger.logkv("fps", fps)
   490                                                               logger.logkv('time_elapsed', int(time.time() - start_time))
   491                                                               if len(episode_successes) > 0:
   492                                                                   logger.logkv("success rate", np.mean(episode_successes[-100:]))
   493                                                               if len(infos_values) > 0:
   494                                                                   for (name, val) in zip(self.infos_names, infos_values):
   495                                                                       logger.logkv(name, val)
   496                                                               logger.logkv("total timesteps", self.num_timesteps)
   497                                                               logger.dumpkvs()
   498                                                               # Reset infos:
   499                                                               infos_values = []
   500         3         12.0      4.0      0.0              callback.on_training_end()
   501         3         90.0     30.0      0.0              return self

x1 6000ts logs

Iteration 1 Took 31.81s
Wrote profile results to
Timer unit: 1e-06 s

Total time: 31.2378 s
File: /home/tirafesi/.local/lib/python3.6/site-packages/stable_baselines/sac/
Function: learn at line 356

Line #      Hits         Time  Per Hit   % Time  Line Contents
   356                                               @profile
   357                                               def learn(self, total_timesteps, callback=None,
   358                                                         log_interval=4, tb_log_name="SAC", reset_num_timesteps=True, replay_wrapper=None):
   360         1          6.0      6.0      0.0          new_tb_log = self._init_num_timesteps(reset_num_timesteps)
   361         1         30.0     30.0      0.0          callback = self._init_callback(callback)
   363         1          2.0      2.0      0.0          if replay_wrapper is not None:
   364         1          8.0      8.0      0.0              self.replay_buffer = replay_wrapper(self.replay_buffer)
   366         1         35.0     35.0      0.0          with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
   367         1          1.0      1.0      0.0                  as writer:
   369         1         10.0     10.0      0.0              self._setup_learn()
   371                                                       # Transform to callable if needed
   372         1          8.0      8.0      0.0              self.learning_rate = get_schedule_fn(self.learning_rate)
   373                                                       # Initial learning rate
   374         1          2.0      2.0      0.0              current_lr = self.learning_rate(1)
   376         1          2.0      2.0      0.0              start_time = time.time()
   377         1          2.0      2.0      0.0              episode_rewards = [0.0]
   378         1          1.0      1.0      0.0              episode_successes = []
   379         1          2.0      2.0      0.0              if self.action_noise is not None:
   380                                                           self.action_noise.reset()
   381         1         49.0     49.0      0.0              obs = self.env.reset()
   382         1          2.0      2.0      0.0              n_updates = 0
   383         1          1.0      1.0      0.0              infos_values = []
   385         1         10.0     10.0      0.0              callback.on_training_start(locals(), globals())
   386         1          2.0      2.0      0.0              callback.on_rollout_start()
   388      6001      12419.0      2.1      0.0              for step in range(total_timesteps):
   389                                                           # Before training starts, randomly sample actions
   390                                                           # from a uniform distribution for better exploration.
   391                                                           # Afterwards, use the learned policy
   392                                                           # if random_exploration is set to 0 (normal setting)
   393      6000      43698.0      7.3      0.1                  if self.num_timesteps < self.learning_starts or np.random.rand() < self.random_exploration:
   394                                                               # actions sampled from action space are from range specific to the environment
   395                                                               # but algorithm operates on tanh-squashed actions therefore simple scaling is used
   396                                                               unscaled_action = self.env.action_space.sample()
   397                                                               action = scale_action(self.action_space, unscaled_action)
   398                                                           else:
   399      6000    2961251.0    493.5      9.5                      action = self.policy_tf.step(obs[None], deterministic=False).flatten()
   400                                                               # Add noise to the action (improve exploration,
   401                                                               # not needed in general)
   402      6000      23325.0      3.9      0.1                      if self.action_noise is not None:
   403                                                                   action = np.clip(action + self.action_noise(), -1, 1)
   404                                                               # inferred actions need to be transformed to environment action_space before stepping
   405      6000     124964.0     20.8      0.4                      unscaled_action = unscale_action(self.action_space, action)
   407      6000      22296.0      3.7      0.1                  assert action.shape == self.env.action_space.shape
   409      6000     359055.0     59.8      1.1                  new_obs, reward, done, info = self.env.step(unscaled_action)
   411      6000      16529.0      2.8      0.1                  self.num_timesteps += 1
   413                                                           # Only stop training if return value is False, not when it is None. This is for backwards
   414                                                           # compatibility with callbacks that have no return statement.
   415      6000      31799.0      5.3      0.1                  if callback.on_step() is False:
   416                                                               break
   418                                                           # Store transition in the replay buffer.
   419      6000    1823569.0    303.9      5.8                  self.replay_buffer.add(obs, action, reward, new_obs, float(done))
   420      6000      11180.0      1.9      0.0                  obs = new_obs
   422                                                           # Retrieve reward and episode length if using Monitor wrapper
   423      6000      13268.0      2.2      0.0                  maybe_ep_info = info.get('episode')
   424      6000      11084.0      1.8      0.0                  if maybe_ep_info is not None:
   425                                                               self.ep_info_buf.extend([maybe_ep_info])
   427      6000      10821.0      1.8      0.0                  if writer is not None:
   428                                                               # Write reward per episode to tensorboard
   429                                                               ep_reward = np.array([reward]).reshape((1, -1))
   430                                                               ep_done = np.array([done]).reshape((1, -1))
   431                                                               tf_util.total_episode_reward_logger(self.episode_reward, ep_reward,
   432                                                                                                   ep_done, writer, self.num_timesteps)
   434      6000      12777.0      2.1      0.0                  if step % self.train_freq == 0:
   435      6000      19612.0      3.3      0.1                      callback.on_rollout_end()
   437      6000      15610.0      2.6      0.0                      mb_infos_vals = []
   438                                                               # Update policy, critics and target networks
   439     11981      36474.0      3.0      0.1                      for grad_step in range(self.gradient_steps):
   440                                                                   # Break if the warmup phase is not over
   441                                                                   # or if there are not enough samples in the replay buffer
   442      6000      36980.0      6.2      0.1                          if not self.replay_buffer.can_sample(self.batch_size) \
   443      5981      12206.0      2.0      0.0                             or self.num_timesteps < self.learning_starts:
   444        19         31.0      1.6      0.0                              break
   445      5981      12012.0      2.0      0.0                          n_updates += 1
   446                                                                   # Compute current learning_rate
   447      5981      15936.0      2.7      0.1                          frac = 1.0 - step / total_timesteps
   448      5981      20271.0      3.4      0.1                          current_lr = self.learning_rate(frac)
   449                                                                   # Update policy and critics (q functions)
   450      5981   20831587.0   3483.0     66.7                          mb_infos_vals.append(self._train_step(step, writer, current_lr))
   451                                                                   # Update target network
   452      5981      20737.0      3.5      0.1                          if (step + grad_step) % self.target_update_interval == 0:
   453                                                                       # Update target network
   454      5981    3906941.0    653.2     12.5                    
   455                                                               # Log losses and entropy, useful for monitor training
   456      6000      13947.0      2.3      0.0                      if len(mb_infos_vals) > 0:
   457      5981     379424.0     63.4      1.2                          infos_values = np.mean(mb_infos_vals, axis=0)
   459      6000      27764.0      4.6      0.1                      callback.on_rollout_start()
   461      6000      16663.0      2.8      0.1                  episode_rewards[-1] += reward
   462      6000      11280.0      1.9      0.0                  if done:
   463       613       1692.0      2.8      0.0                      if self.action_noise is not None:
   464                                                                   self.action_noise.reset()
   465       613       6921.0     11.3      0.0                      if not isinstance(self.env, VecEnv):
   466       613      26734.0     43.6      0.1                          obs = self.env.reset()
   467       613       1518.0      2.5      0.0                      episode_rewards.append(0.0)
   469       613       1356.0      2.2      0.0                      maybe_is_success = info.get('is_success')
   470       613       1204.0      2.0      0.0                      if maybe_is_success is not None:
   471       613       1772.0      2.9      0.0                          episode_successes.append(float(maybe_is_success))
   473      6000      28444.0      4.7      0.1                  if len(episode_rewards[-101:-1]) == 0:
   474         9         19.0      2.1      0.0                      mean_reward = -np.inf
   475                                                           else:
   476      5991     283515.0     47.3      0.9                      mean_reward = round(float(np.mean(episode_rewards[-101:-1])), 1)
   478      6000      15860.0      2.6      0.1                  num_episodes = len(episode_rewards)
   479                                                           # Display training infos
   480      6000      13083.0      2.2      0.0                  if self.verbose >= 1 and done and log_interval is not None and len(episode_rewards) % log_interval == 0:
   481                                                               fps = int(step / (time.time() - start_time))
   482                                                               logger.logkv("episodes", num_episodes)
   483                                                               logger.logkv("mean 100 episode reward", mean_reward)
   484                                                               if len(self.ep_info_buf) > 0 and len(self.ep_info_buf[0]) > 0:
   485                                                                   logger.logkv('ep_rewmean', safe_mean([ep_info['r'] for ep_info in self.ep_info_buf]))
   486                                                                   logger.logkv('eplenmean', safe_mean([ep_info['l'] for ep_info in self.ep_info_buf]))
   487                                                               logger.logkv("n_updates", n_updates)
   488                                                               logger.logkv("current_lr", current_lr)
   489                                                               logger.logkv("fps", fps)
   490                                                               logger.logkv('time_elapsed', int(time.time() - start_time))
   491                                                               if len(episode_successes) > 0:
   492                                                                   logger.logkv("success rate", np.mean(episode_successes[-100:]))
   493                                                               if len(infos_values) > 0:
   494                                                                   for (name, val) in zip(self.infos_names, infos_values):
   495                                                                       logger.logkv(name, val)
   496                                                               logger.logkv("total timesteps", self.num_timesteps)
   497                                                               logger.dumpkvs()
   498                                                               # Reset infos:
   499                                                               infos_values = []
   500         1          5.0      5.0      0.0              callback.on_training_end()
   501         1         33.0     33.0      0.0              return self

siferati avatar Mar 09 '20 16:03 siferati

Thanks @tirafesi , I assume that replacing the list-based replay buffer with a numpy-based replay buffer would solve the issue... There is an example of it in the tf2 draft:

The v3 will rely on that type of buffer.
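
For readers unfamiliar with the idea, here is a minimal sketch of what a numpy-based replay buffer can look like: arrays are preallocated once and overwritten in a ring, so add is a handful of array writes. This is an illustration under assumed names (NumpyReplayBuffer, obs_dim, action_dim), not the code from the tf2 draft or from v3.

```python
import numpy as np


class NumpyReplayBuffer:
    """Fixed-size ring buffer backed by preallocated numpy arrays (sketch)."""

    def __init__(self, capacity, obs_dim, action_dim):
        self.capacity = capacity
        self.obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.float32)
        self.pos = 0   # index of the next slot to write
        self.size = 0  # number of valid transitions stored

    def add(self, obs, action, reward, next_obs, done):
        # Overwrite the slot at the write position, then advance it.
        i = self.pos
        self.obs[i] = obs
        self.actions[i] = action
        self.rewards[i] = reward
        self.next_obs[i] = next_obs
        self.dones[i] = done
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def can_sample(self, batch_size):
        return self.size >= batch_size

    def sample(self, batch_size, rng=None):
        # Uniform sampling with replacement over the valid part of the buffer.
        if rng is None:
            rng = np.random.default_rng()
        idx = rng.integers(0, self.size, size=batch_size)
        return (self.obs[idx], self.actions[idx], self.rewards[idx],
                self.next_obs[idx], self.dones[idx])


buffer = NumpyReplayBuffer(capacity=3, obs_dim=2, action_dim=1)
for step in range(4):
    buffer.add([step, step], [0.5], 1.0, [step + 1, step + 1], 0.0)
batch = buffer.sample(2)
```

With a capacity of 3, the fourth add overwrites the oldest transition instead of growing any Python list.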

araffin avatar Mar 09 '20 16:03 araffin

@toksis I will delete the comments, as they are not related to stable-baselines or to this issue but to line-profiler.

araffin avatar Mar 10 '20 20:03 araffin

Apparently, the problem is solved in v3 thanks to the new replay buffer implementation.

araffin avatar Jun 05 '20 19:06 araffin