stable-baselines3

VecHerReplayBuffer

Open • araffin opened this issue 3 years ago • 8 comments

Description

Motivation and Context

  • [x] I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to change)
  • [ ] Documentation (update in the documentation)

Checklist:

  • [x] I've read the CONTRIBUTION guide (required)
  • [ ] I have updated the changelog accordingly (required).
  • [ ] My change requires a change to the documentation.
  • [ ] I have updated the tests accordingly (required for a bug fix or a new feature).
  • [ ] I have updated the documentation accordingly.
  • [ ] I have reformatted the code using make format (required)
  • [ ] I have checked the codestyle using make check-codestyle and make lint (required)
  • [ ] I have ensured make pytest and make type both pass. (required)
  • [ ] I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

araffin • Nov 06 '21 21:11

Hi, are there any updates on this PR? I am working on a project where I would like to use VecHerReplayBuffer.

thiakx • Dec 14 '21 16:12


I would welcome help on that PR. You can try it by installing SB3 from source (cf. doc), but there is no guarantee regarding performance. The current PR runs and seems to work, but it is not polished enough to be merged into the master branch.

araffin • Dec 14 '21 16:12

Hi, I would be interested too! What kind of polishing would the PR need?

ischubert • Dec 14 '21 22:12


First of all, more tests, especially for the saving/loading part of the buffer. Then, clear error messages for some edge cases (for instance, when the buffer size is too small to use all the buffers).

And finally, more performance tests (these can be done with the RL Zoo) to check that this implementation does not break anything.

If you are working on that, please create a PR that builds on this one ;)

araffin • Dec 15 '21 12:12
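One of the edge cases mentioned above (a buffer too small to use all the buffers) can be caught early with an explicit check. The sketch below is hypothetical: it assumes the vectorized buffer splits buffer_size evenly into one sub-buffer per environment and that each sub-buffer stores whole episodes of at most max_episode_length transitions. check_vec_her_buffer_size is an illustrative name, not an SB3 function.

```python
def check_vec_her_buffer_size(buffer_size: int, n_envs: int, max_episode_length: int) -> int:
    """Hypothetical sanity check for a vectorized HER buffer.

    Assumes buffer_size is split evenly into one sub-buffer per env and that
    each sub-buffer stores whole episodes. Returns how many episodes each
    sub-buffer can hold, or raises ValueError with an actionable message.
    """
    if n_envs < 1 or max_episode_length < 1:
        raise ValueError("n_envs and max_episode_length must be positive")
    per_env_size = buffer_size // n_envs
    episodes_per_env = per_env_size // max_episode_length
    if episodes_per_env < 1:
        min_size = n_envs * max_episode_length
        raise ValueError(
            f"buffer_size={buffer_size} is too small for n_envs={n_envs}: each sub-buffer "
            f"gets {per_env_size} transitions, less than one episode "
            f"(max_episode_length={max_episode_length}). Use buffer_size >= {min_size}."
        )
    return episodes_per_env
```

Raising at construction time, rather than failing later during sampling, points the user directly at the buffer_size they need.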

I have run the latest version of this branch, and it looks like parallel environments do not work well for FetchPush and FetchPickAndPlace. I'm happy to help contribute to this PR.

To reproduce:

  1. If you run the RL Zoo via

    python train.py --algo sac --env FetchPush-v1
    

    This one launches with just 1 environment. I have confirmed that this reproduces the original learning curve.

  2. I have added our launch and logging infrastructure and have reproduced the identical learning curve with just 1 environment using the VecHerReplayBuffer (see below).

    [Screenshot (2021-12-18, 12:15:29 PM): learning curve with 1 environment using VecHerReplayBuffer]
  3. When you use multiple environments, however, FetchReach learns, but FetchPush and FetchPickAndPlace do not.

    [Screenshot (2021-12-18, 12:14:27 PM): learning curves with multiple environments]

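I am running the following:

from params_proto.neo_proto import ParamsProto
from stable_baselines3 import SAC  # looked up below via globals()[Args.algo.upper()]
from stable_baselines3.common.env_util import make_vec_env
# VecHerReplayBuffer is defined on this PR's branch; this module path is an assumption
from stable_baselines3.her import VecHerReplayBuffer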

# strict is needed to overload for various algos.
class Args(ParamsProto, strict=False):
    env_name = "FetchReach-v1"
    algo = "sac"
    n_envs = 1
    n_timesteps = 10_000
    lr = 1e-3
    batch_size = 2048

    gamma = 0.95
    tau = 0.05

    # SAC parameters
    ent_coef = "auto"

    seed = 100


def train(**kwargs):
    from ml_logger import logger

    Args._update(kwargs)
    logger.log_params(Args=vars(Args))
    logger.log_text("""
        keys:
        - job.status
        - Args.env_name
        - Args.n_envs
        - Args.batch_size
        - Args.seed
        charts:
        - yKey: rollout/success_rate
          xKey: step
        """, filename=".charts.yml", overwrite=True)

    logger.job_started()

    env = make_vec_env(env_id=Args.env_name, n_envs=Args.n_envs, seed=Args.seed)

    # todo: Need to add EvalCallback to collect test-time data.

    # Initialize the model
    model = globals()[Args.algo.upper()](
        policy='MultiInputPolicy',
        policy_kwargs=dict(n_critics=2, net_arch=[512, 512, 512]),
        env=env,
        buffer_size=1_000_000,
        ent_coef=Args.ent_coef,
        batch_size=Args.batch_size,
        gamma=Args.gamma,
        tau=Args.tau,
        learning_rate=Args.lr,
        learning_starts=1000,
        replay_buffer_class=VecHerReplayBuffer,
        replay_buffer_kwargs=dict(
            online_sampling=True,
            goal_selection_strategy='future',
            n_sampled_goal=4,
        ),
        verbose=1,
    )

    # Train the model
    model.learn(total_timesteps=Args.n_timesteps)

geyang • Dec 18 '21 17:12

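Hello @geyang, as mentioned in the doc (multi env with off-policy algo example), you probably need to update the gradient_steps parameter to match the number of envs: with train_freq=1, each call to env.step() collects n_envs transitions but only performs gradient_steps gradient updates.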

araffin • Dec 19 '21 10:12
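The effect of gradient_steps can be checked with a little arithmetic. The sketch below is a plain-Python model, not SB3 code: it assumes train_freq is counted in calls to VecEnv.step() (each collecting one transition per env), and treats a negative gradient_steps as "one update per collected transition", mirroring SB3's gradient_steps=-1 setting. updates_per_transition is a hypothetical name.

```python
def updates_per_transition(n_envs: int, train_freq: int = 1, gradient_steps: int = 1) -> float:
    """Gradient updates per collected transition in an SB3-style off-policy loop (sketch).

    Assumes a step-based train_freq: every train_freq calls to VecEnv.step()
    collect train_freq * n_envs transitions, followed by gradient_steps updates.
    A negative gradient_steps means one update per collected transition.
    """
    transitions = train_freq * n_envs
    updates = transitions if gradient_steps < 0 else gradient_steps
    return updates / transitions


# Default gradient_steps=1 vs. gradient_steps matched to n_envs
for n in (1, 2, 4, 8):
    print(n, updates_per_transition(n), updates_per_transition(n, gradient_steps=n))
```

With the default gradient_steps=1 the ratio drops to 1 / n_envs, so a run with 8 envs does 8 times fewer updates per transition than the single-env run it is compared against; setting gradient_steps=n_envs restores a ratio of 1.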

Hi, I'm working on this feature, and it's not clear to me why there is one class dedicated to the case where n_envs == 1 (HerReplayBuffer) and another one for the cases where n_envs > 1 (VecHerReplayBuffer). Having a single class that handles both cases would avoid having to make the following distinctions:

  • https://github.com/DLR-RM/stable-baselines3/blob/8ed3562cdb68f69795668026b5d1cee925fa5798/stable_baselines3/common/off_policy_algorithm.py#L204
  • https://github.com/DLR-RM/stable-baselines3/blob/8ed3562cdb68f69795668026b5d1cee925fa5798/stable_baselines3/common/off_policy_algorithm.py#L189
  • https://github.com/DLR-RM/stable-baselines3/blob/8ed3562cdb68f69795668026b5d1cee925fa5798/stable_baselines3/common/off_policy_algorithm.py#L269

Also, I find the naming misleading: having both HerReplayBuffer and VecHerReplayBuffer suggests that HerReplayBuffer is not vectorized, when in fact it is; it just doesn't work when there is more than one environment.

What is your opinion about this?

qgallouedec • Dec 22 '21 10:12

> Having a single class that handles both cases would avoid having to make the following distinctions:

I agree, it would be great to have one class that handles both, but I'm worried about the added complexity, so I went for the easy solution for now (a minimal proof of concept). Anyway, feel free to submit a draft PR, as I won't have time to work on this one any time soon.

> Also, I find that it causes misleading naming, since having HerReplayBuffer and VecHerReplayBuffer suggests HerReplayBuffer is not vectorized, when it is, it just doesn't work when there is more than one environment.

Well, every env in SB3 is a VecEnv (even when num_envs=1).

araffin • Dec 22 '21 10:12
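Since every env in SB3 is a VecEnv, one way to get a single class is to give the buffer an explicit env axis, so that n_envs == 1 is just the degenerate case. The sketch below is a hypothetical, minimal illustration, not the design of HerReplayBuffer, VecHerReplayBuffer, or the follow-up PR: it stores only achieved goals in arrays of shape (buffer_size, n_envs, goal_dim), tracks episode boundaries per env, and samples "future" goals from the same episode of the same env. It assumes achieved_goals[t] holds the goal achieved after step t, and it ignores wraparound. EnvAxisGoalStore and its methods are illustrative names.

```python
import numpy as np


class EnvAxisGoalStore:
    """Hypothetical sketch of a HER-style goal store with an explicit env axis.

    Arrays have shape (buffer_size, n_envs, ...), so n_envs == 1 needs no special case.
    No wraparound: add() raises once the buffer is full.
    """

    def __init__(self, buffer_size: int, n_envs: int, goal_dim: int, seed: int = 0):
        self.buffer_size = buffer_size
        self.n_envs = n_envs
        # Goal achieved after each step (the "next achieved goal")
        self.achieved_goals = np.zeros((buffer_size, n_envs, goal_dim), dtype=np.float32)
        # Start index and length of the episode each stored transition belongs to
        self.ep_start = np.zeros((buffer_size, n_envs), dtype=np.int64)
        self.ep_length = np.zeros((buffer_size, n_envs), dtype=np.int64)  # 0 = not finished yet
        self._current_ep_start = np.zeros(n_envs, dtype=np.int64)
        self.pos = 0
        self.rng = np.random.default_rng(seed)

    def add(self, achieved_goal: np.ndarray, dones: np.ndarray) -> None:
        """Store one VecEnv step: achieved_goal is (n_envs, goal_dim), dones is (n_envs,)."""
        if self.pos >= self.buffer_size:
            raise IndexError("buffer is full (this sketch does not wrap around)")
        self.achieved_goals[self.pos] = achieved_goal
        self.ep_start[self.pos] = self._current_ep_start
        self.pos += 1
        # Close the episodes that just ended, independently for each env
        for env_idx in np.flatnonzero(dones):
            start = self._current_ep_start[env_idx]
            self.ep_length[start:self.pos, env_idx] = self.pos - start
            self._current_ep_start[env_idx] = self.pos

    def sample_future_goals(self, batch_size: int):
        """Sample (t, env) pairs from finished episodes and a 'future' goal for each."""
        valid = np.argwhere(self.ep_length[: self.pos] > 0)  # rows of (t, env_idx)
        if len(valid) == 0:
            raise ValueError("no finished episode stored yet")
        picked = valid[self.rng.integers(len(valid), size=batch_size)]
        t, env_idx = picked[:, 0], picked[:, 1]
        ep_end = self.ep_start[t, env_idx] + self.ep_length[t, env_idx]
        # future_t is drawn uniformly from [t, ep_end): same env, same episode, not in the past
        future_t = self.rng.integers(t, ep_end)
        return t, env_idx, future_t, self.achieved_goals[future_t, env_idx]
```

The per-env bookkeeping (episode start and length per column) is what replaces the n_envs == 1 special case: with a single env the arrays simply have one column.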

Closing in favor of https://github.com/DLR-RM/stable-baselines3/pull/704

araffin • Mar 13 '23 17:03