stable-baselines3

[Bug] get_obs_shape returns wrong shape for MultiBinary spaces

Open • hjarraya opened this issue 3 years ago • 10 comments


🐛 Bug

stable_baselines3.common.preprocessing.get_obs_shape fails when a MultiBinary space is multi-dimensional: it calls int() on observation_space.n, which is then a list, and raises a TypeError instead of returning the shape.

To Reproduce


from gym import spaces
from stable_baselines3.common.preprocessing import get_obs_shape


test_multi = spaces.MultiBinary([5, 4, 5])

get_obs_shape(test_multi)

    151     elif isinstance(observation_space, spaces.MultiBinary):
    152         # Number of binary features
--> 153         return (int(observation_space.n),)
    154     elif isinstance(observation_space, spaces.Dict):
    155         return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()}

TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'

Expected behavior

get_obs_shape should return observation_space.shape, not int(observation_space.n).
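A minimal sketch of the expected behavior, assuming `n` is either an int or a list/tuple of ints (the two forms gym accepts in the reproduction above). `multibinary_obs_shape` is a hypothetical helper, not SB3's API; it returns the value gym stores in `observation_space.shape`, so the N-dimensional layout is kept instead of calling `int()` on a list.

```python
from typing import Sequence, Tuple, Union


def multibinary_obs_shape(n: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Hypothetical helper: observation shape of a MultiBinary(n) space.

    Roughly mirrors how gym derives MultiBinary.shape: a list/tuple of ints
    keeps its N-dimensional layout, anything else is treated as one size.
    """
    if isinstance(n, (list, tuple)):
        return tuple(int(dim) for dim in n)
    return (int(n),)


print(multibinary_obs_shape(5))          # (5,)
print(multibinary_obs_shape([5, 4, 5]))  # (5, 4, 5)
```

The 1D case still returns `(n,)`, which is what the current code returns for `MultiBinary(5)`, so existing behavior is preserved.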

System Info


  • Installation method: pip
  • GPU models and configuration: none
  • Python version: 3.7.
  • PyTorch version: 1.9.0
  • Gym version: 0.18.3

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

hjarraya (Aug 05 '21 14:08)

Hey there! Nice catch. The code assumes MultiBinary can only be defined as e.g. MultiBinary(5) (five binary variables), so it completely misses the possibility of defining a multi-dimensional array of binary variables, as done here. In that case it should return the total number of binary items, i.e. the product of the dimensions (np.prod(observation_space.n)).
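For the flattening approach, a minimal sketch of the shape such a code path would need to return, assuming `n` is an int or a list/tuple of ints as gym accepts. The count is the product of the dimensions, not their sum. `flat_multibinary_shape` is a hypothetical helper, not part of SB3.

```python
from functools import reduce
from operator import mul
from typing import Sequence, Tuple, Union


def flat_multibinary_shape(n: Union[int, Sequence[int]]) -> Tuple[int]:
    """Hypothetical helper: shape of a MultiBinary(n) observation once flattened.

    The number of binary variables is the product of the dimensions,
    e.g. [5, 4, 5] -> 5 * 4 * 5 = 100 (not 5 + 4 + 5 = 14).
    """
    if isinstance(n, (list, tuple)):
        # reduce() instead of math.prod keeps this working on Python 3.7
        return (reduce(mul, (int(dim) for dim in n), 1),)
    return (int(n),)


print(flat_multibinary_shape(5))          # (5,)
print(flat_multibinary_shape([5, 4, 5]))  # (100,)
```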

However, to support this, the same fixes need to be applied elsewhere in the code (e.g. action sampling). A PR to fix this would be very welcome (unless @araffin has other thoughts on this) :)

Miffyli (Aug 06 '21 12:08)

Happy to make a PR for this. I made the change to get_obs_shape and the tests ran successfully; I will add a regression test. I am new to this codebase, so it would be great if you could point me to the action sampling you are referring to.

hjarraya (Aug 06 '21 13:08)

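The distribution used for MultiBinary action spaces (BernoulliDistribution in stable_baselines3.common.distributions) expects the number of binary actions as an int, so it will not work if given a list of integers. You can update the relevant test script to also test MultiBinary spaces with multiple dimensions.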
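To see why the distribution needs an int: the policy outputs one logit per binary action, so the number of outputs is prod(shape), and samples then have to be reshaped. The following is a minimal NumPy sketch of that idea, not SB3's torch-based implementation; `sample_multibinary` and its signature are hypothetical.

```python
import numpy as np


def sample_multibinary(logits, shape, rng):
    """Hypothetical sketch: sample independent Bernoulli bits for a
    multi-dimensional MultiBinary action and reshape them.

    `logits` must hold one value per binary variable, i.e. prod(shape) values.
    """
    shape = tuple(shape)
    n_flat = int(np.prod(shape))
    logits = np.asarray(logits, dtype=np.float64)
    assert logits.shape == (n_flat,), "expected one logit per binary variable"
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid -> P(bit == 1)
    bits = (rng.random(n_flat) < probs).astype(np.int8)
    return bits.reshape(shape)


rng = np.random.default_rng(0)
action = sample_multibinary(np.zeros(5 * 4 * 5), [5, 4, 5], rng)
print(action.shape)  # (5, 4, 5)
```

Sampling on the flat vector and reshaping at the end keeps the distribution itself one-dimensional, which is the same idea as the flattening approach.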

Note that the simplest solution could be to flatten the multi-dimensional MultiBinary space into one long vector. This might be the easiest way to approach this.
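A minimal NumPy sketch of that flattening, assuming observations arrive as arrays shaped like the space (here 5x4x5). The flat vector is what the agent would see; reshaping with the original shape restores the layout. Variable names are illustrative.

```python
import numpy as np

original_shape = (5, 4, 5)

# A multi-dimensional binary observation, as MultiBinary([5, 4, 5]) would produce
rng = np.random.default_rng(42)
obs = rng.integers(0, 2, size=original_shape, dtype=np.int8)

# Flatten into one long vector of 5 * 4 * 5 = 100 bits for the agent ...
flat_obs = obs.reshape(-1)

# ... and restore the original layout when needed (e.g. for actions)
restored = flat_obs.reshape(original_shape)

print(flat_obs.shape)                 # (100,)
print(np.array_equal(obs, restored))  # True
```

Both `reshape` calls use the default C order, so the round trip is lossless.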

Miffyli (Aug 06 '21 17:08)

Hello,

> stable_baselines3.common.preprocessing.get_obs_shape returns the wrong shape when a MultiBinary space is multi-dimensional.

Quick question: why would you use spaces.MultiBinary([5, 4, 5]) instead of spaces.MultiBinary(14) (as each dimension is independent)?

(I know that gym allows the first case, but in my opinion it does not make much sense, except maybe for binary images?)

Other than that, if the proposed fix works and keeps the old behavior, then please go ahead with the PR ;)

araffin (Aug 07 '21 16:08)

I would use spaces.MultiBinary([5, 4, 5]), for example, to represent the capacity of 5 worker queues for 4 types of jobs at 5 different timesteps, where 1 means busy and 0 means not busy.
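A small NumPy illustration of that layout, assuming the axes are ordered (queue, job type, timestep); the axis order and names are hypothetical. Indexing by meaningful coordinates is the main reason to keep the N-dimensional form rather than a flat vector.

```python
import numpy as np

N_QUEUES, N_JOB_TYPES, N_TIMESTEPS = 5, 4, 5

# 1 = busy, 0 = not busy; one bit per (queue, job type, timestep)
occupancy = np.zeros((N_QUEUES, N_JOB_TYPES, N_TIMESTEPS), dtype=np.int8)

# Queue 2 is busy with job type 1 at timesteps 0, 1 and 2
occupancy[2, 1, 0:3] = 1

print(occupancy.shape)       # (5, 4, 5)
print(occupancy.size)        # 100
print(int(occupancy.sum()))  # 3
```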

hjarraya (Aug 10 '21 16:08)

To add one point: according to gym's code for MultiBinary, spaces.MultiBinary([5, 4, 5]) holds 5x4x5 = 100 binary variables, not 5+4+5.

In my case, the observation is something like a 10x10 chessboard, with 0/1 indicating occupancy. A 1D MultiBinary(100) would also work, but then some back-and-forth array reshaping would be needed.

Any news on this issue? I happen to have been bitten by the same bug. I wonder if I could help.

ylchan87 (Aug 14 '21 12:08)

I think the reshaping will be needed regardless, so a quick wrapper (for now, before official support) would be sufficient. Something along the lines of (not tested, just a sketch):

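import gym
import numpy as np

class MultiBinaryFlattenWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        assert isinstance(self.observation_space, gym.spaces.MultiBinary)
        # gym stores the N-dimensional shape in .shape (.n may be a list)
        self.original_shape = self.observation_space.shape
        assert len(self.original_shape) > 1
        self.observation_space = gym.spaces.MultiBinary(int(np.prod(self.original_shape)))

    def reset(self, **kwargs):
        # Flatten the first observation too
        return np.asarray(self.env.reset(**kwargs)).reshape(-1)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return np.asarray(obs).reshape(-1), reward, done, info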

Miffyli (Aug 14 '21 13:08)

Despite the docs, it turns out gym itself also dislikes N-dimensional MultiBinary spaces:

def flatdim(space):
    ....
    elif isinstance(space, MultiBinary):
>        return int(space.n)

E           TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'

https://github.com/openai/gym/blob/4ede9280f9c477f1ca09929d10cdc1e1ba1129f1/gym/spaces/utils.py#L28

So I guess a wrapper is the actual solution, not just a temporary one.

ylchan87 (Aug 15 '21 01:08)

What exactly is the recommended solution? I've got an environment with a MultiBinary((8, 8)) observation space and can't run PPO or A2C. Has anyone written a wrapper yet? If not, I will attempt to write one.

I just hacked this together:

import gym
import numpy as np

class FlatObservationWrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        # Remember the N-dimensional shape, e.g. (8, 8)
        self.original_shape = self.observation_space.shape
        # Expose a flat MultiBinary space with the same number of bits
        self.observation_space = gym.spaces.MultiBinary(int(np.prod(self.original_shape)))

    def observation(self, obs):
        # Flatten each observation into a 1D vector
        return np.asarray(obs).reshape(-1)

The problem also occurs with non-flat MultiBinary action spaces. Here's a wrapper for that:

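import gym
import numpy as np

class FlatActionWrapper(gym.ActionWrapper):
    def __init__(self, env):
        super().__init__(env)
        # Remember the N-dimensional shape the wrapped env expects, e.g. (8, 8)
        self.original_shape = self.action_space.shape
        # Expose a flat MultiBinary action space to the agent
        self.action_space = gym.spaces.MultiBinary(int(np.prod(self.original_shape)))

    def action(self, act):
        # Reshape the agent's flat action back to the env's original shape
        return np.asarray(act).reshape(self.original_shape)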

Honestly, I don't know if these will work in every case.

SkittlePox (Jan 01 '22 08:01)

@SkittlePox Writing a wrapper like the one you shared is the way to go for now. Updating the SB3 code to support this could be welcome, but it will require careful testing to ensure the old behaviour does not change.

Miffyli (Jan 01 '22 23:01)

Should be fixed in https://github.com/DLR-RM/stable-baselines3/pull/1179

araffin (Dec 09 '22 13:12)