stable-baselines3
[Feature Request] Saving rendered trajectories in evaluations
🚀 Feature
- Save rendered images in evaluate_policy().
- Save rendered images in EvalCallback().
Motivation
On a headless server, the simplest way to examine the behavior of a stable-baselines3 policy is to set render_mode='rgb_array' and call img_array = env.render() at each step. However, saving these images takes some extra effort, which can be cumbersome.
It would be great to support saving rendered images in evaluate_policy() and in EvalCallback, which calls that function periodically.
Pitch
In the evaluate_policy() function, render=True is not an option on a headless server. We can instead use the callback argument to pass a function that solves the problem.
import os

from PIL import Image
from stable_baselines3.common.evaluation import evaluate_policy

def build_save_render_callback(path):
    def save_render(locals_, globals_):
        # NOTE: assume n_envs = 1 for now
        env_idx = 0
        episode_count = locals_['episode_counts'][env_idx]
        episode_path = os.path.join(os.getcwd(), path, str(episode_count))
        os.makedirs(episode_path, exist_ok=True)
        # The step index within the episode is used as the file name
        current_length = locals_['current_lengths'][env_idx]
        # render() returns an image array when render_mode='rgb_array'
        frame = locals_['env'].envs[env_idx].render()
        Image.fromarray(frame).save(os.path.join(episode_path, str(current_length) + '.png'))
    return save_render

# model and env are assumed to be defined already
callback = build_save_render_callback(path='/tmp/rendering')
evaluate_policy(model=model, env=env, n_eval_episodes=10, callback=callback)
This works for me, but I am not sure whether it is good practice, or whether it would be better to refactor the evaluate_policy() function to support this directly.
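One practical detail with the callback above: file names such as 10.png sort before 2.png lexicographically, which is awkward when stitching frames into a video later. A minimal sketch of a zero-padded naming scheme follows. The helper name frame_path and the padding widths are hypothetical choices for illustration, not part of stable-baselines3.

```python
import os


def frame_path(root: str, episode: int, step: int, width: int = 6) -> str:
    """Build <root>/<episode>/<step>.png with zero-padded numbers.

    Zero padding keeps a plain lexicographic sort in step order.
    """
    return os.path.join(root, f"{episode:03d}", f"{step:0{width}d}.png")


# Example: three frames of episode 0, deliberately generated out of order
names = [os.path.basename(frame_path("rendering", 0, s)) for s in (2, 10, 1)]
names.sort()  # lexicographic order now matches step order
```

Inside save_render, the path could then be built with frame_path(path, episode_count, current_length) instead of concatenating str(current_length) + '.png'.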
In the EvalCallback class, evaluate_policy is called like this:
https://github.com/DLR-RM/stable-baselines3/blob/620e58e61f649d0f415b7796386d6fe405778026/stable_baselines3/common/callbacks.py#L460-L469
Note that on line 469, self._log_success_callback is passed:
https://github.com/DLR-RM/stable-baselines3/blob/620e58e61f649d0f415b7796386d6fe405778026/stable_baselines3/common/callbacks.py#L426-L440
To support saving the rendered images, we could add a self._save_render_callback and pass a list of callbacks:
callbacks=[self._log_success_callback, self._save_render_callback]
Then we slightly modify the evaluate_policy() function so that it can handle multiple callbacks sequentially:
if callback is not None:
    if isinstance(callback, list):
        for callback_ in callback:
            callback_(locals(), globals())
    else:
        callback(locals(), globals())
Now we can render and save trajectories on a headless server together with an EvalCallback.
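The dispatch logic proposed above can also be sketched as a standalone helper, which makes it easy to check outside of evaluate_policy(). This is only an illustration: run_callbacks is a hypothetical name, not an existing stable-baselines3 function, and it accepts tuples as well as lists, which goes slightly beyond the proposal.

```python
from typing import Any, Callable, Dict, List, Optional, Union

# Signature used by evaluate_policy callbacks: callback(locals_, globals_)
Callback = Callable[[Dict[str, Any], Dict[str, Any]], None]


def run_callbacks(
    callback: Optional[Union[Callback, List[Callback]]],
    locals_: Dict[str, Any],
    globals_: Dict[str, Any],
) -> None:
    """Call a single callback, or each callback of a list in order."""
    if callback is None:
        return
    # Normalize to a sequence so both cases share one code path
    callbacks = callback if isinstance(callback, (list, tuple)) else [callback]
    for callback_ in callbacks:
        callback_(locals_, globals_)


# Stand-in callbacks that only record the order in which they run
calls: List[str] = []
run_callbacks(
    [lambda l, g: calls.append("log_success"), lambda l, g: calls.append("save_render")],
    {},
    {},
)
run_callbacks(lambda l, g: calls.append("single"), {}, {})
run_callbacks(None, {}, {})  # no-op
```

Normalizing to a list keeps the single-callback behavior unchanged, so existing callers would not need to be modified.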
Alternatives
No response
Additional context
I can submit a pull request later if this looks good to you, thanks.
Checklist
- [X] I have checked that there is no similar issue in the repo
- [X] If I'm requesting a new feature, I have proposed alternatives
Hello,
it seems that you need to use a VecVideoRecorder, as we do in the RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo/blob/8cecab429726d7e6aaebd261d26ed8fc23b7d948/rl_zoo3/record_video.py#L141-L147
In a headless server, the simplest way to examine the behavior of a stable-baselines3 polic
Another easy way is to save checkpoints and evaluate them on the side (that's also possible with the RL Zoo).
Bump for the returning of the rendered data from evaluate_policy. I manipulate the rendered data directly before sending it to Weights&Biases, and saving it as a file to re-open it is wasteful.
Bump for the returning of the rendered data from evaluate_policy. I manipulate the rendered data directly before sending it to Weights&Biases, and saving it as a file to re-open it is wasteful.
I'm not sure what you mean, you can pass a callback and/or get the list of episodic return from evaluate_policy, for something even more custom, you can define a custom version of it.
Bump for the returning of the rendered data from evaluate_policy. I manipulate the rendered data directly before sending it to Weights&Biases, and saving it as a file to re-open it is wasteful.
I'm not sure what you mean, you can pass a callback and/or get the list of episodic return from evaluate_policy, for something even more custom, you can define a custom version of it.
Sorry, I re-read the FE and I guess it's not exactly what I was looking for after all.
Essentially, evaluate_policy calls env.render() but if this is rgb_array or rgb_array_list nothing happens. It'd be great to have a return_render flag to capture this data if desired.
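Until something like a return_render flag exists, the callback mechanism from the original post can keep frames in memory instead of writing them to disk. The following is a rough sketch under the same assumptions as that snippet: locals_['env'] exposes .envs (as DummyVecEnv does) and each env was created with render_mode='rgb_array'. The name build_collect_frames_callback is hypothetical, and the fake env classes are stand-ins so the example runs without stable-baselines3.

```python
from typing import Any, Callable, Dict, List, Tuple

Callback = Callable[[Dict[str, Any], Dict[str, Any]], None]


def build_collect_frames_callback(env_idx: int = 0) -> Tuple[Callback, List[Any]]:
    """Return a callback plus the list it appends rendered frames to."""
    frames: List[Any] = []

    def collect(locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        # Same access pattern as the save_render callback in the original post
        frames.append(locals_["env"].envs[env_idx].render())

    return collect, frames


# Stand-ins for a real vectorized env, for demonstration only
class _FakeEnv:
    def __init__(self) -> None:
        self.t = 0

    def render(self) -> List[List[int]]:
        self.t += 1
        return [[self.t]]  # a real env would return an RGB array


class _FakeVecEnv:
    def __init__(self) -> None:
        self.envs = [_FakeEnv()]


collect, frames = build_collect_frames_callback()
fake_locals = {"env": _FakeVecEnv()}
for _ in range(3):  # evaluate_policy would call the callback once per step
    collect(fake_locals, {})
```

With a real model and env, the callback would be passed as evaluate_policy(model, env, callback=collect), after which frames holds one image per step and can be post-processed or logged without touching the file system. Long evaluations can use a lot of memory this way.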
Sorry, I re-read the FE and I guess it's not exactly what I was looking for after all.
FE?
Essentially, evaluate_policy calls env.render() but if this is rgb_array or rgb_array_list nothing happens. It'd be great to have a return_render flag to capture this data if desired.
That was my first point about VecVideoRecorder: by wrapping the env, you have access to everything.
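For reference, the wrapping approach mentioned above might look roughly like the sketch below. It follows the pattern of the linked RL Zoo script, but the helper names, the default video length, and the file prefix are assumptions chosen for illustration. The import is done lazily so the helper module can be loaded without stable-baselines3 installed.

```python
def record_at_start(step: int) -> bool:
    """Trigger function: start a recording only at the very first step."""
    return step == 0


def wrap_for_recording(venv, video_folder: str, video_length: int = 1000, name_prefix: str = "eval"):
    """Wrap a vectorized env so that rendered frames are written to a video file.

    Assumes the underlying envs were created with render_mode='rgb_array'.
    """
    # Lazy import: only needed when the wrapper is actually built
    from stable_baselines3.common.vec_env import VecVideoRecorder

    return VecVideoRecorder(
        venv,
        video_folder,
        record_video_trigger=record_at_start,
        video_length=video_length,
        name_prefix=name_prefix,
    )
```

The wrapped env can then be passed to evaluate_policy() or to EvalCallback as the evaluation env, which should work on a headless server as long as the env itself can render to an array.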