
pre_fetch option in addition to cache for lib.File

dmpetrov opened this issue 1 year ago · 8 comments

We need to download items in async mode before processing them:

chain.settings(pre_fetch=2, cache=True, parallel=True).gen(laion=process_webdataset(spec=WDSLaion))
  • pre_fetch: this should enable async file download (per thread) up to a given limit of files (e.g. pre_fetch=10), similar to prefetching in PyTorch datasets. The default should be pre_fetch=2.
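To make the intended semantics concrete, here is a minimal sketch of bounded prefetching (the download callable and paths argument are hypothetical, not datachain API): at most pre_fetch downloads run concurrently ahead of the consumer.

import asyncio

async def prefetched(paths, download, pre_fetch=2):
    # Allow at most pre_fetch downloads in flight at once.
    sem = asyncio.Semaphore(pre_fetch)

    async def fetch(path):
        async with sem:
            return await download(path)

    # Results are yielded in order as the consumer asks for them,
    # while later downloads proceed in the background.
    tasks = [asyncio.create_task(fetch(p)) for p in paths]
    for task in tasks:
        yield await task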

OUTDATED:

  • consider introducing pre_fetch=0, which returns a Stream() descriptor with direct access to storage and no caching:
ds.generate(WebDataset(spec=WDSLaion), parallel=4, cache=True, pre_fetch=10)

dmpetrov · Apr 04 '24

Note that with the current architecture, pre_fetch won't do much, since only one File object exists at a time (assuming no batching).

rlamy · Apr 12 '24

@rlamy we should change it so that pre-fetching helps.

dmpetrov · Apr 14 '24

This depends on the file API refactoring (moving indexing to the app level). For now, moving it back to the backlog.

shcheklein · Jul 31 '24

Since we are more or less done with indexing, I'm moving this back to the ready stage, cc @rlamy. It might still depend on some of the work Ronan is doing now on decoupling DatasetQuery and DataChain.

One of the use cases I have atm is:

  • A lot of video files in the bucket
  • Cache can help, but it won't scale (we would pretty much need to download everything at once). It's a fine starting point, though.
  • Pre-fetch would help if we also had an option to clean up files (when cache is disabled).

One thing that is a bit annoying is that some tools (OpenCV) seem to require a local path. Yes, cache helps in that case and pre-fetch can help, but both require downloading the whole file, while for some operations I just need a header. If someone has ideas on how that can be improved, let me know. Is there a way to create a file-like object that is actually a stream from the cloud underneath?
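For the header use case, one option (a sketch, not something the thread settled on; the bucket and path are placeholders) is fsspec, which returns file-like objects that fetch byte ranges from cloud storage on demand:

import fsspec

# Opens lazily; only the bytes actually read are fetched from the cloud.
with fsspec.open("s3://my-bucket/video.mp4", "rb") as f:
    header = f.read(1024)

Tools like OpenCV that insist on a real local path still need a downloaded (or cached) file, though.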

shcheklein · Sep 19 '24

Some notes:

  • In order to implement this, we need to insert logic similar to DatasetQuery.extract() before (or maybe in) udf.run().
  • Fetching should be implemented, or at least controlled in some way, by the model. For instance, if we have an ArrowRow (which contains a File but doesn't inherit from it), we should fetch the row, not the whole file.

This means that udf.run() should receive model instances, not raw DB rows, which requires some refactoring...
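The conversion itself might look roughly like this (a sketch only; self.signals and is_model_signal are hypothetical names, not actual datachain internals):

def _prepare(self, db_row):
    # Turn raw DB values into model instances, e.g.
    # ("s3://bucket", "a.jpg", ...) -> File(source=..., path=..., ...)
    return tuple(
        signal.type.model_validate(value) if is_model_signal(signal) else value
        for signal, value in zip(self.signals, db_row)
    )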

rlamy · Sep 27 '24

This means that udf.run() should receive model instances, not raw DB rows, which requires some refactoring...

Where do we receive raw DB rows there? (I wonder if this is related or should be taken into account: https://github.com/iterative/studio/issues/10531#issuecomment-2379390308)

shcheklein · Sep 27 '24

After probably too much refactoring, I can confirm that this can be implemented inside udf.run(), which means that:

  • we don't need any (significant) changes to parallel or distributed code
  • each process gets pre_fetch async workers

Ignoring a lot of details, the basic idea is to change the implementation of udf.run() from this:

for db_row in udf_inputs:
    obj_row = self._prepare(db_row)
    obj_result = self.process(obj_row)
    yield self._convert_result(obj_result)

to this:

obj_rows = (self._prepare(db_row) for db_row in udf_inputs)
obj_rows = AsyncMapper(prefetch_row, obj_rows, workers=pre_fetch)
for obj_row in obj_rows:
    obj_result = self.process(obj_row)
    yield self._convert_result(obj_result)

where prefetch_row looks like:

async def prefetch_row(row):
    for obj in row:
        if isinstance(obj, File):
            await obj._prefetch()
    return row

Note that the latter can easily be generalised to arbitrary models, if we define some kind of "prefetching protocol".
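For instance, the protocol could be as small as this (a sketch using typing.Protocol; nothing here is settled API):

from typing import Protocol, runtime_checkable

@runtime_checkable
class Prefetchable(Protocol):
    async def _prefetch(self) -> None: ...

async def prefetch_row(row):
    # Any model that knows how to prefetch itself participates.
    for obj in row:
        if isinstance(obj, Prefetchable):
            await obj._prefetch()
    return row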

rlamy · Oct 04 '24

this can be implemented inside udf.run()

It looks like the right way of solving this. Thank you!

dmpetrov · Oct 04 '24

The proposed implementation has a problem: it hangs when run in distributed mode, i.e. when using something like .settings(prefetch=2, workers=2). Here's what happens (with some simplifications!) when running a mapper UDF in that case:

  • UDFDistributor groups rows into tasks and sends them to distributed.UDFWorker.
  • Each task is put in the worker's internal task_queue, to be processed in .run_udf_internal_other_thread().
  • That method creates an input_queue and sets up udf_results = dispatch.run_udf_parallel(None, n_workers, input_queue).
  • For each task, it puts its rows in the input_queue and waits to get the same number of rows back from udf_results to put them in the DB.
  • Meanwhile, UDFDispatcher spawns dispatch.UDFWorker subprocesses that take rows from the input_queue and put results on done_queue. UDFDispatcher.run_udf_parallel gets results from done_queue and yields them.
  • UDFWorker calls udf.run() using UDFWorker.get_inputs() as the value for udf_inputs.
  • In Mapper.run(), AsyncMapper starts iterating over UDFWorker.get_inputs() which eventually blocks, waiting for more input. That blocks the event loop, which blocks AsyncMapper.iterate(), which means nothing goes to done_queue, which blocks iterating over udf_results...

Possible solutions

  • Disable prefetching in distributed mode 😢
  • Ensure that AsyncMapper.produce() doesn't block the event loop by running next(iter(self.iterable)) in a separate thread (see the sketch after this list).
  • ??
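A minimal sketch of the second option (attribute names like work_queue are illustrative, based on the description above): pull from the blocking iterator in an executor thread, so a blocked next() cannot stall the event loop.

import asyncio

_SENTINEL = object()

async def produce(self):
    it = iter(self.iterable)
    loop = asyncio.get_running_loop()
    while True:
        # next() may block waiting for upstream input; run it in a
        # worker thread instead of on the event loop.
        item = await loop.run_in_executor(None, next, it, _SENTINEL)
        if item is _SENTINEL:
            break
        await self.work_queue.put(item)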

rlamy · Oct 15 '24

Using threading in AsyncMapper.produce() runs into the issue that iteration needs to be thread-safe, but that seems fixable, see #521. That PR only deals with Mapper and Generator though. Regarding the other two classes:

  • I'm not sure prefetching really makes sense for Aggregator.
  • It could be implemented in BatchMapper, but that would probably require the batching to be done inside udf.run() (i.e. create the batches after prefetching, sending the file objects to the UDF when they're ready), which requires some refactoring in parallel and distributed mode; see the sketch below.
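A rough sketch of that ordering (the batched helper and batch_size are hypothetical): prefetch individual rows first, then group them, so the UDF only ever sees batches whose files are ready.

from itertools import islice

def batched(iterable, size):
    it = iter(iterable)
    while batch := list(islice(it, size)):
        yield batch

# obj_rows = AsyncMapper(prefetch_row, obj_rows, workers=pre_fetch)
# for batch in batched(obj_rows, batch_size):
#     process the whole batch once its files are prefetched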

rlamy · Oct 18 '24

    def get_inputs(self):
        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
            yield batch

Minor observation: batch could be renamed here, since it's not really a batch, right?


Aggregator and BatchMapper should be related to each other, no? Both probably iterate over batches of rows and send them to the UDF?

I think prefetch still makes sense (we can start fetching the next batch?). It can definitely be a follow-up / separate ticket to discuss and prioritize.

shcheklein · Oct 20 '24

The proposed implementation has a problem: it hangs when run in distributed mode, i.e. when using something like .settings(prefetch=2, workers=2). Here's what happens (with some simplifications!) when running a mapper UDF in that case:

@rlamy, was that fixed? I see that produce is now run in a separate thread, and it seems to work fine on Studio when running locally. I do see occasional failures with a Celery timeout, but I think it's a setup issue on my end.

I tried to fix a hanging issue on interrupt/error in #597, which was causing test failures. If you have a moment, I would appreciate your feedback on the PR. Thank you.

skshetry · Nov 15 '24

@skshetry I think you've understood all the issues by now, but to clarify: my first attempt was hanging in distributed mode, which I then fixed in #521, but that introduced a new issue, which you fixed in #597.

rlamy · Nov 20 '24

@skshetry can it be closed?

shcheklein · Nov 23 '24