[RLlib] eager_tf_policy: _convert_to_tf drops is_training from SampleBatch
What happened + What you expected to happen
In eager_tf_policy.py, _convert_to_tf contains the following:
https://github.com/ray-project/ray/blob/166cd537d5a84041c3c1f3290b5f938ba150bce8/rllib/policy/eager_tf_policy.py#L34-L36
The problem is that is_training is now an attribute of SampleBatch rather than a dictionary key. x.items() therefore never yields it, and it is silently left out of dict_.
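To illustrate the mechanism without needing Ray installed, here is a minimal sketch. FakeSampleBatch is a hypothetical stand-in for SampleBatch, assumed only to be a dict subclass that keeps is_training as a plain attribute. The comprehension has the same shape as the one in _convert_to_tf.

```python
class FakeSampleBatch(dict):
    """Hypothetical stand-in for RLlib's SampleBatch: a dict subclass that
    stores is_training as an instance attribute, not as a dict key."""

    def __init__(self, *args, is_training=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_training = is_training


batch = FakeSampleBatch({"obs": [1, 2], "infos": [{}, {}]}, is_training=True)

# Same pattern as _convert_to_tf: only real dict keys are visited by items().
dict_ = {k: v for k, v in batch.items() if k != "infos"}

print(batch.is_training)       # True: the attribute exists on the batch
print("is_training" in dict_)  # False: items() never yields it
print(sorted(dict_))           # ['obs']
```

Anything stored only as an attribute is invisible to items(), keys() and iteration, so every dict-based conversion of the batch loses it.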
Versions / Dependencies
Ray 1.11.0, Python 3.9, TensorFlow 2.7, RHEL 7.9
Reproduction script
n/a
Issue Severity
Medium: It is a significant difficulty but I can work around it.
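Thanks to this issue, I also found a workaround: patch _convert_to_tf so that it adds the flag back manually.

```python
import tree  # dm-tree, as used by the original snippet
from ray.rllib.policy.sample_batch import SampleBatch


def _convert_to_tf(x, dtype=None):
    if isinstance(x, SampleBatch):
        dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
        dict_["is_training"] = x.is_training  # <------ manually add the 'is_training' flag
        return tree.map_structure(_convert_to_tf, dict_)
    # ... remaining branches of the original function unchanged
```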
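To sanity-check the workaround without Ray or TensorFlow, here is a minimal sketch that mirrors only the patched SampleBatch branch. FakeSampleBatch, map_structure and convert_patched are hypothetical stand-ins for SampleBatch, tree.map_structure and _convert_to_tf. The leaf conversion only tags each value instead of calling tf.convert_to_tensor.

```python
class FakeSampleBatch(dict):
    """Hypothetical SampleBatch stand-in: is_training is an attribute, not a key."""

    def __init__(self, *args, is_training=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_training = is_training


INFOS = "infos"  # stand-in for SampleBatch.INFOS


def map_structure(fn, structure):
    """Minimal stand-in for tree.map_structure over dicts, lists and tuples."""
    if isinstance(structure, dict):
        return {k: map_structure(fn, v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, v) for v in structure)
    return fn(structure)


def convert_patched(x):
    if isinstance(x, FakeSampleBatch):
        dict_ = {k: v for k, v in x.items() if k != INFOS}
        dict_["is_training"] = x.is_training  # re-add the attribute as a key
        return map_structure(convert_patched, dict_)
    # Leaf: stand-in for tf.convert_to_tensor, tags the value as "converted".
    return ("tensor", x)


batch = FakeSampleBatch({"obs": [1, 2], "infos": [{}]}, is_training=True)
result = convert_patched(batch)
print(result["is_training"])  # ('tensor', True)
```

Because the flag is copied into dict_ before the structure is mapped, it goes through the same leaf conversion as the other entries. In the real function, that means it becomes a tensor alongside the rest of the batch.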