ValueError when validating a gym environment with MultiDiscrete action space
Greetings!
Setup
I have a gym environment class (let's call it gym_env) whose action space is a gym.spaces.MultiDiscrete with nvec (4, 3, 2). After wrapping it with suite_gym.wrap_env(gym_env), new_env.action_spec() looks like this:
BoundedArraySpec(shape=(3,), dtype=dtype('int32'), name='action', minimum=0, maximum=[3 2 1])
The Problem
When I validate this new wrapped environment using utils.validate_py_environment(), I get the following error:
[...]/tf_agents/specs/array_spec.py, in sample_bounded_spec
ValueError: cannot reshape array of size 1 into shape (3,)
My Findings
My first guess: spec.minimum is a scalar int (0) while spec.maximum is an array ([3, 2, 1]), so the two bounds have mismatched shapes.
When I looked into the sample_bounded_spec function of array_spec.py, I found what seems to be the source:
np.reshape(
    np.array([
        rng.randint(low, high, size=1, dtype=spec.dtype)
        for low, high in zip(low.flatten(), high.flatten())
    ]), spec.shape)
Here low.flatten() returns a one-element array ([0]) and high.flatten() returns [3, 2, 1]. Because zip stops at the shorter input, zip(low.flatten(), high.flatten()) does not produce the 3 (min, max) pairs ([0, 3], [0, 2], [0, 1]) but only a single pair, [0, 3]. The list therefore contains a single random int, which cannot be reshaped into an array of shape (3,), hence the error.
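The truncation can be reproduced with NumPy alone, without tf_agents. This is a minimal sketch, assuming the bounds are the scalar 0 and the array [3, 2, 1] as in the action spec from the Setup section:

```python
import numpy as np

# Bounds as they appear in the action spec from this post.
low = np.array(0)            # scalar minimum
high = np.array([3, 2, 1])   # per-dimension maximum

# flatten() turns the 0-d scalar into a length-1 array, not a length-3 one.
flat_low = low.flatten()
flat_high = high.flatten()

# zip stops at the shorter input, so only one (low, high) pair survives.
pairs = [(int(lo), int(hi)) for lo, hi in zip(flat_low, flat_high)]

print(flat_low.shape, flat_high.shape)  # (1,) (3,)
print(pairs)                            # [(0, 3)]
```

One pair means one sampled value, which matches the "array of size 1" in the error message.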
I'm not sure whether this is a bug. Please let me know if I'm misunderstanding something!
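If this does turn out to be a bug, one possible direction would be to broadcast both bounds to the spec shape before sampling. The sketch below is hypothetical: sample_bounded is an illustrative name, not a tf_agents function, and it assumes the spec maximum is meant to be inclusive.

```python
import numpy as np


def sample_bounded(minimum, maximum, shape, rng, dtype=np.int32):
    """Hypothetical helper, NOT the tf_agents implementation."""
    # Broadcast both bounds to the full spec shape so a scalar minimum
    # lines up with a per-dimension maximum.
    low = np.broadcast_to(np.asarray(minimum), shape)
    high = np.broadcast_to(np.asarray(maximum), shape)
    # randint accepts array bounds; its upper bound is exclusive, so add 1
    # (assumption: `maximum` itself should be a valid sample).
    return rng.randint(low, high + 1, size=shape, dtype=dtype)


rng = np.random.RandomState(0)
sample = sample_bounded(0, [3, 2, 1], (3,), rng)
print(sample.shape)  # (3,)
```

Each element of sample is drawn from its own range, so the result already has the spec shape and no reshape is needed.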
Code to Reproduce Error
import gym
from tf_agents.environments import suite_gym, utils


class TestEnv(gym.Env):
    def __init__(self):
        self.action_shape = (4, 3, 2)
        self.action_space = gym.spaces.MultiDiscrete((self.action_shape))
        # observation space
        self.observation_shape = (10, 10, 2)
        self.observation_space = gym.spaces.MultiBinary(self.observation_shape)
        self.time_step = 0

    def reset(self):
        self.time_step = 0
        return self.observation_space.sample()

    def step(self, action):
        self.time_step += 1
        return self.observation_space.sample(), 0, self.time_step > 10, None


if __name__ == '__main__':
    gym_env = TestEnv()
    tf_env = suite_gym.wrap_env(gym_env)
    utils.validate_py_environment(tf_env)