array-api
RFC: add `logsumexp`
This RFC proposes adding a new API to the array API specification for computing the logarithm of a sum of exponentials.
Overview
The Array API specification currently includes `logaddexp`, which performs an element-wise operation on two input arrays, but does not include the reduction `logsumexp`. This API is commonly implemented in accelerator libraries for better numerical stability in deep learning applications.
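This can be implemented as `log(sum(exp(x)))`; however, such a naive implementation is not numerically stable, because `exp` overflows for large inputs (and underflows for very negative ones) before the logarithm is applied.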
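To illustrate the stability concern, here is a minimal sketch comparing the direct formula with the common max-shift rewrite, which relies on the identity log(sum(exp(x))) = m + log(sum(exp(x - m))) for m = max(x). NumPy is used only as a convenient host library, and both function names are hypothetical.

```python
import numpy as np


def logsumexp_naive(x):
    # Direct translation of log(sum(exp(x))); exp() overflows to inf for large inputs.
    return np.log(np.sum(np.exp(x)))


def logsumexp_shifted(x):
    # Subtract the maximum first so the largest term becomes exp(0) == 1,
    # then add it back: log(sum(exp(x))) == m + log(sum(exp(x - m))).
    x = np.asarray(x, dtype=np.float64)
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))


x = np.array([1000.0, 1000.0])
with np.errstate(over="ignore"):
    print(logsumexp_naive(x))  # inf
print(logsumexp_shifted(x))    # approximately 1000 + log(2)
```

Both versions agree for small inputs; only the shifted one survives inputs whose exponentials exceed the floating-point range.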
Prior art
- NumPy: (not currently implemented)
  - NumPy does, however, implement `logaddexp.reduce`.
- Dask: (not currently implemented)
- SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
- CuPy: https://docs.cupy.dev/en/v13.0.0b1/reference/generated/cupyx.scipy.special.logsumexp.html
  - In the `scipy.special`-compatible namespace (`cupyx.scipy.special`).
- PyTorch: https://pytorch.org/docs/stable/generated/torch.logsumexp.html (also an alias in `torch.special`: https://pytorch.org/docs/stable/special.html#torch.special.logsumexp)
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/math/reduce_logsumexp
- JAX: `jax.nn.logsumexp` and `jax.scipy.special.logsumexp` (same function, exposed in two places)
Proposal:
```python
def logsumexp(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None, keepdims: bool = False) -> array
```
- The `dtype` kwarg is for consistency with `sum` et al.
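A rough sketch of how the proposed signature could behave, written on top of NumPy for concreteness. This is not a reference implementation: treating `dtype` as a cast applied before the reduction, and shifting by 0 when a slice's maximum is non-finite (to avoid `inf - inf = nan`), are assumptions based on a common guard, not requirements stated in this RFC.

```python
from typing import Optional, Tuple, Union

import numpy as np


def logsumexp(
    x,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype=None,
    keepdims: bool = False,
):
    # Hypothetical NumPy-backed sketch of the proposed API; not a standard function.
    x = np.asarray(x)
    if dtype is not None:
        # Assumption: like sum(), cast the input to the requested dtype first.
        x = x.astype(dtype)
    # Per-slice maximum, kept with singleton axes so it broadcasts against x.
    m = np.max(x, axis=axis, keepdims=True)
    # Guard: if a slice's max is +/-inf, shift by 0 instead to avoid inf - inf = nan.
    m = np.where(np.isfinite(m), m, 0)
    s = np.sum(np.exp(x - m), axis=axis, keepdims=keepdims)
    if not keepdims:
        # Drop the singleton axes from the shift so it matches the shape of s.
        m = np.squeeze(m, axis=axis)
    return np.log(s) + m


x = np.log(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
print(logsumexp(x, axis=1))                        # approximately log([6, 15])
print(logsumexp(x, axis=0, keepdims=True).shape)   # (1, 3)
```

Computing the maximum with `keepdims=True` regardless of the caller's choice keeps the broadcasting against `x` simple; the singleton axes are only removed at the end when `keepdims=False`.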
Related
- https://github.com/numpy/numpy/issues/4260
cc @kgryte
`logsumexp` was also mentioned as a candidate for inclusion in the special functions extension: https://github.com/data-apis/array-api/issues/725. Accordingly, before moving forward with this proposal, we should first determine whether it makes sense to add it to the main namespace or to that extension.
I updated the PR description:
- JAX has `logsumexp` exposed in two places (`jax.nn` and `jax.scipy.special`).
- PyTorch also added an alias in `torch.special`.
So the path of least resistance is probably to add it to the `special` extension. However, it would be strange to have `logaddexp` in the main namespace and `logsumexp` in an optional extension, since `logsumexp` is more commonly used than `logaddexp` and the two are closely related.
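To make the relationship concrete: `logaddexp` is the two-argument case of `logsumexp`, and folding `logaddexp` along an array (which NumPy exposes as `np.logaddexp.reduce`, per the prior-art list) yields the same value as a dedicated reduction, up to rounding. The sketch below checks this with NumPy; the max-shift formula serves only as a reference value.

```python
import numpy as np

x = np.array([0.5, -2.0, 3.0, 10.0])

# Binary case: logaddexp(a, b) == log(exp(a) + exp(b)), computed stably.
pair = np.logaddexp(x[0], x[1])

# Folding the binary op along the array gives logsumexp of all elements.
folded = np.logaddexp.reduce(x)

# Reference: the max-shift form of log(sum(exp(x))).
m = np.max(x)
reference = m + np.log(np.sum(np.exp(x - m)))

print(np.isclose(folded, reference))  # True
```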