data
data copied to clipboard
Open to contributions of utility nodes like `Filter`, `Shuffler`, `Header`, `Cycler`?
Hi, do you think these kinds of nodes would be in scope for torchdata? If so, I'm happy to open a PR to add them, with renaming and testing, for sure.
import logging
import random
from collections import deque
from typing import Any, Callable, Deque, Dict, Optional, TypeVar, Optional
from torchdata.nodes import BaseNode
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Generic type variables for node item types. Only T is referenced by the
# node classes visible in this snippet; X and U are unused here.
X = TypeVar("X")
T = TypeVar("T")
U = TypeVar("U")
class Filter(BaseNode[T]):
    """Pass through only the source items accepted by a predicate.

    Items are pulled from ``source_node`` one at a time; an item is returned
    when ``filter_fn(item)`` is truthy and discarded otherwise.

    Args:
        source_node: Upstream node to pull items from.
        filter_fn: Predicate called on every upstream item.
    """

    SOURCE_KEY = "source"

    def __init__(self, source_node: BaseNode[T], filter_fn: Callable[[T], bool]):
        super().__init__()
        self.source = source_node
        self.filter_fn = filter_fn

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Reset this node and its source, optionally restoring a saved state."""
        super().reset(initial_state)
        # A missing or empty state means the source starts from scratch.
        source_state = None
        if initial_state:
            source_state = initial_state.get(self.SOURCE_KEY)
        self.source.reset(source_state)

    def next(self) -> T:
        """Return the next upstream item for which the predicate is truthy.

        ``StopIteration`` raised by the source propagates to the caller.
        """
        candidate = next(self.source)
        while not self.filter_fn(candidate):
            candidate = next(self.source)
        return candidate

    def get_state(self) -> Dict[str, Any]:
        """Return the source's state dict stored under ``SOURCE_KEY``."""
        source_state = self.source.state_dict()
        return {self.SOURCE_KEY: source_state}
