
RFC: Why is `stable=True` the default for `xp.sort`?

Open cakedev0 opened this issue 3 months ago • 12 comments

In xp.sort and xp.argsort, the sort is stable by default.

I understand the rationale very well for argsort: I've seen a lot of bugs caused by people expecting np.argsort to be stable.

And I guess sort is stable by default to match argsort behavior. But:

  • stable sort is slower. Typically, 8x slower in numpy on my machine.
  • for totally ordered data types, stable sort output is indistinguishable from unstable sort output
  • for complex numbers the doc says:

For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent

So: are we sure it's a good idea to default to stable sort?

Alternatively: shouldn't we call an unstable sort under the hood for non-complex numbers, even when stable=True?

cakedev0 avatar Oct 03 '25 15:10 cakedev0

I tried to do a bit of digging regarding the rationale. See https://github.com/data-apis/array-api/pull/19 for the PR which added the sorting methods. I need to track down the referenced workgroup meeting minutes. Once I do, I'll try to circle back, as my memory is a little fuzzy on why stable=True, apart from defaulting to the more obvious behavior and requiring users to opt out of it if a reproducible sort order is a lower priority than performance.

kgryte avatar Oct 05 '25 22:10 kgryte

defaulting to the more obvious behavior and requiring users to opt out of it if a reproducible sort order is a lower priority than performance

I understand this for argsort, and it makes a lot of sense! But I don't understand it for sort: the sort output is exactly the same whether it's stable or not, for any total order.

cakedev0 avatar Oct 15 '25 09:10 cakedev0

the sort output is exactly the same whether it's stable or not, for any total order.

While true for integer inputs, this isn't true for floating-point values. In the specification, we allow, but do not require, signed zeros to be sorted (e.g., -0.0 before +0.0); similarly for NaNs. See https://data-apis.org/array-api/latest/API_specification/sorting_functions.html.
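
A minimal NumPy sketch of where this becomes observable (illustrative only; the exact ordering produced by the unstable sort is implementation-dependent):

```python
import numpy as np

# Signed zeros compare equal (-0.0 == 0.0), so a stable sort must keep
# their input order, while an unstable sort may permute them freely.
x = np.array([-0.0, 1.0, 0.0, -0.0, 0.0])

stable = np.sort(x, kind="stable")
unstable = np.sort(x, kind="quicksort")  # introsort; order of equal zeros unspecified

print(np.signbit(stable[:4]))    # deterministic: [ True False  True False]
print(np.signbit(unstable[:4]))  # may differ from the stable result
```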

kgryte avatar Oct 15 '25 10:10 kgryte

Oh! I didn't have that in mind, thanks a lot for the explanation! Floating-point values are complex objects... 😅

Well, I think we can close this issue.

Though, we might want to draw the attention of developers from scipy/scikit-learn/array-api-extra/etc. to the fact that xp.sort is stable by default, which is typically 6-9x slower for numpy than an unstable sort.

cakedev0 avatar Oct 15 '25 19:10 cakedev0

Hum... And generally speaking, it would probably be worth raising awareness among Array API developers about the differences between np.sort and xp.sort:

Stability:

xp.sort is stable by default, np.sort isn't. In most cases, when using sort rather than argsort, you don't care much about sort stability. So the performance cost should be well known and kept in mind: we're talking about 6-9x slower for numpy.

NaN handling

The spec says:

Implementations may choose to sort NaNs (e.g., to the end or to the beginning of a returned array) or leave them in-place.

I'm pretty sure that scipy's quantile implementation relies on NaNs being at the end after sorting, for nan_policy="omit" - ping @mdhaber (these lines just change n to handle the NaNs, which is valid only if the NaNs are at the end)

And scikit-learn's weighted percentile function relies on NaNs being at the end, as an explicit comment in the code says. Ping @EmilyXinyi and @lucyleeow. Edit: I see this is already documented: https://github.com/scikit-learn/scikit-learn/issues/31368

cakedev0 avatar Oct 15 '25 21:10 cakedev0

which is valid only if the NaNs are at the end

I know. Need to make progress somehow, though. It would be nice if it were just standardized if current implementations all do it that way.

mdhaber avatar Oct 15 '25 21:10 mdhaber

Need to make progress somehow, though

Very understandable, that was also the conclusion of the discussion for sklearn's weighted percentile function: https://github.com/scikit-learn/scikit-learn/issues/31368#issuecomment-3280938275


But I would say it's not great to specifically implement/rewrite functions against a given spec while ignoring part of that spec 😅 Can't we find a way to avoid that?

It would be nice if it were just standardized if current implementations all do it that way.

Apparently, it might not be 100% true for PyTorch, see:

  • https://github.com/pytorch/pytorch/issues/46544#issuecomment-883356705
  • https://github.com/pytorch/pytorch/issues/116567

What do you think of adding a kwarg nan_policy="end"|None to sort/argsort and writing a wrapper that puts the NaNs at the end when nan_policy == "end", for implementations that don't already put NaNs at the end (only PyTorch for now, as far as I know)?

That should be doable with some masking tricks, I guess?

cakedev0 avatar Oct 15 '25 22:10 cakedev0

That should be doable with some masking tricks, I guess?

Yes, if you want to do it without changing the algorithm, e.g. replacing all $n$ NaNs with positive infinity (or the maximum value allowed by the dtype), sorting as usual, and replacing the last $n$ elements of the output with NaN.
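
A rough NumPy sketch of that masking trick (sort_with_nans_at_end is a hypothetical helper, not an existing API):

```python
import numpy as np

def sort_with_nans_at_end(x):
    # Replace NaNs with +inf, sort as usual, then overwrite the trailing
    # n slots with NaN again (ascending sort only in this sketch).
    nan_mask = np.isnan(x)
    n_nan = int(nan_mask.sum())
    out = np.sort(np.where(nan_mask, np.inf, x))
    if n_nan:
        out[-n_nan:] = np.nan
    return out

x = np.array([3.0, np.nan, -1.0, np.nan, 2.0])
print(sort_with_nans_at_end(x))  # [-1.  2.  3. nan nan]
```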


What do you think of adding a kwarg nan_policy

To the question of whether nan_policy would be a good name, I'd say probably not, because it has a somewhat different meaning here than nan_policy used elsewhere in the ecosystem (e.g. scipy.stats; stats would probably call this behavior propagate, arguing that the output of sort satisfies not (x[i] > x[i + 1]) for all valid indices, and NaNs "propagate" through the sorting algorithm according to that rule).
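
For illustration, a tiny sanity check of that reading (a sketch, assuming a hypothetical output where NaNs are left in place): comparisons against NaN are always False, so the invariant still holds.

```python
import numpy as np

# A "sorted" output with NaNs left wherever they land still satisfies
# not (out[i] > out[i + 1]) for every i, because any comparison with NaN is False.
out = np.array([1.0, np.nan, 2.0, np.nan, 3.0])
assert not any(out[i] > out[i + 1] for i in range(len(out) - 1))
```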

As for whether a keyword is appropriate - we run into the usual problem of a balance between functionality, performance, and API simplicity.

  1. If you don't add the keyword and don't add a NaN requirement, we don't have the functionality we want.
  2. If you don't add the keyword and require always pushing NaNs to the end, there are potential performance implications.
  3. If you do add the keyword, you complicate the API (for something that I assume most users would view as an extreme edge case).

I doubt that everyone will agree on which is the least bad. I've suggested 2 over 3 because I think the downside has less impact overall, since most libraries wouldn't need to change anything. ~~And, if recent history is any guide, non-compliant libraries would probably ignore the suggestion, at least in the short term.~~

mdhaber avatar Oct 15 '25 23:10 mdhaber

Dataframe libraries have different names for this, but the problem is also that the meaning is tricky w.r.t. descending/ascending sorts.

My 2c on stable= is that you could just make it a required kwarg (for the user) in the standard, and then you can even suggest stable=False; dunno if NumPy can change or not. Quite honestly, it seems like a small inconvenience to always pass it. We are talking about library authors after all, not really small scripts... they don't use "array-api-strict".

If you don't add the keyword and require always pushing NaNs to the end, there are potential performance implications.

How bad is that really? I don't really think it is all that expensive in practice IIRC, although I guess it would be more so for partial sorting. (I.e. sorting is normally far more expensive than a single light-weight O(N) pass to deal with NaNs.)

IIRC it is torch that had undefined behavior? Can you convince them to change that? And does the "or at the beginning" come from somewhere actually doing that?

About the NaN placement kwarg

On NaN order, I don't care much about what the name is, but I do care about the name and parameter not referring to the value, such as nan > inf. The difference is that ascending and descending sorts should both place the NaNs at the end (or the front). Dataframe libraries have various naming schemes, and it is a huge source of confusion because names are rarely quite clear about things (at least if you come with the opposite mindset). And the problem is... they don't agree on how it's done. (I think the reason is that nan > inf is more useful from an implementation point of view, while "put at end" is the much clearer user API.)

But the important thing seems actually to be having the "undefined" option (which I guess is the default, because that is what torch does, I think?) that gives you garbage if you have any NaNs? At least unless torch(?) (others?) can be convinced to always deal with NaNs.

One thing to remember is top_k, which needs "to end" sorting for a default "omit" nan-policy, but "to start" sorting for a "propagate" one. (Although "omit" could include NaNs if there aren't enough non-NaN values.)

seberg avatar Oct 16 '25 08:10 seberg

One thing this RFC seems to assume is that stable sorting algorithms are inherently slower than unstable sorting algorithms. While in NumPy this may be the case, I am not convinced that this is true more generally.

So my question is how much of the difference in performance is due to NumPy's choice of algorithms (and associated implementation) rather than to the fact that the specification requires stable sorting by default.

kgryte avatar Oct 16 '25 10:10 kgryte

Question 1: should we discuss this stable=True default?

One thing this RFC seems to assume is that stable sorting algorithms are inherently slower than unstable sorting algorithms.

Indeed, that's my assumption. I looked into the subject, and here's a quick overview of sorting algorithms used in array libraries:

  • GPU: mostly radix sort or variants => inherently stable (so no difference between stable=True and stable=False)
  • CPU - numpy
    • introsort (aka "safe" quicksort): fast, unstable sort
    • radix sort: relatively fast, stable sort. Numpy uses it for bools and small ints (8 or 16 bits). It's faster than introsort for small dtypes and big arrays. For 32 or 64 bits, I tested it and found it slower than introsort, but not by much (about 2x).
    • timsort: "slow" stable sort. Numpy uses it for other dtypes. Typically 8x slower than introsort.
  • CPU - others: torch and jax sorts are about 10x slower than numpy's introsort for floats. I didn't dig too much, as I think numpy largely dominates CPU backends. (See the benchmark sketch below.)
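
To make the 8x figure concrete, here is a rough benchmark sketch (timings are machine- and size-dependent; this shows how the comparison was run, not an authoritative number):

```python
import timeit
import numpy as np

# Compare NumPy's default introsort against its stable sort (timsort for
# float64) on a large random array; the ratio varies with dtype and size.
rng = np.random.default_rng(0)
x = rng.standard_normal(10_000_000)

t_unstable = timeit.timeit(lambda: np.sort(x, kind="quicksort"), number=5)
t_stable = timeit.timeit(lambda: np.sort(x, kind="stable"), number=5)
print(f"unstable: {t_unstable:.2f}s  stable: {t_stable:.2f}s  "
      f"ratio: {t_stable / t_unstable:.1f}x")
```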

While in NumPy this may be the case, I am not convinced that this is true more generally.

That's fair, there might be room for improvement in numpy's stable sort. However:

  • it will likely remain significantly slower
  • for now it's much slower, 8x is a lot.

Also, I suspect that a lot of code currently written against the Array API spec is a rewrite of existing NumPy-based code. We don’t want performance regressions for users who will continue to rely on NumPy (which is likely the majority).

Conclusion: in my opinion, this is something that should be discussed/addressed.


Question 2: how to address this?

Option 1:

My 2c on stable= is that you could just make it a required kwarg (for the user) in the standard

I like that: clean and simple. It's a backward-incompatible change, but it shouldn't be too hard for library maintainers to address. Still, here's an alternative:

Option 2: For numpy and for numerical dtypes, just always do an unstable sort, even when calling xp.sort(..., stable=True). For numerical dtypes, the only difference is the position of signed zeros... And because np.float64(0.) == np.float64(-0.), you have np.array_equal(np.sort(x), np.sort(x, stable=True), equal_nan=True) for any x of any numerical dtype... So why bother stable sorting? (A quick check follows below.)
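
A quick check of that claim, as a sketch (it assumes NumPy >= 2.0, which exposes the stable keyword on np.sort):

```python
import numpy as np

# The sorted *values* match whether or not the sort is stable, even with
# signed zeros and NaNs mixed in: equal elements are interchangeable by
# value, and NumPy sorts NaNs to the end either way.
rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
x[::37] = -0.0
x[::50] = np.nan

assert np.array_equal(np.sort(x),
                      np.sort(x, stable=True),
                      equal_nan=True)
```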

For argsort, the slowdown is "only" 2x, so it’s more acceptable to always use a stable sort there (and defaulting to stable sort definitely helps prevent bugs).

cakedev0 avatar Oct 17 '25 21:10 cakedev0

The 8x is presumably because there are fast SIMD implementations for the unstable sort only, not algorithmic as such. But I doubt that actually changes anything.

seberg avatar Oct 18 '25 06:10 seberg