array-api

RFC: `item()` to return a scalar for arrays with exactly one element.

Open · randolf-scholz opened this issue 8 months ago · 8 comments

```python
def item(self) -> Scalar:
    """If the array contains exactly one element, return it as a scalar; otherwise, raise ValueError."""
    ...
```
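To make the proposed semantics concrete, here is a minimal sketch on a hypothetical `ToyArray` class (a flat Python list plus a shape tuple, not any real library's type). The `Scalar` alias is also an assumption, since the proposal does not define it. The element count is taken as the product of the shape, so a 0-d array (shape `()`) counts as holding one element.

```python
from dataclasses import dataclass
from math import prod
from typing import Union

# Assumed scalar type for this sketch; the proposal leaves `Scalar` undefined.
Scalar = Union[bool, int, float, complex]


@dataclass
class ToyArray:
    """Hypothetical minimal array: a flat buffer plus a shape (not from any real library)."""

    data: list
    shape: tuple

    @property
    def size(self) -> int:
        # Number of elements is the product of the dimensions; prod(()) == 1 for 0-d arrays.
        return prod(self.shape)

    def item(self) -> Scalar:
        """Return the single element as a Python scalar, or raise ValueError."""
        if self.size != 1:
            raise ValueError(
                f"item() requires exactly one element, but the array has {self.size}"
            )
        return self.data[0]
```

With this sketch, `ToyArray([5], (1, 1)).item()` returns `5`, while `ToyArray([1, 2, 3], (3,)).item()` and `ToyArray([], (0,)).item()` both raise `ValueError`.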

Examples:

Demo:

```python
import pytest
import torch
import xarray as xr
import pandas as pd
import polars as pl
import numpy as np


# Arrays with zero or several elements: item() should raise ValueError.
@pytest.mark.parametrize("data", [[], [1, 2, 3]])
@pytest.mark.parametrize(
    "array_type", [torch.tensor, np.array, pd.Series, pd.Index, pl.Series, xr.DataArray]
)
def test_item_valueerror(data, array_type):
    array = array_type(data)
    with pytest.raises(ValueError):
        array.item()


# Arrays with exactly one element: item() should return that element as a scalar.
@pytest.mark.parametrize(
    "array_type", [torch.tensor, np.array, pd.Series, pd.Index, pl.Series, xr.DataArray]
)
def test_item(array_type):
    array = array_type([1])
    assert array.item() == 1
```

Currently, only torch fails, because `torch.Tensor.item()` raises `RuntimeError` instead of `ValueError`.
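Until torch aligns, downstream code that wants uniform behavior could route calls through a small compatibility shim. The sketch below is hypothetical (`item_compat` is not part of any library): it computes the element count from `shape`, which all six array types in the demo expose, and raises `ValueError` before delegating to `.item()`.

```python
import math
from typing import Any


def item_compat(x: Any) -> Any:
    """Hypothetical shim: return x.item(), but raise ValueError unless x holds exactly one element."""
    # torch.Size, NumPy/pandas/polars/xarray shapes are all tuples of ints,
    # so the product of the dimensions is the element count (1 for 0-d arrays).
    n = math.prod(x.shape)
    if n != 1:
        raise ValueError(f"item() requires exactly one element, but got {n}")
    # Safe to delegate: every library listed above returns the scalar here.
    return x.item()
```

Checking the size first, rather than catching `RuntimeError` and re-raising it, avoids turning unrelated torch runtime errors into `ValueError`. With torch installed, `item_compat(torch.tensor([1, 2, 3]))` raises `ValueError`, and `item_compat(torch.tensor([1]))` returns `1`.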

randolf-scholz, Jun 20 '24 07:06