
Allow overriding of existing functional APIs on DataPipe subclasses

Open BarclayII opened this issue 1 year ago • 1 comment

🚀 The feature

Allow users to override existing functional APIs on their own DataPipe subclasses.

Motivation, pitch

Consider the use case where I iterate over a tensor in minibatches:

# MapDataPipe version
x = dp.map.SequenceWrapper(torch.arange(200)).batch(20)
for i in x:
    print(i)
    break

# IterDataPipe version
x = dp.iter.IterableWrapper(torch.arange(200)).batch(20)
for i in x:
    print(i)
    break

Both MapDataPipe and IterDataPipe will return a list of tensors. What should I do if I want to return a single tensor instead of a list of multiple small tensors?
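To make that concrete, here is a minimal check of the behavior (a sketch assuming core PyTorch's built-in DataPipes; the `dp` alias in the snippets above may instead refer to `torchdata.datapipes`):

```python
import torch
from torch.utils.data.datapipes.iter import IterableWrapper

# batch(20) wraps every 20 elements into a list-like DataChunk of
# twenty 0-dim tensors, rather than one tensor of shape (20,).
first = next(iter(IterableWrapper(torch.arange(200)).batch(20)))
print(type(first), len(first))
```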

I thought of subclassing IterDataPipe and registering my own batch() transformation to the subclass, but it doesn't work:

class TensorDataPipe(dp.iter.IterDataPipe):
    def __init__(self, x):
        self.x = x
        
    def __len__(self):
        return self.x.shape[0]
        
    def __iter__(self):
        for i in range(self.x.shape[0]):
            yield self.x[i]
 
# crashes because 'batch' is already used
@dp.functional_datapipe('batch')
class BatchedTensorDataPipe(TensorDataPipe):
    def __init__(self, source_dp, batch_size):
        self.source_dp = source_dp
        self.batch_size = batch_size
        
    def __iter__(self):
        n = len(self.source_dp)
        for i in range(0, n, self.batch_size):
            i_end = min(i + self.batch_size, n)
            yield self.source_dp.x[i:i_end]

Additional context

Iterating over a single tensor on GPU is common when training graph neural networks on large-scale graphs, where one needs to iterate over a tensor of node/edge IDs. We observed that returning multiple scalars at each iteration causes large overhead; as a result, DGL wrote its own Dataset and DataLoader (e.g. https://github.com/dmlc/dgl/pull/2716 and https://github.com/dmlc/dgl/pull/3665) to return a 1-D tensor instead of a list of scalars.

Alternatives

For this particular problem I found a solution by rethinking it as chunking-then-slicing rather than batching individual scalars.

class RangeDataPipe(dp.iter.IterDataPipe):
    def __init__(self, n):
        self.n = n
        
    def __len__(self):
        return self.n
        
    def __iter__(self):
        yield torch.arange(self.n)

@dp.functional_datapipe('chunk')
class Chunker(dp.iter.IterDataPipe):
    def __init__(self, source_dp, chunk_size, drop_last=False):
        self.source_dp = source_dp
        self.chunk_size = chunk_size
        self.drop_last = drop_last
        
    def __iter__(self):
        perm = next(iter(self.source_dp))
        n = len(self.source_dp)
        for i in range(0, n, self.chunk_size):
            if self.drop_last and (i + self.chunk_size > n):
                break
            i_end = min(i + self.chunk_size, n)
            yield perm[i:i_end]
            
@dp.functional_datapipe('slice_from')
class Slicer(dp.iter.IterDataPipe):
    def __init__(self, source_dp, tensor):
        self.source_dp = source_dp
        self.tensor = tensor
        
    def __iter__(self):
        for indices in self.source_dp:
            yield self.tensor[indices]

# This works
p = RangeDataPipe(200).chunk(20).slice_from(torch.arange(200))
dl = torch.utils.data.DataLoader(p, batch_size=None)
for x in dl:
    print(x)

Still, I think allowing overriding would be nice, so that we can define customized "shuffle" and "batch" behaviors on custom DataPipes.

A related question is how to resolve conflicts between DataPipes from two packages that share the same functional API name (but with different behaviors).

BarclayII avatar Aug 08 '22 05:08 BarclayII

Both MapDataPipe and IterDataPipe will return a list of tensors. What should I do if I want to return a single tensor instead of a list of multiple small tensors?

You can add collate after batch to achieve that.
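A minimal sketch of that suggestion, assuming core PyTorch's built-in DataPipes (the issue's `dp` alias may instead refer to `torchdata.datapipes`):

```python
import torch
from torch.utils.data.datapipes.iter import IterableWrapper

# batch(20) yields lists of twenty 0-dim tensors; collate() applies the
# default collate function, which stacks them into a single 1-D tensor.
pipe = IterableWrapper(torch.arange(200)).batch(20).collate()
first = next(iter(pipe))
print(type(first), first.shape)  # a single tensor of shape (20,)
```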

We might not allow overriding existing DataPipes for now because there are still a few caveats for DataPipe, such as the reset function and the __getstate__/__setstate__ functions, which are given a native implementation for the sake of planned features such as snapshotting.

As an alternative, you can always register your custom DataPipe with a different name.
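A sketch of that different-name workaround, assuming core PyTorch's DataPipe registry; the name `tensor_batch` and the class below are made up for illustration:

```python
import torch
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper

@functional_datapipe('tensor_batch')  # non-conflicting name, so no crash
class TensorBatcher(IterDataPipe):
    def __init__(self, source_dp, batch_size):
        self.source_dp = source_dp
        self.batch_size = batch_size

    def __iter__(self):
        buf = []
        for item in self.source_dp:
            buf.append(item)
            if len(buf) == self.batch_size:
                # Stack the buffered scalars into one 1-D tensor
                yield torch.stack(buf)
                buf = []
        if buf:  # emit the final partial batch, if any
            yield torch.stack(buf)

# The registered name becomes available on every IterDataPipe instance
x = IterableWrapper(torch.arange(200)).tensor_batch(20)
print(next(iter(x)).shape)  # each iteration yields a tensor of shape (20,)
```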

ejguan avatar Aug 08 '22 15:08 ejguan