
Feature Request: Allow sampling of log probs and logits for Likelihood.CLASSIFICATION

Open · BlackHC opened this issue 1 year ago · 3 comments

The code in _glm_predictive_samples always applies torch.softmax to the sampled outputs when the likelihood is classification.

For numerical stability, it would be helpful to support torch.log_softmax here. Similarly, it would be helpful if there were an easy way to obtain the raw logits without having to temporarily change self.likelihood.
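For reference, the workaround today looks roughly like this (a sketch, assuming self.likelihood is a plain string attribute; model, train_loader, and x are placeholders):

import torch
from laplace import Laplace

la = Laplace(model, likelihood="classification", subset_of_weights="last_layer")
la.fit(train_loader)

# Temporarily pretend the likelihood is regression so the sampling path
# skips the softmax and returns raw logits.
original_likelihood = la.likelihood
try:
    la.likelihood = "regression"
    logit_samples = la.predictive_samples(x, pred_type="glm", n_samples=100)
finally:
    la.likelihood = original_likelihood

# Log probabilities then have to be recovered by hand; a fused
# torch.log_softmax inside the library would be more numerically robust.
log_prob_samples = torch.log_softmax(logit_samples, dim=-1)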

Thanks,
Andreas

BlackHC · Sep 12 '24, 22:09

Thanks for the input, Andreas! I wonder if something like this works for your case:

def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
+   link_function: Optional[Callable[[torch.Tensor], torch.Tensor]],
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:

Where

  • link_function = None restores the current behavior
  • link_function = lambda f: f gets you sampled raw logits
  • link_function = functools.partial(torch.log_softmax, dim=-1) gets you sampled log-probabilities
  • More generally, this can be used to compute an arbitrary expectation (see the sketch below)
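To make this concrete, a minimal sketch of how the body could dispatch on link_function (illustrative only, not the library's actual internals; generator and diagonal_output handling omitted):

import torch
from torch.distributions import MultivariateNormal

# Inside _glm_predictive_samples: draw n_samples logit vectors from the
# Gaussian predictive over the outputs, then push them through the link.
dist = MultivariateNormal(f_mu, covariance_matrix=f_var)
samples = dist.sample(torch.Size((n_samples,)))
if link_function is not None:
    return link_function(samples)
# link_function=None keeps the current behavior.
if self.likelihood == "classification":
    return torch.softmax(samples, dim=-1)
return samples

Assuming the new argument is also plumbed through the public predictive_samples, a call for log-probabilities would then look like:

import functools

log_prob_samples = la.predictive_samples(
    x,
    pred_type="glm",
    n_samples=100,
    link_function=functools.partial(torch.log_softmax, dim=-1),
)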

@aleximmer, @runame feel free to chime in.

Looking for feedback before implementing this.

wiseodd · Sep 13 '24, 15:09

Aww, yeah, that would be great! It would cover all my use cases and provide a nice extensible interface.

BlackHC · Sep 14 '24, 11:09

Sounds like a good improvement!

runame · Sep 14 '24, 15:09