[Bug] Too general a return type for load and learn in BaseAlgorithm

Open Rocamonde opened this issue 3 years ago • 4 comments

🐛 Bug

The return type of the .load() and .learn() methods in BaseAlgorithm is annotated as "BaseAlgorithm". As a result, for any subclass that does not override these methods with narrower annotations, code that expects an instance of the subclass fails type checking. Example:

from stable_baselines3 import PPO

def load_ppo(cache_path, venv) -> PPO:
    return PPO.load(cache_path, venv)

The code above does not type check ("expected PPO, got BaseAlgorithm").
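
The same mismatch can be reproduced without installing SB3. The sketch below is only an illustration: FakeBaseAlgorithm and FakePPO are hypothetical stand-ins, not the real SB3 classes, and the load signature is simplified. It runs fine, because Python does not enforce annotations at runtime, but a static type checker is expected to flag the return statement in load_fake_ppo.

```python
class FakeBaseAlgorithm:
    """Hypothetical stand-in for BaseAlgorithm; not the SB3 class."""

    @classmethod
    def load(cls, path: str) -> "FakeBaseAlgorithm":
        # At runtime cls is the calling subclass, but the annotation
        # only promises the base class to the type checker.
        return cls()


class FakePPO(FakeBaseAlgorithm):
    """Hypothetical stand-in for PPO; it does not override load()."""


def load_fake_ppo(path: str) -> FakePPO:
    # A type checker sees FakeBaseAlgorithm here, not FakePPO,
    # so it should report an incompatible return type.
    return FakePPO.load(path)


model = load_fake_ppo("model.zip")
print(type(model).__name__)  # the runtime object is a FakePPO all along
```

The runtime object already has the right class, so the problem is purely in the annotation.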

To Reproduce

Run type checker on code above.

Expected behavior

Code should type check.

Solution

The solution to this issue (happy to submit a PR for it) is to replace the return type with the Self type introduced in PEP 673. Self is built into Python 3.11 and is also available on earlier versions through the PyPI package typing_extensions.

Instead of: https://github.com/DLR-RM/stable-baselines3/blob/304c17dc78dbaeba77c709ec03c1b3847991018d/stable_baselines3/common/base_class.py#L533-L544

we use

import sys
from abc import ABC, abstractmethod

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

class BaseAlgorithm(ABC):
    ...

    @abstractmethod
    def learn(self, ...) -> Self: 
        ...

    @classmethod
    def load(cls, ...) -> Self:
        ...
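
To show the effect end to end, here is a minimal runnable sketch of the proposal. It again uses the hypothetical stand-ins FakeBaseAlgorithm and FakePPO rather than the real SB3 classes, the method signatures are simplified assumptions, and on Python older than 3.11 it assumes typing_extensions is installed.

```python
import sys
from abc import ABC, abstractmethod

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self  # backport; assumed to be installed


class FakeBaseAlgorithm(ABC):
    """Hypothetical stand-in for BaseAlgorithm; not the SB3 class."""

    def __init__(self) -> None:
        self.num_timesteps = 0

    @abstractmethod
    def learn(self, total_timesteps: int) -> Self:
        ...

    @classmethod
    def load(cls, path: str) -> Self:
        # Real loading logic is out of scope; just build an instance of cls.
        return cls()


class FakePPO(FakeBaseAlgorithm):
    """Hypothetical stand-in for PPO."""

    def learn(self, total_timesteps: int) -> Self:
        self.num_timesteps += total_timesteps
        return self  # returning self allows call chaining


def load_fake_ppo(path: str) -> FakePPO:
    # With Self, FakePPO.load(...) is inferred as FakePPO, so this
    # return statement should pass a PEP 673-aware type checker.
    return FakePPO.load(path)


model = load_fake_ppo("model.zip").learn(total_timesteps=100)
print(type(model).__name__, model.num_timesteps)
```

Because Self binds to whichever class the method is called on, subclasses get the narrower return type without overriding load or re-annotating it.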

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

Rocamonde, Aug 31 '22 15:08

This seems to me to be a good idea. Do you know if this syntax is compatible with pytype for every Python version? Have you run the pytype check?

qgallouedec, Aug 31 '22 15:08

@qgallouedec I have not actually started a PR yet (I have never contributed to SB3), but the typing_extensions package works with Python 3.7+ and is supported by pytype. Is this in line with what SB3 supports? If so, happy to create a PR and run all the relevant CI checks.

Rocamonde, Aug 31 '22 15:08

> Is this in line with what SB3 supports?

It is.

> If so, happy to create a PR and run all the relevant CI checks.

Please do. :)

qgallouedec, Aug 31 '22 15:08

@qgallouedec see https://github.com/DLR-RM/stable-baselines3/pull/1043

Rocamonde, Sep 01 '22 11:09