[BUG] `MaskedCategorical` missing `mode` and `deterministic_sample` properties.
Describe the bug
This is a similar issue to the one linked here: the `MaskedCategorical` distribution is missing the `mode` and `deterministic_sample` properties.
Reason and Possible fixes
The MaskedCategorical distribution should define the following additional properties:
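    @property
    def mode(self) -> torch.Tensor:
        # Index of the most likely category, kept as a trailing dim of size 1.
        if hasattr(self, "logits"):
            return self.logits.max(-1, keepdim=True)[1]
        return self.probs.max(-1, keepdim=True)[1]

    @property
    def deterministic_sample(self) -> torch.Tensor:
        # For a categorical distribution, the deterministic sample is the mode.
        return self.mode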
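
To illustrate the expected behaviour, here is a minimal sketch built on `torch.distributions.Categorical`. It is not the library's implementation: the class name `MaskedCategoricalSketch`, its constructor signature, and the masking strategy (replacing masked logits with the most negative finite value) are assumptions made only for this example.

```python
import torch
from torch.distributions import Categorical


class MaskedCategoricalSketch(Categorical):
    """Hypothetical stand-in for MaskedCategorical (illustration only)."""

    def __init__(self, logits: torch.Tensor, mask: torch.Tensor):
        # Assumption: masked-out categories get the most negative finite
        # logit, so they can never win the argmax below.
        neg_inf = torch.finfo(logits.dtype).min
        masked_logits = torch.where(
            mask, logits, torch.full_like(logits, neg_inf)
        )
        super().__init__(logits=masked_logits)

    @property
    def mode(self) -> torch.Tensor:
        # Most likely category index, with a trailing dim of size 1.
        return self.logits.argmax(dim=-1, keepdim=True)

    @property
    def deterministic_sample(self) -> torch.Tensor:
        # The deterministic sample of a categorical is its mode.
        return self.mode


logits = torch.tensor([[2.0, 5.0, 1.0], [0.1, 0.5, 3.0]])
mask = torch.tensor([[True, False, True], [True, True, False]])
dist = MaskedCategoricalSketch(logits, mask)

# The unmasked argmax would be index 1 and index 2; both are masked out,
# so the mode falls back to the best allowed category in each row.
print(dist.mode.tolist())
print(dist.deterministic_sample.tolist())
```

With the two properties defined this way, the mode only ever selects a category that the mask allows, which is what deterministic exploration would rely on.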