stable-baselines3
[Proposal, Enhancement] Improve Registry through class initialisation arguments
Proposal
Using class initialization arguments (keyword arguments in the class header, received by __init_subclass__) allows for very easy class registration; this makes registering, using, and retrieving policies through strings trivial. As it stands, every algorithm already provides a BasePolicy argument. We can use that base policy to retrieve the correct policy from a registry specific to that particular base policy.
Motivation
This change removes the need for explicit registration and makes registering trivial. It also lets us use a derived class as a base class, and register a policy under multiple parent policies.
Specifics
Add the following to BasePolicy.
The __init_subclass__ hook is called automatically whenever a subclass is defined; it is not called for BasePolicy itself. Each derived policy provides a policy name, e.g. BaseAlgorithmPolicy provides policy_name="MlpPolicy". Policies such as CnnPolicy provide policy_name="CnnPolicy" and parent_class=BaseAlgorithmPolicy.
Because classes such as QNetwork also subclass BasePolicy, we should avoid registering them. This is done by skipping any subclass that does not pass a policy_name.
```python
class BasePolicy:
    def __init_subclass__(cls, policy_name=None, parent_class=None, is_root_class=False, root_name="", **kwargs):
        super().__init_subclass__(**kwargs)
        # For stuff like QNetworks that subclass BasePolicy
        if policy_name is None:
            return
        if is_root_class:
            cls._registry = {(root_name or policy_name): cls}
        if parent_class is None:
            # A root policy (e.g. DQNPolicy) is its own parent
            parent_class = cls
        if not isinstance(parent_class, tuple):
            parent_class = [parent_class]
        for parent in parent_class:
            parent._registry = getattr(parent, "_registry", {})
            parent._registry[policy_name] = cls

    @classmethod
    def policy_from_name(cls, name):
        # called from the agent using self.base_policy.policy_from_name(policy)
        return cls._registry[name]


class DQNPolicy(BasePolicy, policy_name="MlpPolicy"):
    ...


class CNNDQNPolicy(DQNPolicy, policy_name="CnnPolicy", parent_class=DQNPolicy):
    ...


# Registered under DQNPolicy without inheriting from it
class RNNCNN(BasePolicy, policy_name="RnnCnnPolicy", parent_class=DQNPolicy):
    ...


# Algorithms can use EnsemblePolicy as base policy
class EnsemblePolicy(
    DQNPolicy,
    policy_name="EnsemblePolicy",
    parent_class=DQNPolicy,
    is_root_class=True,
    root_name="MlpPolicy",
):
    ...


class UberEnsemblePolicy(EnsemblePolicy, policy_name="UberPolicy", parent_class=EnsemblePolicy):
    ...
```
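The comment in policy_from_name says the agent resolves its policy string through its base policy. The sketch below shows that lookup end to end. It assumes a hypothetical `Algorithm` class with a `policy_base` class attribute (not the actual stable-baselines3 API) and uses a condensed copy of the hook above so that it runs on its own.

```python
from typing import Type, Union


class BasePolicy:
    # Condensed copy of the proposed hook (same registration rules)
    def __init_subclass__(cls, policy_name=None, parent_class=None,
                          is_root_class=False, root_name="", **kwargs):
        super().__init_subclass__(**kwargs)
        if policy_name is None:
            return  # e.g. Q-networks: never registered
        if is_root_class:
            cls._registry = {(root_name or policy_name): cls}
        if parent_class is None:
            parent_class = cls
        parents = parent_class if isinstance(parent_class, tuple) else (parent_class,)
        for parent in parents:
            parent._registry = getattr(parent, "_registry", {})
            parent._registry[policy_name] = cls

    @classmethod
    def policy_from_name(cls, name: str) -> type:
        return cls._registry[name]


class DQNPolicy(BasePolicy, policy_name="MlpPolicy"):
    pass


class CnnDQNPolicy(DQNPolicy, policy_name="CnnPolicy", parent_class=DQNPolicy):
    pass


class QNetwork(BasePolicy):  # no policy_name, so it is skipped
    pass


class EnsemblePolicy(DQNPolicy, policy_name="EnsemblePolicy", parent_class=DQNPolicy,
                     is_root_class=True, root_name="MlpPolicy"):
    pass


class Algorithm:
    """Hypothetical agent: resolves string policies through its base policy."""

    policy_base: Type[BasePolicy] = DQNPolicy

    def __init__(self, policy: Union[str, Type[BasePolicy]]):
        if isinstance(policy, str):
            policy = self.policy_base.policy_from_name(policy)
        self.policy_class = policy


class EnsembleAlgorithm(Algorithm):
    # Same strings, separate registry: "MlpPolicy" now resolves to EnsemblePolicy
    policy_base = EnsemblePolicy
```

Because the lookup goes through policy_base, two algorithms can share the string "MlpPolicy" while resolving it to different classes, and passing a class directly still works unchanged.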
Sounds reasonable, but I think we should still discuss the registry a bit more (it was not designed to be used outside the policies that ship with the library). I see the appeal (easy to register policies in other files), ~~but on the other hand they limit the initialization parameters to the default ones unless you start overriding classes and such.~~ Edit: It took me too long to realize policy_kwargs exists.
@araffin, comments? Should this also target v1.1, once the main priorities are done?
> but on the other hand they limit the initialization parameters to the default ones unless you start overriding classes and such.
What do you mean?
Oh sorry, my bad. I completely forgot policy_kwargs exists. Never mind that part of my comment ^^'
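To make the policy_kwargs point concrete: the registry only maps a string to a class, and that class is still instantiated with whatever keyword arguments the user supplies, so string lookup does not restrict a policy to its default parameters. This is a minimal sketch that assumes a simplified single registry; make_policy, net_arch and activation are hypothetical names (in stable-baselines3, policy_kwargs is the algorithm constructor argument that plays this role).

```python
from typing import Any, Dict, Optional, Type


class BasePolicy:
    # Simplified: one shared registry instead of one per root policy
    _registry: Dict[str, Type["BasePolicy"]] = {}

    def __init_subclass__(cls, policy_name=None, **kwargs):
        super().__init_subclass__(**kwargs)
        if policy_name is not None:
            BasePolicy._registry[policy_name] = cls


class MlpPolicy(BasePolicy, policy_name="MlpPolicy"):
    # Hypothetical constructor parameters with defaults
    def __init__(self, net_arch=(64, 64), activation="tanh"):
        self.net_arch = list(net_arch)
        self.activation = activation


def make_policy(name: str, policy_kwargs: Optional[Dict[str, Any]] = None) -> BasePolicy:
    """Hypothetical helper: look up the class by name, then forward policy_kwargs."""
    policy_class = BasePolicy._registry[name]
    return policy_class(**(policy_kwargs or {}))


default_policy = make_policy("MlpPolicy")
custom_policy = make_policy("MlpPolicy", policy_kwargs={"net_arch": [256, 256]})
```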
Even so, __init_subclass__ is a hook that Python calls automatically, so from the perspective of the current code, the only change is to pass the policy name and the parent class in the class header, and that's it.
About __init_subclass__:
https://www.python.org/dev/peps/pep-0487/#:~:text=An%20__init_subclass__%20hook,defined%20in%20the%20class%2C%20and
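The hook needs no explicit call: Python invokes it whenever a subclass is created and forwards any extra keyword arguments from the class header (PEP 487). Here is a minimal illustration that is independent of the policy classes; Base, tag and seen are hypothetical names.

```python
from typing import List


class Base:
    # Names of subclasses, in definition order (shared class-level list)
    seen: List[str] = []

    def __init_subclass__(cls, tag=None, **kwargs):
        # Called once for each subclass when that subclass is created;
        # it is not called for Base itself.
        super().__init_subclass__(**kwargs)
        cls.tag = tag
        Base.seen.append(cls.__name__)


class Child(Base, tag="child"):  # header keywords are forwarded to the hook
    pass


class Grandchild(Child):  # no keyword given, so tag falls back to None
    pass
```

Grandchild inherits the hook from Base through Child, so the same mechanism also covers indirect subclasses.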
Ah alright, hmm... It sounds nifty, but at the same time it starts to make things more complicated with these advanced and rarely used features. One of the features of stable-baselines is, or at least should be, readable code, and I think even the current registry system is approaching a level of complexity that is not healthy in the long run.
In any case, I'll let araffin comment on this further, since it was his idea :)
Indeed, I agree with you on the complexity. This began as an attempt to reduce it, but I got carried away ^^'.
I also agree on targeting v1.1+.
should be closed by #842