Saving out the best policy
Hi Brax,
I'm not sure if this is an issue, but it's just tripped me up, so I figured I'd put it here.
In the RL demo (e.g. here), the parameters returned from the training loop appear to be the final parameters, as opposed to the best parameters as evaluated during the training run.* So even if you are evaluating the policy during training, the best parameters are lost.
There doesn't seem to be a way of getting them out of the training loop as-is. progress_fn (which seems most similar to Stable Baselines callbacks) only accepts the number of steps and the metrics, so there is no point at which you can manually "extract" the best parameters.
Ideally, progress_fn would also accept params** as an argument (or at least the raw arguments passed to the evaluator), so that the user can implement whatever logic they like in the progress function to inspect the policy, save out the best policy, and maybe even make the code re-entrant.
My request:
Modify the call signature of progress_fn from progress_fn: Callable[[int, Metrics], None] to progress_fn: Callable[[int, Metrics, Optional[ParamType]], None] (I confess, I don't know what the correct general type for the optional params would be), such that I can define a progress function with a header of def progress_fn(num_steps, metrics, params):.
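For example, here is a sketch of the kind of thing I'd like to be able to write with that signature. The metric key and save path are placeholders, and I'm assuming brax.io.model.save_params is available, as used in the demo notebook:

```python
from brax.io import model

best_reward = float('-inf')  # module-level state, purely for illustration


def progress_fn(num_steps, metrics, params):
    """Save the policy whenever the evaluation reward improves."""
    global best_reward
    eval_reward = metrics['eval/episode_reward']  # placeholder metric key
    if params is not None and eval_reward > best_reward:
        best_reward = eval_reward
        model.save_params('/tmp/best_policy_params', params)  # placeholder path
```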
There may already be a way of doing this, but I cannot see how.
Thanks! Andy
* Given determinism, or close-to-determinism, maybe you could re-run training with early stopping, but that seems a bit archaic.
** Where params is defined as params = _unpmap((training_state.normalizer_params, training_state.params.policy))
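For reference, the call-site change inside the training loop could be as small as passing those params through to the callback. This is only a rough sketch; the evaluation call and variable names are approximations of whatever the loop already uses:

```python
# Inside the existing training loop, roughly where evaluation already runs:
params = _unpmap(
    (training_state.normalizer_params, training_state.params.policy))
metrics = evaluator.run_evaluation(params, training_metrics)  # existing call
progress_fn(current_step, metrics, params)  # new: pass params through as well
```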