[RLlib] EagerTFPolicyV2 wrongly calls overridden action_sampler_fn
What happened + What you expected to happen
Creating a custom Policy based on EagerTFPolicyV2 that overrides action_sampler_fn causes an error during initialization:
File "C:\Users\<username>\AppData\Local\pypoetry\Cache\virtualenvs\venv-py3.11\Lib\site-packages\ray\rllib\policy\eager_tf_policy_v2.py", line 1013, in _compute_actions_helper
actions, logp, dist_inputs, state_out = self.action_sampler_fn(
^^^^^^^^^^^^^^^^^^^^^^^
TypeError: get_custom_appo_tf_policy.<locals>.CustomAPPOTFPolicy.action_sampler_fn() missing 1 required keyword-only argument: 'state_batches'
It seems that self.action_sampler_fn is called incorrectly. Line 1013 of ray\rllib\policy\eager_tf_policy_v2.py should be
actions, logp, dist_inputs, state_out = self.action_sampler_fn(
    self.model,
    obs_batch=input_dict[SampleBatch.OBS],
    state_batches=state_batches,
    explore=explore,
    timestep=timestep,
    episodes=episodes,
)
instead of
actions, logp, dist_inputs, state_out = self.action_sampler_fn(
    self.model,
    input_dict[SampleBatch.OBS],
    explore=explore,
    timestep=timestep,
    episodes=episodes,
)
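The TypeError in the traceback is standard Python behavior for keyword-only parameters: every parameter declared after the bare * in action_sampler_fn must be passed by keyword, and omitting one of them fails at call time. A minimal plain-Python illustration (the sampler function below is hypothetical, not RLlib code):

def sampler(model, *, obs_batch, state_batches, **kwargs):
    return None, None, None, None

# Omitting the keyword-only state_batches argument reproduces the error:
sampler("model", obs_batch=[0.0])
# TypeError: sampler() missing 1 required keyword-only argument: 'state_batches'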
Versions / Dependencies
Ray: 2.10.0
Python: 3.11.9
OS: Windows 10
Reproduction script
from typing import List, Optional, Tuple

from ray.rllib.algorithms.appo.appo_tf_policy import get_appo_tf_policy
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType

APPOTFPolicy = get_appo_tf_policy("CustomAPPOTFPolicy", EagerTFPolicyV2)

# This class crashes during initialization.
class CustomAPPOTFPolicy(APPOTFPolicy):
    @override(EagerTFPolicyV2)
    def action_sampler_fn(
        self,
        model: ModelV2,
        *,
        obs_batch: TensorType,
        state_batches: TensorType,
        seq_lens: Optional[TensorType] = None,
        **kwargs,
    ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
        # Custom sampling logic goes here; returning dummies is enough
        # to reproduce the crash.
        return None, None, None, None
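For completeness, a minimal sketch of how the crash surfaces, assuming CartPole-like spaces and a default APPO config with the tf2 framework (these specifics are my assumptions, not part of the report). Constructing the policy is enough, since EagerTFPolicyV2 already reaches _compute_actions_helper during initialization:

import gymnasium as gym
from ray.rllib.algorithms.appo import APPOConfig

# Assumed setup: any simple spaces/config should do.
config = APPOConfig().framework("tf2").to_dict()
obs_space = gym.spaces.Box(-1.0, 1.0, (4,))
act_space = gym.spaces.Discrete(2)

# Raises the TypeError shown above during __init__.
policy = CustomAPPOTFPolicy(obs_space, act_space, config)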
Issue Severity
High: It blocks me from completing my task.
I created a pull request to fix the issue.