stable-baselines3
[Feature Request] Customize the features_extractor from a static or pre-trained model
🚀 Feature
When instantiating an RL algorithm, e.g., PPO, it should be possible to pass a static or pre-trained model as the features_extractor.
Motivation
Currently, one can customize the features extractor through policy_kwargs:
policy_kwargs = dict(
    features_extractor_class=...,
    features_extractor_kwargs=...,
)
However, there is no clear way to input a pre-trained model as a given feature extractor or just a static one.
I noticed that in https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/policies.py#L63 there was a parameter called features_extractor apart from features_extractor_class. Nevertheless, this parameter became a default one later in BasePolicy and unchangeable in ActorCriticPolicy:
features_extractor: BaseFeaturesExtractor
In my case, I need to train two RL models where the second model shares the same representation module with the first one, but they do not have to be trained and evaluated at the same time. So it would be beneficial if I could train the first model in full, then load its trained features extractor into the second model and freeze it.
Pitch
The desired usage would be
pre_trained_features_extractor = torch.load(...)
policy_kwargs = dict(features_extractor=pre_trained_features_extractor)
model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
where the parameters of features_extractor will no longer be updated.
Alternatives
Currently, I work around this by adding a set_mode method to the custom features extractor, which freezes its parameters when the mode is 'eval'.
class CustomFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim, ...):
        super().__init__(observation_space, features_dim)
        ...

    def set_mode(self, mode):
        if mode == 'eval':
            # Freeze the extractor: its parameters no longer receive gradients.
            for param in self.parameters():
                param.requires_grad = False
Additional context
No response
Checklist
- [X] I have checked that there is no similar issue in the repo
- [X] If I'm requesting a new feature, I have proposed alternatives
Hello, a VecEnv/gym wrapper would do the job in your case, no?
Appreciate your prompt reply! In practice, I guess your way could work, but in principle, the representation ability is supposed to be part of the agent's functionality instead of a property of the (wrapped) environment. In fact, I think there are a few ways to work around this, including the one you suggested and what I wrote in the Alternatives. But to me, it still seems that there is an interface in the base class that we haven't made use of.
Anyways, I guess this is not an urgent feature request :-), and can be put aside for now if you guys are too busy, or I can even start a PR when I have some time 😎
> In practice, I guess your way could work, but in principle, the representation ability is supposed to be part of the agent's functionality instead of properties of the (wrapped) environment.
If you don't do it that way, you will waste computation and memory: the observation will be stored as an image instead of a feature vector, and the feature vector will be recomputed several times.
> In fact, I think there are a few ways to get around including the one you suggested and what I wrote in the Alternatives.
Yes, those are also valid (although not as efficient as a wrapper), so I also think there is currently no real need for a specific feature in SB3.