[RLlib] eager_tf_policy drops is_training in call in _convert_to_tf

Open • HJasperson opened this issue 3 years ago • 1 comment

What happened + What you expected to happen

In eager_tf_policy.py, _convert_to_tf contains the following:

https://github.com/ray-project/ray/blob/166cd537d5a84041c3c1f3290b5f938ba150bce8/rllib/policy/eager_tf_policy.py#L34-L36

The problem is that is_training is now an attribute of SampleBatch rather than a dictionary key. x.items() therefore does not yield it, so it is left out of dict_ and the converted batch no longer carries the is_training flag.
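The mechanism can be reproduced without Ray. The sketch below is illustrative only: FakeSampleBatch is a hypothetical stand-in for SampleBatch (assumption: only its shape matters here, a dict subclass that stores is_training as an instance attribute), and to_plain_dict is a hypothetical helper that copies the dict comprehension from _convert_to_tf.

```python
# Hypothetical stand-in for RLlib's SampleBatch: a dict subclass that keeps
# is_training as an instance attribute rather than as a dictionary key.
class FakeSampleBatch(dict):
    INFOS = "infos"

    def __init__(self, *args, is_training=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_training = is_training


def to_plain_dict(batch):
    # Mirrors the comprehension in _convert_to_tf: only real keys survive,
    # so the is_training attribute is silently dropped.
    return {k: v for k, v in batch.items() if k != FakeSampleBatch.INFOS}


batch = FakeSampleBatch({"obs": [1, 2], "infos": [{}, {}]}, is_training=True)
converted = to_plain_dict(batch)
print(sorted(converted))           # ['obs']
print("is_training" in converted)  # False
print(batch.is_training)           # True
```

Because items() only iterates over the dict's keys, any state kept as a plain attribute on the subclass is invisible to code that rebuilds a dict from it.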

Versions / Dependencies

Ray 1.11.0, Python 3.9, TensorFlow 2.7, RHEL 7.9

Reproduction script

n/a

Issue Severity

Medium: It is a significant difficulty but I can work around it.

HJasperson • Apr 15 '22 23:04

Thanks to this issue, I found a workaround: copy the is_training flag into dict_ manually inside _convert_to_tf.

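import tree  # dm-tree
from ray.rllib.policy.sample_batch import SampleBatch

def _convert_to_tf(x, dtype=None):
    if isinstance(x, SampleBatch):
        dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
        # is_training is an attribute, not a key, so x.items() skips it; add it back.
        dict_["is_training"] = x.is_training
        return tree.map_structure(_convert_to_tf, dict_)
    # ... (remaining branches of the original function unchanged)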

signalprime • Dec 20 '22 23:12