
[feature request] Low-level API

Open mhtb32 opened this issue 5 years ago • 7 comments

Having a lower-level API can be helpful for users who want more control over the training loop. A low-level API is particularly useful in Hierarchical RL or Multi-Agent RL problems. One idea is to provide an API like this repo from @eleurent:


for episode in episodes:
    state = env.reset()
    done = False
    while not done:
        action = agent.act(state)
        next_state, reward, done, info = env.step(action)
        agent.record(state, action, next_state, reward, done, info)
        state = next_state
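To make the proposed interface concrete, here is a runnable sketch of that loop. `RandomAgent` and `CountdownEnv` are illustrative stand-ins invented for this example; they are not part of SB3 or the linked repo:

```python
import random

class RandomAgent:
    """Toy agent exposing the proposed act()/record() interface (hypothetical API)."""
    def __init__(self, n_actions):
        self.n_actions = n_actions
        self.transitions = []  # stands in for a replay/rollout buffer

    def act(self, state):
        return random.randrange(self.n_actions)

    def record(self, state, action, next_state, reward, done, info):
        self.transitions.append((state, action, next_state, reward, done))

class CountdownEnv:
    """Minimal environment stub so the loop runs without gym installed."""
    def reset(self):
        self.t = 5
        return self.t

    def step(self, action):
        self.t -= 1
        done = self.t == 0
        return self.t, 1.0, done, {}

env, agent = CountdownEnv(), RandomAgent(n_actions=2)
state = env.reset()
done = False
while not done:
    action = agent.act(state)
    next_state, reward, done, info = env.step(action)
    agent.record(state, action, next_state, reward, done, info)
    state = next_state

print(len(agent.transitions))  # 5 transitions recorded for the 5-step episode
```

With such an interface, hierarchical or multi-agent setups could interleave several agents' `act()`/`record()` calls inside one environment loop.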

mhtb32 avatar Jun 10 '20 17:06 mhtb32

Hello,

Why not, but I would say you already have one:

action = agent.act(state)

It is called predict(): action, _ = agent.predict(state)

agent.record(state, action, next_state, reward, done, info)

As mentioned in the developer guide (cf doc), if you are rewriting the training loop (take a look at collect_rollout()), you have access to agent.rollout_buffer or agent.replay_buffer, which do what you want. Finally, the gradient update is located in agent.train().

araffin avatar Jun 10 '20 18:06 araffin

Are these available in SB2 as well?

mhtb32 avatar Jun 10 '20 18:06 mhtb32

Are these available in SB2 as well?

Yes for predict() (cf doc). For the rest, more or less; it is a bit messier: train() usually corresponds to _train_step(), and collect_rollout() is done inside learn().

araffin avatar Jun 10 '20 18:06 araffin

Thanks. I think it is worth adding a tutorial about using the low-level API to the docs.

mhtb32 avatar Jun 11 '20 07:06 mhtb32

Agreed, this question has come up in the SB repository quite often. Another related thing we could do is make getting action probabilities/values a bit easier/clearer.

Miffyli avatar Jun 11 '20 08:06 Miffyli

As mentioned in the developer guide (cf doc), if you are rewriting the training loop (take a look at collect_rollout()), you have access to agent.rollout_buffer or agent.replay_buffer, which do what you want.

@araffin could you explain more how you would use the rollout buffer to do this?

LittleRobertTables avatar Nov 29 '21 01:11 LittleRobertTables

Different updates regarding this issue:

Agreed, this question has come up in the SB repository quite often. Another related thing we could do is make getting action probabilities/values a bit easier/clearer.

We now provide policy.obs_to_tensor (for all algorithms), policy.predict_values() and policy.get_distribution() (for on-policy algorithms).

@araffin could you explain more how you would use the rollout buffer to do this?

We have a developer guide on where to look when you want to dive deeper into SB3. I won't give more details than "please read the code and learn about the algorithms", as you really need to understand how the algorithms work when you are accessing or changing the low-level logic.

araffin avatar Nov 29 '21 09:11 araffin