stable-baselines3
NoneType crash with graph observation space even with FeaturesExtractor
🐛 Bug
I have a problem that involves graphs in the observation space, and I'd like to use PyTorch Geometric to train a GNN for my learning agent.
My first idea was to use a Gymnasium observation wrapper to flatten the input into a form compatible with Stable Baselines, but I didn't find a way to train the GNN inside the function that processes the observation.
So I wanted to implement the GNN in a custom features extractor and let SB3 train it. However, the program crashes while wrapping the environment during PPO initialization, before I even get the chance to define a network that transforms the input.
Right now this approach fails with an uninformative error: `'NoneType' object is not iterable`. Some investigation shows that this happens because the `shape` of the `Graph` space is `None`, as shown in the attached code and error messages.
Code example
```python
import numpy as np
import gymnasium as gym
import gymnasium.spaces as sp
import torch as th
from gymnasium.utils.env_checker import check_env
import stable_baselines3 as sb3
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class GraphEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Box(0, 1)
        node_shape = (2,)
        edge_shape = (3,)
        # Dict observation mixing a variable-size Graph with a fixed-size Box
        self.observation_space = sp.Dict({
            "my_graph": sp.Graph(
                node_space=sp.Box(low=0, high=1, shape=node_shape),
                edge_space=sp.Box(low=0, high=1, shape=edge_shape),
            ),
            "my_box": sp.Box(0, 1),
        })
        # Constant observation: two nodes joined by a single edge
        self.const_obs = {
            "my_graph": gym.spaces.graph.GraphInstance(
                nodes=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32),
                edges=np.array([[0.5, 0.6, 0.7]], dtype=np.float32),
                edge_links=np.array([[0, 1]]),
            ),
            "my_box": np.array([0.2], dtype=np.float32),
        }

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        observation = self.const_obs
        info = {}
        return observation, info

    def step(self, action):
        observation = self.const_obs
        reward = 1
        terminated = truncated = False
        info = {}
        return observation, reward, terminated, truncated, info


class MyExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        # We do not know features-dim here before going over all the items,
        # so put something dummy for now. PyTorch requires calling
        # nn.Module.__init__ before adding modules
        super().__init__(observation_space, features_dim=1)
        total_concat_size = 10
        # Update the features dim manually
        self._features_dim = total_concat_size

    def forward(self, observations):
        # Placeholder: one zero feature vector per batch element, as a torch tensor
        batch_size = observations["my_box"].shape[0]
        return th.zeros((batch_size, self.features_dim))


policy_kwargs = dict(
    features_extractor_class=MyExtractor,
    share_features_extractor=False,
)

env = GraphEnv()
check_env(env)  # passes, apart from an irrelevant warning about render modes
# Crashes while wrapping env in DummyVecEnv, before MyExtractor is ever built
sb3.PPO("MultiInputPolicy", env, policy_kwargs=policy_kwargs)
```
Relevant log output / Error message
```
  File "/home/alex/code/deeprop/venv/lib/python3.13/site-packages/stable_baselines3/common/base_class.py", line 170, in __init__
    env = self._wrap_env(env, self.verbose, monitor_wrapper)
  File "/home/alex/code/deeprop/venv/lib/python3.13/site-packages/stable_baselines3/common/base_class.py", line 224, in _wrap_env
    env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value]
  File "/home/alex/code/deeprop/venv/lib/python3.13/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 47, in __init__
    self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
                                                              ~~~~~^^^^^^^^^^^
TypeError: 'NoneType' object is not iterable
```
A post-mortem analysis with pdb shows that, at the failing line, `shapes` is `{'my_box': (1,), 'my_graph': None}`.
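The failing line in `DummyVecEnv.__init__` pre-allocates one zero-filled buffer per observation key. The sketch below replays that allocation with plain NumPy, using the `shapes` dict from the pdb session. The values `num_envs = 1` and the float32 dtypes are assumptions that stand in for SB3's internal values. `_wrap_env` runs in the algorithm constructor before the policy is built, so the custom features extractor is never reached.

```python
import numpy as np

num_envs = 1  # assumption: DummyVecEnv wrapping a single env
shapes = {"my_box": (1,), "my_graph": None}  # values seen in the pdb session
dtypes = {"my_box": np.float32, "my_graph": np.float32}  # assumed dtypes

buffers = {}
errors = {}
for k in ("my_box", "my_graph"):
    try:
        # Mirrors: np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])
        buffers[k] = np.zeros((num_envs, *tuple(shapes[k])), dtype=dtypes[k])
    except TypeError as exc:
        # tuple(None) fails for the Graph entry
        errors[k] = str(exc)

print(buffers["my_box"].shape)  # (1, 1)
print(errors["my_graph"])       # 'NoneType' object is not iterable
```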
System Info
- OS: Linux-6.12.21-4-MANJARO-x86_64-with-glibc2.41 # 1 SMP PREEMPT_DYNAMIC Tue, 01 Apr 2025 18:45:30 +0000
- Python: 3.13.2
- Stable-Baselines3: 2.6.0
- PyTorch: 2.6.0+cu126
- GPU Enabled: True
- Numpy: 2.2.3
- Cloudpickle: 3.1.1
- Gymnasium: 1.0.0
- OpenAI Gym: 0.26.2
Checklist
- [x] I have checked that there is no similar issue in the repo
- [x] I have read the documentation
- [x] I have provided a minimal and working example to reproduce the bug
- [x] I have checked my env using the env checker
- [x] I've used the markdown code blocks for both code and stack traces.
Hello,
The env checker should be updated: SB3 doesn't support `spaces.Graph`.
Duplicate of https://github.com/DLR-RM/stable-baselines3/issues/1723 and others (like https://github.com/DLR-RM/stable-baselines3/issues/1280)
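SB3 only allocates fixed-shape buffers. One common workaround, which is not proposed in this thread, is to encode each graph as zero-padded arrays plus validity masks in an observation wrapper. A custom features extractor can then rebuild the graph and run the GNN on it. Below is a minimal NumPy sketch of the encoding step. The names `pad_graph`, `max_nodes` and `max_edges`, and the output keys, are hypothetical. The caps must bound every graph the environment can emit.

```python
import numpy as np


def pad_graph(nodes, edges, edge_links, max_nodes, max_edges):
    """Hypothetical helper: encode a variable-size graph as fixed-shape arrays.

    Rows beyond the real nodes/edges stay zero; the masks mark the real rows,
    so a features extractor can ignore the padding when rebuilding the graph.
    """
    n_nodes, n_edges = len(nodes), len(edges)
    if n_nodes > max_nodes or n_edges > max_edges:
        raise ValueError("graph is larger than the fixed padding size")
    padded = {
        "nodes": np.zeros((max_nodes, nodes.shape[1]), dtype=np.float32),
        "edges": np.zeros((max_edges, edges.shape[1]), dtype=np.float32),
        "edge_links": np.zeros((max_edges, 2), dtype=np.int64),
        "node_mask": np.zeros(max_nodes, dtype=np.float32),
        "edge_mask": np.zeros(max_edges, dtype=np.float32),
    }
    padded["nodes"][:n_nodes] = nodes
    padded["edges"][:n_edges] = edges
    padded["edge_links"][:n_edges] = edge_links
    padded["node_mask"][:n_nodes] = 1.0
    padded["edge_mask"][:n_edges] = 1.0
    return padded


# Usage with the constant graph from the reproduction script
padded = pad_graph(
    nodes=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32),
    edges=np.array([[0.5, 0.6, 0.7]], dtype=np.float32),
    edge_links=np.array([[0, 1]]),
    max_nodes=4,
    max_edges=6,
)
print(padded["nodes"].shape)  # (4, 2)
print(padded["edge_mask"])    # [1. 0. 0. 0. 0. 0.]
```

Each entry has a fixed shape, so it can be declared as a `Box` inside a `spaces.Dict`, which `MultiInputPolicy` accepts. The padded rows of `edge_links` point at node 0, so the extractor should use `edge_mask` to drop them.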