stable-baselines3

[Feature Request] Add overloads to return type of `preprocess_obs`

Open · Rocamonde opened this issue 3 years ago · 0 comments

🚀 Feature

`preprocess_obs` returns either a tensor or a `Dict[str, th.Tensor]`, depending on whether the observation space is a `Dict` space. The return type hint does not express this dependence on the argument, so users have to cast the returned object to get type safety.
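
To illustrate the problem, here is a simplified sketch of the current situation. `Tensor`, `Space`, `Box` and `DictSpace` are hypothetical stand-ins (not the torch/gym classes) so the example runs without dependencies, and the real `preprocess_obs` takes further arguments and does actual preprocessing, which is omitted here. The point is only that a `Union` return annotation forces the caller to `cast`.

```python
from typing import Any, Dict, Union, cast


class Tensor:
    """Hypothetical stand-in for th.Tensor."""

    def __init__(self, data: Any) -> None:
        self.data = data


class Space:
    """Hypothetical stand-in for gym.spaces.Space."""


class Box(Space):
    """Hypothetical stand-in for gym.spaces.Box."""


class DictSpace(Space):
    """Hypothetical stand-in for gym.spaces.Dict, holding named sub-spaces."""

    def __init__(self, spaces: Dict[str, Space]) -> None:
        self.spaces = spaces


def preprocess_obs(obs: Any, observation_space: Space) -> Union[Tensor, Dict[str, Tensor]]:
    # The runtime return type depends on the space, but the annotation can only
    # say "one of the two", so type checkers lose that information.
    if isinstance(observation_space, DictSpace):
        return {key: Tensor(obs[key]) for key in observation_space.spaces}
    return Tensor(obs)


# Caller side: we know which space we passed, but a type checker only sees the
# Union, so the result must be cast before using it as a specific type.
box_result = cast(Tensor, preprocess_obs([1.0, 2.0], Box()))
dict_result = cast(Dict[str, Tensor], preprocess_obs({"pos": [0.0]}, DictSpace({"pos": Box()})))
```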

Unfortunately, as far as I know there is currently no simple way to implement this. Python's type system cannot exclude types, so there is no way to say "a tensor is returned for any space except `Dict`". And the version of gym currently used in SB3 does not support generic spaces, so we cannot annotate the element type of an observation space (i.e. the type returned by `sample()`).
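
A rough sketch of what overloads could look like once spaces are generic over their sample type. It assumes a gym version whose `Space` is parameterized by the type `sample()` returns; the classes below are hypothetical stand-ins (with `List[float]` standing in for an array sample type), not the actual gym API. Because the overloads are keyed on the sample type (`Space[Dict[str, Any]]` vs `Space[List[float]]`) rather than on "every space except `Dict`", the two signatures do not overlap, so no type exclusion is needed.

```python
from typing import Any, Dict, Generic, List, TypeVar, Union, overload

T = TypeVar("T")


class Tensor:
    """Hypothetical stand-in for th.Tensor."""

    def __init__(self, data: Any) -> None:
        self.data = data


class Space(Generic[T]):
    """Hypothetical generic space: T is the type that sample() would return."""


class Box(Space[List[float]]):
    """Hypothetical array-like space whose samples are lists of floats."""


class DictSpace(Space[Dict[str, Any]]):
    """Hypothetical dict space whose samples are dicts of sub-observations."""

    def __init__(self, spaces: Dict[str, Space[Any]]) -> None:
        self.spaces = spaces


@overload
def preprocess_obs(obs: Dict[str, Any], observation_space: DictSpace) -> Dict[str, Tensor]: ...
@overload
def preprocess_obs(obs: List[float], observation_space: Space[List[float]]) -> Tensor: ...


def preprocess_obs(obs: Any, observation_space: Space[Any]) -> Union[Tensor, Dict[str, Tensor]]:
    # Single runtime implementation; the overloads above only inform type checkers.
    if isinstance(observation_space, DictSpace):
        return {key: Tensor(obs[key]) for key in observation_space.spaces}
    return Tensor(obs)


# A type checker can now pick the return type from the arguments, no cast needed.
tensor_out = preprocess_obs([1.0, 2.0], Box())  # Tensor
dict_out = preprocess_obs({"pos": [0.0]}, DictSpace({"pos": Box()}))  # Dict[str, Tensor]
```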

However, I thought I'd leave the issue here anyway so that we can come back to it once SB3 upgrades to a newer version of gym, or if someone thinks of an implementation that works.

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

Rocamonde, Sep 13 '22 15:09