
Stopping function on alternative criteria

Open drozzy opened this issue 3 years ago • 14 comments

It would be nice if we had access to the info dict in the stop callback.

Presently the stopping callback in trainer looks like this:

def stop_fn(mean_rewards):
    return mean_rewards >= env.spec.reward_threshold

But I would like to be able to early stop based on some other criteria, for example maybe some cost key in the info dict e.g.:

def stop_fn(rewards, info):
    return sum(info['cost']) <= env.spec.cost_constraint

drozzy avatar Oct 01 '21 15:10 drozzy

Are there any other criteria besides reward and info? I'm asking for completeness.

Trinkle23897 avatar Oct 01 '21 15:10 Trinkle23897

Probably all the things that env step returns? reward, state, info?

Initially I expected to just have a pointer to the replay buffer, but then I realized that it's the testing phase, so it probably doesn't store it there.

drozzy avatar Oct 01 '21 15:10 drozzy

How about learning metrics?

Trinkle23897 avatar Oct 01 '21 15:10 Trinkle23897

Learning metrics?

drozzy avatar Oct 02 '21 01:10 drozzy

Hmm... after rethinking, learning metrics are seldom treated as a stopping criterion. Feel free to submit a PR :)

Trinkle23897 avatar Oct 02 '21 12:10 Trinkle23897

This is quite complicated, and I don't want to break anything. For example, the collector itself doesn't return the info objects: https://github.com/thu-ml/tianshou/blob/fc251ab0b85bf3f0de7b24c1c553cb0ec938a9ee/tianshou/data/collector.py#L317

It's also not clear to me how result["rews"] (here) gets summed. It seems like this is just a list of per-episode rewards, but in the stop function it's used as a single value:

def stop_fn(mean_rewards):
    return mean_rewards >= env.spec.reward_threshold
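If the intent is that the trainer reduces the per-episode returns to a scalar mean before calling stop_fn, the call would look roughly like this sketch (the 195.0 threshold is a made-up stand-in for env.spec.reward_threshold):

```python
import numpy as np

def stop_fn(mean_rewards):
    # 195.0 is a made-up stand-in for env.spec.reward_threshold
    return mean_rewards >= 195.0

# Hypothetical per-episode returns from a test collect; the trainer
# would reduce them to a single scalar before calling stop_fn.
episode_returns = np.array([180.0, 210.0, 200.0])
reached = stop_fn(episode_returns.mean())
print(reached)
```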

drozzy avatar Oct 04 '21 06:10 drozzy

It's also not clear to me how result["rews"] (here) gets summed. It seems like this is just a list of per-episode rewards, but in the stop function it's used as a single value:

fixed in https://github.com/thu-ml/tianshou/pull/459/commits/be077afbc1c1ef7f76c971ee70403ee2a4fe1d0d

Trinkle23897 avatar Oct 04 '21 15:10 Trinkle23897

For example, collector itself doesn't return the info objects:

Because info is not gatherable, and returning all infos in the collector's .collect return dict would be ugly...

Trinkle23897 avatar Oct 04 '21 16:10 Trinkle23897

@drozzy any further thought on this API design?

Trinkle23897 avatar Oct 09 '21 12:10 Trinkle23897

Here is an example of where it might be useful: https://github.com/openai/safety-gym#getting-started

Stop when costs are below some threshold.

>>> info
{'cost_hazards': 0.0, 'cost': 0.0}

I'm happy if only certain keys of info are gatherable.

My hack right now is to just not do early stopping, haha.
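To make the Safety Gym use case concrete, a cost-based stop function might look like this sketch (cost_threshold and the per-episode cost bookkeeping are hypothetical; tianshou's stop_fn only receives the mean reward today):

```python
import numpy as np

def cost_stop_fn(episode_costs, cost_threshold):
    # Stop once the mean summed info["cost"] per test episode
    # falls below the constraint.
    return float(np.mean(episode_costs)) <= cost_threshold

# Each entry is one episode's accumulated info["cost"].
ok = cost_stop_fn([0.0, 0.2, 0.1], cost_threshold=0.15)
print(ok)
```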

drozzy avatar Oct 10 '21 02:10 drozzy

I would like to use criteria other than rewards (e.g. the success rate on some tasks). I wonder how we can do or hack this. @drozzy

KwanWaiChung avatar Nov 29 '21 10:11 KwanWaiChung

This hack is kind of messy but I think it is the most efficient way...

  1. use preprocess_fn in the collector to gather the success-rate statistics, e.g.,
# in main.py, when initializing the collector
import numpy as np

from tianshou.data import Batch
from tianshou.utils import MovAvg

rates = [MovAvg() for _ in range(training_num)]

def preprocess_fn(info, env_id, **kwargs):
    succ_rate = info["success_rate"]  # this is a numpy array
    for eid, r in zip(env_id, succ_rate):
        rates[eid].add(r)
    return Batch()  # nothing to modify in the collected data

collector = Collector(..., preprocess_fn=preprocess_fn)
  2. change stop_fn:
def stop_fn(mean_rewards):
    return np.mean([m.get() for m in rates]) >= threshold
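The two steps above can be sketched end to end without tianshou installed; SimpleMovAvg is a minimal stand-in for tianshou.utils.MovAvg, and the success_rate values and 0.9 threshold are fabricated:

```python
import numpy as np

class SimpleMovAvg:
    """Minimal stand-in for tianshou.utils.MovAvg: an unweighted running mean."""
    def __init__(self):
        self.values = []

    def add(self, x):
        self.values.append(float(x))

    def get(self):
        return float(np.mean(self.values)) if self.values else 0.0

training_num = 2  # made-up number of training envs
rates = [SimpleMovAvg() for _ in range(training_num)]

def preprocess_fn(info, env_id, **kwargs):
    # One success_rate value per env listed in env_id.
    for eid, r in zip(env_id, info["success_rate"]):
        rates[eid].add(r)

def stop_fn(mean_rewards):
    # 0.9 is a made-up success-rate threshold.
    return np.mean([m.get() for m in rates]) >= 0.9

# Simulate two collect steps feeding the hook:
preprocess_fn({"success_rate": np.array([1.0, 0.8])}, env_id=np.array([0, 1]))
preprocess_fn({"success_rate": np.array([1.0, 1.0])}, env_id=np.array([0, 1]))
print(stop_fn(0.0))
```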

Trinkle23897 avatar Nov 29 '21 14:11 Trinkle23897

This hack is kind of messy but I think it is the most efficient way...

Thanks, that should work! I guess the thorough solution would be allowing a more flexible output from the collector?

KwanWaiChung avatar Nov 29 '21 15:11 KwanWaiChung

I think so. The solution would be a user-defined function gather_info_fn(prev_result, info, env_id) -> result, and we could return that result from .collect().
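A rough sketch of that hypothetical gather_info_fn contract (nothing like it exists in tianshou; the "cost" and "total_cost" keys are invented for illustration):

```python
import numpy as np

def gather_info_fn(prev_result, info, env_id):
    # Fold one step's info into the running result dict; the collector
    # would call this every step and return the final result from .collect().
    result = dict(prev_result)
    step_cost = float(np.sum(info.get("cost", 0.0)))
    result["total_cost"] = result.get("total_cost", 0.0) + step_cost
    return result

# Simulate two collect steps across two envs:
result = {}
for step_info in [{"cost": [0.1, 0.0]}, {"cost": [0.2, 0.3]}]:
    result = gather_info_fn(result, step_info, env_id=[0, 1])
print(result["total_cost"])
```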

Trinkle23897 avatar Nov 29 '21 15:11 Trinkle23897

Closing as stale. In Gymnasium there are terminated and truncated flags, and overall stopping should rather be handled by the env than by the core trainer, I feel.

MischaPanch avatar Oct 14 '23 15:10 MischaPanch