[BUG] SafeProbabilisticModule constructor missing `log_prob_keys` argument

Open rerz opened this issue 11 months ago • 0 comments

Describe the bug

When log-prob aggregation is disabled for a probabilistic actor, you are supposed to pass a sequence of log_prob_keys instead of a single log_prob_key. However, the parameter is not passed up the class hierarchy correctly. The kwargs are forwarded to SafeProbabilisticModule as expected, but its constructor does not accept a log_prob_keys argument, which results in a TypeError:

TypeError: SafeProbabilisticModule.__init__() got an unexpected keyword argument 'log_prob_keys'
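
The failure mode can be sketched without torchrl installed. The classes below are hypothetical stand-ins (Base, Safe, Actor), not the torchrl source; they only mirror the shape of the hierarchy described above, under the assumption that the actor forwards its keyword arguments unchanged:

```python
# Hypothetical stand-ins for the class hierarchy; not the torchrl source.
class Base:
    """Plays the role of ProbabilisticTensorDictModule: accepts both arguments."""

    def __init__(self, log_prob_key=None, log_prob_keys=None):
        self.log_prob_key = log_prob_key
        self.log_prob_keys = log_prob_keys


class Safe(Base):
    """Plays the role of SafeProbabilisticModule: log_prob_keys is missing."""

    def __init__(self, log_prob_key=None):
        super().__init__(log_prob_key=log_prob_key)


class Actor(Safe):
    """Plays the role of ProbabilisticActor: forwards kwargs upward."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


def try_build(**kwargs):
    """Return the TypeError message if construction fails, else None."""
    try:
        Actor(**kwargs)
    except TypeError as err:
        return str(err)
    return None


# The single-key path works; the multi-key path fails in the middle class.
print(try_build(log_prob_key="sample_log_prob"))
error = try_build(log_prob_keys=["a_log_prob", "b_log_prob"])
print(error)
```

The second call fails inside Safe.__init__, even though Base would have accepted the argument, which matches the shape of the error reported here.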

To Reproduce

Disable log-prob aggregation with set_composite_lp_aggregate(...).set(), then construct a ProbabilisticActor with the return_log_prob argument and a sequence of log-prob keys passed via the log_prob_keys argument.

Expected behavior

The TypeError does not occur.

System info

torchrl + tensordict from the current main branches.

Reason and Possible fixes

To fix this, it should be enough to add the missing argument to the __init__ method of SafeProbabilisticModule and pass it on to the superclass constructor, ProbabilisticTensorDictModule.__init__, which already defines the argument.

class SafeProbabilisticModule(ProbabilisticTensorDictModule):
    def __init__(
        self,
        in_keys: Union[NestedKey, List[NestedKey], Dict[str, NestedKey]],
        out_keys: Optional[Union[NestedKey, List[NestedKey]]] = None,
        spec: Optional[TensorSpec] = None,
        safe: bool = False,
        default_interaction_type: str = InteractionType.DETERMINISTIC,
        distribution_class: Type = Delta,
        distribution_kwargs: Optional[dict] = None,
        return_log_prob: bool = False,
        log_prob_key: NestedKey | None = None,
        log_prob_keys: List[NestedKey] | None = None,  # <-- added
        cache_dist: bool = False,
        n_empirical_estimate: int = 1000,
    ):
        super().__init__(
            in_keys=in_keys,
            out_keys=out_keys,
            default_interaction_type=default_interaction_type,
            distribution_class=distribution_class,
            distribution_kwargs=distribution_kwargs,
            return_log_prob=return_log_prob,
            log_prob_key=log_prob_key,
            log_prob_keys=log_prob_keys,  # <-- added
            cache_dist=cache_dist,
            n_empirical_estimate=n_empirical_estimate,
        )
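
A regression check for this kind of drift could compare constructor signatures directly. This is only a sketch: missing_init_params is a hypothetical helper, and Base, Broken and Fixed are stand-in classes rather than the torchrl ones. A real subclass may also omit parent parameters on purpose, so the result is a hint to review rather than a definite bug list.

```python
import inspect


def missing_init_params(sub, base):
    """List named __init__ parameters of `base` that `sub.__init__` cannot accept."""
    sub_params = inspect.signature(sub.__init__).parameters
    # A **kwargs catch-all in the subclass accepts any keyword argument.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sub_params.values()):
        return []
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    base_params = inspect.signature(base.__init__).parameters
    return [
        name
        for name, param in base_params.items()
        if name != "self" and param.kind not in variadic and name not in sub_params
    ]


# Hypothetical stand-ins; not the torchrl source.
class Base:
    def __init__(self, log_prob_key=None, log_prob_keys=None):
        self.log_prob_key = log_prob_key
        self.log_prob_keys = log_prob_keys


class Broken(Base):
    # Mirrors the reported state: log_prob_keys is not accepted.
    def __init__(self, log_prob_key=None):
        super().__init__(log_prob_key=log_prob_key)


class Fixed(Base):
    # Mirrors the proposed fix: accept the argument and pass it on.
    def __init__(self, log_prob_key=None, log_prob_keys=None):
        super().__init__(log_prob_key=log_prob_key, log_prob_keys=log_prob_keys)


print(missing_init_params(Broken, Base))
print(missing_init_params(Fixed, Base))
```

With the stand-ins above, the first call reports log_prob_keys as missing and the second reports nothing.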

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

rerz avatar Jan 29 '25 22:01 rerz