array-api
array-api copied to clipboard
RFC: add an `item()` method that returns a scalar for arrays with exactly one element.
def item(self) -> Scalar:
    """
    Return the single element of a one-element array as a scalar.

    Returns
    -------
    out: Scalar
        the only element contained in the array.

    Raises
    ------
    ValueError
        if the array does not contain exactly one element (for example, if it
        is empty or contains more than one element).
    """
Examples of existing implementations:
- `numpy.ndarray.item`
- `torch.Tensor.item`
- `pandas.Series.item`
- `pandas.Index.item`
- `polars.Series.item`
- `xarray.DataArray.item`
Demo:
import numpy as np
import pandas as pd
import polars as pl
import pytest
import torch
import xarray as xr
@pytest.mark.parametrize("data", [[], [1, 2, 3]])
@pytest.mark.parametrize(
    "array_type", [torch.tensor, np.array, pd.Series, pd.Index, pl.Series, xr.DataArray]
)
def test_item_valueerror(data, array_type):
    """``item()`` must raise ``ValueError`` unless the array has exactly one element.

    ``data`` covers both failure cases: an empty array and an array with more
    than one element. Each case is checked against every library in
    ``array_type``.
    """
    array = array_type(data)
    with pytest.raises(ValueError):
        array.item()
@pytest.mark.parametrize(
    "array_type", [torch.tensor, np.array, pd.Series, pd.Index, pl.Series, xr.DataArray]
)
def test_item(array_type):
    """``item()`` on a one-element array must return that element."""
    array = array_type([1])
    # Compare the result instead of discarding it, so that a wrong value
    # (not only a raised exception) also fails the test.
    assert array.item() == 1
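Currently, only `torch` fails the demo (in `test_item_valueerror`), because `torch.Tensor.item` raises `RuntimeError` instead of `ValueError` when the tensor does not contain exactly one element.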