
`scan` support

Open samuela opened this issue 2 years ago • 3 comments

🚀 The feature

How does one create an IterDataPipe with scan/fold semantics?

Motivation, pitch

Necessary for pipelines that require some kind of state, e.g., label encoding for an unknown number of labels.

Alternatives

No response

Additional context

No response

samuela, Mar 24 '23 18:03

I don't think any of the existing built-in DataPipes support scan/fold. The only thing I can think of off the top of my head is using .map(fn) with an fn that has side effects, but that is undesirable for many reasons.
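To illustrate one possible reason (an assumption for illustration, not necessarily the one meant above), here is a plain-Python sketch using the built-in `map` rather than a DataPipe. The helper `make_label_encoder` is hypothetical. The hidden state inside the mapped function survives between passes, so a second pass does not start from scratch:

```python
def make_label_encoder():
    """Return a function that assigns integer ids to labels using hidden mutable state."""
    seen = {}

    def encode(label):
        # Side effect: the first time a label is seen, it is recorded in `seen`.
        return seen.setdefault(label, len(seen))

    return encode


encode = make_label_encoder()

# First pass: ids are assigned in order of first appearance.
first_pass = list(map(encode, ["cat", "dog", "cat"]))  # [0, 1, 0]

# Second pass over different data: the state from the first pass leaks in,
# so "bird" gets id 2 instead of 0.
second_pass = list(map(encode, ["bird", "cat"]))  # [2, 0]
```

Nothing in the pipeline itself knows about `seen`, so it cannot reset or save that state.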

Right now, they can easily be implemented as a custom IterDataPipe. It should accept:

  1. source DataPipe
  2. initial state
  3. a function that maps an element from the source DataPipe and the current state to a new state/output

It will also need a reset method to reset the state to initial state when a new iterator is created from the DataPipe.
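As a rough sketch of those three ingredients applied to the label-encoding case from the issue, the following uses plain Python only and does not depend on torchdata. The names `Scan` and `encode_label` are hypothetical. Here the reset is done implicitly by starting from the initial state at the top of `__iter__`, and the function returns a new state instead of mutating the old one so that the initial state stays untouched:

```python
class Scan:
    """Plain-Python stand-in for a scan-style pipe (not a torchdata class)."""

    def __init__(self, source, func, initial):
        self.source = source    # 1. source iterable
        self.initial = initial  # 2. initial state
        self.func = func        # 3. (state, element) -> (new state, output)

    def __iter__(self):
        # Every new iterator starts again from the initial state.
        state = self.initial
        for element in self.source:
            state, output = self.func(state, element)
            yield output


def encode_label(mapping, label):
    """Assign the next free integer id to a label that has not been seen yet."""
    if label not in mapping:
        # Build a new dict so the initial state is never mutated.
        mapping = {**mapping, label: len(mapping)}
    return mapping, mapping[label]


pipe = Scan(["cat", "dog", "cat", "bird"], encode_label, {})
first = list(pipe)   # [0, 1, 0, 2]
second = list(pipe)  # [0, 1, 0, 2] again, because the state was reset
```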

We would also likely accept a PR adding this as a built-in DataPipe, as I can see it being useful for other users as well.

NivekT, Mar 24 '23 19:03

I don't have the bandwidth to submit a PR atm, but I envision something like the following should work:

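from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe


@functional_datapipe("scan")
class ScanDataPipe(IterDataPipe):
    """An IterDataPipe implementing scan.

    scan :: (acc -> x -> (acc, y)) -> acc -> [x] -> [y]
    """
    def __init__(self, datapipe, func, initial):
        super().__init__()
        self.datapipe = datapipe  # source DataPipe
        self.func = func  # maps (accumulated, data) to (new accumulated, result)
        self.initial = initial  # state used at the start of every iteration

    def __iter__(self):
        # Start from the initial state each time a new iterator is created.
        accumulated = self.initial
        for data in self.datapipe:
            accumulated, result = self.func(accumulated, data)
            yield result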

samuela, Mar 24 '23 20:03

Yeah, I believe that will work. Optionally, you can store `accumulated` as an instance variable so that it is included whenever the DataPipe's state is saved.
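A minimal plain-Python sketch of that variation, without torchdata and with hypothetical names (`StatefulScan`, `running_sum`), might look like the following. Keeping the running state on the instance makes it visible from outside the generator mid-iteration, which is what any state-saving mechanism would need. It assumes only one iterator is active at a time:

```python
class StatefulScan:
    """Plain-Python stand-in that keeps the scan state on the instance."""

    def __init__(self, source, func, initial):
        self.source = source
        self.func = func
        self.initial = initial
        self.accumulated = initial

    def reset(self):
        # Return to the initial state; called whenever a new iteration starts.
        self.accumulated = self.initial

    def __iter__(self):
        self.reset()
        for element in self.source:
            self.accumulated, output = self.func(self.accumulated, element)
            yield output


def running_sum(total, x):
    """Scan step: the new state and the output are both the updated total."""
    total = total + x
    return total, total


pipe = StatefulScan([1, 2, 3, 4], running_sum, 0)
it = iter(pipe)
partial = [next(it), next(it)]   # [1, 3]
mid_state = pipe.accumulated     # 3, observable because it lives on the instance
full = list(pipe)                # a fresh iteration resets first: [1, 3, 6, 10]
```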

NivekT, Mar 24 '23 22:03