array-api

RFC: add `broadcast_shapes` to the specification

Open crusaderky opened this issue 10 months ago • 5 comments

This RFC proposes adding an API to the specification for explicitly broadcasting a list of shapes to a single shape.

Overview

Based on array API comparison data, this API, or some variation of it, is commonly implemented across array libraries.

Currently, the Array API specification only includes broadcast_arrays and broadcast_to, both of which require array inputs. The specification lacks APIs for working directly with shapes without needing to create new array instances.

Prior Art

  • NumPy: https://numpy.org/doc/stable/reference/generated/numpy.broadcast_shapes.html
    • added in 2020: https://github.com/numpy/numpy/pull/17535
    • returns a Tuple
  • CuPy: simply re-exports NumPy's broadcast_shapes (corrected; an earlier version of this list said it was unsupported)
    • https://github.com/cupy/cupy/blob/a888cc94c79729cf24ebb808d15b9702c0342392/cupy/init.py#L302
  • Dask: dask.array.core.broadcast_shapes exists as private API only. It supports Dask's bespoke NaN sizes in the shape.
  • JAX: follows NumPy
    • returns a Tuple
  • PyTorch: follows NumPy
    • returns a torch.Size
  • TensorFlow: has two APIs for statically and dynamically known shapes
    • broadcast_static_shape: https://www.tensorflow.org/api_docs/python/tf/broadcast_static_shape
    • broadcast_dynamic_shape: https://www.tensorflow.org/api_docs/python/tf/broadcast_dynamic_shape
    • both functions only accept two shape arguments
  • ndonnx: no API. Shapes can contain None, so NumPy's implementation cannot be used.
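TensorFlow's functions take exactly two shapes, but because broadcasting is associative and commutative, any two-argument function extends to an arbitrary number of shapes by folding. A minimal sketch, assuming NumPy is available; broadcast_pair is a hypothetical stand-in for a two-argument API and is not TensorFlow's function.

```python
from functools import reduce

import numpy as np


def broadcast_pair(a: tuple, b: tuple) -> tuple:
    """Hypothetical two-argument broadcaster, standing in for an API limited to two shapes."""
    return np.broadcast_shapes(a, b)


def broadcast_many(*shapes: tuple) -> tuple:
    # Broadcasting is associative, so folding pairwise yields the n-ary result.
    # The empty shape () is the identity element, which also covers zero arguments.
    return reduce(broadcast_pair, shapes, ())


print(broadcast_many((5, 1, 1), (1, 4, 1), (1, 1, 3)))
```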

Proposal

This RFC proposes adding the following API to the specification:

def broadcast_shapes(*shapes: tuple[int | None, ...]) -> tuple[int | None, ...]

in which one or more shapes are broadcast together according to the broadcasting rules enumerated in the specification.
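A minimal pure-Python sketch of the proposed signature, for illustration only. The treatment of None is an assumption rather than part of the proposal: any None on an axis makes that output dimension None, mirroring the Dask NaN behaviour noted under Questions. Known sizes follow the usual rules: shapes are aligned on their trailing dimensions, a size of 1 stretches, and any other mismatch raises.

```python
from __future__ import annotations


def broadcast_shapes(*shapes: tuple[int | None, ...]) -> tuple[int | None, ...]:
    """Illustrative sketch only; None means 'size unknown' (assumed semantics)."""
    if not shapes:
        return ()
    ndim = max(len(shape) for shape in shapes)
    out: list[int | None] = []
    # Negative axis indices align all shapes on their trailing dimensions;
    # shapes shorter than ndim do not contribute to the leading axes.
    for axis in range(-ndim, 0):
        sizes = [shape[axis] for shape in shapes if len(shape) >= -axis]
        # Known sizes other than 1 must agree; a size of 1 stretches to match.
        known = {size for size in sizes if size is not None and size != 1}
        if len(known) > 1:
            raise ValueError(f"shapes {shapes} are not broadcast-compatible at axis {axis}")
        if None in sizes:
            # Assumed policy: an unknown size on any input makes the output unknown.
            out.append(None)
        else:
            out.append(known.pop() if known else 1)
    return tuple(out)


print(broadcast_shapes((2, 1, 4), (3, 1)))  # (2, 3, 4)
print(broadcast_shapes((None, 1), (5, 3)))  # (None, 3)
```

Returning None whenever any input is unknown is the conservative choice: an unknown size could turn out to be 1 or could conflict at runtime, so the sketch does not guess.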

Questions

  • How to handle shapes having unknown dimensions?

    • dask.array.core.broadcast_shapes sets the output size to nan if any of the input shapes are nan on the same axis
    • ndonnx.broadcast_arrays(a, b) returns arrays with materialized shapes.

    Note that shape materialization can be a very expensive operation, as it requires materializing the whole graph up to that point. In the case of Dask, which deliberately doesn't cache intermediate results as a memory management policy, this means computing everything at least twice.
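Dask marks unknown sizes with NaN while the standard uses None. One possible way for a shape-only function to accept both conventions is to normalize NaN to None first. A minimal sketch; unknown_to_none is a hypothetical helper, not an existing Dask or array-api-extra function.

```python
import math


def unknown_to_none(shape: tuple) -> tuple:
    """Hypothetical helper: map Dask-style NaN sizes to the standard's None."""
    # NaN compares unequal to everything, including itself, so a test such as
    # `size == math.nan` never matches; math.isnan is needed instead.
    return tuple(
        None if isinstance(size, float) and math.isnan(size) else size
        for size in shape
    )


print(unknown_to_none((math.nan, 3)))  # (None, 3)
```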

Notes

The top-level page on broadcasting mentions on the first line, using non-prescriptive language, that broadcasting allows creating views of the inputs:

Broadcasting refers to the automatic (implicit) expansion of array dimensions to be of equal sizes without copying array data

However, no mention of sharing memory is made in broadcast_to or broadcast_arrays.

For the sake of comparison, see the verbiage in asarray(copy=False).

The problem with this ambiguity is that one can work around the lack of broadcast_shapes by calling xp.broadcast_arrays(*args)[0].shape, but there is no strong guarantee that the backend won't deep-copy the inputs.
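For illustration, with NumPy as the backend the workaround and the dedicated function agree. NumPy happens to return views from broadcast_arrays, which np.shares_memory confirms, but nothing in the standard requires other libraries to behave the same way.

```python
import numpy as np

a = np.zeros((2, 1, 4))
b = np.zeros((3, 1))

# Workaround: broadcast actual arrays only to read off the resulting shape.
broadcast = np.broadcast_arrays(a, b)
workaround_shape = broadcast[0].shape

# Shape-only computation, available since NumPy 1.20.
direct_shape = np.broadcast_shapes(a.shape, b.shape)

print(workaround_shape, direct_shape)  # (2, 3, 4) (2, 3, 4)
# NumPy returns views here; other backends are not required to.
print(np.shares_memory(a, broadcast[0]))
```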

Note that numpy.broadcast_shapes doesn't work with shapes containing None (as in ndonnx, and hopefully JAX in the future) or NaN (as in Dask; non-standard).

I suggest either to:

  • Add prescriptive verbiage to broadcast_to and broadcast_arrays that the output must share memory with the input, or in other words the operation must be O(1), or
  • Add broadcast_shapes to the standard, and change the verbiage of the high-level broadcasting page to "typically without copying array data"

For the time being I am adding the function to array_api_extra:

  • https://github.com/data-apis/array-api-extra/issues/80
  • https://github.com/data-apis/array-api-extra/pull/133

crusaderky, Feb 05 '25 15:02

@crusaderky Would you mind sharing your use case for needing broadcast_shapes? (i.e., where you want to determine a shape, without actually broadcasting an array)

kgryte, Feb 06 '25 07:02

@kgryte I'm writing a backend-agnostic wrapper around jax.pure_callback / dask.blockwise / equivalent functions (ndonnx?) that let you apply an arbitrary eager callback to a lazy array: https://github.com/data-apis/array-api-extra/pull/86 This requires me to know the output shape well before the inputs will be used for anything.
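To illustrate why the shape is needed early: jax.pure_callback takes the result's shape and dtype as an argument, before any input values exist. A minimal sketch assuming JAX is installed; eager_add and lazy_add are hypothetical names, and jnp.broadcast_shapes stands in for the proposed standard function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def eager_add(x, y):
    # Hypothetical eager callback: runs on concrete host arrays.
    return np.asarray(x) + np.asarray(y)


@jax.jit
def lazy_add(x, y):
    # Inside jit, x and y are tracers: shapes and dtypes are known, values are not.
    # pure_callback must be told the output shape and dtype up front.
    out_spec = jax.ShapeDtypeStruct(
        jnp.broadcast_shapes(x.shape, y.shape),
        jnp.result_type(x, y),
    )
    return jax.pure_callback(eager_add, out_spec, x, y)


print(lazy_add(jnp.ones((2, 1)), jnp.ones((3,))).shape)  # (2, 3)
```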

crusaderky, Feb 06 '25 12:02

It would be useful to add to the issue description which array libraries do or don't support broadcast_shapes, and if they do whether their signatures and behavior are all compatible.

rgommers, Feb 06 '25 17:02

Add prescriptive verbiage to broadcast_to and broadcast_arrays that the output must share memory with the input, or in other words the operation must be O(1), or

That's a performance/implementation question - as a rule we never specify that, only syntax and semantics. It might not be true for all implementations (e.g., JAX in eager mode).

there is no strong guarantee that the backend won't deep-copy the inputs.

One shouldn't rely on such an assumption.

The problem with this ambiguity is that one can work around the lack of broadcast_shapes by calling xp.broadcast_arrays(*args)[0].shape, there is no strong guarantee that the backend won't deep-copy the inputs.

This syntax in itself is a good indication that broadcast_shapes is a useful thing to have. xp.broadcast_arrays(*args)[0].shape is a pretty clunky workaround - I wouldn't worry about performance here (libraries where performance matters will do the fast thing), but about the maintainability/readability of code.

For the time being I am adding the function to array_api_extra:

That sounds like a good idea to me. Based on usage frequency and library support it can be "promoted" to a standardized function afterwards.

rgommers, Feb 06 '25 17:02

@rgommers I've updated the OP with a comparison across libraries.

kgryte, Feb 06 '25 17:02