stable-baselines3-contrib

RecurrentPPO: 9x speedup - whole sequence batching

Open b-vm opened this issue 3 years ago • 9 comments

Description

Moving from 2D batches to 3D batches of whole sequences gives a 5-9x speedup in FPS while keeping results similar. Proof.
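To illustrate the idea of whole-sequence batching, here is a minimal PyTorch sketch. It is not the code from this PR: the episode lengths, dimensions, and variable names are made up for the example. It pads variable-length episodes into one 3D tensor, runs the LSTM once over the whole batch, and checks that the result matches stepping through each episode one timestep at a time.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
obs_dim, hidden = 3, 8
lstm = nn.LSTM(obs_dim, hidden, batch_first=True)

# Three hypothetical episodes of different lengths, each shaped (T_i, obs_dim).
episodes = [torch.randn(t, obs_dim) for t in (5, 3, 4)]
lengths = torch.tensor([len(e) for e in episodes])

# 3D batch: (n_sequences, max_len, obs_dim), zero-padded at the end.
padded = pad_sequence(episodes, batch_first=True)
# The mask marks real (non-padded) timesteps so padding can be excluded from the loss.
mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]

# A single LSTM call processes every timestep of every sequence.
out_batched, _ = lstm(padded)

# Reference: step through each episode one timestep at a time.
out_stepwise = []
for ep in episodes:
    state = None
    outs = []
    for obs in ep:
        o, state = lstm(obs.view(1, 1, -1), state)
        outs.append(o.view(-1))
    out_stepwise.append(torch.stack(outs))

# Padding sits after the real data, so the valid outputs are unaffected by it.
for i, ref in enumerate(out_stepwise):
    assert torch.allclose(out_batched[i, : len(ref)], ref, atol=1e-5)
```

The speedup reported above would come from replacing many small per-timestep calls with one large call; the padded positions still have to be masked out when computing the loss.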

Context

  • [x] I have raised an issue to propose this change (required)

Types of changes

It's currently implemented as an additional feature, but it would probably be better to replace the original.

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [x] Breaking change (fix or feature that would cause existing functionality to change)
  • [ ] Documentation (update in the documentation)

Checklist:

  • [x] I've read the CONTRIBUTION guide (required)
  • [x] The functionality/performance matches that of the source (required for new training algorithms or training-related features).
  • [ ] I have updated the tests accordingly (required for a bug fix or a new feature).
  • [ ] I have included an example of using the feature (required for new features).
  • [ ] I have included baseline results (required for new training algorithms or training-related features).
  • [ ] I have updated the documentation accordingly.
  • [ ] I have updated the changelog accordingly (required).
  • [ ] I have reformatted the code using make format (required)
  • [ ] I have checked the codestyle using make check-codestyle and make lint (required)
  • [ ] I have ensured make pytest and make type both pass. (required)

Note: we are using a maximum length of 127 characters per line

b-vm avatar Nov 28 '22 17:11 b-vm

@araffin have you been able to take a look at this yet? I am very curious what you think about it.

b-vm avatar Dec 16 '22 20:12 b-vm

> have you been able to take a look at this yet? I am very curious what you think about it.

No, not yet, it's still on my stack... and I'm going on holiday soon, so I'll probably take a look next week or in January.

araffin avatar Dec 16 '22 22:12 araffin

Cool. Let me know if you need any help running experiments/coding

b-vm avatar Dec 29 '22 09:12 b-vm

Hello, I tried but couldn't test the PR. I got an error (before my changes) with both Pendulum and BipedalWalker:

Traceback (most recent call last):
  File "sb3_contrib/whole_sequence_speed_test.py", line 167, in <module>
    model.learn(2e5, tb_log_name=f"PendulumNoVel-v1_whole_sequences_batch_size{batch_size}")
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 505, in learn
    self.train()
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 361, in train
    values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence(
  File "sb3_contrib/sb3_contrib/common/recurrent/policies.py", line 372, in evaluate_actions_whole_sequence
    latent_pi, _ = self.lstm_actor(features)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 810, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 730, in check_forward_args
    self.check_input(input, batch_sizes)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 218, in check_input
    raise RuntimeError(
RuntimeError: input.size(-1) must be equal to input_size. Expected 3, got 6

araffin avatar Apr 03 '23 13:04 araffin

My bad. Bug is fixed now!

b-vm avatar Apr 12 '23 09:04 b-vm

I sometimes had to set drop_last=False; otherwise I got an error because nothing was sampled: UnboundLocalError: local variable 'loss' referenced before assignment
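A plain-Python sketch of how this error can arise, assuming the rollout holds fewer sequences than batch_size. The helper names `iterate_batches` and `train_epoch` are hypothetical and not taken from the PR. With drop_last=True the only (incomplete) batch is discarded, the loop body never runs, and `loss` is never assigned.

```python
def iterate_batches(n_sequences, batch_size, drop_last):
    """Yield lists of sequence indices, mimicking a batch sampler (hypothetical helper)."""
    for start in range(0, n_sequences, batch_size):
        batch = list(range(start, min(start + batch_size, n_sequences)))
        if drop_last and len(batch) < batch_size:
            return  # the incomplete final batch is discarded
        yield batch


def train_epoch(n_sequences, batch_size, drop_last):
    for batch in iterate_batches(n_sequences, batch_size, drop_last):
        loss = float(len(batch))  # stand-in for the PPO loss
    return loss  # unbound if the loop body never ran


# Five sequences, batch size eight: one short batch is kept when drop_last=False.
print(train_epoch(n_sequences=5, batch_size=8, drop_last=False))

try:
    train_epoch(n_sequences=5, batch_size=8, drop_last=True)
except UnboundLocalError as err:
    print("no batch sampled:", err)
```

Besides setting drop_last=False, another option would be to check up front that at least one full batch exists and raise a clearer error if not.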

To reproduce:

python -m rl_zoo3.train --algo ppo_lstm --env PendulumNoVel-v1 -params whole_sequences:True use_sde:False
python -m rl_zoo3.train --algo ppo_lstm --env CartPoleNoVel-v1 -params whole_sequences:True

On CartPole, I have another error:

Traceback (most recent call last):
  File "torchy-zoo/train.py", line 4, in <module>
    train()
  File "torchy-zoo/rl_zoo3/train.py", line 267, in train
    exp_manager.learn(model)
  File "torchy-zoo/rl_zoo3/exp_manager.py", line 236, in learn
    model.learn(self.n_timesteps, **kwargs)
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 521, in learn
    self.train()
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 377, in train
    values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence(
  File "sb3_contrib/sb3_contrib/common/recurrent/policies.py", line 387, in evaluate_actions_whole_sequence
    log_prob = distribution.distribution.log_prob(actions).sum(dim=-1)
  File "mambaforge/lib/python3.10/site-packages/torch/distributions/categorical.py", line 123, in log_prob
    self._validate_sample(value)
  File "mambaforge/lib/python3.10/site-packages/torch/distributions/distribution.py", line 288, in _validate_sample
    raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([32, 15, 1]) vs torch.Size([32, 15]).
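The shapes in this message suggest that discrete actions arrive with a trailing singleton dimension, (32, 15, 1), while the Categorical distribution has batch shape (32, 15). The following sketch reproduces the mismatch with the same shapes and shows one possible fix, squeezing the last dimension. It assumes two discrete actions and is not the PR's code.

```python
import torch
from torch.distributions import Categorical

n_seq, seq_len, n_actions = 32, 15, 2
logits = torch.zeros(n_seq, seq_len, n_actions)
dist = Categorical(logits=logits)  # batch_shape == (32, 15)

# Discrete actions stored with a trailing singleton dimension.
actions = torch.zeros(n_seq, seq_len, 1, dtype=torch.long)

try:
    dist.log_prob(actions)
    mismatch = False
except ValueError:
    mismatch = True  # same failure as in the traceback above

# Dropping the trailing dimension makes the shapes line up.
log_prob = dist.log_prob(actions.squeeze(-1))  # shape (32, 15)
```

For a Categorical distribution the log-probability is already one value per timestep, so no further sum over an action dimension is needed after the squeeze.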

Also, SDE does not seem to be supported (that's OK, but it needs to be checked at runtime).
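A runtime check of this kind could look like the sketch below. The function name is hypothetical; the argument names mirror the use_sde and whole_sequences parameters used in the commands above, and where such a check would live in the algorithm is left open.

```python
def check_whole_sequence_options(use_sde: bool, whole_sequences: bool) -> None:
    """Hypothetical guard rejecting an unsupported option combination early."""
    if whole_sequences and use_sde:
        raise ValueError(
            "whole_sequences=True does not support use_sde=True; "
            "set use_sde=False or disable whole-sequence batching."
        )


# Supported combination: passes silently.
check_whole_sequence_options(use_sde=False, whole_sequences=True)
```

Failing at construction time gives a clearer message than an error deep inside the training loop.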

Finally, I experienced some NaN issues from time to time when drop_last=False (I fixed that by deactivating advantage normalization):

ValueError: Expected parameter loc (Tensor of shape (4, 1)) of distribution Normal(loc: torch.Size([4, 1]), scale: torch.Size([4, 1])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan],
        [nan],
        [nan],
        [nan]])
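One way advantage normalization can produce NaN is a batch containing a single sample, which drop_last=False makes possible for the trailing batch: the sample standard deviation of one value is undefined. Whether this was the cause here is not established in the thread. The sketch below shows the effect and a guard, using a flat advantage tensor for simplicity; `normalize_advantages` is a name chosen for this example.

```python
import torch


def normalize_advantages(advantages: torch.Tensor) -> torch.Tensor:
    """Standardize advantages, skipping batches too small to have a sample std."""
    if advantages.numel() > 1:
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages


single = torch.tensor([0.7])

# Without a guard: std() of one element is NaN, so the result is NaN too.
unguarded = (single - single.mean()) / (single.std() + 1e-8)

# With the guard the single advantage is returned unchanged.
guarded = normalize_advantages(single)
```

With padded whole-sequence batches, the mean and standard deviation would also need to be computed over the non-padded entries only.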

araffin avatar Apr 27 '23 09:04 araffin

Also an error when using CNN:

python train.py --algo ppo_lstm --env CarRacing-v2 -P --n-eval-envs 5 --eval-episodes 20 -params batch_size:8 whole_sequences:True
    self.train()
  File "/home/antonin/Documents/rl/sb3-contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 377, in train
    values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/sb3-contrib/sb3_contrib/common/recurrent/policies.py", line 371, in evaluate_actions_whole_sequence
    features = self.extract_features(obs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/stable-baselines3/stable_baselines3/common/policies.py", line 640, in extract_features
    return super().extract_features(obs, self.features_extractor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/stable-baselines3/stable_baselines3/common/policies.py", line 131, in extract_features
    return features_extractor(preprocessed_obs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/stable-baselines3/stable_baselines3/common/torch_layers.py", line 106, in forward
    return self.linear(self.cnn(observations))
                       ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [281, 8, 2, 64, 64]
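The error says conv2d received a 5D tensor. A common way to handle this, assuming the first two dimensions are sequence length and batch, is to merge them before the CNN and split them again afterwards. The sketch below uses a small stand-in CNN and a hypothetical `extract_sequence_features` helper; it is not the feature extractor from stable-baselines3 or the PR.

```python
import torch
from torch import nn

# Small stand-in for a CNN feature extractor (2 input channels, as in the error above).
cnn = nn.Sequential(
    nn.Conv2d(2, 4, kernel_size=8, stride=4),
    nn.ReLU(),
    nn.Flatten(),
)


def extract_sequence_features(obs: torch.Tensor) -> torch.Tensor:
    """Apply a 2D CNN to observations shaped (seq_len, batch, C, H, W)."""
    seq_len, batch = obs.shape[:2]
    flat = obs.reshape(seq_len * batch, *obs.shape[2:])  # (seq_len * batch, C, H, W)
    features = cnn(flat)  # (seq_len * batch, n_features)
    return features.reshape(seq_len, batch, -1)  # (seq_len, batch, n_features)


obs = torch.zeros(5, 8, 2, 64, 64)
features = extract_sequence_features(obs)  # ready to be fed to an LSTM
```

The convolution treats every timestep independently, so folding time into the batch dimension does not change the per-frame features.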

araffin avatar Apr 27 '23 10:04 araffin

> On CartPole, I have another error:

The error for CartPole still seems to be there...

araffin avatar Oct 06 '23 09:10 araffin

Yes, it has only been implemented for Box action spaces so that might be it.

I don't have much time to work on this anymore, so feel free to take it over.

b-vm avatar Oct 15 '23 22:10 b-vm