feat(manager): Print information statistics in AgentManager
This PR adds a small function to AgentManager that prints some statistics, in particular bootstrap confidence intervals. Inspired in part by https://arxiv.org/abs/2108.13264. Example script:
from rlberry.agents.torch import A2CAgent
from rlberry.manager import AgentManager, evaluate_agents
from rlberry.envs import gym_make
manager = AgentManager(
    A2CAgent,
    (gym_make, dict(id="CartPole-v1")),
    agent_name="A2CAgent",
    fit_budget=1e4,
    eval_kwargs=dict(eval_horizon=500),
    n_fit=8,
)
manager.fit()
manager.print_stats()
Result of this script:
Statistics of writer data collected on last iteration of each fit.
Means of values over 8 fits:
episode_rewards 170.625
total_episodes 157.625
dtype: float64
Medians of values over 8 fits:
episode_rewards 144.0
total_episodes 155.0
dtype: float64
Confidence interval of level 0.95 for the mean of episode_rewards over 8 fits:
[101.25, 249.38]
Confidence interval of level 0.95 for the mean of total_episodes over 8 fits:
[134.38, 180.25]
Mean number of steps per second is 130.97
Statistics of the evaluation of fitted agents
[INFO] Evaluating agent 0
[INFO] Evaluating agent 1
[INFO] Evaluating agent 2
[INFO] Evaluating agent 3
[INFO] Evaluating agent 4
[INFO] Evaluating agent 5
[INFO] Evaluating agent 6
[INFO] Evaluating agent 7
Means of mean evaluations over 8 fits (100 evaluations):
Medians of mean evaluations over 8 fits (100 evaluations):
Confidence interval of level 0.95 for the mean of mean evaluations over 8 fits:
[118.02, 236.24]
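For context, these intervals are percentile bootstrap confidence intervals for the mean over the n_fit values. The sketch below shows the general idea with a hypothetical helper; the exact resampling scheme and function names used in this PR may differ:

import numpy as np

def bootstrap_ci_mean(values, level=0.95, n_boot=10_000, seed=42):
    # Percentile bootstrap CI for the mean of `values`
    # (hypothetical helper, not the PR's implementation).
    rng = np.random.default_rng(seed)
    values = np.asarray(values)
    # Resample with replacement; record the mean of each resample.
    boot_means = np.array([
        rng.choice(values, size=len(values), replace=True).mean()
        for _ in range(n_boot)
    ])
    alpha = 1.0 - level
    return np.percentile(boot_means, [100 * alpha / 2, 100 * (1 - alpha / 2)])

Applied to the 8 per-fit values of episode_rewards collected from the writers, this kind of helper yields an interval like the ones printed above.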
Computing these stats can be a bit time-consuming, because I need to compute a large number of evaluations (the default is 100) for each of the fitted agents.
Of particular interest is the last confidence interval, which shows the "efficiency" of the fitted agents at evaluation time. Note that, contrary to what is done in evaluate_agents, there is an aggregation phase (i.e., the 100 evaluations of each fit are aggregated using the mean, and the CI is over the 8 fits), as illustrated by the sketch below.
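To make the aggregation concrete, here is a small sketch with placeholder data (the shapes match the example above; the variable names are illustrative, not the PR's internals). Each fit contributes the mean of its 100 evaluations, and the bootstrap resamples those 8 per-fit means rather than the 800 pooled evaluations:

import numpy as np

rng = np.random.default_rng(0)
# evals[i, j] = j-th evaluation reward of the i-th fitted agent
# (8 fits x 100 evaluations; random placeholder data, not real results).
evals = rng.normal(loc=170.0, scale=40.0, size=(8, 100))

# Aggregation phase: one mean evaluation per fit.
per_fit_means = evals.mean(axis=1)  # shape (8,)

# Percentile bootstrap CI (level 0.95) over the 8 per-fit means.
boot_means = np.array([
    rng.choice(per_fit_means, size=8, replace=True).mean()
    for _ in range(10_000)
])
print(np.percentile(boot_means, [2.5, 97.5]))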
If you have any suggestions, feel free to comment. Pinging @mmcenta on this.