brax
Issue with rendering for VectorGymWrapper
The render function in brax.envs.wrappers.gym.VectorGymWrapper appears to be wrong. It is written to render only the first environment rather than all of them, whereas in gym/gymnasium, vector environments render all environments at once. The error comes from this line.
The issue is that pipeline_state does not implement take.
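For context, taking index i of a batched state is meant to select environment i's slice from every array it holds, including arrays nested inside sub-states. That is why calling take on the outer env state and then reading .pipeline_state gives the same per-environment pipeline state. The sketch below mimics this with NumPy and plain dataclasses. EnvState, PipelineState and the take helper are hypothetical illustrations, not brax's actual classes or implementation.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class PipelineState:
  # Hypothetical physics state; every array has a leading num_envs axis.
  q: np.ndarray
  qd: np.ndarray


@dataclasses.dataclass
class EnvState:
  # Hypothetical env-level state that wraps the pipeline state.
  pipeline_state: PipelineState
  reward: np.ndarray


def take(obj, i):
  """Recursively select index i along axis 0 of every array field."""
  if dataclasses.is_dataclass(obj):
    # Rebuild the same dataclass type from the sliced fields.
    return type(obj)(**{f.name: take(getattr(obj, f.name), i)
                        for f in dataclasses.fields(obj)})
  return np.take(obj, i, axis=0)


num_envs = 3
state = EnvState(
    pipeline_state=PipelineState(q=np.arange(6.0).reshape(num_envs, 2),
                                 qd=np.zeros((num_envs, 2))),
    reward=np.array([0.1, 0.2, 0.3]),
)
env1 = take(state, 1)
print(env1.pipeline_state.q)  # [2. 3.]
```

Slicing the outer state once means the caller never needs a take method on the nested pipeline state itself.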
The code can be fixed by replacing the line with the following:
def render(self, mode='human'):
  if mode == 'rgb_array':
    sys, state = self._env.sys, self._state
    if state is None:
      raise RuntimeError('must call reset or step before rendering')
    # Change this line to return (num_envs, height, width, 3):
    # slice out each environment's state, render it, and stack the frames.
    return np.stack([
        image.render_array(sys, state.take(i).pipeline_state, 256, 256)
        for i in range(self.num_envs)
    ])
  else:
    return super().render(mode=mode)  # just raise an exception
This returns an ndarray of shape (num_envs, 256, 256, 3), with one image per environment stacked along the first axis.
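The stacking step can be checked in isolation. The sketch below replaces image.render_array with a hypothetical render_one stub that returns a solid-colour RGB frame per environment, so the resulting shape and per-environment ordering are easy to verify. render_one and render_all are illustrative names only.

```python
import numpy as np


def render_one(env_index: int, height: int = 256, width: int = 256) -> np.ndarray:
  """Hypothetical stand-in for image.render_array: one RGB frame."""
  # Fill the frame with the env index so frames are distinguishable.
  return np.full((height, width, 3), env_index, dtype=np.uint8)


def render_all(num_envs: int, height: int = 256, width: int = 256) -> np.ndarray:
  """Render every environment and stack the frames on a new leading axis."""
  return np.stack([render_one(i, height, width) for i in range(num_envs)])


frames = render_all(num_envs=4)
print(frames.shape)  # (4, 256, 256, 3)
```

np.stack is used rather than np.concatenate because it adds a new leading axis, so frames[i] is exactly the image for environment i. Concatenating would instead merge the frames along the height axis.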