Allow custom sharding datapipes
🚀 The feature
https://github.com/pytorch/pytorch/blob/master/torch/utils/data/graph_settings.py#L51 currently checks explicitly for `_ShardingIterDataPipe`, which is (1) a private type and (2) not in line with how, e.g., `apply_shuffle_settings` works (checking for the presence of methods).
As a consequence, there is no canonical way to write custom sharding operations. As a workaround, one could of course inherit from `_ShardingIterDataPipe` for now (but again, this is a private type).
I instead propose to add

```python
import inspect

def _is_sharding_datapipe(datapipe: DataPipe) -> bool:
    if not hasattr(datapipe, "apply_sharding"):
        return False
    if not inspect.ismethod(datapipe.apply_sharding):
        return False
    return True
```
and use that as a criterion instead.
These methods, both for shuffling and for sharding, should also be documented for both IterDataPipe and MapDataPipe.
Motivation, pitch
I'm implementing a specific kind of dataset that is essentially a mixture of a MapDataPipe and an IterDataPipe and can't be implemented satisfactorily with the existing pipes. Thus I need to implement shuffling and sharding manually for this piece of the pipeline.
To give a rough sketch: I'm processing bigger chunks of data, i.e. single arrays, in a streaming fashion (IterDataPipe). These are shards of the overall dataset (distributed), e.g. a single month out of 30 years. However, I also want to:
- Index in a shuffled way within such an array (MapDataPipe) when yielding individual samples
- Shard these arrays by sharding the indices (MapDataPipe) so that I can use the MPRS as well.
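A minimal, framework-free sketch of the two points above (the function name and the seeding scheme are my own): shuffle the per-chunk sample indices with a seed shared by all workers, then hand each worker a round-robin slice of the permutation, so the shards are disjoint but together cover every sample.

```python
import random


def shuffled_sharded_indices(n: int, num_workers: int, worker_id: int, seed: int):
    """Hypothetical helper: shuffled indices for one chunk, sharded per worker."""
    rng = random.Random(seed)      # same seed on every worker
    indices = list(range(n))
    rng.shuffle(indices)           # identical permutation everywhere
    return indices[worker_id::num_workers]  # round-robin shard of the permutation


# The two workers' shards are disjoint and jointly cover all n samples.
parts = [shuffled_sharded_indices(10, num_workers=2, worker_id=w, seed=0) for w in range(2)]
assert sorted(parts[0] + parts[1]) == list(range(10))
```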
To make things clearer, this is how my `__iter__` currently looks; it already implements custom shuffling:
```python
def __iter__(self):
    T = self.steps * self.rate
    for i, ds in enumerate(self.dp):
        N = len(ds.variables[self.dim])
        indices = list(range(N - T))
        if self.shuffle:
            if self._seed is None:
                seed = int(th.empty((), dtype=th.int64).random_().item())
            else:
                seed = self._seed + i
            self._rng.seed(seed)
            self._rng.shuffle(indices)
        for idx in indices:
            yield ds.isel(**{self.dim: slice(idx, idx + T, self.rate)})
```
Each `ds` is essentially a big array that we get in a streaming fashion from an upstream datapipe. That is why I am calling this construct a mixture of both IterDataPipe and MapDataPipe.
Going forward, I would like to find a way to abstract this concept a bit more (potentially by zipping indices and the array), so that the sharding and shuffling can be done independently of the actual indexing operation and can thus be reused. If there is broader interest in such a construct, I would be open to submitting a PR.
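To make the zipping idea a bit more concrete, here is one hypothetical shape for it (both function names are made up): the windowed indexing is driven by a plain index stream, so shuffling (and, by extension, sharding) can act on the indices alone, mirroring the `slice(idx, idx + T, self.rate)` call above.

```python
import random


def window_indices(n: int, steps: int, rate: int):
    """Valid start indices for windows spanning steps*rate elements."""
    return range(n - steps * rate)


def index_windows(array, indices, steps: int, rate: int):
    """Index the array with a strided window at each given start index."""
    for idx in indices:
        yield array[idx : idx + steps * rate : rate]


data = list(range(20))
indices = list(window_indices(len(data), steps=3, rate=2))
random.Random(0).shuffle(indices)  # shuffling happens on the indices alone
samples = list(index_windows(data, indices[:2], steps=3, rate=2))
```

Because the index stream is a separate object, a sharding step could slice it the same way before it ever reaches `index_windows`.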
Alternatives
No response
Additional context
No response
On a related note: I find the check at https://github.com/pytorch/pytorch/blob/master/torch/utils/data/graph_settings.py#L53 overly restrictive, and in my opinion there should be an optional way to disable it. I already ran into it, e.g. by placing sharding operations within branches of a `fork()` operation, so this is not just a hypothetical concern.
Related to https://github.com/pytorch/pytorch/issues/96975
We should allow users to provide custom sharding DataPipe. Will send a PR shortly.
@ejguan I think it'll be great to:

- Clearly document (if allowed) the how + expectations of implementing custom `shuffle()`, `apply_sharding()` (and other...) methods. Currently, this is hard to do without actually reading the DataLoader source code.
- Figure out a way to make these composable via mechanisms other than pipe chaining. For instance (but not necessarily) as a MixIn:

```python
class MyDataPipe(IterDataPipe[str], RoundRobinShardable, BufferedShuffle): ...
```
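To sketch what such a MixIn might look like (neither `RoundRobinShardable` nor `shard_filter` exist in torchdata; everything here is purely illustrative): the sharding behavior lives in a reusable mixin that a datapipe opts into, instead of being wired in via pipe chaining.

```python
class RoundRobinShardable:
    """Hypothetical mixin: adds round-robin sharding to any iterable pipe."""

    num_of_instances: int = 1
    instance_id: int = 0

    def apply_sharding(self, num_of_instances: int, instance_id: int) -> None:
        self.num_of_instances = num_of_instances
        self.instance_id = instance_id

    def shard_filter(self, iterable):
        # Keep every n-th element, offset by this instance's id.
        for i, item in enumerate(iterable):
            if i % self.num_of_instances == self.instance_id:
                yield item


class MyDataPipe(RoundRobinShardable):
    def __init__(self, items):
        self.items = items

    def __iter__(self):
        yield from self.shard_filter(self.items)


dp = MyDataPipe(["a", "b", "c", "d"])
dp.apply_sharding(2, 1)
print(list(dp))  # ['b', 'd']
```

The duck-typed `apply_sharding` check proposed above would pick such a mixin up automatically, which is what makes the two ideas fit together.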
Yes, this would be very much appreciated. I had to do a lot of digging to get hold of all the features I needed to realize my pipeline. That is of course expected for a project that is not completely fleshed out yet, so no worries.
@ejguan If I am not mistaken, this has been fixed by https://github.com/pytorch/pytorch/pull/97287 and can be closed? Correct me if I am wrong.
Thanks for the fix