Feature Request: Allow sampling of log probs and logits for Likelihood.CLASSIFICATION
The code in `_glm_predictive_samples` always applies `torch.softmax` to the results under classification.
For numerical stability, supporting `torch.log_softmax` here would be helpful. Similarly, it would be helpful if there were an easy way to obtain the raw logits without having to temporarily change `self.likelihood`.
Thanks,
Andreas
Thanks for the input, Andreas! I wonder if something like this works for your case:
def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
    link_function: Callable[[torch.Tensor], torch.Tensor] | None,  # proposed new parameter
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
Where
- `link_function = None` restores the current implementation
- `link_function = lambda f: f` gets you sampled logits
- `link_function = functools.partial(torch.log_softmax, dim=-1)` gets you sampled log-softmax outputs
- This can also be used to compute an arbitrary expectation.
@aleximmer, @runame feel free to chime in.
Looking for feedback before implementing this.
Aww, yeah, that would be great! It would cover all my use cases and provide a nice extensible interface.
Sounds like a good improvement!