
Spell out where views are allowed

Open: crusaderky opened this issue 8 months ago • 14 comments

In https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html, the standard says:

Array API consumers are strongly advised to avoid any mutating operations when an array object may [...] be a “view” [...] It is not always clear, however, when a library will return a view and when it will return a copy. This standard does not attempt to specify this—libraries may do either.

This is fine for __getitem__, asarray(..., copy=None), astype(..., copy=False), and similar functions that the standard explicitly says may return views.

However, there are a few corner cases where views are possible but a typical user is very unlikely to think about them. I just stumbled on one in https://github.com/data-apis/array-api-compat/pull/298, where array_api_compat.torch.sum(x, dtype=x.dtype, axis=()) was accidentally returning x itself instead of a copy of it.

There are a few more cases where a library could try to be smart, for example:

  • statistical functions (min, max, others?) could return a view of the minimum/maximum element
  • replacement functions (minimum, maximum, clip, where) could return one of the input arrays when there is nothing to do
  • arithmetic functions could do the same (__add__ / __sub__ with 0, __mul__ / __truediv__ with 1, etc.)
  • sort functions could do the same when they detect that the input is already sorted
  • possibly more

In real life, I expect end users to assume that the above functions always return a copy. I think the standard should spell this out, limiting the possibility of views to an explicit list of allowed functions:

  • __getitem__
  • asarray
  • astype
  • __dlpack__
  • from_dlpack
  • reshape
  • broadcast_to
  • broadcast_arrays
  • ...more?
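
As a point of reference, here is a small NumPy check of several of the explicitly allowed cases listed above. NumPy is used only as one example backend; np.shares_memory and the flags.writeable attribute are NumPy-specific, and other libraries may return a copy where NumPy returns a view.

```python
import numpy as np

x = np.arange(6, dtype=np.int64)

# asarray(copy=None) and astype(copy=False) may hand back the input object itself.
same = np.asarray(x)
cast = x.astype(x.dtype, copy=False)

# Basic slicing (__getitem__) and reshape of a contiguous array give views.
sliced = x[::2]
reshaped = np.reshape(x, (2, 3))

# broadcast_to gives a read-only view.
broadcast = np.broadcast_to(x, (4, 6))

for name, out in [
    ("asarray", same),
    ("astype", cast),
    ("__getitem__", sliced),
    ("reshape", reshaped),
    ("broadcast_to", broadcast),
]:
    print(f"{name}: shares memory={np.shares_memory(out, x)}, "
          f"writeable={out.flags.writeable}")
```

broadcast_to stands out: NumPy makes the view read-only, so an accidental write through it fails loudly instead of silently reaching x.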

crusaderky, Apr 04 '25 10:04

xpx.at(x, idx).set(y) has a copy parameter that defaults to None. This default was picked after some discussion, to avoid unnecessary copies in libraries with writable arrays. The pattern currently used in scipy is that, when the developer thinks that xpx.at may write back to the input, they explicitly pass copy=True. However, developers only think of doing this when the input is an unmodified parameter of the function, not when it is the output of some processing or reduction applied to it.

crusaderky, Apr 04 '25 10:04

Practical example from scipy:

https://github.com/scipy/scipy/blob/27157ac1db4fc23bef76df50dfd8a4393453153c/scipy/special/_logsumexp.py#L397-L403

The above code works with all the backends we know of. However, a backend could have max return x[argmax(x)] as a view, which would cause the function to write back to its input.
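
To make the hazard concrete, here is a minimal NumPy sketch. max_keepdims_as_view and logsumexp_prologue are hypothetical names invented for illustration; no real backend is claimed to behave this way. To stay dependency-free, the sketch replaces xpx.at(x_max, mask).set(0) with a plain in-place assignment, which is assumed here to be equivalent to what the default copy=None does on a writable NumPy array.

```python
import numpy as np


def max_keepdims_as_view(x):
    """Hypothetical backend max(x, keepdims=True) for 1-D x that returns a view.

    A length-1 basic slice at the argmax position aliases x's memory,
    which the standard does not forbid.
    """
    i = int(np.argmax(x))
    return x[i:i + 1]


def logsumexp_prologue(x):
    # Mirrors the scipy pattern: take the max, then zero out non-finite
    # maxima in place, believing that x_max is a fresh array.
    x_max = max_keepdims_as_view(x)
    x_max[~np.isfinite(x_max)] = 0
    return x_max


x = np.asarray([1.0, np.inf, 3.0])
logsumexp_prologue(x)
print(x)  # x[1] is now 0.0: the caller's input was modified
```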

crusaderky, Apr 04 '25 10:04

> In real life, I expect end users to assume that the above functions will always return a copy. I think the standard should spell this out, limiting the possibility of views to an explicit list of allowed functions:

This won't fly, since views aren't a concept in the standard. There really is no way to fix this problem in the standard; the only ways to address it are (a) fixing bugs in libraries like the torch.sum one (I'm fairly sure that that is indeed a bug and not a feature), and (b) libraries implementing ways to return read-only arrays, so that any user of in-place operations can actually tell the difference between "I'm modifying one array" and "I'm modifying >=2 arrays".

(b) is quite desirable: it's in the works for PyTorch, and I hope it will actually materialize at some point. For NumPy we've brainstormed about it a bit recently, since it's also desirable for thread safety, which is becoming much more relevant with free-threading.

Also, pragmatically: even if the standard did what you suggest, libraries aren't going to follow it and do a whole bunch of work to audit everything and change how functions behave (these would all be backwards-incompatible changes anyway; I think that's a nonstarter for anything that isn't considered a bug).

> ...more?

linalg.diagonal is an infamous example in NumPy.
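
For illustration: NumPy documents that np.diagonal has returned a read-only view of its input since NumPy 1.9. The sketch below uses np.diagonal only; np.linalg.diagonal in NumPy 2.x is assumed to behave the same way.

```python
import numpy as np

a = np.arange(9).reshape(3, 3)
d = np.diagonal(a)  # documented by NumPy as a read-only view of a

print("shares memory:", np.shares_memory(d, a))
print("writeable:", d.flags.writeable)

try:
    d[0] = 100  # rejected: the view is read-only
except ValueError as exc:
    print("write rejected:", exc)
```

The read-only flag is what turns a silent write-through into an error, which is the kind of signal point (b) above asks libraries to provide more broadly.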

rgommers, Apr 04 '25 11:04

> (b) for libraries to implement ways to return read-only arrays so that any user that uses in-place operations can actually tell the difference between "I'm modifying one array" vs. "I'm modifying >=2 arrays".

With this, do you mean something like xp.asarray(obj, writable=False)?

That would indeed solve the scipy example I posted:

    xp = array_namespace(x)
    x = xp.asarray(x, writable=False)  # NOTE THIS!

    x_max = xp.max(x, axis=axis, keepdims=True)

    if x_max.ndim > 0:
        x_max = xpx.at(x_max, ~xp.isfinite(x_max)).set(0)

With this change, xp.max can either return

  • a writable brand new array, which causes xpx.at to efficiently write back to it;
  • or a read-only view of x, which causes xpx.at to perform a copy.
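
The two outcomes above can be sketched in NumPy, where flags.writeable stands in for the proposed writable=False (a hypothetical argument that xp.asarray does not have today). at_set is likewise a hypothetical stand-in for xpx.at(x, idx).set(value) with copy=None, not array-api-extra's actual implementation.

```python
import numpy as np


def at_set(x, idx, value):
    """Hypothetical stand-in for xpx.at(x, idx).set(value) with copy=None.

    Writes in place when x is writable and copies first when it is not.
    """
    if not x.flags.writeable:
        x = x.copy()  # read-only input: never write back to it
    x[idx] = value
    return x


x = np.asarray([1.0, np.inf])
x.flags.writeable = False  # emulates the proposed asarray(..., writable=False)

# Backend A: max returns a read-only view of x, so at_set copies.
view_max = x[1:2]
res_view = at_set(view_max, ~np.isfinite(view_max), 0)

# Backend B: max returns a brand-new writable array, so at_set writes in place.
fresh_max = np.max(x, keepdims=True)
res_fresh = at_set(fresh_max, ~np.isfinite(fresh_max), 0)
```

In both cases the input x is left untouched, and only the fresh-array case avoids the extra copy.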

crusaderky, Apr 04 '25 12:04

Kinda, but without a writable=False argument; that's too ugly. The idea is to keep better track of views internally, so that the last comment below changes to match the one on the line above it:

    >>> import numpy as np
    >>> x = np.arange(5)
    >>> y = x[::2]
    >>> y.data is x.data
    False
    >>> y.base
    array([0, 1, 2, 3, 4])
    >>> y.base is x
    True
    >>> x.base

    >>> y[0] += 1  # you, and numpy, can be sure this modifies >1 arrays (because of .base)
    >>> x[0] += 1  # you, and numpy, cannot know if this modifies 1 or >1 arrays

Once you can always know, it's straightforward to implement modes (context manager, global setting, etc.) where in-place operations either raise or do copy-on-write if the operation affects >1 array, and possibly even to migrate the default over time. The harder part is implementing the machinery for this.
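
A toy sketch of such modes, assuming a wrapper that does know how many arrays share a buffer. Tracked, inplace_mode and the three mode names are all hypothetical, not a NumPy or PyTorch API. The share counter is also never decremented when an array is garbage collected, which real machinery would have to handle.

```python
import contextlib
import contextvars

import numpy as np

# Hypothetical global mode: "allow" (today's behaviour), "raise", or "copy".
_inplace_mode = contextvars.ContextVar("inplace_mode", default="allow")


@contextlib.contextmanager
def inplace_mode(mode):
    """Temporarily choose how in-place writes to a shared buffer are handled."""
    token = _inplace_mode.set(mode)
    try:
        yield
    finally:
        _inplace_mode.reset(token)


class Tracked:
    """Toy wrapper that always knows how many arrays share its buffer."""

    def __init__(self, data, shared=None):
        self.data = data
        # One-element list used as a mutable share counter, common to all views.
        self._shared = shared if shared is not None else [1]

    def view(self, idx):
        # Assumes idx is a basic slice, so self.data[idx] is a real NumPy view.
        self._shared[0] += 1
        return Tracked(self.data[idx], self._shared)

    def __setitem__(self, idx, value):
        if self._shared[0] > 1:  # this write would affect more than one array
            mode = _inplace_mode.get()
            if mode == "raise":
                raise ValueError("in-place write would modify more than one array")
            if mode == "copy":
                # Copy-on-write: detach from the shared buffer before writing.
                self._shared[0] -= 1
                self.data = self.data.copy()
                self._shared = [1]
        self.data[idx] = value


x = Tracked(np.arange(5))
y = x.view(slice(None, None, 2))  # x and y now share one buffer

with inplace_mode("copy"):
    y[0] = 100  # y detaches first, so x is untouched

z = x.view(slice(1, 3))  # x is shared again, this time with z
with inplace_mode("raise"):
    try:
        x[0] = 7
    except ValueError as exc:
        print("refused:", exc)

x[1] = 42  # default "allow": the write is visible through z as well
```

The point of the toy is that the policy itself is a few lines once "how many arrays does this write affect?" has a reliable answer; obtaining that answer is the hard part.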

rgommers, Apr 04 '25 12:04

I ran a test on a bunch of functions with obvious no-op use cases, and I think the picture it paints is very problematic:

| backend | abs(uint) | clip(min=None, max=None) | trunc(int) | ceil(int) | floor(int) | round(int) |
| --- | --- | --- | --- | --- | --- | --- |
| numpy 2.2 | copy | copy | copy | copy | copy | view |
| array_api_compat.numpy 2.2 | copy | copy | view | view | view | view |
| array_api_compat.numpy 1.26 | copy | copy | view | view | view | view |
| array_api_compat.dask.array | copy | copy | view | view | view | copy |
| array_api_compat.cupy | copy | copy | view | view | view | copy |
| array_api_compat.torch | copy | copy | copy | copy | copy | copy |
| ndonnx | copy | copy | copy | copy | copy | copy |
| array_api_strict (numpy 2.2) | copy | view | view | view | view | view |
| array_api_strict (numpy 1.26) | copy | view | view | view | view | view |

Reproducer:

    import importlib

    BACKENDS = (
        "array_api_compat.numpy",
        "array_api_compat.dask.array",
        "array_api_compat.cupy",
        "array_api_compat.torch",
        "numpy",
        "ndonnx",
        "array_api_strict",
    )

    FUNCTIONS = ["abs", "clip", "trunc", "ceil", "floor", "round"]

    # Print a Markdown table: one row per backend, one column per function.
    print("| backend | " + " | ".join(FUNCTIONS) + " |")
    print("| --- | " + " | ".join(["---"] * len(FUNCTIONS)) + " |")
    for backend in BACKENDS:
        print(f"| {backend} |", end="")
        xp = importlib.import_module(backend)
        for func_name in FUNCTIONS:
            a = xp.asarray([1, 2], dtype=xp.uint8)
            func = getattr(xp, func_name)
            try:
                b = func(a)
                # Only no-op cases are of interest: the dtype must be preserved.
                assert b.dtype == a.dtype
            except Exception:
                print(" n/a |", end="")
            else:
                # Write through the output; if the input changes, b is a view.
                b[0] = 3
                res = "view" if a[0] == 3 else "copy"
                print(f" {res} |", end="")
        print()

crusaderky, Apr 23 '25 12:04

> I ran a test on a bunch of functions with obvious no-op use cases, and I think the picture it paints is very problematic:

I agree, but it's primarily a historical design issue in NumPy et al.; I don't know what to do about it here.

The one thing that does seem worth addressing is making array_api_compat not diverge from the library it wraps (e.g., array_api_compat.numpy 2.2 vs. numpy 2.2), if that can be avoided.

rgommers, Apr 24 '25 11:04

cc @ev-br

kgryte, May 29 '25 09:05

This indeed looks like a collection of array-api-compat issues, which stem from constructions like this one:

    def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
        if xp.issubdtype(x.dtype, xp.integer):
            return x  # no-op for integers, but hands back the input itself, not a copy
        return xp.trunc(x, **kwargs)

So I'd suggest we transfer this issue to array-api-compat and address it there (most likely with some if numpy.__version__ branching).
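
A minimal sketch of the direction a fix could take: return a copy in the integer no-op branch instead of x itself. This is not the actual patch from the follow-up PR; it assumes a NumPy-like xp that provides issubdtype and integer, and arrays with a .copy() method (true for NumPy arrays). The type annotations are dropped to keep it self-contained.

```python
import numpy as np


def trunc(x, /, xp, **kwargs):
    """Sketch of a trunc wrapper that never hands back its own input.

    For integer dtypes truncation is a no-op, but returning x itself would
    alias the caller's array, so a copy is returned instead.
    """
    if xp.issubdtype(x.dtype, xp.integer):
        return x.copy()
    return xp.trunc(x, **kwargs)


a = np.asarray([1, 2], dtype=np.uint8)
b = trunc(a, xp=np)
b[0] = 3  # writing to the result no longer reaches a
```

Whether to copy unconditionally or to mirror a specific NumPy version's behaviour (the if numpy.__version__ branching mentioned above) is a policy choice; this sketch always copies.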

ev-br, May 29 '25 13:05

@ev-br Transferred.

kgryte, May 29 '25 15:05

https://github.com/data-apis/array-api-compat/pull/333 loops over unary functions and fixes trunc, ceil and floor. There are potentially more, trickier cases, but gh-333 does not test for them yet.
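
An alternative way to detect views, sketched for NumPy-backed namespaces only: instead of writing through the output as the reproducer earlier in the thread does, ask np.shares_memory whether the output aliases the input. view_returning is a hypothetical helper, not part of gh-333.

```python
import numpy as np


def view_returning(xp, names, x):
    """Return the names of unary functions whose output aliases x's memory.

    np.shares_memory only understands NumPy arrays, so this applies to
    namespaces whose arrays are NumPy arrays (e.g. numpy or
    array_api_compat.numpy).
    """
    found = []
    for name in names:
        out = getattr(xp, name)(x)
        if np.shares_memory(out, x):
            found.append(name)
    return found


x = np.asarray([1, 2], dtype=np.uint8)
print(view_returning(np, ["abs", "positive", "trunc", "ceil", "floor", "round"], x))
```

The print call includes round, for which the table earlier in the thread reports a view with NumPy 2.2; the exact result may vary by NumPy version, so no output is shown here. Unlike the write-and-check approach, this does not mutate anything.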

ev-br, Jun 01 '25 13:06

To clarify my position on the matter: since the spec does not mandate whether any given function returns a view or a copy, this is "unspecified, thus implementation-defined". This is one place where the whole Array API abstraction leaks (inevitably, IMO).

Thus, what's left at the array-api-compat level is to decide what we do about it. The only reasonable thing, IMO, is to declare that for a bare array library X, its wrapped version behaves the same. A small favor to users is to extend this to library versions: if version X.Y of a library returns a view, so does its wrapped version. If we all agree on this, maybe it's worth spelling it out explicitly somewhere in the docs; I'm not sure where, though.

This of course does not shield users from surprises where, say, numpy returns a view but jax.numpy returns a copy. Not much we can do about it, maybe we could document it somewhere (not sure where either).

ev-br, Jun 02 '25 18:06

> A small favor to users is to extend this to library versions: if a library version X.Y returns a view, so does its wrapped version.

For array_api_compat.dask.array.asarray(copy=None) we did the opposite: array_api_compat makes old wrapped versions behave like the latest one. I think this makes things easier for end users.

> This of course does not shield users from surprises where, say, numpy returns a view but jax.numpy returns a copy.

It's not really meaningful to talk about views for immutable backends, as their only impact is memory usage. But yes, your comment stands, e.g., for numpy vs. cupy vs. torch.

crusaderky, Jun 03 '25 07:06

Agreed on both counts: the dask.array.asarray(copy=None) behaviour is a conscious break of the version policy so that the dask wrapper layer is forward compatible, and there are no views at all in jax.numpy, so anything that returns a view in other backends returns a copy in jax.numpy. And we don't wrap jax anyway. Modulo these two corrections, we seem to be in agreement.

ev-br, Jun 03 '25 08:06