RFC: update `full` to accept 0D arrays for `fill_value`
According to the current standard, the `fill_value` of `full` must be a Python scalar:
full(shape: int | Tuple[int, ...], fill_value: bool | int | float | complex, *, dtype: dtype | None = None, device: device | None = None) → array
(https://data-apis.org/array-api/latest/API_specification/generated/array_api.full.html)
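For reference, the proposal here would widen that annotation roughly as follows (a sketch of possible wording, not final spec text):

full(shape: int | Tuple[int, ...], fill_value: bool | int | float | complex | array, *, dtype: dtype | None = None, device: device | None = None) → array

with the added `array` case restricted to 0-D arrays.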
Are 0-D arrays allowed? This usage seems common when one wants the dtype of the output to be determined by the dtype of the `fill_value`. That's what happens in array_api_strict, for instance:
import array_api_strict as xp

xp.full(1, xp.asarray(1.0, dtype=xp.float32))
# Array([1.], dtype=array_api_strict.float32)
@mdhaber Do you have a sense as to how common array fill_value support is across the ecosystem? (e.g., NumPy, PyTorch, JAX, Dask, MLX, ndonnx, CuPy)
Yes, support is already quite good. I haven't tested MLX.
from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx

xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    x = xp.asarray(0.)
    try:
        print(xp.full((2, 2), x))
    except Exception as e:
        print(e)
/usr/local/lib/python3.11/dist-packages/ndonnx/__init__.py:16: UserWarning: onnxruntime is not installed. ndonnx will use the (incomplete) reference implementation for value propagation.
warn(
[[0. 0.]
[0. 0.]]
tensor([[0., 0.],
[0., 0.]])
[[0. 0.]
[0. 0.]]
dask.array<full_like, shape=(2, 2), dtype=float64, chunksize=(2, 2), chunktype=numpy.ndarray>
[[0. 0.]
[0. 0.]]
tf.Tensor(
[[0. 0.]
[0. 0.]], shape=(2, 2), dtype=float64)
unable to infer dtype from `array(data: 0.0, dtype=float64)`
ndonnx is the only one of those that fails, and it would presumably either want to update its error message or add support. It doesn't seem right to say that the dtype can't be inferred from an array; the dtype is right there.
This makes sense to me. The fill value might come from something like `a[0]`, where it would be a 0-D array. The only concern is what happens if the fill value's dtype disagrees with the `dtype` argument. I guess the `dtype` argument should take priority, but we could also say the behavior is unspecified if not all libraries do that. The existing text also doesn't completely leave undefined the cases where the `dtype` and the fill value are of different kinds; I'm not sure whether that is intentional.
from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx

xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    x = xp.asarray(0, dtype=xp.int16)
    try:
        print(xp.__name__, xp.full((2, 2), x, dtype=xp.int32).dtype)
    except Exception as e:
        print(e)
array_api_compat.numpy int32
array_api_compat.torch torch.int32
array_api_compat.cupy int32
array_api_compat.dask.array int32
jax.numpy int32
tensorflow.experimental.numpy <dtype: 'int32'>
ndonnx int32
So the `dtype` passed into `xp.full` takes precedence in all libraries, and ndonnx no longer complains about being unable to infer the dtype.
It's similar with `full_like`: the `dtype` passed into `full_like` takes precedence over the dtypes of both positional arguments, and if no `dtype` is specified, the first argument's dtype takes precedence over the second's.
from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx

xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    zeros = xp.zeros((2, 2), dtype=xp.int16)
    x = xp.asarray(0, dtype=xp.int16)
    try:
        print(xp.__name__, xp.full_like(zeros, x, dtype=xp.int32).dtype)
    except Exception as e:
        print(e)
array_api_compat.numpy int32
array_api_compat.torch torch.int32
array_api_compat.cupy int32
array_api_compat.dask.array int32
jax.numpy int32
tensorflow.experimental.numpy <dtype: 'int32'>
ndonnx int32
from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx

xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    zeros = xp.zeros((2, 2), dtype=xp.int16)
    x = xp.asarray(0, dtype=xp.int32)
    try:
        print(xp.__name__, xp.full_like(zeros, x).dtype)
    except Exception as e:
        print(e)
array_api_compat.numpy int16
array_api_compat.torch torch.int16
array_api_compat.cupy int16
array_api_compat.dask.array int16
jax.numpy int16
tensorflow.experimental.numpy <dtype: 'int16'>
ndonnx int16
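To summarize the observed behavior in code (an illustrative description of the precedence only; `resolve_output_dtype` is a hypothetical name, not any library's actual logic):

def resolve_output_dtype(x_dtype, fill_value_dtype, dtype=None):
    # Observed across NumPy, PyTorch, CuPy, Dask, JAX, TensorFlow, and ndonnx:
    # an explicit `dtype` wins; otherwise `full_like` uses the first
    # argument's dtype, and the fill value's dtype is ignored.
    return dtype if dtype is not None else x_dtype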
Dask currently produces a warning with:
from array_api_compat.dask import array as xp
float(xp.full((), xp.asarray(1)))
# 1.0
# FutureWarning: The `numpy.copyto` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
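Until that is fixed, a caller-side workaround is to unwrap the 0-D array into a Python scalar before calling `full` (a sketch only; it materializes the value, which may be undesirable for lazy libraries):

import array_api_compat.dask.array as xp

fill = xp.asarray(1.0)
# Converting to a Python scalar avoids the copyto-based fill path that
# triggers the FutureWarning, at the cost of computing the value eagerly.
out = xp.full((2, 2), float(fill), dtype=fill.dtype)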
FYI Dask loudly complains if fill_value is an array: https://github.com/scipy/scipy/pull/22900
@crusaderky Are you referring to the warning that appears just above https://github.com/data-apis/array-api/issues/909#issuecomment-2835419568? (Please expand the details.) It is not a very specific complaint, and Dask complains about a lot of things. It looks like a mistake on Dask's side.
The warning message is misleading, but the root cause is that Dask doesn't expect a Dask Array as `fill_value`, so the whole machinery works by pure accident:
>>> import dask.array as da
>>> da.full((2, ), da.asarray(1)).persist()
/home/crusaderky/miniforge3/envs/array-api-compat/lib/python3.10/site-packages/dask/array/core.py:1763: FutureWarning: The `numpy.copyto` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
warnings.warn(
dask.array<full_like, shape=(2,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
>>> import array_api_compat.dask.array as xp
>>> xp.full((2, ), xp.asarray(1)).persist()
/home/crusaderky/miniforge3/envs/array-api-compat/lib/python3.10/site-packages/dask/array/core.py:1763: FutureWarning: The `numpy.copyto` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
warnings.warn(
dask.array<full_like, shape=(2,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
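For what it's worth, the fix on Dask's side could be as simple as special-casing a 0-D Dask array before the generic fill path. A minimal sketch of the idea (a hypothetical helper, not Dask's actual implementation):

import dask.array as da

def full_with_array_fill(shape, fill_value, dtype=None):
    # Hypothetical shim: broadcast a 0-D Dask array to the requested shape
    # instead of letting da.full copy it into a buffer via np.copyto.
    if isinstance(fill_value, da.Array) and fill_value.ndim == 0:
        out = da.broadcast_to(fill_value, shape)
        return out.astype(dtype) if dtype is not None else out
    return da.full(shape, fill_value, dtype=dtype)

A real implementation would also need to pick sensible chunks rather than the single chunk `broadcast_to` produces.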
So you mean yes, the same warning that appears at the end of the details?
Yes, it's the same warning.
This was discussed in the community meeting yesterday. Participants agreed that allowing 0-D arrays is indeed nicer UX, and it seems fine to add this to the spec, preferably after someone implements it in Dask.
There was a discussion about static typing, since allowing 0-D arrays makes typing harder, but it's still doable, and the usability improvement seems to outweigh the extra static-typing complexity. The improvement shows up especially when `fill_value` is derived from another array:
x = somefunc(...)
xp.full(new_shape, float(x[0]), dtype=x.dtype, device=x.device)
becomes
x = somefunc(...)
xp.full(new_shape, x[0])
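On the static typing point, the widened signature is expressible, though the 0-D restriction itself cannot be captured by the type system and would have to be documented or checked at runtime. A rough sketch (the `Array` protocol here is illustrative, standing in for the spec's array type):

from typing import Protocol

class Array(Protocol):
    @property
    def ndim(self) -> int: ...
    @property
    def dtype(self) -> object: ...

# The 0-D requirement on the Array case is not statically enforceable.
FillValue = bool | int | float | complex | Array

def full(shape: int | tuple[int, ...], fill_value: FillValue, *, dtype=None, device=None) -> Array: ...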
We are more than happy to officially support this in ndonnx! ndonnx's API simply follows the latest array API standard and rarely goes beyond it.