Fix return type for load, learn in BaseAlgorithm
Description
Fixes the return type of the .load() and .learn() methods in BaseAlgorithm so that they use the Self type (PEP 673) instead of BaseAlgorithm. The hard-coded BaseAlgorithm return type breaks type checking whenever these methods are called on a subclass.
Motivation and Context
Closes #1040.
- [x] I have raised an issue to propose this change (required for new features and bug fixes)
Types of changes
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation (update in the documentation)
Checklist:
- [x] I've read the CONTRIBUTION guide (required)
- [x] I have updated the changelog accordingly (required).
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (required for a bug fix or a new feature).
- [ ] I have updated the documentation accordingly.
- [x] I have reformatted the code using make format (required)
- [x] I have checked the codestyle using make check-codestyle and make lint (required)
- [x] I have ensured make pytest and make type both pass. (required)
- [x] I have checked that the documentation builds using make doc (required)
Note: You can run most of the checks using make commit-checks.
Note: we are using a maximum length of 127 characters per line
Users may use Python >= 3.11, so I suggest distinguishing between the two cases:
import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self
This is what I was doing initially. But if typing_extensions is installed anyway, importing Self from it works regardless of the Python version. Though I agree it’s cleaner.
Apparently pytype does not support typing_extensions.Self yet, lol.
It's a shame, I opened a question about it: https://github.com/google/pytype/issues/1283
Given Self is not supported, we should probably go with the TypeVar solution from https://peps.python.org/pep-0673/ ? WDYT, @Rocamonde ?
It seems like it won't be supported for some time. Best to use the TypeVar solution for now and update it later.
We might want to open another issue (or not mark the current one as completed) so that we remember to go back to this in a couple of months when support is added.
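For reference, here is a minimal sketch of the TypeVar approach described in PEP 673. The BaseAlgorithm and PPO classes below are simplified stand-ins rather than the real SB3 classes, and the TypeVar name SelfBaseAlgorithm, the num_timesteps counter, and the clip_range() method are assumptions made for illustration only.

```python
from typing import Type, TypeVar

# Hypothetical TypeVar name; bound to the base class so that subclasses are accepted.
SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")


class BaseAlgorithm:
    """Simplified stand-in for the real base class (illustration only)."""

    def __init__(self) -> None:
        self.num_timesteps = 0

    def learn(self: SelfBaseAlgorithm, total_timesteps: int) -> SelfBaseAlgorithm:
        # Annotating `self` with the TypeVar makes the return type follow the subclass.
        self.num_timesteps += total_timesteps
        return self

    @classmethod
    def load(cls: Type[SelfBaseAlgorithm], path: str) -> SelfBaseAlgorithm:
        # A real implementation would read saved parameters from `path`.
        return cls()


class PPO(BaseAlgorithm):
    """Stand-in subclass with a method the base class does not have."""

    def clip_range(self) -> float:
        return 0.2


# A type checker should infer `PPO` here, so the subclass-only method is accepted.
model = PPO.load("model.zip").learn(total_timesteps=100)
print(type(model).__name__, model.clip_range())
```

With a plain BaseAlgorithm return annotation, a type checker would reject the model.clip_range() call, since BaseAlgorithm has no such method. Once Self is supported, the TypeVar annotations could be swapped for it without changing runtime behavior.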
@AdamGleave @qgallouedec what do you think of the current status?
Sounds good. Made those changes. I left the small format change in the load_from_zip_file() call because, despite being accidental, I think it makes the call more readable. Let me know your thoughts.
You also should remove the unused import sys
Sorry about that. Just did that now.
Was going to merge the approved PR but don't have write access.
Just wait for the checks to complete, I'll merge it soon.
Should we do the same fix here?
https://github.com/DLR-RM/stable-baselines3/blob/18b29a68e8d5a11d0e98aeea539c247f0a913019/stable_baselines3/common/policies.py#L161
Sure, let's do that too.
Could you also correct the following lines? I think it is within the scope of this PR.
https://github.com/DLR-RM/stable-baselines3/blob/18b29a68e8d5a11d0e98aeea539c247f0a913019/stable_baselines3/common/off_policy_algorithm.py#L332
https://github.com/DLR-RM/stable-baselines3/blob/18b29a68e8d5a11d0e98aeea539c247f0a913019/stable_baselines3/common/on_policy_algorithm.py#L238
https://github.com/DLR-RM/stable-baselines3/blob/18b29a68e8d5a11d0e98aeea539c247f0a913019/stable_baselines3/a2c/a2c.py#L196
https://github.com/DLR-RM/stable-baselines3/blob/18b29a68e8d5a11d0e98aeea539c247f0a913019/stable_baselines3/ddpg/ddpg.py#L129
https://github.com/DLR-RM/stable-baselines3/blob/18b29a68e8d5a11d0e98aeea539c247f0a913019/stable_baselines3/dqn/dqn.py#L268
https://github.com/DLR-RM/stable-baselines3/blob/18b29a68e8d5a11d0e98aeea539c247f0a913019/stable_baselines3/ppo/ppo.py#L310
https://github.com/DLR-RM/stable-baselines3/blob/18b29a68e8d5a11d0e98aeea539c247f0a913019/stable_baselines3/sac/sac.py#L302
https://github.com/DLR-RM/stable-baselines3/blob/18b29a68e8d5a11d0e98aeea539c247f0a913019/stable_baselines3/td3/td3.py#L218
Added those changes @qgallouedec .
Thank you for contributing @Rocamonde!
@Rocamonde in case you have time in the coming weeks, could you do a similar PR to our contrib repo? Otherwise, I will open an issue in that repo not to forget ;)
@araffin I suppose you're referring to the return types of specific algorithms? As I guess the base classes are shared with sb3.
What is a good way to search for all occurrences in the codebase?
I suppose you're referring to the return types of specific algorithms?
yes, and I mean this repo: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib It follows the exact same structure as SB3.
What is a good way to search for all occurrences in the codebase?
I think there will be only one occurrence in each algorithm folder (for the learn method).
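One possible way to list the occurrences, sketched with the standard ast module. It assumes the annotations to fix have the form -> "ClassName" (or the bare class name) on learn and load methods; the helper names find_hardcoded_returns and scan_package are hypothetical, and the TQC class in the sample is only an illustrative input.

```python
import ast
from pathlib import Path
from typing import List, Tuple


def find_hardcoded_returns(
    source: str, method_names: Tuple[str, ...] = ("learn", "load")
) -> List[Tuple[str, str, int]]:
    """Return (class, method, line) for methods annotated as returning their own class by name."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.ClassDef):
            continue
        for item in node.body:
            if not isinstance(item, ast.FunctionDef) or item.name not in method_names:
                continue
            ret = item.returns
            # The annotation may be a quoted forward reference or a bare name.
            if isinstance(ret, ast.Constant) and isinstance(ret.value, str):
                annotation = ret.value
            elif isinstance(ret, ast.Name):
                annotation = ret.id
            else:
                continue
            if annotation == node.name:
                hits.append((node.name, item.name, item.lineno))
    return hits


def scan_package(root: str) -> None:
    """Print every hit found in the .py files under `root` (hypothetical helper)."""
    for path in sorted(Path(root).rglob("*.py")):
        for cls, method, line in find_hardcoded_returns(path.read_text()):
            print(f"{path}:{line}: {cls}.{method}")


# Illustrative input only; not copied from any repository.
sample = '''
class TQC:
    def learn(self, total_timesteps: int) -> "TQC":
        return self

    def predict(self, obs):
        return obs
'''
hits = find_hardcoded_returns(sample)
print(hits)
```

Calling scan_package on a local checkout of the package directory would then print one line per method to update. A plain text search for "def learn" would likely work just as well if there is only one occurrence per algorithm folder.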