stable-baselines3
model.learn() does not properly handle gym.spaces.Discrete spaces where start != 0
🐛 Bug
One component of my Dict observation space is spaces.Discrete(19, 2). The documentation for the Discrete space lists the possible values as {a, a+1, ..., a+n-1}, so my observation space for this value should be {2, 3, ..., 20}. However, model.learn() does not seem to account for this non-zero start: it errors during one-hot encoding, because preprocess_obs calls return F.one_hot(obs.long(), num_classes=int(observation_space.n)).float(), and torch.nn.functional.one_hot only accepts class indices from 0 to observation_space.n - 1. With n = 19, torch creates a vector of length 19 and therefore errors when it tries to one-hot encode the value 20.
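To make the failure concrete, here is a pure-Python stand-in for the range check that `torch.nn.functional.one_hot` applies (every index must lie in `[0, num_classes)`). The helper `one_hot` is hypothetical: it mimics that precondition, not torch's exact error message, and it assumes the intended space {2, ..., 20}, i.e. n = 19 and start = 2.

```python
from typing import List


def one_hot(index: int, num_classes: int) -> List[float]:
    """Hypothetical stand-in for torch.nn.functional.one_hot on a single index.

    Like the torch function, it only accepts indices in [0, num_classes).
    """
    if not 0 <= index < num_classes:
        raise ValueError(f"index {index} is outside [0, {num_classes})")
    vector = [0.0] * num_classes
    vector[index] = 1.0
    return vector


n, start = 19, 2  # intended space {2, 3, ..., 20}

# Encoding the raw value fails: 20 is not a valid index for 19 classes.
try:
    one_hot(20, n)
except ValueError as err:
    print("error:", err)

# Subtracting `start` first maps {2, ..., 20} onto {0, ..., 18}.
encoded = one_hot(20 - start, n)
print(encoded.index(1.0))  # 18
```

Only the index needs the offset; the number of classes stays `n`.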
Am I doing something wrong, is there an automatic way to handle this somewhere that I have missed, or do I need to manually map my Discrete observation to start from 0 before?
Thank you.
Code example
```python
import polars as pl
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import DummyVecEnv
import os


class simple_env(gym.Env):
    """Custom Environment that follows gym interface."""

    def __init__(self, df: pl.DataFrame):
        super().__init__()
        self.df = df
        self.cur_row = 0
        self.action_space = spaces.Box(low=-1, high=1, shape=(1,))
        self.observation_space = spaces.Dict({
            "a": spaces.Box(0, 1, (1,), np.float64),
            "b": spaces.Discrete(19, 2)
        })
        self.metadata = {"render_modes": ["human"], "render_fps": 30}
        self.render_mode = "human"

    def step(self, action):
        terminated = False
        info = {}
        reward = float(action)
        # Wrap around to the first row after the last one
        if self.cur_row == self.df.shape[0] - 1:
            self.cur_row = -1
        self.cur_row += 1
        # Discrete entries are returned as scalars, Box entries as 1-D arrays
        observation = {key:
                       self.df[self.cur_row].select(pl.col(f"{key}")).to_numpy().flatten()[0]
                       if isinstance(value, gym.spaces.Discrete)
                       else self.df[self.cur_row].select(pl.col(f"{key}")).to_numpy().flatten()
                       for key, value in self.observation_space.items()}
        return observation, reward, terminated, False, info

    def reset(self, seed=None):
        super().reset(seed=seed)
        self.cur_row = 0
        observation = {key:
                       self.df[self.cur_row].select(pl.col(f"{key}")).to_numpy().flatten()[0]
                       if isinstance(value, gym.spaces.Discrete)
                       else self.df[self.cur_row].select(pl.col(f"{key}")).to_numpy().flatten()
                       for key, value in self.observation_space.items()}
        info = {}
        return observation, info

    def render(self, mode="human"):
        pass

    def close(self):
        pass


def make_env(env_config, rank, seed=0):
    """
    Utility function for multiprocessed env.
    :param env_config: (dict) the dictionary of environment parameters
    :param seed: (int) the initial seed for RNG
    :param rank: (int) index of the subprocess
    """
    def _init():
        env = simple_env(**env_config)
        env.reset(seed=(seed + rank))
        return env
    set_random_seed(seed)
    pl.set_random_seed(rank + seed)
    return _init


os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
df = pl.DataFrame({"a": [0.0, 0.12, 0.17, 0.99, 0.86], "b": [7, 8, 5, 15, 20]})
env = simple_env(df)
check_env(env)
env_config = {"df": df}
num_cpu = 1
vec_env = DummyVecEnv([make_env(env_config, i) for i in range(num_cpu)])
model = PPO("MultiInputPolicy", vec_env, device="cpu")
obs = vec_env.reset()
model.learn(10)
```
Relevant log output / Error message
```
  File "c:test\tests.py", line 231, in <module>
    action, _states = model.predict(obs, deterministic=True)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\base_class.py", line 553, in predict
    return self.policy.predict(observation, state, episode_start, deterministic)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\policies.py", line 366, in predict
    actions = self._predict(obs_tensor, deterministic=deterministic)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\policies.py", line 715, in _predict
    return self.get_distribution(observation).get_actions(deterministic=deterministic)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\policies.py", line 748, in get_distribution
    features = super().extract_features(obs, self.pi_features_extractor)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\policies.py", line 130, in extract_features
    preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\preprocessing.py", line 113, in preprocess_obs
    preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\preprocessing.py", line 125, in preprocess_obs
    return F.one_hot(obs.long(), num_classes=int(observation_space.n)).float()
RuntimeError: Class values must be smaller than num_classes.
```
System Info
- OS: Windows-10-10.0.22621-SP0 10.0.22621
- Python: 3.8.10
- Stable-Baselines3: 2.2.1
- PyTorch: 2.1.2+cu118
- GPU Enabled: True
- Numpy: 1.24.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1
Checklist
- [X] I have checked that there is no similar issue in the repo
- [X] I have read the documentation
- [X] I have provided a minimal and working example to reproduce the bug
- [X] I have checked my env using the env checker
- [X] I've used the markdown code blocks for both code and stack traces.
Hi @araffin, if this is a duplicate, can you please provide the link to the original? And what checkboxes need more clarification?
> I have checked that there is no similar issue in the repo
Try harder next time =)
Duplicate of https://github.com/DLR-RM/stable-baselines3/issues/1509, #1295 and #913
but it seems we need to update the env checker to warn when users are using dict obs space.
> Am I doing something wrong, is there an automatic way to handle this somewhere that I have missed, or do I need to manually map my Discrete observation to start from 0 before?
Yes, or you can use a wrapper.
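A wrapper applies the same remapping to every observation before the policy sees it. Below is a minimal, framework-free sketch of the remapping a `gymnasium.ObservationWrapper` subclass could perform in its `observation()` method, assuming a Dict observation space whose Discrete entries were built with an explicit `start=`. The function name `shift_discrete_obs` and its `starts` argument are hypothetical, not part of Gymnasium or Stable-Baselines3.

```python
from typing import Any, Dict, Mapping


def shift_discrete_obs(obs: Mapping[str, Any], starts: Mapping[str, int]) -> Dict[str, Any]:
    """Return a copy of a dict observation with Discrete entries shifted to start at 0.

    `starts` maps each key backed by a Discrete(n, start=s) sub-space to its `s`;
    all other keys (e.g. Box entries) are passed through unchanged.
    """
    return {
        key: (value - starts[key]) if key in starts else value
        for key, value in obs.items()
    }


# Hypothetical observation matching the issue's layout: "a" is a Box, "b" a Discrete
obs = {"a": [0.12], "b": 20}
print(shift_discrete_obs(obs, {"b": 2}))  # {'a': [0.12], 'b': 18}
```

In an actual wrapper you would also replace each such entry in `observation_space` with `spaces.Discrete(n)` (start 0), so the shifted values are still contained in the declared space.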
> but it seems we need to update the env checker to warn when users are using dict obs space.
The env checker was up to date; you were using the Discrete constructor the wrong way (the second positional argument is `seed`, not `start`).
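The positional-argument pitfall can be shown without Gymnasium installed. `DiscreteLike` below is a hypothetical stand-in that only copies the parameter order of `gymnasium.spaces.Discrete.__init__(n, seed=None, start=0)` (as in Gymnasium 0.29); it is not the real class.

```python
class DiscreteLike:
    """Hypothetical stand-in with the same parameter order as gymnasium's Discrete."""

    def __init__(self, n, seed=None, start=0):
        self.n = n
        self.seed = seed
        self.start = start

    def values(self):
        # Possible values: {start, start + 1, ..., start + n - 1}
        return range(self.start, self.start + self.n)


wrong = DiscreteLike(19, 2)        # 2 binds to `seed`, so start stays 0
right = DiscreteLike(19, start=2)  # the keyword makes the offset explicit
print(wrong.seed, wrong.start, max(wrong.values()))  # 2 0 18
print(right.seed, right.start, max(right.values()))  # None 2 20
```

With the positional call the space is {0, ..., 18}, so the value 20 in the example DataFrame falls outside it; passing `start=2` by keyword gives the intended {2, ..., 20}.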