
Use pytorch as backend for xarrays

Open fjanoos opened this issue 4 years ago • 49 comments

I would be interested in using pytorch as a backend for xarrays, because:

a) pytorch is very similar to numpy, so the conceptual overhead is small
b) [most helpful] it would enable having a GPU as the underlying hardware for compute, which would provide a non-trivial speed up
c) it would allow seamless integration with deep-learning algorithms and techniques

Any thoughts on what the interest in such a feature might be? I would be open to implementing parts of it, so any suggestions on where I could start?

Thanks

fjanoos avatar Aug 19 '19 21:08 fjanoos

If pytorch implements overrides of NumPy's API via the __array_function__ protocol, then this could work with minimal effort. We are already using this to support sparse arrays (this isn't an official release yet, but functionality is working in the development version).

I think there has been some discussion about this, but I don't know the current status (CC @rgommers). The biggest challenge for pytorch would be defining the translation layer that implements NumPy's API.

Personally, I think the most viable way to achieve seamless integration with deep learning libraries would be to support integration with JAX, which already implements NumPy's API almost exactly. I have an experimental pull request adding __array_function__ to JAX, but it still needs a bit of work to finish it up, e.g., we probably want to hide this behind a flag at first.

shoyer avatar Aug 20 '19 01:08 shoyer

I think there has been some discussion about this, but I don't know the current status (CC @rgommers).

The PyTorch team is definitely receptive to the idea of adding __array_function__ and __array_ufunc__, as well as expanding the API for better NumPy compatibility.

Also, they want a Tensor.__torch_function__ styled after __array_function__ so they can make their own API overridable.

The tracking issue for all of this is https://github.com/pytorch/pytorch/issues/22402
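
For illustration, an override hook styled after __array_function__ would look roughly like this (the classmethod form below is the shape the protocol eventually took in later PyTorch releases; the logging body is purely illustrative):

import logging

import torch


class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # intercept every torch.* call involving this subclass, then defer
        # to the default implementation
        if kwargs is None:
            kwargs = {}
        logging.info("torch function called: %s", getattr(func, "__name__", func))
        return super().__torch_function__(func, types, args, kwargs)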

The biggest challenge for pytorch would be defining the translation layer that implements NumPy's API.

Agreed. No one is working on __array_function__ at the moment. Implementing it has some backwards compat concerns as well, because people may be relying on np.somefunc(some_torch_tensor) to be coerced to ndarray. It's not a small project, but implementing a prototype with a few functions in the torch namespace that do not exactly match the NumPy API would be a useful way to start pushing this forward.

rgommers avatar Aug 20 '19 02:08 rgommers

Personally, I think the most viable way to achieve seamless integration with deep learning libraries would be to support integration with JAX, which already implements NumPy's API almost exactly.

Less familiar with that, but pytorch does have experimental XLA support, so that's a start.

rgommers avatar Aug 20 '19 02:08 rgommers

Implementing it has some backwards compat concerns as well, because people may be relying on np.somefunc(some_torch_tensor) to be coerced to ndarray.

Yes, this is a concern for JAX as well. This is a definite downside of reusing NumPy's existing namespace.

It turns out even xarray was relying on this behavior with dask in at least one edge case: https://github.com/pydata/xarray/issues/3215

shoyer avatar Aug 20 '19 07:08 shoyer

This is a definite downside of reusing NumPy's existing namespace.

We didn't discuss an alternative very explicitly I think, but at least we'll have wide adoption fast. Hopefully the pain is limited ....

rgommers avatar Aug 20 '19 16:08 rgommers

I haven't used JAX - but was just browsing through its documentation and it looks super cool. Any ideas on how it compares with Pytorch in terms of:

a) execution speed, especially on GPU
b) memory management on GPUs (PyTorch has the DataLoader/Dataset paradigm, which uses background multithreading to shuttle batches of data back and forth, along with a lot of tips and tricks on efficient memory usage)
c) support for deep-learning optimization algorithms?

fjanoos avatar Aug 23 '19 15:08 fjanoos

Within a jit compiled function, JAX's execution speed should be quite competitive on GPUs. It uses the XLA compiler, which was recently enabled by default in TensorFlow.
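
For context, a minimal jax.jit example (illustrative only):

import jax
import jax.numpy as jnp


@jax.jit
def rms(x):
    # traced once, compiled by XLA, then reused for inputs of the same shape/dtype
    return jnp.sqrt(jnp.mean(x ** 2))


x = jnp.linspace(0.0, 1.0, 1_000_000)
print(rms(x))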

For data loading and deep learning algorithms, take a look at the examples in the notebooks directory in the JAX repo. The APIs for deep learning in JAX are still undergoing rapid development, so they are not quite as stable/usable as pytorch or keras yet, but they are quite capable. See jax.experimental.stax and tensor2tensor.trax for examples.

shoyer avatar Aug 23 '19 17:08 shoyer

While it is pretty straightforward to implement a lot of standard xarray operations with a pytorch / Jax backend (since they just fall back on native functions), it will be interesting to think about how to implement rolling operations / expanding / exponential window in a way that is both efficient and maintains differentiability.

Expanding and exponential window operations would be easy to do leveraging RNN semantics - but doing rolling using convolutions is going to be very inefficient.

Do you have any thoughts on this?

fjanoos avatar Aug 23 '19 18:08 fjanoos

I have not thought too much about these yet. But I agree that they will probably require backend specific logic to do efficiently.
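
For illustration, one backend-specific approach to rolling windows in PyTorch would be torch.Tensor.unfold, which avoids an explicit convolution and keeps the operation differentiable; a minimal sketch (the window size is an assumption):

import torch

x = torch.arange(10.0, requires_grad=True)
window = 3

# unfold returns a strided view of shape (len(x) - window + 1, window),
# so the rolling mean is just a reduction over the last axis
rolling_mean = x.unfold(0, window, 1).mean(dim=-1)
rolling_mean.sum().backward()
print(rolling_mean.shape, x.grad.shape)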

shoyer avatar Aug 23 '19 18:08 shoyer

This might be a good time to revive this thread and see if there is wider interest (and bandwidth) in having xarray use CuPy (https://cupy.chainer.org/) as a backend (along with numpy). It appears to be a plug-and-play replacement for numpy, so it might not have all the issues that were brought up regarding pytorch/jax?

Any thoughts ? cc @mrocklin

fjanoos avatar Mar 30 '20 20:03 fjanoos

Just chiming in quickly. I think there's definitely interest in doing this through NEP-18.

It looks like CuPy has implemented __array_function__ (https://docs-cupy.chainer.org/en/stable/reference/interoperability.html), so many things may "just work". There was some work earlier on plugging in pydata/sparse, and there is some ongoing work to plug in pint. With both these efforts, a lot of xarray's code should be "backend-agnostic", but it's not perfect.

Have you tried creating DataArrays with cupy arrays yet? I would just try things and see what works vs what doesn't.
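
For anyone wanting to experiment, a minimal sketch of what "just trying things" could look like (assuming a CUDA-capable environment with cupy installed; behaviour may vary by version):

import cupy as cp
import xarray as xr

da = xr.DataArray(cp.arange(12.0).reshape(3, 4), dims=("x", "y"))
print(type(da.data))         # the cupy.ndarray should be kept as the underlying data
result = (da * 2).mean("y")  # arithmetic and reductions dispatch via __array_ufunc__ / __array_function__
print(type(result.data))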

Practically, our approach so far has been to add a number of xfailed tests (test_sparse.py and test_units.py) and slowly start fixing them. So that's one way to proceed if you're up for it.

dcherian avatar Mar 30 '20 20:03 dcherian

@jacobtomlinson gave CuPy a go a few months back. I seem to remember that he ran into a few problems but it would be good to get those documented here.

jhamman avatar Mar 30 '20 20:03 jhamman

Yeah Jacob and I played with this a few months back. There were some issues, but my recollection is pretty hazy. If someone gives this another try, it would be interesting to hear how things go.

jakirkham avatar Mar 30 '20 21:03 jakirkham

If you have any pointers on how to go about this - I can give it a try.

fjanoos avatar Mar 31 '20 00:03 fjanoos

Well here's a blogpost on using Dask + CuPy. Maybe start there and build up to using Xarray.

jakirkham avatar Mar 31 '20 02:03 jakirkham

@jacobtomlinson gave CuPy a go a few months back. I seem to remember that he ran into a few problems but it would be good to get those documented here.

I've been test-driving xarray objects backed by CuPy arrays, and one issue I keep running into is that operations (such as plotting) that expect numpy arrays fail due to xarray's implicit conversion to NumPy arrays via np.asarray(). CuPy decided not to allow implicit conversion to NumPy arrays (see https://github.com/cupy/cupy/pull/3421).

I am wondering whether there is a plan for dealing with this issue?

Here's a small, reproducible example:


In [23]: ds.tmin.data.device
Out[23]: <CUDA Device 0>

In [24]: ds.isel(time=0, lev=0).tmin.plot()  # fails
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-21-69a72de2b9fd> in <module>
----> 1 ds.isel(time=0, lev=0).tmin.plot()

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in __call__(self, **kwargs)
    444 
    445     def __call__(self, **kwargs):
--> 446         return plot(self._da, **kwargs)
    447 
    448     @functools.wraps(hist)

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in plot(darray, row, col, col_wrap, ax, hue, rtol, subplot_kws, **kwargs)
    198     kwargs["ax"] = ax
    199 
--> 200     return plotfunc(darray, **kwargs)
    201 
    202 

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in newplotfunc(darray, x, y, figsize, size, aspect, ax, row, col, col_wrap, xincrease, yincrease, add_colorbar, add_labels, vmin, vmax, cmap, center, robust, extend, levels, infer_intervals, colors, subplot_kws, cbar_ax, cbar_kwargs, xscale, yscale, xticks, yticks, xlim, ylim, norm, **kwargs)
    684 
    685         # Pass the data as a masked ndarray too
--> 686         zval = darray.to_masked_array(copy=False)
    687 
    688         # Replace pd.Intervals if contained in xval or yval.

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/dataarray.py in to_masked_array(self, copy)
   2325             Masked where invalid values (nan or inf) occur.
   2326         """
-> 2327         values = self.values  # only compute lazy arrays once
   2328         isnull = pd.isnull(values)
   2329         return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/dataarray.py in values(self)
    556     def values(self) -> np.ndarray:
    557         """The array's data as a numpy.ndarray"""
--> 558         return self.variable.values
    559 
    560     @values.setter

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/variable.py in values(self)
    444     def values(self):
    445         """The variable's data as a numpy.ndarray"""
--> 446         return _as_array_or_item(self._data)
    447 
    448     @values.setter

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/variable.py in _as_array_or_item(data)
    247     TODO: remove this (replace with np.asarray) once these issues are fixed
    248     """
--> 249     data = np.asarray(data)
    250     if data.ndim == 0:
    251         if data.dtype.kind == "M":

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     83 
     84     """
---> 85     return array(a, dtype, copy=False, order=order)
     86 
     87 

ValueError: object __array__ method not producing an array

andersy005 avatar Jul 08 '20 20:07 andersy005

@andersy005 I'm about to start working actively on cupy support in xarray. Would be great to get some of your input.

CuPy requests that, instead of calling __array__, you call their .get method for explicit conversion to numpy. So we need to add a little compatibility code for this.
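
A minimal sketch of what that compatibility shim could look like (the helper name is made up; xarray's actual implementation may differ):

import numpy as np


def to_numpy(data):
    # hypothetical helper: cupy arrays need an explicit device-to-host copy
    # via .get(), everything else keeps going through np.asarray as before
    if type(data).__module__.startswith("cupy"):
        return data.get()
    return np.asarray(data)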

jacobtomlinson avatar Jul 09 '20 14:07 jacobtomlinson

@andersy005 I'm about to start working actively on cupy support in xarray. Would be great to get some of your input.

CuPy requests that, instead of calling __array__, you call their .get method for explicit conversion to numpy. So we need to add a little compatibility code for this.

Do you have a sense of the overhead / effort of making jax vs cupy the gpu backend for xarrays? One advantage of jax would be built-in auto-diff functionality that would enable xarray to be plugged directly into deep-learning pipelines. The downside is that it is not as numpy-compatible as cupy. How much of a non-starter would this be?

fjanoos avatar Jul 09 '20 22:07 fjanoos

@fjanoos I'm afraid I don't. In RAPIDS we support cupy as our GPU array implementation. So this request has come from the desire to make xarray compatible with the RAPIDS suite of tools.

We commonly see folks using cupy to switch straight over to a tool like pytorch using DLPack. https://docs-cupy.chainer.org/en/stable/reference/interoperability.html#dlpack
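
For context, the CuPy to PyTorch hand-off via DLPack looks roughly like this (API names as of the CuPy/PyTorch releases around that time; newer versions also accept the arrays directly):

import cupy as cp
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

cupy_array = cp.arange(6.0).reshape(2, 3)

# zero-copy hand-off of the same GPU buffer in both directions
torch_tensor = from_dlpack(cupy_array.toDlpack())
roundtrip = cp.fromDlpack(to_dlpack(torch_tensor))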

But I don't really see #4212 as an effort to make cupy the GPU backend for xarray. I see it as adding support for another backend to xarray. The more the merrier!

jacobtomlinson avatar Jul 10 '20 11:07 jacobtomlinson

I'd like to cast my vote in favor of getting this functionality in. It would be nice to autodiff through xarray operations.

From reading this and related threads, I'm trying to determine a game plan to make this happen. I'm not familiar with xarray code, so any guidance would be much appreciated. This is what I'm thinking:

  1. Create a custom subclass of PyTorch's Tensors which meets the duck array required methods and attributes. Since this isn't officially supported, looks like I could run into issues getting this subclass to persist through tensor operations.
  2. Implement the __array_function__ protocol for PyTorch similar to how is demo-ed here.
  3. Pass this custom class into data array constructors and hope the .grad attribute works.

My first attempts at this haven't been successful. Whatever custom class I make and pass to the DataArray constructor gets converted to something xarray can handle with this line:

https://github.com/pydata/xarray/blob/bc35548d96caaec225be9a26afbbaa94069c9494/xarray/core/dataarray.py#L408

Any suggestions would be appreciated. I'm hoping to figure out the shortest path to a working prototype.

Duane321 avatar Jan 22 '21 22:01 Duane321

No one is working on __array_function__ at the moment. Implementing it has some backwards compat concerns as well, because people may be relying on np.somefunc(some_torch_tensor) to be coerced to ndarray. It's not a small project, but implementing a prototype with a few functions in the torch namespace that do not exactly match the NumPy API would be a useful way to start pushing this forward.

@rgommers Do you expect this solution to work with a PyTorch Tensor custom subclass? Or is monkey patching necessary?

Duane321 avatar Jan 22 '21 23:01 Duane321

Create a custom subclass of PyTorch's Tensors which meets the duck array required methods and attributes. Since this isn't officially supported, looks like I could run into issues getting this subclass to persist through tensor operations.

If you use PyTorch 1.7.1 or later, then Tensor subclasses are much better preserved through pytorch functions and operations like slicing. So a custom subclass that adds the attributes and methods xarray requires for a duck array should be feasible.

data = as_compatible_data(data)

Looks like you need to patch that internally just a bit, probably adding pytorch to NON_NUMPY_SUPPORTED_ARRAY_TYPES.

Note that I do not expect anymore that we'll be adding __array_function__ to torch.Tensor, and certainly not any time soon. My current expectation is that the "get the correct namespace from an array/tensor object directly" from https://numpy.org/neps/nep-0037-array-module.html#how-to-use-get-array-module and https://data-apis.github.io/array-api/latest/ will turn out to be a much better design long-term.
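
For illustration, the "get the correct namespace from the array object" idea looks roughly like this (the __array_namespace__ attribute is the array-API-standard spelling and is an assumption here; torch.Tensor did not expose it at the time):

import numpy as np


def duck_mean(x, axis=None):
    # dispatch to whichever library the array itself advertises,
    # falling back to plain NumPy
    xp = x.__array_namespace__() if hasattr(x, "__array_namespace__") else np
    return xp.mean(x, axis=axis)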

rgommers avatar Jan 23 '21 10:01 rgommers

Note that the main work in adding __array_function__ is not the dispatch mechanism, but mapping to 100% compatible APIs. That job should have gotten a lot easier now compared to 9 months ago. PyTorch now has a completely matching fft module and a ~70% complete linalg module in master. And functions in the main namespace have gained dtype keywords, integer-to-float promotion, and other NumPy compat changes. So it should be feasible to write your custom subclass.

rgommers avatar Jan 23 '21 11:01 rgommers

@Duane321 While it would be fantastic to have gpu-enabled auto-diff-able xarrays / DataArrays, an interesting development worth looking into is named tensors in https://pytorch.org/docs/stable/named_tensor.html. This appears to be an attempt to bridge the gap, in that they are making pytorch tensors increasingly DataArray-like. I would not be surprised if within the next few iterations they add indexes to the tensors, closing the gap even further.

fjanoos avatar Jan 23 '21 14:01 fjanoos

While it would be fantastic to have gpu-enabled auto-diff-able xarrays / DataArrays, an interesting development worth looking into is named tensors in https://pytorch.org/docs/stable/named_tensor.html. This appears to be an attempt to bridge the gap, in that they are making pytorch tensors increasingly DataArray-like. I would not be surprised if within the next few iterations they add indexes to the tensors, closing the gap even further.

I really hope so. I explored named_tensors at first, but the lack of an index for each dimension was a non-starter. So, I'll keep an eye out.

Duane321 avatar Jan 25 '21 00:01 Duane321

Note that the main work in adding __array_function__ is not the dispatch mechanism, but mapping to 100% compatible APIs. That job should have gotten a lot easier now compared to 9 months ago. PyTorch now has a completely matching fft module and a ~70% complete linalg module in master. And functions in the main namespace have gained dtype keywords, integer-to-float promotion, and other NumPy compat changes. So it should be feasible to write your custom subclass.

Glad to hear there's progress I can lean on. I'll come back with a minimum version that does the API matching for maybe 1-2 methods, just to get feedback on the overall structure. If it works, I can brute-force through a lot of the rest 🤞

Looks like you need to patch that internally just a bit, probably adding pytorch to NON_NUMPY_SUPPORTED_ARRAY_TYPES.

Thank you. I hesitated to change xarray code, but not anymore.

Note that I do not expect anymore that we'll be adding array_function to torch.Tensor, and certainly not any time soon. My current expectation is that the "get the correct namespace from an array/tensor object directly" from https://numpy.org/neps/nep-0037-array-module.html#how-to-use-get-array-module and https://data-apis.github.io/array-api/latest/ will turn out to be a much better design long-term.

Does this mean I shouldn't fill out __array_function__ in my subclass? Or is this just a forward looking expectation?

Duane321 avatar Jan 25 '21 00:01 Duane321

Looks like you need to patch that internally just a bit, probably adding pytorch to NON_NUMPY_SUPPORTED_ARRAY_TYPES.

defining __array_function__ (and the other properties listed in the docs) should be enough: https://github.com/pydata/xarray/blob/a0c71c1508f34345ad7eef244cdbbe224e031c1b/xarray/core/variable.py#L232-L235

keewis avatar Jan 25 '21 00:01 keewis

Does this mean I shouldn't fill out __array_function__ in my subclass? Or is this just a forward looking expectation?

No, adding it should be perfectly fine. The dispatch mechanism itself isn't going anywhere, it's part of numpy and it works. Whether or not torch.Tensor itself has an __array_function__ method isn't too relevant for your subclass.

rgommers avatar Jan 25 '21 09:01 rgommers

I've made some mild progress, but it raises a few questions. I've defined this simple Tensor subclass which meets the duck array criteria:

import torch
from typing import Tuple


class XArrayTensor(torch.Tensor):
    def __new__(cls, data=None, requires_grad=False):
        if data is None:
            data = torch.Tensor()
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, data=None, dims: Tuple[str] = None):
        self.dims = dims

    def __array_function__(self, func, types, args, kwargs):
        # defer to other implementations unless we know the function and all
        # involved types are torch tensors
        if func not in IMPLEMENTED_FUNCTIONS or not all(issubclass(t, torch.Tensor) for t in types):
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[func](*args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if ufunc not in IMPLEMENTED_FUNCTIONS or method != "__call__":
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[ufunc](*inputs, **kwargs)

where IMPLEMENTED_FUNCTIONS holds a mapping from numpy functions to API-compatible tensor operators (similar in style to this).
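
For concreteness, a minimal sketch of such a registry (the implements decorator and the function names are illustrative):

import numpy as np
import torch

IMPLEMENTED_FUNCTIONS = {}


def implements(np_function):
    # register a torch-backed implementation of a NumPy function
    def decorator(func):
        IMPLEMENTED_FUNCTIONS[np_function] = func
        return func
    return decorator


@implements(np.mean)
def _mean(tensor, axis=None, **kwargs):
    return torch.mean(tensor) if axis is None else torch.mean(tensor, dim=axis)


@implements(np.sum)
def _sum(tensor, axis=None, **kwargs):
    return torch.sum(tensor) if axis is None else torch.sum(tensor, dim=axis)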

I added a torch_array_type to pycompat.py, which allows DataArray's .data attribute to persist as an XArrayTensor:

import xarray as xr

xr_tsr = XArrayTensor(torch.rand(3, 2))

data_array = xr.DataArray(
    xr_tsr,
    coords=dict(a=["a1", "a2", "a3"], b=["b1", "b2"]),
    dims=["a", "b"],
    name="dummy",
    attrs={"grad": xr_tsr.grad},
)
print(type(data_array.data))  # yields <class 'xarray_tensor.XArrayTensor'>

The issue I'm running into is when I run an operation like np.mean(data_array). The operation gets dispatched to functions within duck_array_ops.py, which are the things I'd like to override.

Also, I'd like to confirm something. If the API matching were complete, would the following be possible?

some_sum = data_array.sum()
some_sum.backward()
data_array.grad  # --> provides the gradient

I'm starting to suspect not, because that would involve data_array being both a DataArray and a torch.Tensor object. It seems what I'm in fact enabling is that DataArray.data is a torch.Tensor.

Duane321 avatar Jan 27 '21 19:01 Duane321

I'm starting to suspect not, because that would involve data_array being both a DataArray and a torch.Tensor object. It seems what I'm in fact enabling is that DataArray.data is a torch.Tensor.

some_sum is still a DataArray, which doesn't have a backward method. You could use

data_array = xr.DataArray(
    xr_tsr,
    coords=dict(a=["a1", "a2", "a3"], b=["b1", "b2"]),
    dims=["a", "b"],
    name="dummy",
    attrs={"grad": xr_tsr.grad, "backward": xr_tsr.backward},
)

and your example should work (I assume you meant .grad not .grid).

rgommers avatar Jan 29 '21 08:01 rgommers