
[RLlib] EagerTFPolicyV2 wrongly calls overridden action_sampler_fn

Open · rubenjacob opened this issue 10 months ago · 1 comment

What happened + What you expected to happen

Creating a custom Policy based on EagerTFPolicyV2 that overrides action_sampler_fn causes an error during initialization:

File "C:\Users\<username>\AppData\Local\pypoetry\Cache\virtualenvs\venv-py3.11\Lib\site-packages\ray\rllib\policy\eager_tf_policy_v2.py", line 1013, in _compute_actions_helper
    actions, logp, dist_inputs, state_out = self.action_sampler_fn(
                                            ^^^^^^^^^^^^^^^^^^^^^^^
TypeError: get_custom_appo_tf_policy.<locals>.CustomAPPOTFPolicy.action_sampler_fn() missing 1 required keyword-only argument: 'state_batches'

It seems that self.action_sampler_fn is called incorrectly: the observation batch is passed positionally and state_batches is omitted. On line 1013 of ray\rllib\policy\eager_tf_policy_v2.py the call should be

actions, logp, dist_inputs, state_out = self.action_sampler_fn(
    self.model,
    obs_batch=input_dict[SampleBatch.OBS],
    state_batches=state_batches,
    explore=explore,
    timestep=timestep,
    episodes=episodes,
)

instead of

actions, logp, dist_inputs, state_out = self.action_sampler_fn(
    self.model,
    input_dict[SampleBatch.OBS],
    explore=explore,
    timestep=timestep,
    episodes=episodes,
)
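The mismatch can be reproduced without Ray. The following is a minimal sketch, assuming a stand-in class `FakePolicy` (hypothetical, not part of RLlib) whose `action_sampler_fn` uses the keyword-only signature from the reproduction script. The two helper functions are also hypothetical and only mimic the current and the proposed call styles. The exact TypeError message depends on the override's signature, so the sketch only checks that an error is raised.

```python
class FakePolicy:
    """Hypothetical stand-in for a policy; not an RLlib class."""

    def action_sampler_fn(self, model, *, obs_batch, state_batches, **kwargs):
        # Return placeholders in (actions, logp, dist_inputs, state_out) order.
        return obs_batch, None, None, state_batches


def call_as_in_current_code(policy, obs, **kwargs):
    # Mimics the call quoted from eager_tf_policy_v2.py: obs is positional
    # and state_batches is not passed at all.
    return policy.action_sampler_fn("model", obs, **kwargs)


def call_as_proposed(policy, obs, state_batches, **kwargs):
    # Mimics the proposed call: everything after the model is a keyword.
    return policy.action_sampler_fn(
        "model", obs_batch=obs, state_batches=state_batches, **kwargs
    )


policy = FakePolicy()

try:
    call_as_in_current_code(policy, [1, 2], explore=True)
    current_call_failed = False
except TypeError:
    current_call_failed = True

result = call_as_proposed(policy, [1, 2], [], explore=True)
```

With the keyword-only signature, only the proposed call style binds every argument.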

Versions / Dependencies

Ray: 2.10.0
Python: 3.11.9
OS: Windows 10

Reproduction script

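from typing import List, Optional, Tuple

from ray.rllib.algorithms.appo.appo_tf_policy import get_appo_tf_policy
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType

APPOTFPolicy = get_appo_tf_policy("CustomAPPOTFPolicy", EagerTFPolicyV2)

# Instantiating this class crashes during policy initialization.
class CustomAPPOTFPolicy(APPOTFPolicy):
    @override(EagerTFPolicyV2)
    def action_sampler_fn(
        self,
        model: ModelV2,
        *,
        obs_batch: TensorType,
        state_batches: TensorType,
        seq_lens: Optional[TensorType] = None,
        **kwargs,
    ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
        # custom code here
        return None, None, None, None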
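Until the call site is fixed, one possible workaround is to make the override tolerant of both call styles. This is a sketch that has not been verified against Ray itself: `TolerantPolicy` is a hypothetical stand-in for the custom policy class, and falling back to an empty state list is an assumption that would not suit recurrent models that need real state.

```python
class TolerantPolicy:
    """Hypothetical stand-in; in practice this would subclass the APPO policy."""

    def action_sampler_fn(
        self, model, obs_batch=None, *, state_batches=None, **kwargs
    ):
        # Accept obs_batch positionally or by keyword, and fall back to an
        # empty state list when the caller does not pass state_batches.
        if state_batches is None:
            state_batches = []
        # custom sampling code would go here
        return obs_batch, None, None, state_batches


policy = TolerantPolicy()

# Call style currently used by the library: obs positional, no state_batches.
positional = policy.action_sampler_fn("model", [1, 2], explore=True, timestep=0)

# Call style proposed in this issue: everything after the model by keyword.
keyword = policy.action_sampler_fn(
    "model", obs_batch=[1, 2], state_batches=[], explore=True, timestep=0
)
```

Both calls bind successfully and return the same placeholders, so the override would keep working once the call site is corrected.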

Issue Severity

High: It blocks me from completing my task.

rubenjacob · Apr 11 '24 12:04

I created a pull request to fix the issue.

RocketRider · May 25 '24 15:05