Stopping function on alternative criteria
It would be nice if we had access to the info dict in the stop callback.
Presently the stopping callback in the trainer looks like this:

def stop_fn(mean_rewards):
    return mean_rewards >= env.spec.reward_threshold
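For context, a minimal sketch of how that callback is wired into a trainer; the policy, collectors, and hyperparameters are assumed to be set up elsewhere, and the exact trainer arguments vary by tianshou version:

# Hedged sketch: only stop_fn's role is the point here.
from tianshou.trainer import offpolicy_trainer

result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, step_per_collect=10,
    episode_per_test=10, batch_size=64,
    stop_fn=stop_fn,  # consulted with the mean test reward; True stops training
)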
But I would like to be able to stop early based on some other criterion, for example a cost key in the info dict, e.g.:

def stop_fn(rewards, info):
    return sum(info['cost']) <= env.spec.cost_constraint
Are there any other criteria besides reward and info? I'm thinking about completeness.

Probably all the things that env.step returns? Reward, state, info?
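A hypothetical signature along those lines (purely illustrative; no such API exists in tianshou):

# Hypothetical API sketch -- not part of tianshou.
def stop_fn(rewards, observations, infos):
    # rewards: per-episode test returns; observations/infos: the
    # matching data from env.step during evaluation.
    return rewards.mean() >= env.spec.reward_threshold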
Initially I expected to just get a pointer to the replay buffer, but then I realized that this is the testing phase, so the data probably isn't stored there.
How about learning metrics?
Learning metrics?
Hmm... after rethinking, learning metrics are seldom treated as a stopping criterion. Feel free to submit a PR :)
This is very complicated, and I don't want to mess things up. For example, the collector itself doesn't return the info objects: https://github.com/thu-ml/tianshou/blob/fc251ab0b85bf3f0de7b24c1c553cb0ec938a9ee/tianshou/data/collector.py#L317
It's also not clear to me how result["rews"] (here) gets summed. It seems like this is just a list of rewards from the episodes, but in the stop function it's used as a single value:

def stop_fn(mean_rewards):
    return mean_rewards >= env.spec.reward_threshold
Fixed in https://github.com/thu-ml/tianshou/pull/459/commits/be077afbc1c1ef7f76c971ee70403ee2a4fe1d0d
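For reference, the gist of what the trainer does with the collected test result (a paraphrase, not the exact tianshou source):

# Paraphrased sketch of the trainer's test-reward reduction.
import numpy as np

result = test_collector.collect(n_episode=episode_per_test)
mean_rewards = np.mean(result["rews"])  # "rews" holds one return per test episode
if stop_fn(mean_rewards):
    ...  # training stops early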
For example, the collector itself doesn't return the info objects:

Because info is not gatherable, and returning all the infos in the dict that Collector.collect returns would be ugly...
@drozzy any further thought on this API design?
Here is an example of where it might be useful: https://github.com/openai/safety-gym#getting-started
Stop when costs are below some threshold.
>>> info
{'cost_hazards': 0.0, 'cost': 0.0}
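For instance, with a plain rollout loop the episodic cost can be tracked like this (the env id and the cost budget of 25.0 are assumptions for illustration):

# Illustrative sketch only; env name and cost threshold are assumed.
import gym
import safety_gym  # noqa: F401  # registers the Safexp-* environments

env = gym.make('Safexp-PointGoal1-v0')
obs = env.reset()
episode_cost, done = 0.0, False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
    episode_cost += info['cost']
stop = episode_cost <= 25.0  # stop once the episodic cost fits the budget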
I'm happy if only certain keys of info are gatherable. My hack right now is to just not do early stopping, haha.
I would like to compute criteria other than reward (e.g. success rate in some tasks). I wonder how we can do this, or hack around it. @drozzy
This hack is kind of messy, but I think it is the most efficient way...

- use preprocess_fn in the collector to gather the success-rate statistic, e.g.:

# in main.py, when initializing the collector
from tianshou.data import Batch
from tianshou.utils import MovAvg

rates = [MovAvg() for _ in range(training_num)]

def preprocess_fn(info, env_id, **kwargs):
    succ_rate = info["success_rate"]  # this is a numpy array
    for eid, r in zip(env_id, succ_rate):
        rates[eid].add(r)
    return Batch()  # empty Batch: leave the collected data unchanged

collector = Collector(..., preprocess_fn=preprocess_fn)

- change stop_fn (its mean_rewards argument is ignored here; see the note after this list):

import numpy as np

def stop_fn(mean_rewards):
    return np.mean([m.get() for m in rates]) >= threshold
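One note on the sketch above: keeping a separate MovAvg per training environment gives each env its own smoothed success-rate estimate, and stop_fn simply ignores its mean_rewards argument in favour of the shared rates list; that reliance on module-level state is what makes the hack messy.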
Thanks, that should work! I guess the thorough solution would be allowing a more flexible output from the collector?
I think so. The solution would be a self-defined function gather_info_fn(prev_result, info, env_id) -> result, and we could return that result when .collect is called.
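A rough sketch of what that hypothetical hook could look like (nothing here exists in tianshou; the names come straight from the comment above):

# Hypothetical API sketch -- gather_info_fn is not part of tianshou.
def gather_info_fn(prev_result, info, env_id):
    # Fold the per-step info of the ready envs into the running result.
    prev_result.setdefault("cost", 0.0)
    prev_result["cost"] += float(info["cost"].sum())
    return prev_result

# The collector would call it every step and return the accumulated
# result from .collect, e.g.:
# result = collector.collect(n_episode=10, gather_info_fn=gather_info_fn)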
Closing as stale. In gymnasium there are terminated and truncated, and overall stopping should rather be handled by the env, not the core trainer, I feel.