
feat: Implement Proximal Policy Optimization (PPO) policy

Open • bsprenger opened this pull request 10 months ago • 0 comments

What this does

This pull request introduces an implementation of Proximal Policy Optimization (PPO) (🤖 Policy) into LeRobot, based on the original paper Proximal Policy Optimization Algorithms.
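
For reference, the core objective from that paper is the clipped surrogate loss (with $\epsilon$ the clipping range and $\hat{A}_t$ the advantage estimate):

$$L^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}$$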

💡Note: For initial testing and debugging, I used the simple Pendulum-v1 environment. See the How to Checkout & Try section below for details on the custom wrapper that integrates this environment into the LeRobot training pipeline.

https://github.com/user-attachments/assets/9551838b-361a-4a30-8ab9-b90e2a02d68d

Key Changes:

  • Added the policy implementation (in modeling_ppo.py)
  • Added the configuration dataclass (in configuration_ppo.py)
  • Minor changes to the training script to accommodate the additional tensors needed during training (e.g., log probs, advantages, returns); a short sketch of these quantities follows this list
  • Minor refactoring to allow reuse of the MLP network from VQ-BeT.
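
For context, here is a minimal sketch of how these extra tensors and the loss typically fit together, i.e. the standard recipe from the PPO paper (GAE advantages/returns plus the clipped surrogate loss). The function names and hyperparameters below are illustrative only, not the actual identifiers used in modeling_ppo.py:

# Illustrative sketch of the standard PPO quantities; names and defaults are
# assumptions for this example, not the code in modeling_ppo.py.
import torch


def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over a single rollout of length T."""
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        # Bootstrap with the value of the next state, unless the episode ended at step t
        next_value = last_value if t == T - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t].float()
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        gae = delta + gamma * lam * next_nonterminal * gae
        advantages[t] = gae
    returns = advantages + values  # regression targets for the value head
    return advantages, returns


def ppo_clipped_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Clipped surrogate objective from the PPO paper (sign flipped for minimization)."""
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()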

Why PPO?

PPO is one of the most popular RL algorithms in robotics due to its stability and sample efficiency, and it has proven its worth in many real-world applications.

Integrating PPO into LeRobot not only provides a proven benchmark against which newer policies can be compared, but also makes it easier to port pre-trained models from popular RL libraries such as Stable Baselines3, which already has many models available on the Hub! 🚀✨

Additional Notes ✨

  • PPO is traditionally used in an online setting. Since the train.py script currently requires a dataset, I pushed a simple dataset to the Hub to avoid major refactoring. Setting offline.steps=0 ensures that the training behaves as online-only.
    • There is some literature on offline PPO, which could be explored in the future.
  • The current implementation only supports state-based observations (no vision pipeline) to simplify testing and validation. A future commit could introduce a vision backbone for more general use.
  • To ensure everything works correctly, the policy was validated on the simplest possible environment (Pendulum-v1 from gym/classic-control) using a custom wrapper that integrates into the LeRobot configuration and training API.

How it was tested

I trained the PPO policy to convergence on the Pendulum-v1 environment as a sanity check. The tests confirmed that the policy integrates seamlessly with the LeRobot training pipeline.

How to checkout & try? (for the reviewer)

To simplify testing and debugging, I created a custom wrapper for the Pendulum-v1 environment that adapts its API to match LeRobot’s requirements. This lets you use the standard training pipeline without extra modifications.

Since PPO is generally used online, loading a dataset isn’t ideal. However, to avoid major changes to train.py, I pushed a small dataset to the Hub and set offline.steps=0 to effectively perform online-only training.

Steps to Test:

To test with the same Pendulum-v1 environment:

  1. Create a Python file (e.g., wrapped_pendulum_train.py) with the following code:
import sys
import types
from dataclasses import dataclass, field

import gymnasium as gym
from gymnasium.envs.registration import register
from gymnasium.spaces import Dict

from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.utils.utils import init_logging
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.scripts.train import train

# Create simple wrapper around Pendulum-v1 environment
# This allows us to use this really simple env with the LeRobot pipeline
class PendulumDictWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Dict({"agent_pos": self.env.observation_space})

    def observation(self, obs):
        return {"agent_pos": obs}

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        info["is_success"] = False  # Always False for now
        return self.observation(obs), reward, terminated, truncated, info

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        return self.observation(obs), info


def make_pendulum_dict(**kwargs):
    env = gym.make("Pendulum-v1", **kwargs)
    return PendulumDictWrapper(env)


# Register the wrapped environment under the namespaced id that LeRobot's
# make_env builds from the config below (gym_<env.type>/<task>)
register(
    id="gym_pendulum_v1/Pendulum-v1",
    entry_point=make_pendulum_dict,
)


# Register the env config so it can be selected with --env.type=pendulum_v1
@EnvConfig.register_subclass("pendulum_v1")
@dataclass
class PendulumEnv(EnvConfig):
    task: str = "Pendulum-v1"
    fps: int = 30
    episode_length: int = 200
    obs_type: str = "state"
    render_mode: str = "rgb_array"
    features: dict[str, PolicyFeature] = field(
        default_factory=lambda: {
            "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
            "state": PolicyFeature(type=FeatureType.STATE, shape=(3,)),
        }
    )
    features_map: dict[str, str] = field(
        default_factory=lambda: {
            "action": ACTION,
            "state": OBS_ROBOT,
        }
    )

    @property
    def gym_kwargs(self) -> dict:
        return {
            "render_mode": self.render_mode,
            "max_episode_steps": self.episode_length,
        }


# Create dummy module to avoid import errors from `make_env`
if "gym_pendulum_v1" not in sys.modules:
    sys.modules["gym_pendulum_v1"] = types.ModuleType("gym_pendulum_v1")


if __name__ == "__main__":
    init_logging()
    train()

  2. Execute the newly created script with the following command-line arguments:
python path/to/wrapped_pendulum_train.py \
--policy.type=ppo \
--dataset.repo_id=bensprenger/lerobot_ppo_pendulum_v1 \
--env.type=pendulum_v1 \
--env.task=Pendulum-v1 \
--job_name=ppo_pendulum_v1 \
--log_freq=2500 \
--batch_size=64 \
--offline.steps=0 \
--online.steps=100000 \
--online.sampling_ratio=1.0 \
--online.steps_between_rollouts=370 \
--online.env_seed=1 \
--online.buffer_capacity=2412 \
--online.rollout_batch_size=12 

@Cadene @aliberts Let me know what you think!

bsprenger • Feb 09 '25, 21:02