
Faster ShardedFragmentSampler

Open tonyf opened this issue 6 months ago • 2 comments

Problem: the fragment sampler downloads fragments in a blocking fashion. This is a bottleneck if your batch size is larger than the number of rows in a fragment: in that regime each fragment contributes at most one (short) batch, so the loop pays for a synchronous fragment download on every batch.
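
For reference, a minimal sketch of what that looks like with the stock sampler (the dataset URI and shard parameters here are hypothetical); the fetch happens inline in the iterator:

import lance
from lance.sampler import ShardedFragmentSampler

dataset = lance.dataset("s3://bucket/dataset.lance")  # hypothetical URI
sampler = ShardedFragmentSampler(rank=0, world_size=8, randomize=True)

# When batch_size exceeds a fragment's row count, every batch is gated on
# a blocking fragment fetch before any rows can be yielded.
for batch in sampler(dataset, batch_size=4096):
    ...  # train step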

Solution: read-ahead with threads

Rough implementation

import random
import threading
from queue import Empty, Queue
from typing import Generator

import lance
import pyarrow as pa
from lance.sampler import ShardedFragmentSampler


class ReadAheadShardedFragmentSampler(ShardedFragmentSampler):
    """Same sharding as ShardedFragmentSampler, but fragments are fetched
    and decoded by background threads so the consumer never blocks on a
    fragment download."""

    def __init__(
        self,
        rank: int,
        world_size: int,
        randomize: bool = False,
        seed: int = 0,
        read_ahead: int = 16,
    ):
        super().__init__(rank, world_size, randomize, seed)
        self._read_ahead = read_ahead

    def iter_fragments(
        self, dataset: lance.LanceDataset, fragment_queue: Queue, num_workers: int
    ):
        # Producer: enqueue this rank's fragments, tagged with a dense
        # ordinal (0, 1, 2, ...) that the consumer uses to restore order.
        fragments = dataset.get_fragments()
        if self._randomize:
            random.seed(self._seed)
            random.shuffle(fragments)
        for offset, idx in enumerate(range(self._rank, len(fragments), self._world_size)):
            fragment_queue.put((offset, fragments[idx]))
        # One sentinel per worker so every worker thread shuts down
        # (a single sentinel would only ever stop one of them).
        for _ in range(num_workers):
            fragment_queue.put((None, None))

    def iter_batches(self, fragment_queue: Queue, batch_queue: Queue, **kwargs):
        # Worker: decode fragments into batches. Each batch is keyed by
        # (fragment ordinal, batch index within the fragment) so keys are
        # unique across fragments; a trailing (frag, count, None) marker
        # tells the consumer how many batches the fragment produced.
        while True:
            frag, fragment = fragment_queue.get(block=True)
            if frag is None:
                batch_queue.put((None, None, None))  # this worker is done
                fragment_queue.task_done()
                return

            idx = 0
            for batch in fragment.to_batches(**kwargs):
                batch_queue.put((frag, idx, batch))
                idx += 1
            batch_queue.put((frag, idx, None))  # fragment `frag` has `idx` batches
            fragment_queue.task_done()

    def __call__(
        self,
        dataset: lance.LanceDataset,
        *args,
        batch_size: int = 128,
        columns: list[str] | dict[str, str] | None = None,
        filter: str | None = None,
        batch_readahead: int = 16,
        with_row_id: bool = False,
        **kwargs,
    ) -> Generator[pa.RecordBatch, None, None]:
        fragment_queue = Queue(maxsize=self._read_ahead)
        batch_queue = Queue()

        fragment_thread = threading.Thread(
            target=self.iter_fragments,
            args=(dataset, fragment_queue, batch_readahead),
            daemon=True,
        )
        batch_threads = [
            threading.Thread(
                target=self.iter_batches,
                args=(fragment_queue, batch_queue),
                kwargs=dict(
                    batch_size=batch_size,
                    columns=columns,
                    filter=filter,
                    with_row_id=with_row_id,
                    **kwargs,
                ),
                daemon=True,
            )
            for _ in range(batch_readahead)
        ]

        fragment_thread.start()
        for thread in batch_threads:
            thread.start()

        # Consumer: workers finish fragments out of order, so buffer batches
        # and yield them in (fragment, batch-index) order. Stop once every
        # worker has sent its sentinel and the buffer is drained (PEP 479
        # forbids raising StopIteration inside a generator).
        frag, idx = 0, 0
        buffer: dict[tuple[int, int], pa.RecordBatch] = {}
        counts: dict[int, int] = {}  # fragment ordinal -> total batch count
        finished = 0
        while finished < len(batch_threads) or buffer or counts:
            try:
                f, i, batch = batch_queue.get(timeout=0.01)
            except Empty:
                continue
            if f is None:
                finished += 1
            elif batch is None:
                counts[f] = i
            else:
                buffer[(f, i)] = batch

            # Drain everything that is ready, in order.
            while True:
                if (frag, idx) in buffer:
                    yield buffer.pop((frag, idx))
                    idx += 1
                elif counts.get(frag) == idx:
                    del counts[frag]  # current fragment fully yielded
                    frag, idx = frag + 1, 0
                else:
                    break
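
Usage would mirror the stock sampler; a sketch under the same assumptions as above (hypothetical URI and shard parameters):

dataset = lance.dataset("s3://bucket/dataset.lance")  # hypothetical URI
sampler = ReadAheadShardedFragmentSampler(
    rank=0, world_size=8, randomize=True, read_ahead=16
)

# Up to `batch_readahead` worker threads decode fragments while the loop
# consumes batches, so fragment fetches overlap with the train step.
for batch in sampler(dataset, batch_size=4096):
    ...  # train step

One caveat with the sketch: `fragment_queue` is bounded by `read_ahead`, but `batch_queue` is unbounded, so decoded batches can accumulate in memory if the consumer falls behind the workers.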

tonyf · Aug 08 '24 21:08