feat: Implement Proximal Policy Optimization (PPO) policy
What this does
This pull request introduces an implementation of Proximal Policy Optimization (PPO) (🤖 Policy) into LeRobot, based on the original paper Proximal Policy Optimization Algorithms.
💡Note: For initial testing and debugging, I used the simple Pendulum-v1 environment. See the How to Checkout & Try section below for details on the custom wrapper that integrates this environment into the LeRobot training pipeline.
https://github.com/user-attachments/assets/9551838b-361a-4a30-8ab9-b90e2a02d68d
Key Changes:
- Added the policy implementation (in `modeling_ppo.py`)
- Added the configuration dataclass (in `configuration_ppo.py`)
- Minor changes to the training script to accommodate the additional tensors needed during training (e.g. log probs, advantages, returns); a short sketch of how these are typically computed follows this list
- Minor refactoring to allow for re-use of the MLP network from VQ-BeT.
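For context on those extra tensors, here is a minimal, self-contained sketch of how returns and advantages are typically computed with Generalized Advantage Estimation (GAE) and how they feed into the clipped surrogate objective from the PPO paper. The function names and default hyperparameters below are illustrative only, not necessarily what `modeling_ppo.py` uses:

```python
import torch


def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Illustrative GAE over one rollout of length T (no bootstrap beyond the last step)."""
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    last_adv = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        next_nonterminal = 1.0 - dones[t].float()
        # TD residual: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        # Exponentially weighted sum of TD residuals
        last_adv = delta + gamma * lam * next_nonterminal * last_adv
        advantages[t] = last_adv
    returns = advantages + values
    return advantages, returns


def ppo_clip_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Clipped surrogate objective from the PPO paper (returned as a loss to minimize)."""
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()
```

Minimizing the clipped objective keeps each policy update close to the policy that collected the rollout, which is the main reason the extra log-prob/advantage/return tensors have to flow through the training loop.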
Why PPO?
PPO is one of the most popular RL algorithms in robotics due to its stability and sample efficiency, and it has proven its worth in many real-world applications.
Integrating PPO into LeRobot not only provides a proven benchmark against which newer policies can be compared, but also makes it easier to port pre-trained models from popular RL libraries like Stable Baselines3, which already has many models available on the Hub! 🚀✨
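To illustrate the porting idea, here is a rough sketch of pulling an SB3 PPO checkpoint from the Hub. The `repo_id`/`filename` are placeholders, and mapping the SB3 weights onto the LeRobot PPO policy would still need a dedicated conversion step that is not part of this PR:

```python
from huggingface_sb3 import load_from_hub  # pip install huggingface-sb3
from stable_baselines3 import PPO

# Placeholder repo/filename: any SB3 PPO checkpoint on the Hub works the same way.
checkpoint_path = load_from_hub(
    repo_id="sb3/ppo-Pendulum-v1",
    filename="ppo-Pendulum-v1.zip",
)
sb3_model = PPO.load(checkpoint_path)

# The actor/critic weights live in sb3_model.policy (a torch.nn.Module);
# converting them into LeRobot's PPO policy would require a separate mapping step.
print(sb3_model.policy)
```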
Additional Notes ✨
- PPO is traditionally used in an online setting. Since the `train.py` script currently requires a dataset, I pushed a simple dataset to the Hub to avoid major refactoring. Setting `offline.steps=0` ensures that training behaves as online-only.
- There is some literature on offline PPO, which could be explored in the future.
- The current implementation only supports state-based observations (no vision pipeline) to simplify testing and validation. A future commit could introduce a vision backbone for more general use
- To ensure everything works correctly, the policy was validated on the simplest possible environment (Pendulum-v1 from gym/classic-control) using a custom wrapper that integrates with the LeRobot configuration and training API.
How it was tested
I trained the PPO policy to convergence on the Pendulum-v1 environment as a sanity check. The tests confirmed that the policy integrates seamlessly with the LeRobot training pipeline.
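For reference, a qualitative check after training could look like the sketch below. The class name `PPOPolicy`, its module path, the `from_pretrained`/`select_action` calls, and the checkpoint path are assumptions about the final API and output layout, so adjust them to whatever the merged code actually exposes:

```python
import gymnasium as gym
import torch

# Assumed module path / class name for the new policy (adjust if different).
from lerobot.common.policies.ppo.modeling_ppo import PPOPolicy

# Placeholder path to a checkpoint produced by the training command below.
policy = PPOPolicy.from_pretrained("outputs/train/ppo_pendulum_v1/checkpoints/last/pretrained_model")
policy.eval()

env = gym.make("Pendulum-v1", render_mode="human")
obs, _ = env.reset(seed=1)
done = False
while not done:
    # The training pipeline maps the state observation to the OBS_ROBOT key
    # ("observation.state"), so the same key is used here.
    batch = {"observation.state": torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)}
    with torch.no_grad():
        action = policy.select_action(batch)
    obs, reward, terminated, truncated, _ = env.step(action.squeeze(0).cpu().numpy())
    done = terminated or truncated
env.close()
```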
How to checkout & try? (for the reviewer)
To simplify testing and debugging, I created a custom wrapper for the Pendulum-v1 environment that adapts its API to match LeRobot’s requirements. This lets you use the standard training pipeline without extra modifications.
Since PPO is generally used online, loading a dataset isn’t ideal. However, to avoid major changes to `train.py`, I pushed a small dataset to the Hub and set `offline.steps=0` to effectively perform online-only training.
Steps to Test:
To test with the same Pendulum-v1 environment:
- Create a Python file (e.g., `wrapped_pendulum_train.py`) with the following code:
```python
import sys
import types
from dataclasses import dataclass, field

import gymnasium as gym
from gymnasium.envs.registration import register
from gymnasium.spaces import Dict

from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.utils.utils import init_logging
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.scripts.train import train


# Create simple wrapper around Pendulum-v1 environment
# This allows us to use this really simple env with the LeRobot pipeline
class PendulumDictWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Dict({"agent_pos": self.env.observation_space})

    def observation(self, obs):
        return {"agent_pos": obs}

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        info["is_success"] = False  # Always False for now
        return self.observation(obs), reward, terminated, truncated, info

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        return self.observation(obs), info


def make_pendulum_dict(**kwargs):
    env = gym.make("Pendulum-v1", **kwargs)
    return PendulumDictWrapper(env)


register(
    id="gym_pendulum_v1/Pendulum-v1",
    entry_point=make_pendulum_dict,
)


@EnvConfig.register_subclass("pendulum_v1")
@dataclass
class PendulumEnv(EnvConfig):
    task: str = "Pendulum-v1"
    fps: int = 30
    episode_length: int = 200
    obs_type: str = "state"
    render_mode: str = "rgb_array"
    features: dict[str, PolicyFeature] = field(
        default_factory=lambda: {
            "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
            "state": PolicyFeature(type=FeatureType.STATE, shape=(3,)),
        }
    )
    features_map: dict[str, str] = field(
        default_factory=lambda: {
            "action": ACTION,
            "state": OBS_ROBOT,
        }
    )

    @property
    def gym_kwargs(self) -> dict:
        return {
            "render_mode": self.render_mode,
            "max_episode_steps": self.episode_length,
        }


# Create dummy module to avoid import errors from `make_env`
if "gym_pendulum_v1" not in sys.modules:
    sys.modules["gym_pendulum_v1"] = types.ModuleType("gym_pendulum_v1")


if __name__ == "__main__":
    init_logging()
    train()
```
- Execute the newly created script with the following command-line arguments:
```bash
python path/to/wrapped_pendulum_train.py \
    --policy.type=ppo \
    --dataset.repo_id=bensprenger/lerobot_ppo_pendulum_v1 \
    --env.type=pendulum_v1 \
    --env.task=Pendulum-v1 \
    --job_name=ppo_pendulum_v1 \
    --log_freq=2500 \
    --batch_size=64 \
    --offline.steps=0 \
    --online.steps=100000 \
    --online.sampling_ratio=1.0 \
    --online.steps_between_rollouts=370 \
    --online.env_seed=1 \
    --online.buffer_capacity=2412 \
    --online.rollout_batch_size=12
```
@Cadene @aliberts Let me know what you think!