`scan` support
🚀 The feature
How does one create an IterDataPipe with scan/fold semantics?
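For reference, the difference between the two: a scan emits every intermediate accumulator, while a fold returns only the final one. In plain Python, with no DataPipes involved, `itertools.accumulate` behaves as a scan and `functools.reduce` as a fold. The `initial` keyword of `accumulate` requires Python 3.8+.

```python
import functools
import itertools
import operator

xs = [1, 2, 3, 4]

# Scan: yields the initial value, then the running accumulator after each element.
scanned = list(itertools.accumulate(xs, operator.add, initial=0))

# Fold: collapses the whole sequence into a single final accumulator.
folded = functools.reduce(operator.add, xs, 0)

print(scanned)  # [0, 1, 3, 6, 10]
print(folded)   # 10
```

The `scan` proposed later in this thread is a little more general. Its function returns a separate output alongside the new state, similar to Haskell's `mapAccumL`, and it does not emit the initial state.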
Motivation, pitch
Necessary for pipelines that require some kind of state, e.g. label encoding for an unknown number of labels.
Alternatives
No response
Additional context
No response
I don't think there are existing built-in DataPipes for scan/fold. The only thing I can think of off the top of my head is using .map(fn) with an fn that has side effects, but that is undesirable for many reasons.
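One illustration of the problem with side effects, sketched in plain Python with hypothetical names. State captured by a closure lives outside the pipeline, so nothing resets it when the data is iterated again, for example in a second epoch.

```python
def make_running_total():
    # Hidden mutable state captured by the closure; nothing resets it between passes.
    state = {"total": 0}

    def add_to_total(x):
        state["total"] += x
        return state["total"]

    return add_to_total


running_total = make_running_total()
data = [1, 2, 3]

epoch1 = list(map(running_total, data))
epoch2 = list(map(running_total, data))  # second pass continues from the old state

print(epoch1)  # [1, 3, 6]
print(epoch2)  # [7, 9, 12]
```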
Right now, this can easily be implemented as a custom IterDataPipe. It should accept:
- the source DataPipe
- the initial state
- a function that maps an element from the source DataPipe and the current state to a new state and an output
It will also need a reset method to reset the state to initial state when a new iterator is created from the DataPipe.
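A plain-Python sketch of that contract, without depending on torch. `ScanIterable` and `encode_label` are hypothetical names for illustration. Because the state lives in a local variable of `__iter__`, every new iterator starts again from `initial`, which covers the reset requirement. The step function does the label encoding mentioned in the motivation; it returns a new dict rather than mutating the old one, so the shared `initial` value is never modified.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, Tuple


class ScanIterable:
    """Hypothetical re-iterable wrapper with scan semantics."""

    def __init__(self, source: Iterable, func: Callable[[Any, Any], Tuple[Any, Any]], initial: Any):
        self.source = source    # upstream iterable (a DataPipe in the real case)
        self.func = func        # (state, element) -> (new_state, output)
        self.initial = initial  # state every new iterator starts from

    def __iter__(self) -> Iterator:
        state = self.initial  # reset: a fresh iterator always begins at the initial state
        for element in self.source:
            state, output = self.func(state, element)
            yield output


def encode_label(mapping: Dict[str, int], label: str) -> Tuple[Dict[str, int], int]:
    # Assign the next free integer to unseen labels; build a new dict instead of mutating.
    if label in mapping:
        return mapping, mapping[label]
    new_mapping = {**mapping, label: len(mapping)}
    return new_mapping, new_mapping[label]


pipe = ScanIterable(["cat", "dog", "cat", "bird"], encode_label, {})
print(list(pipe))  # [0, 1, 0, 2]
print(list(pipe))  # [0, 1, 0, 2], the same result on a second pass
```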
We would also likely accept a PR adding this as a built-in DataPipe, as I can see it being useful for other users as well.
I don't have the bandwidth to submit a PR atm, but I think something like the following should work:
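```python
from torch.utils.data import IterDataPipe, functional_datapipe


@functional_datapipe("scan")
class ScanDataPipe(IterDataPipe):
    """An IterDataPipe implementing scan.

    scan :: (acc -> x -> (acc, y)) -> acc -> [x] -> [y]
    """

    def __init__(self, datapipe, func, initial):
        super().__init__()
        self.datapipe = datapipe
        self.func = func
        self.initial = initial

    def __iter__(self):
        # Start from the initial state on every new iterator, so each epoch is independent.
        # If func mutates a mutable initial state in place (e.g. a dict), that change carries
        # over to the next epoch; have func return a new object instead.
        accumulated = self.initial
        for data in self.datapipe:
            # func maps (state, element) -> (new state, output element).
            accumulated, result = self.func(accumulated, data)
            yield result
```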
Yea, I believe that will work. Optionally, you can save accumulated as an instance variable so that it will be saved when the state is saved.
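A sketch of that variant, again in plain Python with a hypothetical class name (`StatefulScan`). The running state is kept on `self.accumulated`, so it is part of the object's `__dict__` and would be picked up by anything that saves the object's attributes (by default, pickling). It is reset at the start of each new iteration. Because `__iter__` is a generator, the reset runs when the first element is requested, not when `iter()` is called.

```python
class StatefulScan:
    """Hypothetical scan wrapper that keeps its running state as an attribute."""

    def __init__(self, source, func, initial):
        self.source = source
        self.func = func
        self.initial = initial
        self.accumulated = initial  # stored on the instance, so it is part of the object's state

    def __iter__(self):
        self.accumulated = self.initial  # reset when a new iteration starts
        for data in self.source:
            self.accumulated, result = self.func(self.accumulated, data)
            yield result


def running_sum(acc, x):
    # The new state and the output are both the running total.
    return acc + x, acc + x


pipe = StatefulScan([1, 2, 3, 4], running_sum, 0)
it = iter(pipe)
first_two = [next(it), next(it)]
print(first_two, pipe.accumulated)  # [1, 3] 3
print(list(pipe), pipe.accumulated)  # [1, 3, 6, 10] 10
```

One trade-off: two iterators over the same object would share `self.accumulated`, so this variant assumes only one iterator is active at a time.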