Faster ShardedFragmentSampler
Problem: the fragment sampler downloads fragments in a blocking fashion. This is a bottleneck when your batch size is larger than the number of rows in a fragment.
Solution: read-ahead with threads
Rough implementation
One thread enumerates this rank's fragments into a bounded queue, a pool of worker threads reads batches out of those fragments, and the generator reorders the results so the iteration order stays deterministic:
```python
from __future__ import annotations

import random
import threading
from queue import Empty, Queue
from typing import Generator

import lance
import pyarrow as pa
from lance.sampler import ShardedFragmentSampler

# Marker placed on the batch queue after the last batch of each fragment.
_END_OF_FRAGMENT = object()


class ReadAheadShardedFragmentSampler(ShardedFragmentSampler):
    def __init__(
        self,
        rank: int,
        world_size: int,
        randomize: bool = False,
        seed: int = 0,
        read_ahead: int = 16,
    ):
        super().__init__(rank, world_size, randomize, seed)
        self._read_ahead = read_ahead

    def iter_fragments(self, dataset: lance.LanceDataset, fragment_queue: Queue):
        # Producer: enqueue this rank's fragments, then a (None, None) sentinel.
        fragments = dataset.get_fragments()
        if self._randomize:
            random.seed(self._seed)
            random.shuffle(fragments)
        for offset, idx in enumerate(range(self._rank, len(fragments), self._world_size)):
            fragment_queue.put((offset, fragments[idx]))
        fragment_queue.put((None, None))

    def iter_batches(self, fragment_queue: Queue, batch_queue: Queue, **kwargs):
        # Worker: read batches from fragments and tag them with
        # (fragment offset, batch index) so the consumer can restore order.
        while True:
            offset, fragment = fragment_queue.get(block=True)
            if offset is None:
                # Re-queue the sentinel so the other workers also shut down,
                # then tell the consumer this worker is finished.
                fragment_queue.put((None, None))
                fragment_queue.task_done()
                batch_queue.put((None, None))
                return
            num_batches = 0
            for batch in fragment.to_batches(**kwargs):
                batch_queue.put(((offset, num_batches), batch))
                num_batches += 1
            # Mark the end of this fragment so the consumer knows when to
            # advance to the next fragment offset.
            batch_queue.put(((offset, num_batches), _END_OF_FRAGMENT))
            fragment_queue.task_done()

    def __call__(
        self,
        dataset: lance.LanceDataset,
        *args,
        batch_size: int = 128,
        columns: list[str] | dict[str, str] | None = None,
        filter: str | None = None,
        batch_readahead: int = 16,
        with_row_id: bool = False,
        **kwargs,
    ) -> Generator[pa.RecordBatch, None, None]:
        fragment_queue = Queue(maxsize=self._read_ahead)
        batch_queue = Queue()

        fragment_thread = threading.Thread(
            target=self.iter_fragments,
            args=(dataset, fragment_queue),
            daemon=True,
        )
        batch_threads = [
            threading.Thread(
                target=self.iter_batches,
                args=(fragment_queue, batch_queue),
                kwargs=dict(
                    batch_size=batch_size,
                    columns=columns,
                    filter=filter,
                    with_row_id=with_row_id,
                    **kwargs,
                ),
                daemon=True,
            )
            for _ in range(batch_readahead)
        ]
        fragment_thread.start()
        for thread in batch_threads:
            thread.start()

        # Consumer: reorder batches so iteration order matches fragment order,
        # even though the workers finish out of order.
        workers_done = 0
        buffer = {}
        current_key = (0, 0)  # (fragment offset for this rank, batch index)
        while workers_done < len(batch_threads) or buffer:
            try:
                key, item = batch_queue.get(timeout=0.01)
                if key is None:
                    workers_done += 1
                else:
                    buffer[key] = item
                batch_queue.task_done()
            except Empty:
                pass
            # Yield everything that is ready, in order.
            while current_key in buffer:
                item = buffer.pop(current_key)
                if item is _END_OF_FRAGMENT:
                    current_key = (current_key[0] + 1, 0)
                else:
                    yield item
                    current_key = (current_key[0], current_key[1] + 1)
```
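
For context, a minimal usage sketch (the dataset path and column names below are placeholders, not part of the proposal):

```python
import lance

# Hypothetical example: stream training batches for rank 0 of a 4-way sharded job.
ds = lance.dataset("data/train.lance")  # placeholder dataset path
sampler = ReadAheadShardedFragmentSampler(rank=0, world_size=4, randomize=True, read_ahead=8)

for batch in sampler(ds, batch_size=256, columns=["image", "label"]):
    ...  # each item is a pyarrow.RecordBatch, ready for the training loop
```

Note that `read_ahead` only bounds the fragment queue; the batch queue in the sketch above is unbounded, so a slow consumer can still accumulate decoded batches in memory.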