[Bug]: Video upload to wandb broken since 2.4.0

Open OliverUrbann opened this issue 11 months ago • 9 comments

🐛 Bug

With stable_baselines3 2.3.2 on Python 3.11, the unit test below uploads videos to W&B (wandb) successfully. With 2.4.0, the same test fails.

To Reproduce

import unittest
import time
import gymnasium as gym
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
import wandb
from wandb import Api
from wandb.integration.sb3 import WandbCallback
from stable_baselines3 import PPO

class TestWandbVideoUpload(unittest.TestCase):
    def test_video_upload(self):
        env_id = "CartPole-v1"
        video_folder = "videos"
        video_length = 100

        vec_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")])

        obs = vec_env.reset()

        run = wandb.init(
            project="test",
            sync_tensorboard=True,  # Automatically upload SB3's TensorBoard metrics
            monitor_gym=True,       # Automatically upload agent playing videos
            # save_code=True,       # Optional
        )

        # Record the video starting at the first step
        vec_env = VecVideoRecorder(
            vec_env,
            video_folder,
            record_video_trigger=lambda x: x == 0,
            video_length=video_length,
            name_prefix=f"agent-{env_id}"
        )

        vec_env.reset()

        model = PPO("MlpPolicy", vec_env, verbose=1, tensorboard_log=f"runs/{run.id}")
        model.learn(
            total_timesteps=5000,
            callback=WandbCallback(
                model_save_path=f"tmp/models/{run.id}",
                verbose=2,
            ),
        )
        run.finish()

        # Give some time for the upload (adjust depending on connection speed)
        time.sleep(30)

        # Use the wandb API to check the run
        api = Api()
        # If you're logged into a different W&B account or using an organization, adjust 'entity' accordingly
        run_path = f"{run.entity}/{run.project}/{run.id}"
        run_api = api.run(run_path)

        # Retrieve a list of all files in the run
        files = run_api.files()
        file_names = [f.name for f in files]

        # Check if a video file is present
        video_files = [name for name in file_names if name.endswith('.mp4')]

        self.assertTrue(len(video_files) > 0, "The video was not uploaded to wandb.")

        # Optional: Print the uploaded video files
        print("Uploaded video files:", video_files)

        # Clean up
        vec_env.close()
        wandb.finish()

if __name__ == '__main__':
    unittest.main()
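
Since the test only inspects the files listed by the W&B API, it cannot tell whether the video was never written or was written but not uploaded. A minimal sketch of a local check that could be run before querying the API, assuming `VecVideoRecorder` writes `.mp4` files directly into `video_folder`; `find_local_videos` is a hypothetical helper name, not part of stable_baselines3 or wandb:

```python
from pathlib import Path


def find_local_videos(video_folder: str) -> list[str]:
    """Return the sorted names of .mp4 files directly inside video_folder.

    Returns an empty list if the folder does not exist.
    """
    folder = Path(video_folder)
    if not folder.is_dir():
        return []
    return sorted(p.name for p in folder.glob("*.mp4"))


if __name__ == "__main__":
    # "videos" is the video_folder value used in the test above.
    print("Local video files:", find_local_videos("videos"))
```

If this list is non-empty while the assertion in the test still fails, recording itself works and the problem is likely somewhere in the upload path.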

Relevant log output / Error message

No response

System Info

  • OS: Linux-5.15.0-124-generic-x86_64-with-glibc2.35 # 134-Ubuntu SMP Fri Sep 27 20:20:17 UTC 2024
  • Python: 3.11.0rc1
  • Stable-Baselines3: 2.4.0
  • PyTorch: 2.5.1+cu124
  • GPU Enabled: False
  • Numpy: 1.26.4
  • Cloudpickle: 3.1.0
  • Gymnasium: 0.29.1

Checklist

  • [X] My issue does not relate to a custom gym environment. (Use the custom gym env template instead)
  • [X] I have checked that there is no similar issue in the repo
  • [X] I have read the documentation
  • [X] I have provided a minimal and working example to reproduce the bug
  • [X] I've used the markdown code blocks for both code and stack traces.

OliverUrbann • Dec 13 '24 10:12