RecurrentPPO: 9x speedup - whole sequence batching
Description
Moving from 2D batches to 3D batches of whole sequences leads to a 5-9x speedup in terms of FPS while keeping results similar. Proof.
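As a rough sketch of the idea (names and shapes are illustrative, not the actual sb3-contrib code): instead of flattening everything into 2D `(batch, features)` chunks and stepping the LSTM piecewise, pad whole episodes into a 3D `(batch, seq_len, features)` tensor and run the LSTM over each sequence in a single call:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

lstm = nn.LSTM(input_size=3, hidden_size=64, batch_first=True)

# Three episodes of different lengths, each of shape (T_i, features)
episodes = [torch.randn(t, 3) for t in (5, 8, 2)]

# Pad to a single (batch, max_T, features) tensor and run the LSTM once,
# instead of one forward call per chunk with manually carried hidden states.
batch = pad_sequence(episodes, batch_first=True)  # (3, 8, 3)
latent, _ = lstm(batch)                           # (3, 8, 64)

# A mask marks the real (non-padded) steps so the loss can ignore the padding.
mask = pad_sequence([torch.ones(len(e)) for e in episodes], batch_first=True)
```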
Context
- [x] I have raised an issue to propose this change (required)
Types of changes
It's currently implemented as an additional feature, but it would probably be better to replace the original implementation.
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation (update in the documentation)
Checklist:
- [x] I've read the CONTRIBUTION guide (required)
- [x] The functionality/performance matches that of the source (required for new training algorithms or training-related features).
- [ ] I have updated the tests accordingly (required for a bug fix or a new feature).
- [ ] I have included an example of using the feature (required for new features).
- [ ] I have included baseline results (required for new training algorithms or training-related features).
- [ ] I have updated the documentation accordingly.
- [ ] I have updated the changelog accordingly (required).
- [ ] I have reformatted the code using `make format` (required)
- [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (required)
- [ ] I have ensured `make pytest` and `make type` both pass. (required)
Note: we are using a maximum length of 127 characters per line
@araffin have you been able to take a look at this yet? I am very curious what you think about it.
> have you been able to take a look at this yet? I am very curious what you think about it.
No, not yet, it's still on my stack... and I'm going on holidays soon, so I'll probably take a look next week or in January.
Cool. Let me know if you need any help running experiments or coding.
Hello, I tried to test the PR but couldn't: I got an error (before my changes) with both Pendulum and BipedalWalker:
```
Traceback (most recent call last):
  File "sb3_contrib/whole_sequence_speed_test.py", line 167, in <module>
    model.learn(2e5, tb_log_name=f"PendulumNoVel-v1_whole_sequences_batch_size{batch_size}")
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 505, in learn
    self.train()
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 361, in train
    values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence(
  File "sb3_contrib/sb3_contrib/common/recurrent/policies.py", line 372, in evaluate_actions_whole_sequence
    latent_pi, _ = self.lstm_actor(features)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 810, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 730, in check_forward_args
    self.check_input(input, batch_sizes)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 218, in check_input
    raise RuntimeError(
RuntimeError: input.size(-1) must be equal to input_size. Expected 3, got 6
```
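For reference, the error comes from `nn.LSTM`'s input validation: the last dimension of the input must equal `input_size`. A minimal reproduction (sizes chosen to match the message above):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=3, hidden_size=64, batch_first=True)
lstm(torch.randn(4, 10, 6))  # RuntimeError: input.size(-1) must be equal to input_size. Expected 3, got 6
```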
My bad. Bug is fixed now!
I had to set `drop_last=False` sometimes, otherwise I was getting an error because nothing was sampled:
```
UnboundLocalError: local variable 'loss' referenced before assignment
```
To reproduce:
```
python -m rl_zoo3.train --algo ppo_lstm --env PendulumNoVel-v1 -params whole_sequences:True use_sde:False
python -m rl_zoo3.train --algo ppo_lstm --env CartPoleNoVel-v1 -params whole_sequences:True
```
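The failure mode, sketched (`compute_loss` and the index math are illustrative, not the actual code): with `drop_last=True`, few sequences and a large `batch_size` can leave zero minibatches, so the training loop body never runs and `loss` is never assigned before it is used.

```python
def compute_loss(batch):  # stand-in for the PPO loss
    return sum(batch)

def train(n_sequences=6, batch_size=8, drop_last=True):
    indices = list(range(n_sequences))
    batches = [indices[i : i + batch_size] for i in range(0, n_sequences, batch_size)]
    if drop_last:
        batches = [b for b in batches if len(b) == batch_size]  # -> [] here
    for batch in batches:  # loop body never runs
        loss = compute_loss(batch)
    return loss  # UnboundLocalError: local variable 'loss' referenced before assignment

train()
```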
On CartPole, I have another error:
```
Traceback (most recent call last):
  File "torchy-zoo/train.py", line 4, in <module>
    train()
  File "torchy-zoo/rl_zoo3/train.py", line 267, in train
    exp_manager.learn(model)
  File "torchy-zoo/rl_zoo3/exp_manager.py", line 236, in learn
    model.learn(self.n_timesteps, **kwargs)
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 521, in learn
    self.train()
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 377, in train
    values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence(
  File "sb3_contrib/sb3_contrib/common/recurrent/policies.py", line 387, in evaluate_actions_whole_sequence
    log_prob = distribution.distribution.log_prob(actions).sum(dim=-1)
  File "mambaforge/lib/python3.10/site-packages/torch/distributions/categorical.py", line 123, in log_prob
    self._validate_sample(value)
  File "mambaforge/lib/python3.10/site-packages/torch/distributions/distribution.py", line 288, in _validate_sample
    raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([32, 15, 1]) vs torch.Size([32, 15]).
```
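The mismatch is the trailing action dimension: discrete actions are stored as `(n_seq, seq_len, 1)`, while `Categorical.log_prob` expects values matching its batch shape `(n_seq, seq_len)`. A minimal sketch of the error and a possible fix:

```python
import torch
from torch.distributions import Categorical

logits = torch.randn(32, 15, 2)    # (n_seq, seq_len, n_actions)
dist = Categorical(logits=logits)  # batch_shape = (32, 15)
actions = torch.zeros(32, 15, 1)   # stored with a trailing action dim

# dist.log_prob(actions)           # ValueError: Value is not broadcastable ...
log_prob = dist.log_prob(actions.squeeze(-1).long())  # shape (32, 15)
```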
Also, SDE does not seem to be supported (that's ok, but it needs to be checked at runtime).
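A runtime guard could look something like this (a sketch; the helper and flag names are assumptions based on the PR's `whole_sequences` kwarg):

```python
def check_whole_sequence_support(use_sde: bool, whole_sequences: bool) -> None:
    """Hypothetical helper: fail fast instead of silently misbehaving later."""
    if use_sde and whole_sequences:
        raise NotImplementedError("gSDE is not supported with whole_sequences=True")
```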
Finally, I experienced some NaN issues from time to time when `drop_last=False` (I fixed that by deactivating advantage normalization):
```
ValueError: Expected parameter loc (Tensor of shape (4, 1)) of distribution Normal(loc: torch.Size([4, 1]), scale: torch.Size([4, 1])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan],
        [nan],
        [nan],
        [nan]])
```
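The NaNs are consistent with normalizing the advantages of a one-element minibatch, which `drop_last=False` can produce: the unbiased standard deviation of a single sample is NaN. A minimal illustration:

```python
import torch

advantages = torch.tensor([0.5])  # stray 1-element minibatch left over by drop_last=False
normalized = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
print(normalized)  # tensor([nan]): std() of a single element divides by n - 1 == 0
```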
There is also an error when using a CNN:
```
python train.py --algo ppo_lstm --env CarRacing-v2 -P --n-eval-envs 5 --eval-episodes 20 -params batch_size:8 whole_sequences:True
```
```
    self.train()
  File "/home/antonin/Documents/rl/sb3-contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 377, in train
    values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/sb3-contrib/sb3_contrib/common/recurrent/policies.py", line 371, in evaluate_actions_whole_sequence
    features = self.extract_features(obs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/stable-baselines3/stable_baselines3/common/policies.py", line 640, in extract_features
    return super().extract_features(obs, self.features_extractor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/stable-baselines3/stable_baselines3/common/policies.py", line 131, in extract_features
    return features_extractor(preprocessed_obs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/stable-baselines3/stable_baselines3/common/torch_layers.py", line 106, in forward
    return self.linear(self.cnn(observations))
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [281, 8, 2, 64, 64]
```
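The CNN features extractor only accepts 4D `(batch, C, H, W)` input, while whole-sequence batching hands it a 5D `(seq_len, n_seq, C, H, W)` tensor. The usual workaround is to fold the two leading dims before the CNN and unfold afterwards for the LSTM; a sketch (the `cnn` call is a stand-in for the actual features extractor):

```python
import torch

obs = torch.randn(281, 8, 2, 64, 64)                 # (seq_len, n_seq, C, H, W): conv2d rejects 5D
seq_len, n_seq = obs.shape[:2]
flat = obs.reshape(seq_len * n_seq, *obs.shape[2:])  # (2248, 2, 64, 64): valid 4D conv input
# features = cnn(flat)                               # hypothetical features extractor call
# features = features.reshape(seq_len, n_seq, -1)    # back to 3D for the LSTM
```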
> On CartPole, I have another error:

The error for CartPole still seems to be there...
Yes, it has only been implemented for Box action spaces, so that might be it.
I don't have much time to work on this anymore, so feel free to take it over.
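For whoever picks this up: the log-prob computation would need to branch on the action space, roughly like this (a hedged sketch with a hypothetical helper, not the actual code):

```python
from gym import spaces  # sb3-contrib used gym at the time; gymnasium spaces work the same way

def sequence_log_prob(distribution, actions, action_space):
    """Hypothetical helper: log-probs for whole-sequence (n_seq, seq_len, ...) tensors."""
    if isinstance(action_space, spaces.Box):
        # independent Gaussian dims: sum log-probs over the action dimension
        return distribution.log_prob(actions).sum(dim=-1)
    if isinstance(action_space, spaces.Discrete):
        # Categorical has no trailing action dim: drop it instead of summing
        return distribution.log_prob(actions.squeeze(-1).long())
    raise NotImplementedError(f"{action_space} not supported with whole_sequences=True")
```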