stable-baselines3
[feature request] Low-level API
Having a lower-level API can be helpful for users who want more control over the training loop. A low-level API is particularly useful for Hierarchical RL or Multi-Agent RL problems. One idea is to provide an API like this repo from @eleurent:
for episode in episodes:
    while not done:
        action = agent.act(state)
        next_state, reward, done, info = env.step(action)
        agent.record(state, action, next_state, reward, done, info)
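A runnable toy instantiation of the proposed interface might look like the sketch below (DummyEnv and DummyAgent are hypothetical stand-ins for illustration, not part of any library):

```python
import random

class DummyEnv:
    """Toy environment: episode ends after 5 steps, reward 1 per step."""
    def reset(self):
        self.t = 0
        return 0  # initial state

    def step(self, action):
        self.t += 1
        done = self.t >= 5
        return self.t, 1.0, done, {}

class DummyAgent:
    """Agent exposing the proposed act()/record() low-level API."""
    def __init__(self):
        self.transitions = []

    def act(self, state):
        return random.choice([0, 1])

    def record(self, state, action, next_state, reward, done, info):
        self.transitions.append((state, action, next_state, reward, done))

env, agent = DummyEnv(), DummyAgent()
for episode in range(2):
    state, done = env.reset(), False
    while not done:
        action = agent.act(state)
        next_state, reward, done, info = env.step(action)
        agent.record(state, action, next_state, reward, done, info)
        state = next_state

print(len(agent.transitions))  # 2 episodes x 5 steps = 10 transitions
```

Note that a complete loop also needs the env.reset() call and the state update at the end of each step, which the shorter sketch above omits.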
Hello,
Why not, but I would say you already have one:
action = agent.act(state)
It is called predict(): action, _ = agent.predict(state)
agent.record(state, action, next_state, reward, done, info)
As mentioned in the developer guide (cf. doc), if you are rewriting the training loop (take a look at collect_rollout()), you have access to agent.rollout_buffer or agent.replay_buffer, which do what you want.
Finally, the gradient update is located in agent.train().
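Putting those three pieces together, a custom off-policy training loop would roughly follow the structure below. This is a self-contained sketch with simplified stand-in classes, not SB3 code: in real SB3 the calls would be agent.predict(obs), agent.replay_buffer.add(...), and agent.train(gradient_steps=...), whose exact signatures depend on the algorithm and version.

```python
import random

class ReplayBuffer:
    """Simplified stand-in for SB3's agent.replay_buffer."""
    def __init__(self):
        self.storage = []

    def add(self, obs, next_obs, action, reward, done):
        self.storage.append((obs, next_obs, action, reward, done))

    def sample(self, batch_size):
        return random.sample(self.storage, min(batch_size, len(self.storage)))

class Agent:
    """Mirrors the SB3 split: predict() for acting, train() for gradient updates."""
    def __init__(self):
        self.replay_buffer = ReplayBuffer()
        self.n_updates = 0

    def predict(self, obs):
        # In SB3: action, _states = agent.predict(obs)
        return random.choice([0, 1]), None

    def train(self, gradient_steps, batch_size=32):
        # In SB3 this is where the gradient updates on sampled batches happen.
        for _ in range(gradient_steps):
            batch = self.replay_buffer.sample(batch_size)
            self.n_updates += 1

agent = Agent()
obs, done, t = 0, False, 0
while not done:
    action, _ = agent.predict(obs)
    next_obs, reward, t = obs + 1, 1.0, t + 1  # toy env transition
    done = t >= 10
    agent.replay_buffer.add(obs, next_obs, action, reward, done)
    obs = next_obs

agent.train(gradient_steps=4)
```

The point of the split is that data collection (predict + buffer.add) and optimization (train) are independent, so a hierarchical or multi-agent setup can interleave them however it needs.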
Are these available in SB2 as well?
Are these available in SB2 as well?
Yes for predict() (cf. doc).
For the rest, more or less; it is a bit messier there.
train() usually corresponds to _train_step(), and the collect_rollout() logic is done inside learn().
Thanks. I think it is worth adding a tutorial about using the low-level API to the docs.
Agreed, this question came up in the sb repository quite often. Another related thing we could do is to make getting action probabilities/values a bit easier / clearer.
@araffin could you explain more how you would use the rollout buffer to do this?
Different updates regarding this issue:
Agreed, this question came up in sb repository quite often. Another related thing we could do is to make getting action probabilities/values bit easier / clarified.
We now provide policy.obs_to_tensor (for all algorithms), policy.predict_values() and policy.get_distribution() (for on-policy algorithms).
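To illustrate what these helpers expose, here is a simplified, self-contained stand-in (the classes below are hypothetical and only mimic the shape of the real SB3 calls policy.obs_to_tensor(obs), policy.get_distribution(obs_tensor), and policy.predict_values(obs_tensor); in SB3 these operate on torch tensors and the policy networks):

```python
import math

class CategoricalDistribution:
    """Stand-in for the action distribution get_distribution() returns."""
    def __init__(self, logits):
        z = [math.exp(l) for l in logits]
        total = sum(z)
        self.probs = [p / total for p in z]  # softmax over logits

    def log_prob(self, action):
        return math.log(self.probs[action])

class Policy:
    """Stand-in policy exposing the three helpers mentioned above."""
    def obs_to_tensor(self, obs):
        # SB3 converts a numpy observation to a torch tensor on the right device.
        return obs

    def get_distribution(self, obs_tensor):
        logits = [0.0, 1.0]  # would come from the policy network
        return CategoricalDistribution(logits)

    def predict_values(self, obs_tensor):
        return 0.5           # would come from the value network

policy = Policy()
obs_tensor = policy.obs_to_tensor([0.0])
dist = policy.get_distribution(obs_tensor)
value = policy.predict_values(obs_tensor)
```

With the real helpers, dist gives you per-action probabilities and log-probabilities, and predict_values gives the state-value estimate, without rewriting any forward-pass code yourself.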
@araffin could you explain more how you would use the rollout buffer to do this?
We have a developer guide on where to look when you want to dive deeper into SB3. I won't give more details than "please read the code and learn about the algorithms", as you really need to understand how the algorithms work when you are accessing/changing the low-level logic.