
[BUG] `MaskedCategorical` missing `mode` and `deterministic_sample` properties.

Open MorganeAyle opened this issue 1 year ago • 0 comments

Describe the bug

Similar to the issue reported here. The `MaskedCategorical` distribution is missing the `mode` and `deterministic_sample` properties.

Reason and Possible fixes

The MaskedCategorical distribution should define the following additional properties:

@property
def mode(self) -> torch.Tensor:
    if hasattr(self, "logits"):
        return self.logits.max(-1, keepdim=True)[1]
    return self.probs.max(-1, keepdim=True)[1]

@property
def deterministic_sample(self) -> torch.Tensor:
    return self.mode
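
To illustrate what the proposed properties would return, here is a minimal, self-contained sketch. It does not use the actual `MaskedCategorical` class: `SketchMaskedCategorical` is a hypothetical stand-in built on `torch.distributions.Categorical`, and the way the mask is applied (replacing masked logits with the most negative finite value) is an assumption made only for this example.

```python
import torch
from torch.distributions import Categorical


class SketchMaskedCategorical(Categorical):
    """Hypothetical stand-in for a masked categorical (not the library class)."""

    def __init__(self, logits: torch.Tensor, mask: torch.Tensor):
        # Assumption: masked-out categories receive the most negative finite
        # logit, so they can never be selected by the argmax in `mode`.
        neg = torch.finfo(logits.dtype).min
        super().__init__(logits=logits.masked_fill(~mask, neg))

    @property
    def mode(self) -> torch.Tensor:
        # Same logic as proposed in the issue: index of the largest logit
        # (or probability), keeping a trailing dimension of size 1.
        if hasattr(self, "logits"):
            return self.logits.max(-1, keepdim=True)[1]
        return self.probs.max(-1, keepdim=True)[1]

    @property
    def deterministic_sample(self) -> torch.Tensor:
        return self.mode


logits = torch.tensor([[1.0, 3.0, 2.0], [0.5, 0.1, 4.0]])
mask = torch.tensor([[True, False, True], [True, True, False]])
dist = SketchMaskedCategorical(logits, mask)

# Row 0: index 1 has the largest raw logit but is masked, so index 2 wins.
# Row 1: index 2 is masked, so index 0 wins.
print(dist.mode)
print(dist.deterministic_sample.shape)
```

Note that with `keepdim=True` the result has shape `[..., 1]` rather than the batch shape `[...]`; whether that trailing dimension is desired depends on how callers consume the deterministic sample.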

Posted by MorganeAyle on Oct 09 '24 17:10