RFC: require that dtypes obey Python hashing rules
Python's documentation promises that: "The only required property is that objects which compare equal have the same hash value…" However, NumPy dtypes do not follow this requirement. As discussed in https://github.com/numpy/numpy/issues/7242, dtype objects, their types, and their names all compare equal despite hashing unequal. Could the Array API promise that this will no longer be the case?
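For illustration, here is roughly what the violation looks like in practice on a NumPy version affected by the linked issue (output abbreviated): equal keys that hash unequally make dict lookups inconsistent.
>>> import numpy as np
>>> np.dtype(np.float64) == np.float64  # dtype instance vs. scalar type: equal
True
>>> hash(np.dtype(np.float64)) == hash(np.float64)  # ...yet the hashes differ
False
>>> {np.dtype(np.float64): "a"}[np.float64]  # so lookup by an "equal" key fails
...
KeyError: <class 'numpy.float64'>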
That seems fine to me to explicitly specify. float32 == 'float32' should clearly return False. In NumPy it's a bit messy:
>>> import numpy as np
>>> np.float32 == 'float32'
False
>>> np.dtype(np.float32) == 'float32'
True
Only the first example is relevant for the array API standard, so I think this will be fine to specify since NumPy already complies.
There is a problem in NumPy with this one, however:
>>> np.dtype(np.float64) == float
True
That can be considered a clear bug though, should be fixed in NumPy.
Only the first example is relevant for the array API standard, so I think this will be fine to specify since NumPy already complies.
So you're saying that np.dtype(np.float32) == 'float32' will be true or false?
That can be considered a clear bug though, should be fixed in NumPy.
Agreed.
What about np.float32 == np.dtype(np.float32)?
This also violates Python's hashing invariant.
What about np.float32 == np.dtype(np.float32)?
I agree that it's a bug technically. Not 100% sure that the NumPy team will want that changed, but I hope so (and a proposal for a major release is in the works, so that could go into it). For the array API standard it's not an issue, because there is no dtype constructor/function in the main namespace.
So you're saying that np.dtype(np.float32) == 'float32' will be blocked or not?
That's more for the NumPy issue tracker, but if it were up to me then yes.
For this issue tracker, I'm +1 on adopting language in the standard like: "All objects in this standard must adhere to the following requirement (as required by Python itself): objects which compare equal have the same hash value".
For this issue tracker, I'm +1 on adopting language in the standard like: "All objects in this standard must adhere to the following requirement (as required by Python itself): objects which compare equal have the same hash value".
That would be amazing. That's exactly what I was hoping for.
I agree that it's a bug technically. Not 100% sure that the NumPy team will want that changed, but I hope so (and a proposal for a major release is in the works, so that could go into it). For the array API standard it's not an issue, because there is no dtype constructor/function in the main namespace.
Okay, thanks for explaining. If the above language were adopted, NumPy could implement that by making xp.float32 not simply equal to np.dtype(np.float32), but rather a special dtype object that doesn't have the pernicious behavior.
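As a rough sketch of what such a special dtype object could look like (all names here are hypothetical, not an actual NumPy design): it compares equal only to wrappers of the same underlying dtype, and its hash is consistent with its equality.

import numpy as np

class _Dtype:
    # Hypothetical wrapper dtype for a compliant xp namespace.
    def __init__(self, np_dtype):
        self._np_dtype = np.dtype(np_dtype)

    def __eq__(self, other):
        # Equal only to wrappers of the same underlying dtype; never to
        # strings, scalar types, or raw numpy dtype instances.
        if isinstance(other, _Dtype):
            return self._np_dtype == other._np_dtype
        return NotImplemented

    def __hash__(self):
        # Consistent with __eq__, as Python requires.
        return hash((_Dtype, self._np_dtype))

float32 = _Dtype(np.float32)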
Let's give it a bit of time to see if anyone sees a reason not to add such a requirement. I can open a PR after the holidays.
For the array API standard it's not an issue, because there is no dtype constructor/function in the main namespace.
Just noticed this comment. It is currently an issue in NumPy's implementation of the Array API:
import numpy.array_api as xp
xp.float32 == xp.float32.type # True!
This is because xp.float32 points to an object np.dtype(np.float32). For this to be fixed, NumPy would just need a new dtype class for use in its Array API xp.
With the language you suggested above, NumPy would be forced to do this to become compliant 🙂.
So you're saying that np.dtype(np.float32) == 'float32' will be blocked or not?
That's more for the NumPy issue tracker, but if it were up to me then yes.
Same thing here, I think. NumPy will probably reject this for their own namespace (np), but if you adopt that language, they would have to fix it in the array API (xp).
Incidentally, I assume you want numpy.array_api.float32 to compare equal to jax.array_api.float32? Since there is no root project to provide a base implementation of dtypes, you may need to standardize how dtype.__hash__ and comparison work.
xp.float32 == xp.float32.type # True!
There is no float32.type in the standard. That it shows up with numpy.array_api.float32 is because the dtype objects there are aliases to the regular numpy ones, rather than new objects. That was a shortcut I think, because adding new dtypes is a lot of work. So that's one place where currently numpy.array_api doesn't 100% meet its goal of being completely minimal.
Incidentally, I assume you want numpy.array_api.float32 to compare equal to jax.array_api.float32?
No, definitely not. No objects from two different libraries should ever compare equal, unless they're indeed the same object.
So that's one place where currently numpy.array_api doesn't 100% meet its goal of being completely minimal.
Ok! Thanks for explaining.
No, definitely not. No objects from two different libraries should ever compare equal, unless they're indeed the same object.
So to do things like checking that two arrays have the same dtype, or creating a NumPy array that has the same type as a Jax array, we'll need mappings like:
m = {jax.array_api.float32: np.array_api.float32, ...}
And code like
np.array_api.ones_like(some_jax_array) # works today, in either direction.
is impossible, yes? You need:
np.array_api.ones(some_jax_array.shape, dtype=m[some_jax_array.dtype])
So to do things like checking that two arrays have the same dtype ...
Having to use library-specific constructs should not be needed; if so, we're missing an API, I'd say. More importantly: mixing arrays from different libraries like this is a bit of an anti-pattern. You can't do much with that; neither library has kernels for functions that use both array types, so you're probably relying on implicit conversion of one to the other.
So in this case, let me assume that x is a numpy array, y a JAX array, and you're wanting to use functions from x (numpy):
# First retrieve the namespace you want to work with
xp = x.__array_namespace__()
# Use DLPack or the buffer protocol to convert a CPU JAX array to a NumPy array
y = xp.asarray(y)
# Now we can compare dtypes:
if x.dtype == y.dtype == xp.float32:
    ...  # if the dtypes match, do stuff
# Or, similarly (isdtype takes a dtype, not an array):
if xp.isdtype(x.dtype, xp.float32) and xp.isdtype(y.dtype, xp.float32):
    ...
is impossible, yes? You need:
yes indeed
I'm actually a little surprised JAX accepts numpy arrays. It seems to go against its philosophy; TensorFlow, PyTorch and CuPy will all raise. When you call jnp.xxx(a_numpy_array), JAX will also always make a copy, I believe, since it doesn't want to share memory. An explicit copy made by the user is clearer and more portable.
JAX is also annotating its array inputs as array_like, but it doesn't mean the same as for NumPy:
>>> jnp.sin([1, 2, 3])
...
TypeError: sin requires ndarray or scalar arguments, got <class 'list'> at position 0
All this stuff is bug-prone:
>>> jnp.sin(np.array([1, 2, 3]))
Array([0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
>>> jnp.sin(np.ma.array([1, 2, 3], mask=[True, False, True])) # bug in user code here, because JAX silently discards mask
Array([0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
>>> np.sin(np.ma.array([1, 2, 3], mask=[True, False, True]))
masked_array(data=[--, 0.9092974268256816, --],
mask=[ True, False, True],
fill_value=1e+20)
More importantly: mixing arrays from different libraries like this is a bit of an anti-pattern. You can't do much with that; neither library has kernels for functions that use both array types, so you're probably relying on implicit conversion of one to the other.
Okay, makes sense. I haven't been very conscious about this because (as you pointed out) Jax implicitly converts. I will be more careful.
y = xp.asarray(y)
I think this is where I'm confused. Somehow numpy has to know what its equivalent dtypes are for Jax's dtypes even though they don't compare equal? Or will it produce a numpy array with a Jax dtype? As this seems to work:
In [12]: x = jnp.ones(10, jnp.bfloat16)
In [14]: np.asarray(x)
Out[14]: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=bfloat16)
When you call jnp.xxx(a_numpy_array), JAX will also always make a copy, I believe, since it doesn't want to share memory. An explicit copy made by the user is clearer and more portable.
Very interesting. I wonder what the Jax team would say.
I think this is where I'm confused. Somehow numpy has to know what its equivalent dtypes are for Jax's dtypes even though they don't compare equal? Or will it produce a numpy array with a Jax dtype?
NumPy knows the dtype, as does JAX. This conversion uses the Python buffer protocol or DLPack, both of which are protocols explicitly meant for exchanging data in a reliable way (that includes dtype, shape, endianness, etc.). So the asarray call will produce a numpy array with a numpy dtype, and to do so numpy does not need to know anything specifically about JAX.
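For illustration, the DLPack route looks like this (a sketch assuming a recent NumPy and a CPU JAX array):

import numpy as np
import jax.numpy as jnp

x_jax = jnp.arange(3, dtype=jnp.float32)
# NumPy consumes the JAX array through the DLPack protocol; dtype, shape
# and strides travel with the exchanged buffer, no JAX knowledge needed.
x_np = np.from_dlpack(x_jax)
print(x_np.dtype)  # float32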
When you call jnp.xxx(a_numpy_array), JAX will also always make a copy, I believe, since it doesn't want to share memory. An explicit copy made by the user is clearer and more portable.
Very interesting. I wonder what the Jax team would say.
Let's try to find out :) This section of the JAX docs only explains why JAX doesn't accept list/tuple/etc., but I cannot find an explanation of why it does accept numpy arrays and scalars. @shoyer or @jakevdp, would you be able to comment on why JAX implements a limited form of "array-like"?
Also, in addition to the bug with masked arrays above, here is another bug:
>>> jnp.sin(np.float64(1.5)) # silent precision loss here, downcasting to float32
Array(0.997495, dtype=float32)
>>> jax.__version__
'0.4.1'
So the asarray call will produce a numpy array with a numpy dtype, and to do so numpy does not need to know anything specifically about JAX.
In that case, there should be a way to convert dtypes using either the buffer protocol or DLPack? Something more efficient than:
def x_to_y_dtype(some_xp_dtype: DType, yp: ArrayNamespace) -> DType:
    xp = some_xp_dtype.__array_namespace__()  # doesn't exist
    x = xp.ones((), dtype=some_xp_dtype)
    return yp.asarray(x).dtype
Should dtypes have a __array_namespace__ attribute? Currently, they don't. So, the above function can't be written unless you already know xp.
No, those protocols are specifically for exchanging data (strided arrays/buffers). A dtype without data isn't very meaningful. You could exchange a size-1 array if needed, or a 'float32' string representation, or whatever works.
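So a hypothetical helper in that spirit (convert_dtype is a made-up name, not a standard API) could simply round-trip a size-1 array:

def convert_dtype(xp, yp, dtype):
    # Round-trip a size-1 array through yp.asarray (which uses the buffer
    # protocol or DLPack under the hood) and read off the resulting dtype.
    return yp.asarray(xp.ones((), dtype=dtype)).dtype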
No, those protocols are specifically for exchanging data (strided arrays/buffers).
I understand, but in order to exchange data, they have to be able to convert dtypes. So, that dtype conversion is happening somehow, and I was just wondering if that conversion can be accessed by the user.
It's not user-accessible, it's all under the hood.
Specifically for JAX you have a shortcut, because it reuses NumPy dtypes directly:
>>> type(jnp.float32)
<class 'jax._src.numpy.lax_numpy._ScalarMeta'>
>>> type(jnp.float32.dtype)
<class 'numpy.dtype[float32]'>
(Thanks for all the patient explanations!)
@shoyer or @jakevdp, would you be able to comment on why JAX implements a limited form of "array-like"?
JAX avoids implicit conversion of Python sequences, because it can hide severe performance issues. When something like x = [i for i in range(10000)] is passed to the XLA compiler, it is passed as a list of 10000 XLA scalars. We found this to be a common mistake people made, and decided to disallow it. This is discussed at Non-Array Inputs: Numpy vs. JAX.
On the other hand, np.arange(10000) is a single XLA array, and doesn't have this problem. On CPU, the transfer can even be done in most cases in a zero-copy fashion, although on accelerators there will be a device transfer cost.
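To make the distinction concrete, a small sketch (the rejected call is the same one that raised TypeError earlier in the thread):

import numpy as np
import jax.numpy as jnp

x = [i for i in range(10000)]
# jnp.sin(x) raises TypeError: JAX refuses Python sequences outright.
jnp.sin(np.asarray(x))  # fine: XLA receives one array, not 10000 scalars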
Also, in addition to the bug with masked arrays above, here is another bug:
>>> jnp.sin(np.float64(1.5)) # silent precision loss here, downcasting to float32
Array(0.997495, dtype=float32)
>>> jax.__version__
'0.4.1'
This is working as intended: JAX only allows 64-bit values when explicitly enabled; see Double (64-bit) Precision. This was an early design decision that the team recognizes as non-ideal, but it has proven difficult to change because so many users depend on the bit truncation behavior and enjoy the accelerator-friendly type safety it confers.
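For reference, a minimal way to opt in to 64-bit values (jax.config.update and the JAX_ENABLE_X64 environment variable are real JAX mechanisms; the flag must be set before JAX creates any arrays):

import jax
jax.config.update("jax_enable_x64", True)  # or: export JAX_ENABLE_X64=1

import jax.numpy as jnp
import numpy as np
print(jnp.sin(np.float64(1.5)).dtype)  # float64: no silent downcast now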
Specifically for JAX you have a shortcut, because it reuses NumPy dtypes directly:
>>> type(jnp.float32)
<class 'jax._src.numpy.lax_numpy._ScalarMeta'>
>>> type(jnp.float32.dtype)
<class 'numpy.dtype[float32]'>
The reason JAX defines this is that it made the early design choice to not distinguish between scalars and zero-dimensional arrays. np.float32, despite its common use (probably stemming from NumPy's anything-goes approach to dtype equality/identity that is the original reason for this issue), is not a dtype but rather a scalar float32 type. When JAX added the jax.numpy convenience wrapper around its core functionality, it needed dtype-specific scalar constructors similar to NumPy's np.float32, np.int32, etc. that would output appropriate zero-dimensional arrays. It does not make sense for jnp.float32 to be its own type, because unlike numpy there are no dedicated scalar types in JAX.
We could have defined simple functions named float32, int32, etc., but because np.float32 is so commonly used as a stand-in for np.dtype('float32'), we needed the scalar constructor functions to be something that np.dtype would treat as a dtype, and so the _ScalarMeta classes were born.
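Both roles can be seen directly; a small REPL sketch (outputs are what a recent JAX/NumPy pairing would be expected to print):

>>> import numpy as np
>>> import jax.numpy as jnp
>>> x = jnp.float32(1.5)  # acts as a constructor for a zero-dimensional array
>>> x.shape, x.dtype
((), dtype('float32'))
>>> np.dtype(jnp.float32)  # and np.dtype accepts it as a dtype specifier
dtype('float32')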
>>> jnp.sin(np.ma.array([1, 2, 3], mask=[True, False, True])) # bug in user code here, because JAX silently discards mask
Array([0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
To my knowledge, this bug has never come up (probably because masked arrays are so rarely used in practice). I'll raise it in the JAX repo.
Thanks for the context @jakevdp!
This is working as intended: JAX only allows 64-bit values when explicitly enabled; see Double (64-bit) Precision. This was an early design decision that has proven difficult to change
I knew that 64-bit precision must be explicitly enabled, but this is still surely a bug? The expected behavior is an exception asking the user either to explicitly downcast if the precision loss is fine, or to enable 64-bit precision. Or at the very least emit a warning. Silent downcasting is terrible; it may be okay for deep learning, but it typically isn't for general-purpose numerical/scientific computing.
On the other hand, np.arange(10000) is a single XLA array, and doesn't have this problem. On CPU, the transfer can even be done in most cases in a zero-copy fashion, although on accelerators there will be a device transfer cost.
That might be unfortunate when one converts a CPU program to a GPU one? It might be nice to be able to enable a flag that makes this a runtime error. That way I can remove all of my unintentional jax/numpy array mixing.
Silent downcasting is terrible; it may be okay for deep learning, but it typically isn't for general-purpose numerical/scientific computing.
I think you hit on the key point here: there are different communities with different requirements, and JAX attempts, maybe clumsily, to serve them all. If you are doing deep learning and care about performance over potential precision loss, you can set JAX_ENABLE_X64=0. If you are using JAX for general purposes and don't want this, you can set JAX_ENABLE_X64=1. The fact that the former is the default was an early decision based on initial uses of the package; we've actively discussed changing it, but it would be a big change and there are many pros and cons that must be weighed.
It's a difficult problem to solve well in a single package: it's worth noting that NumPy's answer to requests to serve the needs of deep learning is essentially no, which is a defensible choice given the package's early design decisions.
What about np.float32 == np.dtype(np.float32)?
This has been one of the few NumPy things that I dislike (and that would be moot for the Array API). In NumPy, np.float32 is a Python type
>>> type(np.float32)
<class 'type'>
whereas np.dtype(np.float32) is a dtype instance
>>> type(np.dtype(np.float32))
<class 'numpy.dtype[float32]'>
The former is needed, IIUC, only because of the need to construct NumPy scalars. Once NumPy removes this concept (how about NumPy 2.0, @seberg? 🙂) we can (and should) make them equivalent!
I do not really want to touch removing scalars from NumPy; maybe someone more confident about it can push for such a thing...
Maybe to be clear: to change NumPy here, I see no other way than the following (I think this is what Ralf said):
- Most dtype comparisons that currently work must raise an error (this is very noisy and in some instances possibly hard to work around). The alternative of just returning False seems too error-prone to me.
- Is arr.dtype == np.dtype(...) good enough, or do we need another way to spell that conveniently? arr.dtype.is_equiv("float32"), ...?
If you remove scalars, then np.float32(0) would have to raise an error, which helps, but is also noisy?
I don't see another way, so you can put it into np.array_api or np.ndarray.__array_namespace__, but np.float32 is borked and I don't see how to fix it except by doing the above, and probably doing it very slowly.
I'd argue that if there's any design that could bring us closer to full compliance in the main namespace with the standard, we should consider it, and removing scalars in favor of 0-D arrays is one of them. It's been a source of confusion with no obvious gain except for keeping legacy code working. It's been made clear that no accelerator library would support it. Also, removing scalars would keep the type promotion lattice cleaner.
So,
Is arr.dtype == np.dtype(...) good enough
Yes.
then np.float32(0) would have to raise an error, which helps, but is also noisy?
Not at all noisy 🙂
and probably doing it very slowly
All I care about is 1. eventual compliance, and 2. reducing both user confusion and developer (you) workload 🙂 If this is something that could take one full developer year to do, so be it.
You can change NumPy relatively easily. The problem is dealing with whether pandas and others would need involved changes. So the issue about scalars (and to some degree also this in general) is that it is very much holistic: I can zoom in on NumPy and give you a branch where scalars may still be in the code base but should never be created... But I am not sure I am equipped to understand what would happen to pandas or ... if we do it.
(I am also admittedly the one person who hates everything about NumPy scalars, but isn't sure that scalars themselves are all that bad.)
Scalars themselves aren't that bad, if only they weren't created by operations like np.sum(x) and x[()]. If those started returning 0-D arrays, that probably wouldn't even break that much in downstream libraries; the problem is the corner cases in end user code where they're used in places that Python scalars are expected.
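For concreteness, these are the scalar-producing operations in question (REPL output from a current NumPy):

>>> import numpy as np
>>> type(np.sum(np.ones(3)))  # reductions return a scalar, not a 0-D array
<class 'numpy.float64'>
>>> type(np.array(1.0)[()])   # so does full indexing of a 0-D array
<class 'numpy.float64'>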
I have a sense that it's doable in principle, but that it's one step too far for NumPy 2.0.
Yes, I would be willing to experiment with the "getting scalars more right" part. But also yes: even that needs at least testing to have confidence that it would be 2.0-scoped (i.e. few enough users actually notice, and if they do, mostly in harmless ways).
Yes, I would be willing to experiment with the "getting scalars more right" part. But also yes: even that needs at least testing to have confidence that it would be 2.0-scoped (i.e. few enough users actually notice, and if they do, mostly in harmless ways).
I'd be interested in helping out with an effort like this. But I don't think I can be one of the two "champions" (see here) for this one; I already signed up for enough other stuff for NumPy 2.0.
Going back to the original discussion, another annoying thing NumPy does is
>>> np.dtype('float64') == None
True
which has tripped us up a few times in the test suite.
Would it be possible in NumPy to make np.float64 just be np.dtype('float64') by implementing __call__ and __instancecheck__ on it (the actual type of a scalar float64 would become a hidden _float64 class)? That wouldn't remove scalars but it would make it so that there is only one dtype object representing any given dtype.
That wouldn't remove scalars but it would make it so that there is only one dtype object representing any given dtype.
I love this idea. This would be a step towards what Leo wanted above: "any design that could bring us closer to full compliance in the main namespace with the standard". I think if we don't do what you're suggesting, it will be a source of confusion that np.float32 is a type, but np.array_api.float32 is a dtype.
by implementing __call__ and __instancecheck__

(And __subclasscheck__.)
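A hedged sketch of the shape this proposal could take (all names here are hypothetical; a real implementation would presumably make the object an actual np.dtype subclass rather than a plain wrapper):

import numpy as np

_float64 = np.float64  # the scalar class, hidden in this sketch

class _CallableDtype:
    # One object per dtype: it stands in for the dtype, constructs
    # scalars, and participates in isinstance()/issubclass() checks.
    def __init__(self, name, scalar_cls):
        self._np_dtype = np.dtype(name)
        self._scalar_cls = scalar_cls

    def __call__(self, value):         # float64(0) keeps working
        return self._scalar_cls(value)

    def __instancecheck__(self, obj):  # isinstance(x, float64)
        return isinstance(obj, self._scalar_cls)

    def __subclasscheck__(self, cls):  # issubclass(t, float64)
        return issubclass(cls, self._scalar_cls)

float64 = _CallableDtype('float64', _float64)

With something like this there would be a single object playing both roles, so equality and hashing could be made mutually consistent.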