class Shuffler(BaseNode[T]):
    """Node that shuffles items from a source node using a bounded buffer.

    Up to ``buffer_size`` items are held in a buffer. Each call to ``next``
    removes one buffered item chosen uniformly at random and then tops the
    buffer up from the source.

    The state returned by ``get_state`` contains the buffered items in
    addition to the source and RNG state. The source has already advanced
    past the buffered items, so without them a restore would silently drop
    every item that was sitting in the buffer at checkpoint time. Note that
    this makes the state size proportional to ``buffer_size``.

    Args:
        source_node: Upstream node to pull items from.
        buffer_size: Maximum number of items held for shuffling (>= 1).
        seed: Optional seed for the private ``random.Random`` instance. When
            given, a reset without state re-seeds the RNG with it.

    Raises:
        ValueError: If ``buffer_size`` is smaller than 1.
    """

    SOURCE_KEY = "source"
    RNG_STATE_KEY = "rng_state"
    BUFFER_KEY = "buffer"

    def __init__(self, source_node: BaseNode[T], buffer_size: int, seed: Optional[int] = None):
        super().__init__()
        if buffer_size < 1:
            raise ValueError("Buffer size must be at least 1")
        self.source = source_node
        self.buffer_size = buffer_size
        # NOTE(review): indexing a deque away from its ends is O(n); for very
        # large buffers a list would make the swap-remove in ``next`` O(1).
        self.buffer: Deque[T] = deque()
        self.rng = random.Random(seed)
        self._initial_seed = seed
        # Set once the source raises StopIteration so it is not polled again.
        self._source_exhausted = False

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Reset the node, optionally restoring a state from ``get_state``.

        States that lack the buffer entry are still accepted; the buffer then
        simply starts empty.
        """
        super().reset(initial_state)
        self.buffer.clear()
        self._source_exhausted = False
        if initial_state is not None:
            self.source.reset(initial_state.get(self.SOURCE_KEY))
            self.rng.setstate(initial_state[self.RNG_STATE_KEY])
            # Restore the items that were buffered when the state was taken,
            # in the same order, so the shuffled sequence continues exactly.
            self.buffer.extend(initial_state.get(self.BUFFER_KEY, ()))
        else:
            self.source.reset()
            if self._initial_seed is not None:
                self.rng = random.Random(self._initial_seed)

    def _fill_buffer(self) -> bool:
        """Top the buffer up from the source.

        Returns:
            True if the buffer holds at least one item afterwards.
        """
        if not self._source_exhausted:
            try:
                while len(self.buffer) < self.buffer_size:
                    self.buffer.append(next(self.source))
            except StopIteration:
                self._source_exhausted = True
        return bool(self.buffer)

    def next(self) -> T:
        """Return a randomly chosen buffered item.

        Raises:
            StopIteration: When both the buffer and the source are empty.
        """
        if not self.buffer and not self._fill_buffer():
            raise StopIteration
        # Swap-remove: overwrite the chosen slot with the last item, then
        # drop the last slot.
        idx = self.rng.randrange(len(self.buffer))
        item = self.buffer[idx]
        self.buffer[idx] = self.buffer[-1]
        self.buffer.pop()
        # Refill before returning so a state taken after this call has the
        # buffer and the source position consistent with each other.
        self._fill_buffer()
        return item

    def get_state(self) -> Dict[str, Any]:
        """Return source state, RNG state and a copy of the buffered items."""
        return {
            self.SOURCE_KEY: self.source.state_dict(),
            self.RNG_STATE_KEY: self.rng.getstate(),
            self.BUFFER_KEY: list(self.buffer),
        }
class Header(BaseNode[T]):
    """Emit at most the first ``n`` items of the source node.

    Args:
        source_node: Upstream node to pull items from.
        n: Maximum number of items to emit.

    Raises:
        ValueError: If ``n`` is negative.
    """

    SOURCE_KEY = "source"

    def __init__(self, source_node: BaseNode[T], n: int):
        super().__init__()
        if n < 0:
            raise ValueError("n must be non-negative")
        self.source = source_node
        self.n = n
        self._count = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Reset the source and the emitted-item counter.

        With a state, the counter is restored from its ``"count"`` entry;
        without one it restarts at zero.
        """
        super().reset(initial_state)
        source_state = None
        if initial_state:
            source_state = initial_state.get(self.SOURCE_KEY)
        self.source.reset(source_state)
        self._count = 0 if initial_state is None else initial_state["count"]

    def next(self) -> T:
        """Return the next source item until ``n`` items have been emitted.

        Raises:
            StopIteration: Once ``n`` items have been returned; the source's
                own ``StopIteration`` also propagates.
        """
        if self._count < self.n:
            item = next(self.source)
            # Only count items that were actually obtained from the source.
            self._count += 1
            return item
        raise StopIteration

    def get_state(self) -> Dict[str, Any]:
        """Return the source state together with the emitted-item count."""
        source_state = self.source.state_dict()
        return {self.SOURCE_KEY: source_state, "count": self._count}
class Cycler(BaseNode[T]):
    """Restart the source node whenever it runs out of items.

    Each time the source raises ``StopIteration`` the cycle counter is
    incremented, the source is reset, and iteration continues from its start.

    Args:
        source_node: Upstream node to cycle over.
    """

    SOURCE_KEY = "source"

    def __init__(self, source_node: BaseNode[T]):
        super().__init__()
        self.source = source_node
        self._cycle_count: int = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Reset the node, restoring cycle count and source state if given."""
        super().reset(initial_state)
        if initial_state is None:
            self._cycle_count = 0
            self.source.reset(None)
            return
        self._cycle_count = initial_state["cycle_count"]
        self.source.reset(initial_state.get(self.SOURCE_KEY))

    def next(self) -> T:
        """Return the next source item, restarting the source when exhausted.

        If the freshly reset source raises ``StopIteration`` immediately,
        that exception propagates to the caller.
        """
        try:
            item = next(self.source)
        except StopIteration:
            # Source ran dry: count the completed pass and start over.
            self._cycle_count += 1
            self.source.reset(None)
            item = next(self.source)
        return item

    def get_state(self) -> Dict[str, Any]:
        """Return the source state together with the completed-cycle count."""
        source_state = self.source.state_dict()
        return {self.SOURCE_KEY: source_state, "cycle_count": self._cycle_count}
Hey @keunwoochoi, thanks for this! These would be a great addition, especially excited for Shuffler.
cc @ramanishsingh, who has been looking at the Filter node.