array-api

RFC: add `logsumexp`

Open steff456 opened this issue 1 year ago • 2 comments

This RFC proposes adding a new API to the array API specification for computing the logarithm of a sum of exponentials.

Overview

The array API specification currently includes logaddexp, which performs an element-wise operation on two input arrays, but it does not include the corresponding reduction, logsumexp. This reduction is commonly implemented in accelerator libraries, where a numerically stable version is needed for deep learning applications.
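Since logaddexp(a, b) computes log(exp(a) + exp(b)), the missing reduction is, mathematically, a fold of logaddexp over the reduced axis. A small illustration, assuming NumPy purely as an example backend (it exposes such a fold through the ufunc method `np.logaddexp.reduce`, which is not part of the array API specification):

```python
import numpy as np

x = np.array([[0.0, 1.0, 2.0],
              [1000.0, 1000.0, 1000.0]])

# Fold the element-wise logaddexp along the last axis:
# logaddexp(logaddexp(x0, x1), x2) == log(exp(x0) + exp(x1) + exp(x2))
folded = np.logaddexp.reduce(x, axis=-1)

# Naive formula for the first row only, where exp() does not overflow
direct_row0 = np.log(np.sum(np.exp(x[0])))

print(folded)
```

A sequential fold performs one logaddexp per element; dedicated logsumexp implementations typically subtract the maximum once and then make a single pass of exp and sum instead, which is one argument for a standalone reduction.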

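logsumexp can be composed as log(sum(exp(x))); however, that naive composition is not numerically stable: exp(x) overflows for large elements and underflows to zero for very negative ones, so the result degenerates to inf or -inf even when the true value is finite.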
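To illustrate the stability problem, here is a minimal sketch in plain Python (standard library only; the function names `logsumexp_naive` and `logsumexp_stable` are hypothetical) that contrasts the direct composition with the common max-shift rewrite, log(sum(exp(x_i))) = m + log(sum(exp(x_i - m))) where m = max(x_i):

```python
import math


def logsumexp_naive(xs):
    # Direct composition: exp() can overflow (large inputs) or underflow to 0.0
    # (very negative inputs), after which log(0.0) fails.
    return math.log(sum(math.exp(v) for v in xs))


def logsumexp_stable(xs):
    # Shift by the maximum so every exponent x - m is <= 0: each exp(x - m)
    # lies in [0, 1] and cannot overflow, and the maximum element contributes
    # exactly 1, so the sum is never 0 and log() is always defined.
    m = max(xs)
    if math.isinf(m):
        # Every element is -inf (result -inf) or some element is +inf (result +inf)
        return m
    return m + math.log(sum(math.exp(v - m) for v in xs))


print(logsumexp_stable([1000.0, 1000.0]))   # 1000 + log(2)
print(logsumexp_stable([-1000.0, -1000.0])) # -1000 + log(2)
```

With the standard library, `math.exp(1000.0)` raises `OverflowError` rather than returning inf; array libraries such as NumPy return inf with a runtime warning instead, so the naive version silently yields inf there.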

Prior art

Proposal

def logsumexp(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None, keepdims: bool = False) -> array
  • The dtype keyword argument is included for consistency with sum and similar reductions (e.g., prod).
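One way the proposed signature could behave is sketched below. It assumes NumPy as the backend and is a hypothetical illustration, not a normative implementation: it casts to `dtype` first (mirroring `sum`), then applies the max-shift rewrite over `axis` and honors `keepdims`.

```python
from typing import Optional, Tuple, Union

import numpy as np


def logsumexp(
    x: np.ndarray,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[np.dtype] = None,
    keepdims: bool = False,
) -> np.ndarray:
    """Hypothetical NumPy sketch of the proposed logsumexp signature."""
    # Cast first, analogous to the dtype argument of sum
    if dtype is not None:
        x = x.astype(dtype)
    # Maximum over the reduced axes, kept as size-1 dims so it broadcasts
    m = np.max(x, axis=axis, keepdims=True)
    # If the maximum is +/-inf (e.g. all elements are -inf), x - m would be nan;
    # shift by 0 instead and let exp/log propagate the infinities.
    m = np.where(np.isfinite(m), m, 0)
    # log(sum(exp(x))) == m + log(sum(exp(x - m))); exp(x - m) <= 1 cannot overflow
    s = np.sum(np.exp(x - m), axis=axis, keepdims=keepdims)
    if not keepdims:
        m = np.squeeze(m, axis=axis)
    # log(0) == -inf is the intended result for all -inf inputs
    with np.errstate(divide="ignore"):
        return np.log(s) + m


x = np.array([[0.0, 1.0, 2.0], [1000.0, 1000.0, 1000.0]])
print(logsumexp(x, axis=1))
print(logsumexp(x, axis=1, keepdims=True).shape)
```

Subtracting the maximum once per reduced slice keeps the cost to a single exp/sum pass, which is why a dedicated reduction can be both faster and more accurate than composing log, sum, and exp by hand.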

Related

  • https://github.com/numpy/numpy/issues/4260

cc @kgryte

steff456, Feb 14 '23 22:02

logsumexp was also mentioned as a candidate for inclusion in the special functions extension: https://github.com/data-apis/array-api/issues/725. Accordingly, before moving forward with this proposal, we should first determine whether it makes sense to add in the main namespace or in that extension.

kgryte, Jan 11 '24 08:01

I updated the issue description:

  • JAX has logsumexp exposed in two places (jax.nn and jax.scipy.special)
  • PyTorch also added an alias in torch.special.

So the path of least resistance is probably to add it in special. However, it would be strange to have logaddexp in the main namespace and logsumexp in an optional extension, since logsumexp is more commonly used than logaddexp, and they're very much related.

rgommers, Jan 11 '24 15:01