matplotlib
Update `_unpack_to_numpy` function to convert JAX and PyTorch arrays to NumPy
PR summary [Testing in progress]
This PR closes #25882 by modifying the `_unpack_to_numpy` function. The main changes are the following:
- Added an `if` condition to check whether an object has an `__array__` method and whether the object returned by calling `__array__` is a NumPy array.
- Added an `if` condition to capture NumPy scalars, which were not caught by the earlier `ndarray` check. This was needed because otherwise NumPy scalar objects get stuck in an endless `__array__` check, since they are converted to `ndarray` when `__array__` is called on them.
PR checklist
- [x] "closes #0000" is in the body of the PR description to link the related issue
- [x] new and changed code is tested
- [x] Plotting related features are demonstrated in an example
- [x] New Features and API Changes are noted with a directive and release note
- [x] Documentation complies with general and docstring guidelines
Please remove the unrelated changes.
Hmm, it seems like this is probably the clearest indication of the test failures:

> assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "hours"
> E AssertionError: assert None == 'hours'

As the `Quantity` test class has an `__array__` method
(https://github.com/matplotlib/matplotlib/blob/10e8bf134c77e73863c65a1f246bb8d9a684fc59/lib/matplotlib/tests/test_units.py#L41-L42),
the units will be dropped and the test fails.
I have no idea about the unit support though, so it's not really clear to me how to get around it...
FYI, I don't think you're supposed to call `__array__` directly:

> Users should not call this directly. Rather, it is invoked by numpy.array() and numpy.asarray().

But I'm not sure how much it matters in practice, or whether another library like matplotlib is a "user" in this context.
This is hitting the same reason why we do not blindly call `np.asanyarray(...)` on all of our input: that can strip off unit information (which would break some of our users; changing that is off the table for now).
Between this and the note about not calling `__array__` directly, this particular solution is probably not the right path.
Thank you all for the diverse feedback, @oscargus, @mwaskom, and @tacaswell. So, what do you think would be a better way forward?
- Implement a `to_numpy()` method in JAX and PyTorch.
- Add JAX- and PyTorch-specific `if` checks in the `_unpack_to_numpy` function:

```python
if str(type(x)) == "<class 'torch.Tensor'>":
    return x.numpy()
if str(type(x)) == "<class 'jaxlib.xla_extension.ArrayImpl'>":
    return np.asarray(x)
```

Or something completely different from these directions?
Do not trust this fully, but I think that checking if there is a `numpy` method and then calling that (and checking the output) would probably make sense and would solve `torch.Tensor`. But I also think that from Matplotlib's perspective (and probably that of many other libraries), having a well-defined interface, like `to_numpy`, on everything that can be converted to NumPy would make the most sense.
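A sketch of that idea (the helper name `maybe_via_numpy_method` is hypothetical, not matplotlib code; only the presence and output of a `numpy` method are checked):

```python
import numpy as np

def maybe_via_numpy_method(x):
    """If x has a callable .numpy() (as torch.Tensor does), try it."""
    numpy_method = getattr(x, 'numpy', None)
    if callable(numpy_method):
        try:
            result = numpy_method()
        except Exception:
            return x
        # Only trust the result if it actually is a NumPy array.
        if isinstance(result, np.ndarray):
            return result
    return x
```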
I am very :-1: on added string checking of types.
Unfortunately we are in an awkward bind: we are very permissive in what we take as input, we do not want to depend on any imports, and due to the diversity of input we cannot treat it all the same.
I have not checked, but maybe there is something in the Python array API standard. At least it would belong there.
Would `__array_namespace__` be a solution for us?
There's a third option not mentioned here: use `__array__` as the standard convert-to-array function, and give libraries for which this is not the appropriate behavior some way to opt out. Adding a new standard (`to_numpy`) with poorly-defined semantics is not a great way forward in my opinion.
My sense is that the direction data libraries would like to move in is exchange via the `__array__` method, so despite throwing some cold water on that above, it's probably a better path than trying to get torch to add `to_numpy` in addition to `numpy` and `__array__`.
We discussed the history of this a bit on the dev call today. I think the below is close to correct, but I could be misunderstanding:
We officially support numpy arrays as inputs to our data plotting functions.
We also officially support mechanisms for passing objects that contain "unit" information (e.g. pint). Somewhat confusingly, this unit information is sometimes at the container level (e.g. pint), and sometimes at the element level or in the dtype of the elements (e.g. NumPy arrays of datetime64, or lists of strings).
We unofficially support xarray and pandas objects, assuming they have no units, by calling their `values` or `to_numpy` methods.
At the level where `_unpack_to_numpy` is called, we cannot strip units from objects with units, because units have not been checked for yet. In the case discussed here, it is indeed the unit checking that is slowing things down.
After we have checked for units, we usually call `np.asarray`. But we can't call that right away because of our unit support.
I'm not sure what the path out of the conundrum is. I somewhat feel the unit conversion interface should have been less magical and more explicit, so users would have to specify a converter on an axis manually, rather than us guessing the converter.
> Would `__array_namespace__` be a solution for us?
That isn't quite the right thing; the array API standard is meant to use "native functions", so this method is what you'd use if you want to retrieve the `torch` namespace and use `torch.asarray` & co. Here you specifically want `numpy` arrays instead.
I agree with @jakevdp and @mwaskom that use of `__array__` is more idiomatic. The most standard thing is `np.asarray` (which relies on `__array__`, the Python buffer protocol, or DLPack), but if that's too permissive then using `__array__` directly is fine.
> After we have checked for units, we usually call `np.asarray`. But we can't call that right away because of our unit support.
If units libraries silently lose data when `np.asarray` is called on their container objects, they really should implement `__array__` and make it raise an exception. This is also what, for example, sparse arrays do.
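For illustration, a unit-carrying container that opts out of silent conversion could look like this (a hypothetical `UnitArray` class, not taken from any real units library):

```python
import numpy as np

class UnitArray:
    """Toy unit-carrying container that refuses silent conversion."""
    def __init__(self, values, unit):
        self.values = np.asarray(values)
        self.unit = unit

    def __array__(self, dtype=None):
        # Raising here means np.asarray(UnitArray(...)) fails loudly
        # instead of silently dropping the unit.
        raise TypeError(
            "Implicit conversion strips units; access .values explicitly.")
```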
From an interface perspective, it's reasonable to rely on `__array__`. I think we should investigate how we can make this work internally.
@timhoffm It has been a while since I worked on this PR. Can you please suggest whether your latest suggestion in #25882 will resolve the issue we are discussing in this PR? If not, could you please suggest potential workarounds?
@patel-zeel in my comment https://github.com/matplotlib/matplotlib/issues/25882#issuecomment-1872440671, I hadn't considered the unit problem. That indeed makes the problem much more complicated.
To all: To summarize and comment on the above proposed solutions:
1. Implement a `to_numpy()` method in JAX and PyTorch. IMO this won't happen. With some justification, they say that `__array__` is nowadays their idiomatic hook for converting to NumPy arrays.
2. Import and type-check by type. We don't want to try to import complex libraries just because someone might have passed an element of that type. This would be a performance hit for all users who happen to have that library installed but don't use it.
3. Type-check by string. This is inelegant and brittle.
4. Rework our unit handling system. Any changes to unit handling that could help here would definitely be a major project in matplotlib, and would likely also require changes for some downstream users.
There is no easy solution here. Special situations sometimes require special measures:
Given all the boundary conditions, I'd be +0.5 on type-checking by string, despite @tacaswell being strongly 👎 on this. Usually, I'd agree, but it's the only realistic way forward. 1 won't happen; 2 introduces strong coupling, which IMHO is worse; 4 won't realistically happen, because we don't have the capacity for it.
So what would we buy into with type-checking by string? Drawbacks are (1) the string comparison is slower than a type check, but that should be negligible; and (2) it's brittle, because the string representation could change without us noticing, and then the functionality would be broken. To alleviate (2), we could use `f"{type(x).__module__}.{type(x).__qualname__}"`. That leaves out unnecessary fluff and would only change when the libraries reorganize. Additionally, in the worst-case scenario that the string changes, we would fall back to the current solution.
In short: we can easily make using JAX/Torch arrays faster with the string-type check. With a not-too-high likelihood, that can break in the future, which would bring us back to where we are now. Sounds like a reasonable deal to me.
The only other alternative would be to tell users to convert their JAX/Torch arrays explicitly (or live with the performance impact). But that would not be user friendly.
> won't realistically happen, because we don't have the capacity for it.

This is what Kyle is working on, but it is 1-2 years off, and I don't think we should wait for it.
I am convinced by @timhoffm 's analysis and am also +0.5 on string typing now.
Couldn’t you accomplish option 2 without the performance impact by looking to see if certain modules are already in `sys.modules`?
Good point! That may indeed work. E.g. `"torch" in sys.modules` is very likely to be true if a user can create and pass us a torch tensor. It's not a strict prerequisite, but the edge cases would require non-canonical imports (e.g. `import torch as _torch`) or messing with the import system or `sys.modules`. Such cases would not be detected and would fall back to the current implementation.
To be clear: this is viable here because it works under normal circumstances (98%), and the missing cases are covered by the current fallback.
@jakevdp @timhoffm Are you suggesting something like this?

```python
...
if "torch" in sys.modules:
    from torch import Tensor as TorchTensor
    if isinstance(x, TorchTensor):
        return x.numpy()
if "jax" in sys.modules:
    from jax import Array as JAXArray
    if isinstance(x, JAXArray):
        return x.__array__()
...
```
Yeah, probably something like that. I might even avoid the imports, and wrap with a `try`/`except` just in case there's something strange (like a `torch.py` in the namespace, which either doesn't define `Tensor` or defines `Tensor` as a non-type):
```python
import sys

def is_torch_array(x) -> bool:
    try:
        return isinstance(x, sys.modules['torch'].Tensor)
    except Exception:  # TypeError, KeyError, AttributeError, maybe others?
        return False

def is_jax_array(x) -> bool:
    try:
        return isinstance(x, sys.modules['jax'].Array)
    except Exception:  # TypeError, KeyError, AttributeError, maybe others?
        return False
```
I think the only way to confound this would be if the user manually deleted entries from `sys.modules`, in which case the worst thing that would happen is a false-negative result.
This could also return false negatives for unrecognized versions: e.g. `jax.Array` was introduced in version 0.4.X (first released December 2022); this would incorrectly return `False` if an array were created with an older JAX version.
@jakevdp Thanks for the improved solution! I'd never have thought about a `torch.py` in the namespace and the various exceptions in different scenarios.
@timhoffm I have applied @jakevdp's solution. I am now wondering how to test this feature, or whether we should even consider testing it.
> I am now wondering how to test this feature or should we even consider testing it.
Testing this is quite hard because we don't have PyTorch or JAX as test dependencies. The only possibility I see is mocking. The mock would be quite implementation-specific (make a namespace with a `Tensor` class that has an `__array__` attribute). So we cannot really test whether `_is_torch_array` actually detects torch arrays. But what is still valuable is passing such a mock to `_unpack_to_numpy` and testing that the numpy array from `__array__` is returned.
@timhoffm Thanks for the review and suggestion for testing. I have applied the suggested changes and implemented the first version of testing for this feature.
@jakevdp How'd you suggest abstracting this? Some relevant points:
- I tried creating a `cupy` array and getting the NumPy array with the `__array__()` method, but it breaks and suggests using the `.get()` method instead:

```python
import cupy

array = cupy.array([1, 2, 3.0])
np_array = array.__array__()  # fails
# np_array = array.get()  # this works
```
Output:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 1
----> 1 array.__array__()

File cupy/_core/core.pyx:1475, in cupy._core.core._ndarray_base.__array__()

TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.
```
- I came across this issue on PyTorch and realized that the `.numpy(force=True)` method can help in cases where arrays need `.detach()` and/or `.cpu()` before `.__array__()` can successfully get the underlying NumPy array. As far as `matplotlib` is concerned, I guess `.numpy(force=True)` can be a better alternative to `.__array__()` for PyTorch (or even `.detach().cpu().numpy()` can work, to support older versions).
Considering both of the above cases, the discussion in #25882, and the discussion in this PR, would it be better to provide two methods, `is_type` and `to_numpy`, for each object, like the following?
```python
import sys
import numpy as np
from abc import ABC, abstractmethod

class TypeArray(ABC):
    @staticmethod
    @abstractmethod
    def is_type(x):
        pass

    @staticmethod
    @abstractmethod
    def to_numpy(x):
        pass

class TorchArray(TypeArray):
    @staticmethod
    def is_type(x):
        """Check if 'x' is a PyTorch Tensor."""
        try:
            # We're intentionally not attempting to import torch. If somebody
            # has created a torch array, torch should already be in sys.modules.
            return isinstance(x, sys.modules['torch'].Tensor)
        except Exception:  # TypeError, KeyError, AttributeError, maybe others?
            # We're accessing attributes on imported modules, which may run
            # arbitrary user code, so we deliberately catch all exceptions.
            return False

    @staticmethod
    def to_numpy(x):
        """Convert to NumPy array."""
        # Preferred over `.numpy(force=True)` to support older PyTorch versions.
        return x.detach().cpu().numpy()

class JaxArray(TypeArray):
    @staticmethod
    def is_type(x):
        """Check if 'x' is a JAX array."""
        try:
            # We're intentionally not attempting to import jax. If somebody
            # has created a jax array, jax should already be in sys.modules.
            return isinstance(x, sys.modules['jax'].Array)
        except Exception:  # TypeError, KeyError, AttributeError, maybe others?
            return False

    @staticmethod
    def to_numpy(x):
        """Convert to NumPy array."""
        return x.__array__()  # works even if `x` is on GPU

class CupyArray(TypeArray):
    @staticmethod
    def is_type(x):
        """Check if 'x' is a CuPy array."""
        try:
            # We're intentionally not attempting to import cupy. If somebody
            # has created a cupy array, cupy should already be in sys.modules.
            return isinstance(x, sys.modules['cupy'].ndarray)
        except Exception:  # TypeError, KeyError, AttributeError, maybe others?
            return False

    @staticmethod
    def to_numpy(x):
        """Convert to NumPy array."""
        return x.get()

external_objects = [TorchArray, JaxArray, CupyArray]

def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp
    for obj in external_objects:
        assert issubclass(obj, TypeArray)
        if obj.is_type(x):
            xtmp = obj.to_numpy(x)
            # In case to_numpy() doesn't return a numpy array in future
            if isinstance(xtmp, np.ndarray):
                return xtmp
    return x
```
IMHO further abstraction would be premature. The current implementation is simple and good enough. Paraphrased from https://youtu.be/UANN2Eu6ZnM?feature=shared:

> If something happens for the first time, write a concrete implementation. If it happens a second time, copy and adapt the code. If it happens a third time, factor out the commonalities.

This has two major advantages: 1. You don't create abstractions that you don't use. 2. When you build the abstraction, you have three concrete use cases, so it's more likely the abstraction is suitable.
I didn't mean to suggest any complicated abstraction; I was thinking of something simple like this:

```python
import sys
import numpy as np

ARRAYLIKE_OBJECTS = [('jax', 'Array'), ('torch', 'Tensor')]

def maybe_convert_to_array(x):
    for mod, name in ARRAYLIKE_OBJECTS:
        try:
            is_array = isinstance(x, getattr(sys.modules[mod], name))
        except Exception:
            pass
        else:
            if is_array:
                return np.asarray(x)
    return x
```

It reduces duplication of logic and makes it easier to add additional types if/when needed. If you wanted to add cupy support, it would just require `ARRAYLIKE_OBJECTS.append(('cupy', 'ndarray'))`.
> I didn't mean to suggest any complicated abstraction; I was thinking of something simple like this:
> [...]
> It reduces duplication of logic and makes it easier to add additional types if/when needed. If you wanted to add cupy support, it would just require `ARRAYLIKE_OBJECTS.append(('cupy', 'ndarray'))`.

Yes, that would be marginally better, and can optionally be done. In the interest of not endlessly bikeshedding the PR, I have accepted the current version. After all, this is all internal and can be refactored at any time.
OK, apologies for not paying attention to this properly, but hard-coding workarounds for certain libraries seems incorrect and brittle to me. What criteria will we have if we get requests to support other libraries?
I think the fundamental problem is where `cbook._reshape_2D` gets called. This method is only used for `hist` and a couple of the helpers for `violin_stats` and `boxplot_stats`.
I think it would be a mistake to change `_unpack_to_numpy`, which is used for all unit conversion, and hence almost every other plot method, to work around this problem. If it were me, I'd split the preprocessing in `hist` and friends to keep units where they are required (e.g. using `Quantity`), and strip them when we are ready to do so.
The methods where this gets used are all binning methods. The `bins` need to keep units (so they can be added to the axes properly). However, the data needs to be turned into an array to pass to `histogram` (or other stats functions). I think the proper solution here is to properly differentiate these roles in `hist` (and friends).
> hard-coding workarounds for certain libraries seems incorrect and brittle to me.

This is indeed a workaround. The proper way would be for `_unpack_to_numpy()` to use the `__array__` interface for all types where available (maybe through `np.asarray()`). However, in the current internal usages this may lead to loss of units. If you have an alternative proposal for how to make JAX and PyTorch arrays work, I'm more than happy to take it. To be honest, I don't fully oversee the unit handling and its implications.
Otherwise, I think this PR is good enough to be included in 3.9. It achieves the desired speedup and is otherwise completely internal, so we can still change the implementation whenever we like.
> What criteria will we have if we get requests to support other libraries?

Case by case. Support them if it's easily possible; don't if it's not. There's little maintenance burden and no API liability. Also, I don't expect there would be more than a handful of such libraries.