[feature request] LstmPolicy does not support using net_arch with feature_extraction="cnn"
Currently, the code for `common.policies.LstmPolicy` does not support using the (apparently new) `net_arch` parameter together with the setting `feature_extraction="cnn"`. I was wondering if there is a reason for this?
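For concreteness, here is a minimal sketch of how the limitation shows up (PPO2 and the Atari env id are arbitrary choices for illustration, not my actual setup):

```python
from stable_baselines import PPO2
from stable_baselines.common.policies import CnnLstmPolicy

# Passing any net_arch to a CNN-based recurrent policy currently raises
# NotImplementedError while the policy graph is being built.
model = PPO2(
    CnnLstmPolicy,
    "BreakoutNoFrameskip-v4",
    policy_kwargs=dict(net_arch=["lstm", dict(pi=[64], vf=[64])]),
    nminibatches=1,  # recurrent policies need n_envs % nminibatches == 0
)
```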
More specifically, the above LstmPolicy class has this snippet on lines 438-442:
if feature_extraction == "cnn":
raise NotImplementedError()
with tf.variable_scope("model", reuse=reuse):
latent = tf.layers.flatten(self.processed_obs)
whereas earlier in the same class, in the `net_arch is None` branch, we have this on lines 416-420:
with tf.variable_scope("model", reuse=reuse):
if feature_extraction == "cnn":
extracted_features = cnn_extractor(self.processed_obs, **kwargs)
else:
extracted_features = tf.layers.flatten(self.processed_obs)
This suggests (naively, at least?) that perhaps we could support feature_extraction="cnn" in this block by changing the first snippet to:
with tf.variable_scope("model", reuse=reuse):
if feature_extraction == "cnn":
latent = cnn_extractor(self.processed_obs, **kwargs)
else:
latent = tf.layers.flatten(self.processed_obs)
Would that change work as expected and not break anything else? If so, I'm happy to submit a PR. If it won't work, I'd love to understand why not :)
Hmm, at a quick glance you might be right: it might be this trivial, especially if things are limited to having the CNN only at the very beginning.
Normally we would not take new features for stable-baselines (in favor of supporting stable-baselines3), but if you get things working as trivially as you mentioned, I believe we can have a PR to add this feature (@araffin, comments? Seems like a few-line change + docs for a seemingly obvious thing). Please post your findings here on this issue before creating the PR so we can decide if it is simple enough.
@Miffyli @araffin It certainly runs with that change - not sure if that's what you meant by "if you get things working", and whether there's a straightforward way to validate that it's doing the right thing...
> Seems like a few-line change + docs for a seemingly obvious thing
Well, as always: I think we did not investigate that too much, as MLP + frame-stacking is usually both faster and better (in terms of performance) than an LSTM.
> not sure if that's what you meant by "if you get things working", and whether there's a straightforward way to validate that it's doing the right thing...
It means that it works on a non-trivial environment that requires memory (i.e. one that is partially observable). One example would be Pong without frame-stacking, I would say.
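Something along these lines, for example (just a rough sketch assuming the proposed change is applied; the hyperparameters are arbitrary):

```python
from stable_baselines import PPO2
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.policies import CnnLstmPolicy

# No VecFrameStack wrapper: each observation is a single frame, so the agent
# has to rely on the LSTM to recover velocity information.
env = make_atari_env("PongNoFrameskip-v4", num_env=8, seed=0)

model = PPO2(
    CnnLstmPolicy,
    env,
    # Exercises the patched branch: CNN feature extraction, then an LSTM
    # layer, then separate policy/value heads.
    policy_kwargs=dict(net_arch=["lstm", dict(pi=[64], vf=[64])]),
    nminibatches=4,  # must divide num_env for recurrent policies
    verbose=1,
)
model.learn(total_timesteps=int(1e6))
```

If it reaches a reasonable score without frame-stacking, that would be a good indication the recurrent part is doing its job.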