
Serialize np.ndarray via shared memory

Open sehoffmann opened this issue 1 year ago • 8 comments

🚀 The feature

When transmitting np.ndarray via torch.multiprocessing (e.g. as used by the MultiProcessingReadingService (MPRS)), back them with shared memory (SM) to significantly speed up transmission. This could potentially be extended to other objects implementing the array interface in the future as well.

We can easily reuse existing implementations for this. E.g. to (pre-)copy a numpy array to SM (potentially before serialization):

import numpy as np
import torch


def share_memory(arr: np.ndarray) -> np.ndarray:
    view = torch.as_tensor(arr)  # zero-copy view of the numpy buffer
    sm = view.share_memory_()    # this copies the data into shared memory
    return sm.numpy()            # np.ndarray view into the freshly copied SM storage
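
For example, a quick sanity check (a usage sketch with arbitrary data, assuming the helper above):

arr = np.arange(12, dtype=np.float32).reshape(3, 4)
shared = share_memory(arr)

assert np.array_equal(arr, shared)        # same contents ...
assert not np.shares_memory(arr, shared)  # ... but backed by a fresh shared-memory copy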

For serialization we would just need to call torch.as_tensor(arr).share_memory_() (plus the mandatory bookkeeping of course, e.g. recording that "this is a np.ndarray"), which either yields a view if the array is already backed by SM or makes a copy, and then reuse the existing serialization infrastructure for torch.Tensor to transmit it.
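
A rough sketch of how this could be hooked into the pickler used by torch.multiprocessing (a minimal illustration only: the names are made up, dtype/stride bookkeeping is omitted, and object dtypes are not handled):

from multiprocessing.reduction import ForkingPickler

import numpy as np
import torch


def _rebuild_ndarray(tensor):
    # The tensor arrives already backed by shared memory; wrapping it as a
    # numpy array is zero-copy.
    return np.asarray(tensor)


def _reduce_ndarray(arr):
    # Copy the data into shared memory (a no-op if it is already shared)
    # and let torch's existing Tensor reducer handle the actual transport.
    tensor = torch.as_tensor(arr).share_memory_()
    return (_rebuild_ndarray, (tensor,))


ForkingPickler.register(np.ndarray, _reduce_ndarray)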

Since np.ndarray is a very foundational type used by many other libraries in turn, this could have a very significant impact.

Motivation, pitch

Transmitting np.ndarrays with the MPRS is very expensive, since they get fully serialized via pickle and then pushed through a (named) pipe, cf. https://github.com/pytorch/data/issues/1078.

However, the same does not hold for torch.Tensor, for which torch.multiprocessing already provides a very cheap and fully fledged serialization method based on shared memory. Since we can cheaply create views in both directions (torch -> numpy and numpy -> torch), we can reuse the existing serialization infrastructure in torch.

Thus it should be possible to deliver the very same performance boost to np.ndarray with very little effort.

Alternatives

No response

Additional context

No response

sehoffmann · Mar 15 '23

I created a small proof of concept and benchmarked it: https://github.com/sehoffmann/numpy_shared_mem/blob/master/numpy_shared_memory.ipynb

Bottom line: an order of magnitude faster. Also, when the data is preloaded into shared memory, it obviously is not duplicated across worker processes, which was the main reason for me to delve into this. Both the memory and the performance savings become even more significant when running multiple IO worker processes (almost everything here scales O(n)).

sehoffmann · Mar 17 '23

Also, I did some quick testing and it appears that this new serialization method for np.ndarray trickles down, out of the box, to downstream libraries such as xarray, i.e. these libraries and their datatypes would instantly benefit from the same performance gains. I still need to do more testing to find out the exact limitations of this.

sehoffmann · Mar 17 '23

I did some further testing and can confirm these points:

  • The approach works directly with both xr.DataArray and xr.Dataset, including ones containing multiple variables and coordinate arrays.
  • When testing on real-world data, 100 timesteps of ERA5 temperature fields (each ~700x1400), transmission time drops from 3.4s to 190ms!
  • With the current version it does not work for pd.DataFrame, where it produces an error.
    • The reason is that pd.DataFrame also wants to serialize its column names, which are an np.ndarray of strings (dtype: 'object'), and for those arr.view(np.int8) rightfully fails.
  • After fixing this, we see the same performance gain when transmitting the same data via a pd.DataFrame.

The object dtype issue can be solved easily by falling back to the default serialization behavior in that case:

import numpy as np
import torch


def reduce_ndarray(arr: np.ndarray):
    if arr.dtype.hasobject:  # fall back to the default impl for Python objects
        return arr.__reduce__()

    shape = arr.__array_interface__['shape']
    strides = arr.__array_interface__['strides']
    typestr = arr.__array_interface__['typestr']

    # Walk up the chain of bases; only plain np.ndarray views are supported for now.
    base = arr.base
    while type(base) is np.ndarray and base.base is not None:
        base = base.base

    if isinstance(base, torch.Tensor):
        # arr is already a view into a (shared-memory) tensor: reuse it and record
        # the byte offset of arr's data within that tensor's storage.
        tensor = base
        offset = arr.__array_interface__['data'][0] - np.asarray(base).__array_interface__['data'][0]
    else:
        # Wrap the raw bytes in a tensor; torch's own reducer will move it to shared memory.
        tensor = torch.as_tensor(arr.view(np.int8))
        offset = 0

    return (rebuild_ndarray, (tensor, (offset, shape, strides, typestr)))
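
rebuild_ndarray itself is not shown here; a matching counterpart could look roughly like this (a sketch only, assuming offset is the byte offset of the array's data inside the base tensor's storage, as recorded above):

def rebuild_ndarray(tensor, metadata):
    offset, shape, strides, typestr = metadata
    # The received tensor already lives in shared memory; viewing it as a numpy
    # buffer and re-applying the recorded layout is zero-copy.
    buffer = np.asarray(tensor)
    return np.ndarray(shape, dtype=np.dtype(typestr), buffer=buffer,
                      offset=offset, strides=strides)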

Also, one idea I came up with, which in my opinion is important: probably the most general way to preload any kind of data structure into shared memory is to simply pickle and unpickle it. By doing so, internal np.ndarrays get replaced by ones backed by shared memory, without requiring any knowledge of, or specific code for, these structures. One just has to be careful to only call such a function on objects that can actually benefit, i.e. that contain internal np.ndarrays; otherwise unnecessary overhead is introduced.
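
A minimal sketch of that round-trip helper (assuming reduce_ndarray above has been registered via ForkingPickler.register(np.ndarray, reduce_ndarray); the name share is purely illustrative):

from multiprocessing.reduction import ForkingPickler
import pickle


def share(obj):
    # Pickle with the pickler that torch.multiprocessing uses and immediately
    # unpickle: every np.ndarray nested inside obj comes back as a view into
    # shared memory, with no type-specific code for the container.
    return pickle.loads(ForkingPickler.dumps(obj))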

sehoffmann · Mar 17 '23

I can report that in a real pipeline with real data, implementing this results in a 13.2% throughput increase. But obviously the biggest gain is the vastly reduced memory consumption (when sharing before dispatching to the worker processes).

sehoffmann · Mar 21 '23

A little bit of context on the old DataLoader: it always tries to collate samples into Tensors via collate_fn. This already helps reduce the overhead of transmitting samples from the worker processes to the main process.

For DataLoader2, we are providing a more flexible, composable pipeline for users, which means there is no such mandatory mechanism. We might add an argument to the MPRS to provide such a feature.

ejguan · Mar 23 '23

Hey @ejguan,

I'm already working solely with DataLoader2. Also, since at this point my implementation does all IPC via shared memory, I experimented with transmitting the processed arrays (which already/still reside in shared memory at that point) to the main process and collating there. This degraded performance to the level of the single-threaded case, which leads me to believe that my main performance overhead right now is actually the collate. This makes sense, since it essentially needs to stack a lot of different slices into one tensor, and memory copies are among the most expensive operations on modern CPUs.

If this is interesting for you, you can find a working implementation at: https://github.com/sehoffmann/atmodata/blob/develop/atmodata/serialization.py

sehoffmann · Mar 23 '23

This resulted in a degradation of the performance to the single-threaded case which lets me believe that my main performance overhead right now is actually the collate.

I am not sure about that. Do you mean collate is the problem causing the overhead? I thought collate would help here, because it converts arrays to Tensors and relies on shared memory to pass the Tensors from the worker process to the main process.

ejguan · Apr 17 '23

Hey @ejguan,

I still have to dive deeper into this and run some more sophisticated benchmarks; as of now, this is only speculation on my side.

What I meant is that I believe that in my specific pipeline, the most expensive operation is the collate. Not because there is anything wrong with it, but because I collate many different (shared-memory) slices, i.e. views, into a new tensor. This necessarily results in many memory copies (one for each slice), and these are expensive.
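
A tiny standalone illustration of that effect (not taken from the actual pipeline; shapes and counts are arbitrary): stacking views forces one copy per slice, regardless of whether the slices already live in shared memory.

import timeit

import torch

base = torch.randn(256, 128, 128).share_memory_()
views = list(base)  # 256 zero-copy views into the shared-memory storage

# torch.stack has to copy every view into a freshly allocated batch tensor,
# so the "collate" cost grows with the total number of bytes copied.
t = timeit.timeit(lambda: torch.stack(views), number=10) / 10
print(f"stacking {len(views)} views: {t * 1e3:.1f} ms")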

sehoffmann · Apr 18 '23