[Question] Why do the hidden & cell shapes of an LSTM actor change during training with a single environment?
❓ Question
First, thank you for the great work on this package (and happy Friday)!
I have a question about the way sequences are handled in RecurrentRolloutBuffer. The context is that I am trying to extract the hidden and cell states at every time step from an LSTM actor, but during training the shapes of the internal hidden and cell states change, and I don't quite understand why since I'm using a single environment (which seems to be a factor, I've detailed the steps below - 🐻 with me).
Here is my minimal setup:
```python
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from stable_baselines3.common.env_checker import check_env
from sb3_contrib import RecurrentPPO
import gymnasium as gym
import torch
from torch import nn

# Forward hook for the actor LSTM
step = 0


def fw_hook(
    module: nn.Module,
    input_: tuple[torch.Tensor, ...],
    output: tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
):
    global step
    step += 1
    # nn.LSTM returns (output, (h_n, c_n)), so output[1][0] is the hidden state h_n
    hidden = output[1][0]
    print(f"[ {step:3d} ] Hidden shape: {hidden.shape} | Training mode: {module.training}")
    # Uncomment the following to stop the first time shape[1] is not 1
    # if hidden.shape[1] > 1:
    #     print("Anomalous shape encountered")
    #     for s in range(hidden.shape[1]):
    #         print(f"Slice {s}:\n{hidden[:, s]}")
    #     raise SystemExit()


# Generic environment
env = gym.make("Pendulum-v1")

# Agent
agent = RecurrentPPO(
    RecurrentActorCriticPolicy,
    env,
    policy_kwargs={"lstm_hidden_size": 64},
)

# Register the forward hook with the actor
agent.policy.lstm_actor.register_forward_hook(fw_hook)

# Train the agent.
# This prints the shape of the hidden state every time the actor LSTM's `forward()` is called
agent.learn(total_timesteps=250)
```
This produces the following output:
```
...truncated...
[ 127 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 128 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 129 ] Hidden shape: torch.Size([1, 2, 64]) | Training mode: True # Two sequences...
[ 130 ] Hidden shape: torch.Size([1, 2, 64]) | Training mode: True
[ 131 ] Hidden shape: torch.Size([1, 2, 64]) | Training mode: True
[ 132 ] Hidden shape: torch.Size([1, 2, 64]) | Training mode: True
[ 133 ] Hidden shape: torch.Size([1, 2, 64]) | Training mode: True
...truncated...
[ 1286 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 1287 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 1288 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 1289 ] Hidden shape: torch.Size([1, 3, 64]) | Training mode: True # Three sequences!
[ 1290 ] Hidden shape: torch.Size([1, 3, 64]) | Training mode: True # See bonus question below.
[ 1291 ] Hidden shape: torch.Size([1, 3, 64]) | Training mode: True
[ 1292 ] Hidden shape: torch.Size([1, 3, 64]) | Training mode: True
[ 1293 ] Hidden shape: torch.Size([1, 3, 64]) | Training mode: True
...truncated...
```
The second dimension of the shape varies between 1 and 3, which makes it difficult to determine what the actual hidden state is at the current time step. The slices are not identical: stopping the first time `shape[1]` is not 1 (by uncommenting the block in the hook) produces the following:
```
...truncated...
[ 126 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 127 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 128 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 129 ] Hidden shape: torch.Size([1, 2, 64]) | Training mode: True
Anomalous shape encountered
Slice 0:
tensor([[-0.0829, -0.1313, -0.0784, -0.0005,  0.1998, -0.1344,  0.0718,  0.0770,
          0.0633, -0.1303, -0.1619, -0.0804,  0.0221,  0.0895,  0.1033,  0.1790,
         -0.0631, -0.0337,  0.0096,  0.0497,  0.1222, -0.0074,  0.0711, -0.0137,
         -0.0203, -0.0729,  0.1066, -0.1037,  0.2195, -0.1407,  0.0269, -0.0968,
         -0.1034,  0.1163, -0.0931,  0.0071,  0.0914, -0.0213,  0.1505,  0.0700,
          0.1196,  0.0809, -0.0368, -0.0342,  0.0384, -0.1669,  0.0109,  0.1535,
         -0.0206, -0.1599, -0.0975, -0.0114, -0.0040,  0.0757,  0.0590, -0.0663,
          0.0200, -0.0264, -0.2038, -0.2178, -0.0444, -0.1478,  0.0482, -0.0903]],
       grad_fn=<SelectBackward0>)
Slice 1:
tensor([[-0.0356, -0.0018,  0.0154,  0.0565, -0.0229,  0.0009,  0.0023,  0.0469,
          0.0245,  0.0436,  0.0444,  0.0029, -0.0717, -0.0088, -0.0291,  0.0016,
          0.0177, -0.0729, -0.0166, -0.0091, -0.0412, -0.0033,  0.0781, -0.0058,
          0.0177, -0.0253,  0.0068, -0.0249, -0.0008,  0.0197, -0.0673,  0.0052,
         -0.0183, -0.0185,  0.0483,  0.0335, -0.0625, -0.0217, -0.0155, -0.0839,
         -0.0129,  0.0036,  0.0642, -0.0446, -0.0107, -0.0002, -0.0331, -0.0013,
         -0.0827, -0.0065,  0.0314,  0.0236,  0.0364,  0.0125,  0.0256,  0.0098,
          0.0923, -0.0712, -0.0166, -0.0073,  0.0583,  0.0363, -0.0390,  0.0089]],
       grad_fn=<SelectBackward0>)
```
After some digging, it seems that this is caused, at least partially, by `create_sequencers()` (L64 in `sb3_contrib/common/recurrent/buffers.py`). The flow is as follows:
- `agent.learn()` calls `OnPolicyAlgorithm.learn()` in SB3, which in turn calls `RecurrentPPO.collect_rollouts()` (L324).
- `RecurrentPPO.collect_rollouts()` calls `RecurrentPPO.policy.forward()` on L242.
- This calls `RecurrentActorCriticPolicy._process_sequence()` (L237).
- Finally, inside `RecurrentActorCriticPolicy._process_sequence()`, `n_seq` is set to whatever the second dimension of the hidden state is (L182).
Now, at the beginning this produces the expected outcome because the condition on L191 is satisfied for the entire duration of the while loop in RecurrentPPO.collect_rollouts() (L233). So for the first 128 steps we are just collecting rollouts with gradients disabled, and the shape is [1,1,64] and training mode is False, as confirmed by the output above. So far, so good.
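To make the role of `n_seq` concrete, here is a simplified NumPy sketch of the regrouping step, assuming that `_process_sequence()` takes `n_seq` from the second dimension of the hidden state and uses it to split a flat batch of features into a time-major `(seq_len, n_seq, n_features)` array. `regroup_features` is a hypothetical helper, not the library's code. With one env during rollout collection, the batch is a single step of a single sequence; a training minibatch made of two padded sequences gives `n_seq = 2`.

```python
import numpy as np


def regroup_features(features: np.ndarray, hidden_state: np.ndarray) -> np.ndarray:
    """Hypothetical helper: group a flat batch into sequences using n_seq.

    features:     flat array of shape (n_seq * seq_len, n_features),
                  with each sequence stored as a contiguous block of rows
    hidden_state: array of shape (num_layers, n_seq, hidden_size)
    returns:      time-major array of shape (seq_len, n_seq, n_features)
    """
    n_seq = hidden_state.shape[1]  # number of sequences = 2nd dim of the state
    n_features = features.shape[-1]
    # Split rows by sequence first, then swap to (time, batch, features)
    return features.reshape((n_seq, -1, n_features)).swapaxes(0, 1)


# Rollout collection with one env: a single step of a single sequence
rollout_step = regroup_features(np.zeros((1, 3)), np.zeros((1, 1, 64)))
print(rollout_step.shape)  # (1, 1, 3)

# Training minibatch made of two (padded) sequences of length 5
minibatch = regroup_features(np.arange(30.0).reshape(10, 3), np.zeros((1, 2, 64)))
print(minibatch.shape)  # (5, 2, 3)
```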
But then we hit `OnPolicyAlgorithm.train()` (SB3 L337, or L345 in RecurrentPPO), which sends us to `RecurrentRolloutBuffer.get()` (L147).
Here is where my understanding starts to falter. The code from L184 onwards employs a 'shuffling trick' for minibatch sampling. Following the `yield` statement on L196 to `_get_samples()` on L199 and then to `create_sequencers()` on L206, we get to L82 (I've pasted lines 81-89 here for convenience):
```python
# Create sequence if env changes too
seq_start = np.logical_or(episode_starts, env_change).flatten()
# First index is always the beginning of a sequence
seq_start[0] = True
# Retrieve indices of sequence starts
seq_start_indices = np.where(seq_start == True)[0]  # noqa: E712
# End of sequence are just before sequence starts
# Last index is also always end of a sequence
seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])])
```
Because of `seq_start[0] = True` (actually, since the first element of `env_change` was already marked as 1.0, this is probably redundant), `seq_start` now contains two elements that are True (the original `split_index` and index 0). This means that `seq_start_indices` has length 2, and since `n_seq` is assigned the length of `seq_start_indices` on L354, it evaluates to 2, which propagates all the way to the second dimension of the LSTM actor's (and critic's) hidden state.
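The argument above can be reproduced without SB3. The sketch below is a simplified reimplementation, assuming that `get()` rotates the buffer indices around a random `split_index` (keeping time order otherwise) and that `env_change` flags index 0 of the single env's data, as described above. `rotated_indices` and `count_sequences` are hypothetical names, and the 128-step buffer size is an assumption. For one env and no episode boundary, the rotated minibatch has two sequence starts unless `split_index` happens to be 0.

```python
import numpy as np


def rotated_indices(buffer_size: int, split_index: int) -> np.ndarray:
    # Assumed form of the "shuffling trick": keep time order, but start at
    # split_index and wrap around, so original index 0 lands mid-batch
    indices = np.arange(buffer_size)
    return np.concatenate((indices[split_index:], indices[:split_index]))


def count_sequences(episode_starts: np.ndarray, env_change: np.ndarray) -> int:
    # Same start detection as the create_sequencers() lines pasted above
    seq_start = np.logical_or(episode_starts, env_change).flatten()
    seq_start[0] = True  # first index is always the beginning of a sequence
    seq_start_indices = np.where(seq_start)[0]
    return len(seq_start_indices)  # this is what ends up as n_seq


buffer_size = 128  # single env, 128-step rollout (assumed)
episode_starts = np.zeros(buffer_size)  # no episode boundary in this rollout
env_change = np.zeros(buffer_size)
env_change[0] = 1.0  # first timestep of the single env's data

results = {}
for split_index in (0, 37):
    batch_inds = rotated_indices(buffer_size, split_index)
    results[split_index] = count_sequences(episode_starts[batch_inds], env_change[batch_inds])
print(results)  # {0: 1, 37: 2}
```

With `split_index = 37`, position 0 of the batch holds original index 37 (forced start) and original index 0 sits at position 91 (flagged by `env_change`), which gives the two sequences seen in the log.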
So my questions are:
- Why is this 'shuffling trick' employed in `RecurrentRolloutBuffer.get()`? It would be good to add some logic to disable it on demand.
- Is `create_sequencers()` working as intended?
- Why do we assume that we are using more than one environment even when this is not the case (I hope that my original comment about the number of environments at the very beginning makes sense now...)? In other words, is the use of the shuffling trick and the `create_sequencers()` function in `_get_samples()` justified when we are using a single environment?
- Importantly, how do I get the hidden & cell states during training? Maybe this doesn't even make sense with this setup, but it would be good to understand that. This shuffling trick not only messes with the tensor shapes, but also makes it impossible to know which time steps those states apply to.
- Bonus question: how do we end up with three sequences (e.g., line 1289 onwards in the output above)? I understand how we end up with two (also, based on the comment on L180, I think the intention is to have only two), but I'm not sure how this would result in three with a single environment.
Thank you, and apologies for this essay; this question ended up a lot longer than I intended it to be.
Checklist
- [x] I have checked that there is no similar issue in the repo
- [x] I have read the documentation
- [x] If code there is, it is minimal and working
- [x] If code there is, it is formatted using the markdown code blocks for both code and stack traces.
related https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/284#issuecomment-2766526133
> Why is this 'shuffling trick' employed in `RecurrentRolloutBuffer.get()`? It would be good to add some logic to disable it on demand.

As written in https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/284#issuecomment-2766526133 and other related issues, this was an attempt to add a bit of randomization, so that PPO doesn't see the same sequences in the same order all the time. In practice, I didn't have time to check its influence; my hypothesis is that it might not be needed (so being able to disable it would be good).
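One way such an on-demand switch could look, as a sketch only: `batch_indices` and its `shuffle` flag are hypothetical and do not exist in sb3_contrib, and the rotation is the assumed form of the trick discussed above. With `shuffle=False` the indices stay in time order, so for a single env the first element of the batch is index 0 and the rotation introduces no extra sequence start.

```python
import numpy as np


def batch_indices(buffer_size: int, shuffle: bool, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical index generator with an on/off switch for the rotation trick."""
    indices = np.arange(buffer_size)
    if not shuffle:
        # Plain time order: index 0 stays first, so no extra sequence start appears
        return indices
    # Rotate around a random split point; time order is kept within each part
    split_index = int(rng.integers(buffer_size))
    return np.concatenate((indices[split_index:], indices[:split_index]))


rng = np.random.default_rng(0)
print(batch_indices(8, shuffle=False, rng=rng))  # [0 1 2 3 4 5 6 7]
print(batch_indices(8, shuffle=True, rng=rng))   # some rotation of 0..7
```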
> Is `create_sequencers()` working as intended?

> now contains two elements that are True

It looks OK so far. If we split the sequence in two, we should have two sequences, no?
> Why do we assume that we are using more than one environment

SB3 relies on `VecEnv`, which means most of the code should not behave differently when using one or multiple envs.
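A small sketch of why one env is not a special case, assuming (as in SB3's rollout buffers) that data is stored as `(n_steps, n_envs)` and flattened env by env before sampling, with the first timestep of every env flagged in `env_change`. `env_change_flags` is a hypothetical helper. With `n_envs=1` only index 0 is flagged; with two envs, the start of the second env's block is flagged as well, and the same code path handles both.

```python
import numpy as np


def env_change_flags(n_steps: int, n_envs: int) -> np.ndarray:
    """Hypothetical helper: flag the first timestep of every env, then flatten."""
    flags = np.zeros((n_steps, n_envs))
    flags[0, :] = 1.0  # each env's data starts a new sequence
    # (n_steps, n_envs) -> (n_envs * n_steps,): each env's data stays contiguous,
    # like the swap-and-flatten step in SB3 rollout buffers
    return flags.swapaxes(0, 1).reshape(-1)


print(np.flatnonzero(env_change_flags(4, n_envs=1)))  # [0]
print(np.flatnonzero(env_change_flags(4, n_envs=2)))  # [0 4]
```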
> the shuffling trick

See first answer, this should be unrelated to the number of envs.
> how do I get the hidden & cell states during training?

I'm not sure what your end code looks like, but subclassing/forking SB3 might be the best option here.
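As a framework-free illustration of one possible approach (an assumption, not SB3 code), the sketch below re-runs a stored single-env trajectory in time order through a toy NumPy LSTM cell with random weights, zeroing the state whenever a new episode starts, and records `(h, c)` after every step. Row `t` of the result then belongs unambiguously to time step `t`, independent of how minibatches are shuffled. `lstm_cell` and `replay_states` are hypothetical helpers; in a subclass, the same loop would presumably be run with the policy's own `lstm_actor` and features instead of the toy cell.

```python
import numpy as np


def sigmoid(v: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-v))


def lstm_cell(x, h, c, W, U, b):
    """One step of a standard LSTM cell (toy NumPy version, gates ordered i, f, g, o)."""
    z = x @ W + h @ U + b
    i, f, g, o = np.split(z, 4, axis=-1)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new


def replay_states(obs, episode_starts, W, U, b, hidden_size):
    """Hypothetical helper: re-run one env's stored trajectory in time order and
    record (h, c) after every step, zeroing the state when a new episode starts."""
    h = np.zeros(hidden_size)
    c = np.zeros(hidden_size)
    hs, cs = [], []
    for x, start in zip(obs, episode_starts):
        mask = 1.0 - start  # 0.0 at an episode start, 1.0 otherwise
        h, c = lstm_cell(x, mask * h, mask * c, W, U, b)
        hs.append(h)
        cs.append(c)
    return np.stack(hs), np.stack(cs)  # row t = state after time step t


rng = np.random.default_rng(0)
obs_dim, hidden_size, n_steps = 3, 8, 10
W = rng.normal(scale=0.1, size=(obs_dim, 4 * hidden_size))
U = rng.normal(scale=0.1, size=(hidden_size, 4 * hidden_size))
b = np.zeros(4 * hidden_size)
obs = rng.normal(size=(n_steps, obs_dim))
episode_starts = np.zeros(n_steps)
episode_starts[[0, 6]] = 1.0  # episodes start at t=0 and t=6

hs, cs = replay_states(obs, episode_starts, W, U, b, hidden_size)
print(hs.shape, cs.shape)  # (10, 8) (10, 8)
```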
> how this would result in three with a single environment

The number of sequences should be independent of the number of envs, no?
In the sense that we collect n_steps transitions and then we cut whatever sequences we had (one sequence = one episode, normally unrelated to the number of envs).
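The "one sequence = one episode" point also accounts for the bonus question. The hypothetical sketch below combines the assumed rotation with the start detection from the question, assuming 128-step rollouts (consistent with the log above, where training starts after step 128) and Pendulum-v1's default 200-step time limit. The second rollout then covers global steps 128 to 255 and contains an episode start at local index 72, so a rotated minibatch can hold three sequences: the forced start at position 0, the start of the env's data (original index 0), and the new episode.

```python
import numpy as np


def count_sequences(episode_starts: np.ndarray, env_change: np.ndarray, split_index: int) -> int:
    # Hypothetical reimplementation: rotate indices, then detect sequence starts
    n = len(episode_starts)
    batch_inds = np.concatenate((np.arange(split_index, n), np.arange(split_index)))
    seq_start = np.logical_or(episode_starts[batch_inds], env_change[batch_inds])
    seq_start[0] = True  # first index of the batch always starts a sequence
    return int(seq_start.sum())


n_steps = 128
env_change = np.zeros(n_steps)
env_change[0] = 1.0  # start of the single env's data

# Second rollout: global steps 128..255, new Pendulum episode at global step 200
episode_starts = np.zeros(n_steps)
episode_starts[200 - 128] = 1.0

print(count_sequences(episode_starts, env_change, split_index=0))   # 2
print(count_sequences(episode_starts, env_change, split_index=37))  # 3
```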