RFC: add `logsumexp`
This RFC seeks to include a new API in the array API specification for the purpose of computing the log of summed exponentials.
Overview
The Array API specification currently includes logaddexp which performs an element-wise operation on two input arrays, but does not include the reduction logsumexp. This API is commonly implemented in accelerator libraries for better numerical stability in deep learning applications.
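For reference, the two functions are related as follows. `logsumexp` over n inputs can be written as a left fold of the binary `logaddexp`, which is why it is the natural reduction counterpart:

$$
\begin{aligned}
\operatorname{logaddexp}(a, b) &= \log\left(e^{a} + e^{b}\right), \\
\operatorname{logsumexp}(x_1, \dots, x_n) &= \log \sum_{i=1}^{n} e^{x_i} \\
&= \operatorname{logaddexp}\bigl(\cdots\operatorname{logaddexp}(\operatorname{logaddexp}(x_1, x_2), x_3)\cdots, x_n\bigr).
\end{aligned}
$$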
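This could be implemented as `log(sum(exp(x)))`; however, such a naive implementation is not numerically stable: `exp` overflows to `inf` for large inputs (e.g., `exp(1000.0)` in float64), and for sufficiently negative inputs every term underflows to zero, giving `log(0) = -inf`.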
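The usual remedy (often called the "log-sum-exp trick") relies on the identity log Σ exp(x_i) = m + log Σ exp(x_i - m), which holds for any finite m. Choosing m = max_i x_i makes every shifted exponent non-positive, so `exp` cannot overflow, and at least one shifted term equals exp(0) = 1, so the sum is at least 1 and its log cannot underflow to `-inf`. A minimal NumPy illustration (the input values are arbitrary):

```python
import numpy as np

x = np.array([1000.0, 1000.0])

# Naive formula: exp(1000.0) overflows float64, so the sum and its log are inf.
with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(x)))

# Max-shifted formula: every exponent x_i - m is <= 0, so exp cannot overflow.
m = np.max(x)
stable = m + np.log(np.sum(np.exp(x - m)))

print(naive)   # inf
print(stable)  # approximately 1000.693 (= 1000 + log 2)
```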
Prior art
- NumPy: (not currently implemented)
  - NumPy does, however, implement `logaddexp.reduce`.
- Dask: (not currently implemented)
- SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
- CuPy: https://docs.cupy.dev/en/stable/reference/generated/cupyx.scipy.special.logsumexp.html
  - In the `cupyx.scipy.special` namespace.
- PyTorch: https://pytorch.org/docs/stable/generated/torch.logsumexp.html (also an alias in `torch.special`: https://pytorch.org/docs/stable/special.html#torch.special.logsumexp)
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/math/reduce_logsumexp
- JAX: jax.nn.logsumexp and jax.scipy.special.logsumexp (same function, exposed in two places)
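As a point of comparison for the NumPy entry above, `np.logaddexp.reduce` already computes a logsumexp-style reduction by folding the binary ufunc along an axis. A small check (input values are arbitrary) that it agrees with the max-shifted formula:

```python
import numpy as np

x = np.array([[0.0, 1.0, 2.0],
              [1000.0, 1000.0, 1000.0]])

# Folds the binary logaddexp along axis 1: logaddexp(logaddexp(x0, x1), x2).
via_reduce = np.logaddexp.reduce(x, axis=1)

# Max-shifted reference: m + log(sum(exp(x - m))), computed row by row.
m = np.max(x, axis=1, keepdims=True)
via_shift = np.squeeze(m, axis=1) + np.log(np.sum(np.exp(x - m), axis=1))

print(np.allclose(via_reduce, via_shift))  # True
```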
Proposal
def logsumexp(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None, keepdims: bool = False) -> array
- The `dtype` kwarg is for consistency with `sum` et al.
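A minimal sketch of the proposed signature, assuming a NumPy backend rather than a generic array-API namespace (purely for brevity; this is not a normative implementation). It uses the max-shift trick; replacing non-finite maxima with 0, so that all-`-inf` slices return `-inf` and slices containing `+inf` return `+inf`, is a design choice of this sketch and not something the proposal specifies. `dtype` is applied by casting the input before computing, by analogy with how the specification describes `dtype` for `sum`.

```python
from typing import Optional, Tuple, Union

import numpy as np


def logsumexp(
    x: np.ndarray,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[np.dtype] = None,
    keepdims: bool = False,
) -> np.ndarray:
    """Sketch of the proposed logsumexp reduction (NumPy-backed, not normative)."""
    # Cast up front, by analogy with how `dtype` is described for `sum`.
    x = np.asarray(x, dtype=dtype)

    # Per-slice maximum, kept broadcastable against `x`.
    m = np.max(x, axis=axis, keepdims=True)
    # Replace non-finite maxima with 0 so `x - m` never computes inf - inf;
    # slices that are all -inf then give log(0) = -inf, and slices containing
    # +inf give log(inf) = +inf.
    m = np.where(np.isfinite(m), m, 0)

    # Shifted exponents are <= 0 wherever m was finite, so exp cannot overflow.
    s = np.sum(np.exp(x - m), axis=axis, keepdims=keepdims)
    if not keepdims:
        m = np.squeeze(m, axis=axis)

    # Silence the divide-by-zero warning from log(0) on all -inf slices.
    with np.errstate(divide="ignore"):
        return np.asarray(np.log(s) + m)
```

For example, `logsumexp(np.array([1000.0, 1000.0]))` returns approximately `1000 + log(2)`, where the naive `log(sum(exp(x)))` returns `inf`.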
Related
- https://github.com/numpy/numpy/issues/4260
cc @kgryte
logsumexp was also mentioned as a candidate for inclusion in the special functions extension: https://github.com/data-apis/array-api/issues/725. Accordingly, before moving forward with this proposal, we should first determine whether it belongs in the main namespace or in that extension.
I updated the PR description:
- JAX has `logsumexp` exposed in two places (`jax.nn` and `jax.scipy.special`).
- PyTorch also added an alias in `torch.special`.
So the path of least resistance is probably to add it to the special functions extension. However, it would be strange to have `logaddexp` in the main namespace and `logsumexp` in an optional extension, since `logsumexp` is more commonly used than `logaddexp` and the two are closely related.