
NoneType crash with graph observation space even with FeaturesExtractor

Open alexsantee opened this issue 7 months ago • 1 comments

🐛 Bug

I have a problem that involves graphs in the observation space, and I'd like to use PyTorch Geometric to train a GNN for my learning agent.

My first idea was to use Gymnasium's observation transform to flatten the input in a way that's compatible with Stable Baselines, but I didn't find a way to train the GNN inside the function that processes the observation.

So I wanted to use a custom features extractor to implement the GNN and let SB3 train it, but the program crashes while wrapping the environment during PPO initialization, before I even get the chance to define a neural network to transform the input.

Right now this approach gives a non-descriptive error message: 'NoneType' object is not iterable. Some investigation shows this happens because the shape field of the graph space is None, as shown in the attached code and error messages.

Code example

import numpy as np
import gymnasium as gym
import gymnasium.spaces as sp
from gymnasium.utils.env_checker import check_env
import stable_baselines3 as sb3
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class GraphEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Box(0, 1)
        node_shape = (2,)
        edge_shape = (3,)
        self.observation_space = sp.Dict({
            "my_graph": sp.Graph(
                node_space=sp.Box(low=0, high=1, shape=node_shape),
                edge_space=sp.Box(low=0, high=1, shape=edge_shape)
                ),
            "my_box": sp.Box(0, 1),
            })

        self.const_obs = {
            "my_graph": gym.spaces.graph.GraphInstance(
                nodes=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32),
                edges=np.array([[0.5, 0.6, 0.7]], dtype=np.float32),
                edge_links=np.array([[0, 1]])
                ),
            "my_box": np.array([0.2], dtype=np.float32),
            }

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        observation = self.const_obs
        info = {}
        return observation, info

    def step(self, action):
        observation = self.const_obs
        reward = 1
        terminated = truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

class MyExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        # We do not know features-dim here before going over all the items,
        # so put something dummy for now. PyTorch requires calling
        # nn.Module.__init__ before adding modules
        super().__init__(observation_space, features_dim=1)
        total_concat_size = 10
        # Update the features dim manually
        self._features_dim = total_concat_size

    def forward(self, observations):
        # Placeholder: return a (batch_size, features_dim) tensor of zeros
        box = observations["my_box"]
        return box.new_zeros((box.shape[0], self._features_dim))


policy_kwargs = dict(
    features_extractor_class=MyExtractor,
    share_features_extractor=False,
)
env = GraphEnv()
check_env(env)  # passes with irrelevant warning about render modes
sb3.PPO("MultiInputPolicy", env, policy_kwargs=policy_kwargs)

Relevant log output / Error message

File "/home/alex/code/deeprop/venv/lib/python3.13/site-packages/stable_baselines3/common/base_class.py", line 170, in __init__
    env = self._wrap_env(env, self.verbose, monitor_wrapper)
  File "/home/alex/code/deeprop/venv/lib/python3.13/site-packages/stable_baselines3/common/base_class.py", line 224, in _wrap_env
    env = DummyVecEnv([lambda: env])  # type: ignore[list-item, return-value]
  File "/home/alex/code/deeprop/venv/lib/python3.13/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 47, in __init__
    self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
                                                              ~~~~~^^^^^^^^^^^
TypeError: 'NoneType' object is not iterable

From a postmortem analysis with pdb, the shapes dict used on the failing line contains:

{'my_box': (1,), 'my_graph': None}
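The failing expression can be reproduced without SB3. The sketch below mimics the buffer allocation from the traceback using plain NumPy; `allocate_buffers` is a hypothetical helper name, and using float32 for every key is an assumption made only to keep the example short:

```python
import numpy as np

# Shapes as reported by pdb in this issue
shapes = {"my_box": (1,), "my_graph": None}


def allocate_buffers(shapes, num_envs=1):
    """Allocate one zero-filled buffer per observation key."""
    buffers = {}
    for key, shape in shapes.items():
        # tuple(None) raises TypeError, which is the crash seen above
        buffers[key] = np.zeros((num_envs, *tuple(shape)), dtype=np.float32)
    return buffers


try:
    allocate_buffers(shapes)
except TypeError as exc:
    error_message = str(exc)
    print(error_message)
```

With only the Box entry, the same allocation succeeds, which suggests the Graph entry alone triggers the error.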

System Info

  • OS: Linux-6.12.21-4-MANJARO-x86_64-with-glibc2.41 # 1 SMP PREEMPT_DYNAMIC Tue, 01 Apr 2025 18:45:30 +0000
  • Python: 3.13.2
  • Stable-Baselines3: 2.6.0
  • PyTorch: 2.6.0+cu126
  • GPU Enabled: True
  • Numpy: 2.2.3
  • Cloudpickle: 3.1.1
  • Gymnasium: 1.0.0
  • OpenAI Gym: 0.26.2

Checklist

  • [x] I have checked that there is no similar issue in the repo
  • [x] I have read the documentation
  • [x] I have provided a minimal and working example to reproduce the bug
  • [x] I have checked my env using the env checker
  • [x] I've used the markdown code blocks for both code and stack traces.

alexsantee, Apr 23 '25 12:04

Hello, the env checker should be updated: SB3 doesn't support spaces.Graph.

Duplicate of https://github.com/DLR-RM/stable-baselines3/issues/1723 and others (like https://github.com/DLR-RM/stable-baselines3/issues/1280)

araffin, Apr 23 '25 12:04
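Since spaces.Graph is not supported, one possible workaround (not proposed in this thread, so treat it as an assumption) is to expose the graph as fixed-size arrays that can be described by Box spaces inside the Dict, and to rebuild the graph inside the custom features extractor. A minimal sketch of the padding step follows; `pad_graph`, `max_nodes`, and `max_edges` are hypothetical names, and the limits must be chosen for the problem at hand:

```python
import numpy as np


def pad_graph(nodes, edges, edge_links, max_nodes, max_edges):
    """Convert a variable-size graph into fixed-shape arrays plus masks."""
    num_nodes, num_edges = len(nodes), len(edges)
    if num_nodes > max_nodes or num_edges > max_edges:
        raise ValueError("graph exceeds the padding limits")

    padded_nodes = np.zeros((max_nodes, nodes.shape[1]), dtype=np.float32)
    padded_nodes[:num_nodes] = nodes
    padded_edges = np.zeros((max_edges, edges.shape[1]), dtype=np.float32)
    padded_edges[:num_edges] = edges
    padded_links = np.zeros((max_edges, 2), dtype=np.int64)
    padded_links[:num_edges] = edge_links

    # Masks mark which rows hold real data and which are padding
    node_mask = np.zeros(max_nodes, dtype=np.float32)
    node_mask[:num_nodes] = 1.0
    edge_mask = np.zeros(max_edges, dtype=np.float32)
    edge_mask[:num_edges] = 1.0

    return {
        "nodes": padded_nodes,
        "edges": padded_edges,
        "edge_links": padded_links,
        "node_mask": node_mask,
        "edge_mask": edge_mask,
    }


# The constant observation from the issue, padded to 4 nodes and 3 edges
padded = pad_graph(
    nodes=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32),
    edges=np.array([[0.5, 0.6, 0.7]], dtype=np.float32),
    edge_links=np.array([[0, 1]]),
    max_nodes=4,
    max_edges=3,
)
```

Such a function could be called from an observation wrapper whose observation space declares one Box per returned key. The trade-off is that the graph size is capped and padding wastes memory for small graphs.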