[Bug]: Higher memory usage on sequential training runs

Open NickLucche opened this issue 1 year ago • 6 comments

🐛 Bug

Hey, thanks a lot for your work! I am trying to debug an apparent memory leak (or at least higher memory usage) when running the training code multiple times, but I can't pinpoint its cause. I've boiled my problem down to the snippet below. Basically, when I start sequential training runs, I get higher memory consumption than with a single run, whereas I would expect all resources to be released once the PPO object is collected. I believe the only unusual part of this example is the obs and action space, which mimic my use case.

Single run memory usage, model.learn(total_timesteps=500_000): (image)

Multi-run memory usage, model.learn(total_timesteps=25_000) N times. Crashes early due to OOM: (image)

To Reproduce

import gymnasium as gym
from gymnasium.wrappers.time_limit import TimeLimit
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv
from gymnasium.spaces import MultiDiscrete, Box
import numpy as np

OBS_SHAPE = (320, 320, 2)

class DummyEnv(gym.Env):
    def __init__(self, *args, **kwargs):
        super(DummyEnv, self).__init__()
        self.action_space = MultiDiscrete([6 * 8 + 1, 2, 2], dtype=np.uint8)
        self.observation_space = Box(low=0, high=255, shape=OBS_SHAPE, dtype=np.uint8)

    def reset(self, *args, **kwargs):
        return self.observation_space.sample(), {}

    def step(self, action, *args, **kwargs):
        assert self.action_space.contains(action), f"{action} ({type(action)}) invalid"
        state = self.observation_space.sample()
        reward = 0.0
        done = False
        return state, reward, done, False, {}
        

def make_env():
    env = DummyEnv()
    env = TimeLimit(env, 100)
    return Monitor(env)

def train(ts):
    vec_env = SubprocVecEnv([make_env for _ in range(12)])
    model = PPO("CnnPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=ts)
    model.get_env().close()

if __name__ == "__main__":    
    for i in range(20):
        print("Starting", i)
        train(25_000)
        print(i, "finished")
    # train(500_000)

Relevant log output / Error message

No response

System Info

- OS: Linux-6.8.7-arch1-1-x86_64-with-glibc2.39 # 1 SMP PREEMPT_DYNAMIC Wed, 17 Apr 2024 15:20:28 +0000
- Python: 3.11.8
- Stable-Baselines3: 2.3.0
- PyTorch: 2.3.0+cu121
- GPU Enabled: True
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1

Checklist

  • [X] My issue does not relate to a custom gym environment. (Use the custom gym env template instead)
  • [X] I have checked that there is no similar issue in the repo
  • [X] I have read the documentation
  • [X] I have provided a minimal and working example to reproduce the bug
  • [X] I've used the markdown code blocks for both code and stack traces.

NickLucche avatar Jul 09 '24 07:07 NickLucche

Hello, why would you do that instead of using a callback, for instance? I'm also wondering why you would recreate the environment every time instead of just calling learn(..., reset_num_timesteps=False) (see our doc)?

Does the higher memory usage also happen when using a DummyVecEnv?

araffin avatar Jul 10 '24 07:07 araffin

The code is structured that way because my (actual) environment depends on some initial seed/state, which I can use to simulate ~"unseen" data and test generalization. It's then very straightforward to reuse the regular training script to train on different "splits"/conditions, i.e. just calling train in a loop like that.

But I believe the use case is of secondary importance here if there's some actual unreleased resource we could address (assuming there's no blatant mistake on my side).

Does the higher memory usage also happen when using a DummyVecEnv?

Yep, it still happens even with:

vec_env = make_vec_env(make_env, n_envs=12)

I decreased the obs space to 256x256x2 to better highlight the ramp-up before my system goes OOM: (image)
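For a sense of scale, the sketch below estimates how much memory one rollout buffer's observations would need for the two observation shapes mentioned in this thread. It is a back-of-the-envelope calculation, not a measurement: it assumes PPO's default n_steps of 2048 (the snippet does not override it), the 12 environments from the snippet, and that the rollout buffer stores observations as float32 (4 bytes per value). The helper name rollout_obs_bytes is made up for this illustration.

```python
import math

GIB = 1024 ** 3


def rollout_obs_bytes(n_steps, n_envs, obs_shape, bytes_per_value=4):
    """Bytes needed to hold n_steps * n_envs observations of obs_shape."""
    return n_steps * n_envs * math.prod(obs_shape) * bytes_per_value


# Assumptions: n_steps=2048 (PPO default), 12 envs as in the snippet,
# observations stored as float32 (4 bytes per value).
for shape in [(320, 320, 2), (256, 256, 2)]:
    size = rollout_obs_bytes(2048, 12, shape)
    print(shape, f"{size / GIB:.2f} GiB")
# (320, 320, 2) 18.75 GiB
# (256, 256, 2) 12.00 GiB
```

If these assumptions hold, a single model's buffer is already in the range of typical system RAM, so even one buffer from a previous run that has not yet been released when the next run fills its own could be enough to trigger an OOM.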

NickLucche avatar Jul 10 '24 08:07 NickLucche

depends on some initial seed/state, which I can use to simulate ~"unseen" data and test generalization

Normally, .reset(seed=...) is made for that (for a VecEnv, call .seed() and then do a reset).

araffin avatar Jul 10 '24 09:07 araffin

I have the same issue.

  • OS: Ubuntu 22.04.5 LTS
  • Python: 3.10.12
  • Stable-Baselines3: 2.3.0
  • PyTorch: 2.4.1+cu121
  • GPU Enabled: True
  • Numpy: 2.1.1
  • Cloudpickle: 3.0.0
  • Gymnasium: 0.29.1
Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   139    683.0 MiB    683.0 MiB           1       @profile
   140                                             def collect_rollouts(
   141                                                 self,
   142                                                 env: VecEnv,
   143                                                 callback: BaseCallback,
   144                                                 rollout_buffer: RolloutBuffer,
   145                                                 n_rollout_steps: int,
   146                                             ) -> bool:
   147                                                 """
   148                                                 Collect experiences using the current policy and fill a ``RolloutBuffer``.
   149                                                 The term rollout here refers to the model-free notion and should not
   150                                                 be used with the concept of rollout used in model-based RL or planning.
   151                                         
   152                                                 :param env: The training environment
   153                                                 :param callback: Callback that will be called at each step
   154                                                     (and at the beginning and end of the rollout)
   155                                                 :param rollout_buffer: Buffer to fill with rollouts
   156                                                 :param n_rollout_steps: Number of experiences to collect per environment
   157                                                 :return: True if function returned with at least `n_rollout_steps`
   158                                                     collected, False if callback terminated rollout prematurely.
   159                                                 """
   160    683.0 MiB      0.0 MiB           1           assert self._last_obs is not None, "No previous observation was provided"
   161                                                 # Switch to eval mode (this affects batch norm / dropout)
   162    683.0 MiB      0.0 MiB           1           self.policy.set_training_mode(False)
   163                                         
   164    683.0 MiB      0.0 MiB           1           n_steps = 0
   165    683.2 MiB      0.2 MiB           1           rollout_buffer.reset()
   166                                                 # Sample new weights for the state dependent exploration
   167    683.2 MiB      0.0 MiB           1           if self.use_sde:
   168                                                     self.policy.reset_noise(env.num_envs)
   169                                         
   170    683.2 MiB      0.0 MiB           1           callback.on_rollout_start()
   171                                         
   172  13234.2 MiB      0.0 MiB        1025           while n_steps < n_rollout_steps:
   173  13222.5 MiB      0.0 MiB        1024               if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
   174                                                         # Sample a new noise matrix
   175                                                         self.policy.reset_noise(env.num_envs)
   176                                         
   177  13222.5 MiB      0.0 MiB        2048               with th.no_grad():
   178                                                         # Convert to pytorch tensor or to TensorDict
   179  13222.5 MiB      0.0 MiB        1024                   obs_tensor = obs_as_tensor(self._last_obs, self.device)
   180  13222.5 MiB    432.8 MiB        1024                   actions, values, log_probs = self.policy(obs_tensor)
   181  13222.5 MiB      0.0 MiB        1024               actions = actions.cpu().numpy()
   182                                         
   183                                                     # Rescale and perform action
   184  13222.5 MiB      0.0 MiB        1024               clipped_actions = actions
   185                                         
   186  13222.5 MiB      0.0 MiB        1024               if isinstance(self.action_space, spaces.Box):
   187                                                         if self.policy.squash_output:
   188                                                             # Unscale the actions to match env bounds
   189                                                             # if they were previously squashed (scaled in [-1, 1])
   190                                                             clipped_actions = self.policy.unscale_action(clipped_actions)
   191                                                         else:
   192                                                             # Otherwise, clip the actions to avoid out of bound error
   193                                                             # as we are sampling from an unbounded Gaussian distribution
   194                                                             clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
   195                                         
   196  13222.5 MiB     25.3 MiB        1024               new_obs, rewards, dones, infos = env.step(clipped_actions)
   197                                         
   198  13222.5 MiB      0.0 MiB        1024               self.num_timesteps += env.num_envs
   199                                         
   200                                                     # Give access to local variables
   201  13222.5 MiB      0.0 MiB        1024               callback.update_locals(locals())
   202  13222.5 MiB      0.0 MiB        1024               if not callback.on_step():
   203                                                         return False
   204                                         
   205  13222.5 MiB      0.0 MiB        1024               self._update_info_buffer(infos, dones)
   206  13222.5 MiB      0.0 MiB        1024               n_steps += 1
   207                                         
   208  13222.5 MiB      0.0 MiB        1024               if isinstance(self.action_space, spaces.Discrete):
   209                                                         # Reshape in case of discrete action
   210  13222.5 MiB      0.0 MiB        1024                   actions = actions.reshape(-1, 1)
   211                                         
   212                                                     # Handle timeout by bootstraping with value function
   213                                                     # see GitHub issue #633
   214  13222.5 MiB      0.0 MiB       50176               for idx, done in enumerate(dones):
   215  13222.5 MiB      0.0 MiB       49152                   if (
   216  13222.5 MiB      0.0 MiB       49152                       done
   217  13222.5 MiB      0.0 MiB        4560                       and infos[idx].get("terminal_observation") is not None
   218  13222.5 MiB      0.0 MiB        2280                       and infos[idx].get("TimeLimit.truncated", False)
   219                                                         ):
   220                                                             terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
   221                                                             with th.no_grad():
   222                                                                 terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
   223                                                             rewards[idx] += self.gamma * terminal_value
   224                                         
   225  13234.2 MiB  12095.8 MiB        2048               rollout_buffer.add(
   226  13222.5 MiB      0.0 MiB        1024                   self._last_obs,  # type: ignore[arg-type]
   227  13222.5 MiB      0.0 MiB        1024                   actions,
   228  13222.5 MiB      0.0 MiB        1024                   rewards,
   229  13222.5 MiB      0.0 MiB        1024                   self._last_episode_starts,  # type: ignore[arg-type]
   230  13222.5 MiB      0.0 MiB        1024                   values,
   231  13222.5 MiB      0.0 MiB        1024                   log_probs,
   232                                                     )
   233  13234.2 MiB     -2.8 MiB        1024               self._last_obs = new_obs  # type: ignore[assignment]
   234  13234.2 MiB      0.0 MiB        1024               self._last_episode_starts = dones
   235                                         
   236  13234.2 MiB      0.0 MiB           2           with th.no_grad():
   237                                                     # Compute value for the last timestep
   238  13234.2 MiB      0.0 MiB           1               values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))  # type: ignore[arg-type]
   239                                         
   240  13234.2 MiB      0.0 MiB           1           rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
   241                                         
   242  13234.2 MiB      0.0 MiB           1           callback.update_locals(locals())
   243                                         
   244  13234.2 MiB      0.0 MiB           1           callback.on_rollout_end()
   245                                         
   246  13234.2 MiB      0.0 MiB           1           return True

I profiled the memory usage of the code. I guess the buffer needs to be reset somewhere, but that isn't being done?
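The numbers in the profile can be read a little further with some arithmetic. The sketch below uses only values copied from the table above: the loop body ran 1024 times, the per-environment check ran 49152 times, and 12095.8 MiB is attributed to rollout_buffer.add(...). The variable names are chosen here for illustration.

```python
MIB = 1024 * 1024

# Values read off the profile output above.
loop_iterations = 1024        # occurrences of the while-loop body
done_checks = 49152           # occurrences of the `if (` line in the per-env loop
add_increment_mib = 12095.8   # increment attributed to rollout_buffer.add(...)

# The per-env loop runs once per environment per step.
n_envs = done_checks // loop_iterations
transitions = loop_iterations * n_envs

# Average memory committed per stored transition.
kib_per_transition = add_increment_mib * MIB / transitions / 1024

print(n_envs)                        # 48
print(round(kib_per_transition))     # 252
```

So the increase is about 252 KiB for each of the 1024 x 48 transitions written during this one rollout, which is consistent with the buffer being filled for the first time rather than with something accumulating across rollouts. The profile also shows rollout_buffer.reset() being called at line 165, so the buffer is reset at the start of each collection. One likely reason the increment shows up at add() rather than at allocation time is that large zero-initialized arrays are typically only backed by physical memory once they are written to, though that depends on the platform.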

NIvo172 avatar Sep 19 '24 14:09 NIvo172

@NickLucche I think the problem is not SB3; I think it's that Python does not free the memory in the loop.

In your loop, you train 20 different models? In mine I run 20 loops and load the last save.


for train_rounds in range(20):
    if exists(model_savestate):
        # Resume from the last save
        model = PPO.load(model_savestate)
    else:
        model = PPO()  # arguments omitted

    model.learn()  # arguments omitted
    model.save()  # arguments omitted

This went OOM. I did the following:


import gc

for train_rounds in range(20):
    if exists(model_savestate):
        # Resume from the last save
        model = PPO.load(model_savestate)
    else:
        model = PPO()  # arguments omitted

    model.learn()  # arguments omitted
    model.save()  # arguments omitted
    del model  # Dereference old model
    gc.collect()  # Force free memory

Without the gc.collect() call, the memory is still not freed. Maybe this helps you.
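One possible explanation for why del alone is not enough: CPython frees an object immediately only when its reference count drops to zero, and objects that are part of a reference cycle have to wait for the cyclic garbage collector. The sketch below demonstrates that mechanism with two made-up classes, Model and Callback, that point at each other. Whether an SB3 model actually ends up in such a cycle is an assumption here, not something confirmed in this thread.

```python
import gc
import weakref


class Callback:
    """Hypothetical helper that keeps a back-reference to its owner."""

    def __init__(self, model):
        self.model = model  # callback -> model closes the cycle


class Model:
    """Hypothetical stand-in for a large object such as a trained model."""

    def __init__(self):
        self.buffer = bytearray(1024)
        self.callback = Callback(self)  # model -> callback


def run_once():
    model = Model()
    probe = weakref.ref(model)  # lets us check liveness without keeping it alive
    del model                   # drops the only external reference
    return probe


gc.disable()  # keep the automatic collector out of the way for the demo
try:
    probe = run_once()
    alive_after_del = probe() is not None      # the cycle keeps it alive
    gc.collect()                               # explicit cyclic collection
    alive_after_collect = probe() is not None  # now it is gone
finally:
    gc.enable()

print(alive_after_del, alive_after_collect)  # True False
```

In a normal program the automatic collector would eventually run on its own, but it is triggered by object allocation counts rather than by memory pressure, so a handful of very large objects can linger for a while. Calling gc.collect() at the end of each round, as in the loop above, removes that delay.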

NIvo172 avatar Sep 20 '24 09:09 NIvo172

I profiled the memory usage of the code. I guess the buffer needs to be reset somewhere, but that isn't being done?

Hi, did you figure out how to solve this problem? I've run into it as well. My conclusion is that callback.on_rollout_start() may have a problem that causes the extra memory consumption.

jingyang-huang avatar May 11 '25 12:05 jingyang-huang