RFC: add `broadcast_shapes` to the specification
This RFC proposes adding an API to the specification for explicitly broadcasting a list of shapes to a single shape.
Overview
Based on array API comparison data, this API, or some variation of it, is commonly implemented across array libraries.
Currently, the Array API specification only includes `broadcast_arrays` and `broadcast_to`, which both require array input. The specification lacks APIs for working directly with shapes without needing to create new array instances.
Prior Art
- NumPy: https://numpy.org/doc/stable/reference/generated/numpy.broadcast_shapes.html
  - added in 2020: https://github.com/numpy/numpy/pull/17535
  - returns a `tuple`
- CuPy: ~~does not currently support~~
  - Correction: CuPy simply borrows `broadcast_shapes` from NumPy: https://github.com/cupy/cupy/blob/a888cc94c79729cf24ebb808d15b9702c0342392/cupy/__init__.py#L302
- Dask: `da.core.broadcast_shapes` exists as private API only. Supports Dask's bespoke `nan`s in the shape.
- JAX: follows NumPy
  - returns a `tuple`
- PyTorch: follows NumPy
  - returns a `torch.Size`
- TensorFlow: has two APIs, for statically and dynamically known shapes:
  - `broadcast_static_shape`: https://www.tensorflow.org/api_docs/python/tf/broadcast_static_shape
  - `broadcast_dynamic_shape`: https://www.tensorflow.org/api_docs/python/tf/broadcast_dynamic_shape
  - both functions only accept two shape arguments
- ndonnx: no API. Shapes can contain `None`, so one cannot use NumPy's implementation.
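For a concrete feel of the NumPy-style signature (a usage sketch; the TensorFlow reduction in the comment is illustrative only):

```python
import numpy as np

# NumPy (and CuPy, JAX, and PyTorch, which follow it) accepts any number
# of shapes in one call and returns a plain tuple:
np.broadcast_shapes((8, 1, 6, 1), (7, 1, 5), (7, 6, 5))  # -> (8, 7, 6, 5)

# TensorFlow's two-argument APIs need a reduction to handle N shapes, e.g.:
#   functools.reduce(tf.broadcast_static_shape, map(tf.TensorShape, shapes))
```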
Proposal
This RFC proposes adding the following API to the specification:
`def broadcast_shapes(*shapes: tuple[int | None, ...]) -> tuple[int | None, ...]`

in which one or more shapes are broadcast together according to the broadcasting rules enumerated in the specification.
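A non-normative, pure-Python sketch of the proposed semantics (the `None`-propagation rule here is one possible answer to the open question below, not settled behavior):

```python
import itertools


def broadcast_shapes(*shapes: tuple[int | None, ...]) -> tuple[int | None, ...]:
    """Sketch only: broadcast shapes per the specification's rules.

    None marks a dimension of unknown size.
    """
    result: list[int | None] = []
    # Walk the shapes right to left, padding shorter shapes with 1.
    for dims in itertools.zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        size: int | None = 1
        for d in dims:
            if d is None:
                # Unknown input dimension: the output is unknown unless
                # another shape pins down a concrete size for this axis.
                if size == 1:
                    size = None
            elif d == 1:
                continue
            elif size == 1 or size is None:
                size = d
            elif size != d:
                raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(size)
    return tuple(reversed(result))
```

For example, `broadcast_shapes((8, 1, 6, 1), (7, 1, 5))` returns `(8, 7, 6, 5)`, and `broadcast_shapes((None, 3), (1, 3))` returns `(None, 3)`.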
Questions
- How to handle shapes having unknown dimensions?
  - `dask.array.core.broadcast_shapes` sets the output size to `nan` if any of the input shapes are `nan` on the same axis
  - `ndonnx.broadcast_arrays(a, b)` returns arrays with material shapes.

Note that shape materialization can be a very expensive operation, as it requires materializing the whole graph up to that point. In the case of Dask, which deliberately does not cache intermediate results as a memory management policy, this means computing everything at least twice.
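To illustrate the Dask behavior mentioned above (a sketch against Dask's private helper; not a public or stable API):

```python
import math

from dask.array.core import broadcast_shapes  # private API, as noted above

# nan marks a dimension whose size is unknown until the graph is computed;
# it propagates through broadcasting instead of forcing materialization.
broadcast_shapes((math.nan, 3), (5, 1, 3))  # -> (5, nan, 3)
```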
Notes
The top-level page on broadcasting mentions on its first line, using non-prescriptive language, that broadcasting allows creating views of the inputs:

> Broadcasting refers to the automatic (implicit) expansion of array dimensions to be of equal sizes without copying array data
However, no mention of sharing memory is made in `broadcast_to` or `broadcast_arrays`. For the sake of comparison, see the verbiage in `asarray(copy=False)`.
The problem with this ambiguity is that one can work around the lack of `broadcast_shapes` by calling `xp.broadcast_arrays(*args)[0].shape`, but there is no strong guarantee that the backend won't deep-copy the inputs.
Note that `numpy.broadcast_shapes` doesn't work with shapes containing `None` (ndonnx, and hopefully in the future JAX too) or `NaN` (Dask; non-standard).
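To make the workaround concrete (a sketch using NumPy as a stand-in for any standard-compliant namespace):

```python
import numpy as np

args = [np.ones((8, 1, 6, 1)), np.ones((7, 1, 5))]

# Today's workaround: broadcast real arrays just to read the resulting shape.
# Nothing in the spec guarantees this is an O(1) view rather than a deep copy.
shape = np.broadcast_arrays(*args)[0].shape  # (8, 7, 6, 5)

# What this RFC would allow instead, operating on shapes alone:
shape = np.broadcast_shapes(*(a.shape for a in args))  # (8, 7, 6, 5)
```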
I suggest either:

- Add prescriptive verbiage to `broadcast_to` and `broadcast_arrays` that the output must share memory with the input (in other words, the operation must be O(1)), or
- Add `broadcast_shapes` to the standard, and change the verbiage of the broadcasting high-level page to "typically without copying array data"
For the time being I am adding the function to `array_api_extra`:
- https://github.com/data-apis/array-api-extra/issues/80
- https://github.com/data-apis/array-api-extra/pull/133
@crusaderky Would you mind sharing your use case for needing `broadcast_shapes`? (i.e., where you want to determine a shape without actually broadcasting an array)
@kgryte I'm writing a backend-agnostic wrapper around `jax.pure_callback` / `dask.blockwise` / equivalent functions (ndonnx?) that lets you apply an arbitrary eager callback to a lazy array: https://github.com/data-apis/array-api-extra/pull/86
This requires me to know the output shape well before the inputs are used for anything.
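For context, here is a minimal sketch of that kind of wrapper for the JAX case (the `lazy_add` helper is hypothetical, and `np.broadcast_shapes` stands in for the proposed function):

```python
import jax
import numpy as np


def eager_add(a, b):
    # An arbitrary eager (NumPy) callback applied to lazy JAX inputs.
    return np.add(a, b)


def lazy_add(a: jax.Array, b: jax.Array) -> jax.Array:
    # jax.pure_callback must be told the output shape *before* anything is
    # computed, which is exactly where broadcast_shapes comes in.
    out_shape = np.broadcast_shapes(a.shape, b.shape)
    result_spec = jax.ShapeDtypeStruct(out_shape, a.dtype)  # assumes matching dtypes
    return jax.pure_callback(eager_add, result_spec, a, b)
```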
It would be useful to add to the issue description which array libraries do or don't support `broadcast_shapes` and, if they do, whether their signatures and behavior are all compatible.
> Add prescriptive verbiage to `broadcast_to` and `broadcast_arrays` that the output must share memory with the input (in other words, the operation must be O(1)), or
That's a performance/implementation question - as a rule we never specify that, only syntax and semantics. It might not be true for all implementations (e.g., JAX in eager mode).
> there is no strong guarantee that the backend won't deep-copy the inputs.
One shouldn't rely on such an assumption.
> The problem with this ambiguity is that one can work around the lack of `broadcast_shapes` by calling `xp.broadcast_arrays(*args)[0].shape`, but there is no strong guarantee that the backend won't deep-copy the inputs.
This syntax in itself is a good indication that `broadcast_shapes` is a useful thing to have. `xp.broadcast_arrays(*args)[0].shape` is a pretty clunky workaround; I wouldn't worry about performance here (libraries where performance matters will do the fast thing), but about the maintainability/readability of code.
> For the time being I am adding the function to `array_api_extra`:
That sounds like a good idea to me. Based on usage frequency and library support it can be "promoted" to a standardized function afterwards.
@rgommers I've updated the OP with a comparison across libraries.