Batching

Open abidlabs opened this issue 1 year ago • 13 comments

Here is the design I propose for GPU Batching:

User API

  • To enable "batch mode", the user passes batch_size and batch_timeout arguments to the .queue() method. The queue will process jobs in batches of size batch_size. If fewer than batch_size jobs have been queued after batch_timeout seconds (1.0 by default), the queue will process whatever jobs are in the batch at that point.
  • The user must also provide a batch_fn in the Blocks event trigger or in Interface, which is a version of the function that is designed to take in a batch of samples.
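
The timeout behaviour in the first bullet can be sketched with a plain asyncio.Queue. This is only an illustration, not Gradio's queue code: collect_batch is a hypothetical helper, and it assumes the batch_timeout clock starts when the first job of a batch arrives, which the proposal does not specify.

```python
import asyncio
import time


async def collect_batch(queue, batch_size, batch_timeout=1.0):
    """Gather up to batch_size jobs, giving up once batch_timeout has elapsed.

    Hypothetical helper: the timer starts when the first job arrives.
    """
    batch = [await queue.get()]  # block until there is at least one job
    deadline = time.monotonic() + batch_timeout
    while len(batch) < batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            job = await asyncio.wait_for(queue.get(), timeout=remaining)
        except asyncio.TimeoutError:
            break  # timeout hit: process the partial batch
        batch.append(job)
    return batch


async def demo():
    queue = asyncio.Queue()
    for job in ("a", "b", "c"):
        queue.put_nowait(job)
    # Only 3 jobs are waiting, so the batch is released after the timeout.
    return await collect_batch(queue, batch_size=8, batch_timeout=0.05)


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

With three jobs waiting and batch_size=8, the helper returns the partial batch once the timeout expires instead of waiting for five more jobs.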

Backend changes

  • The queue will have a new process_batch_events() method to handle a batch of events
  • There will be a new /api/batch/{api_name} route created to handle batched prediction
  • This route will internally use the batch_fn instead of the regular fn

Would appreciate any feedback on this design before I start working on it @aliabid94 @freddyaboulton @apolinario

Fixes: #1597

abidlabs avatar Sep 09 '22 01:09 abidlabs

All the demos for this PR have been deployed at https://huggingface.co/spaces/gradio-pr-deploys/pr-2218-all-demos

github-actions[bot] avatar Sep 09 '22 01:09 github-actions[bot]

makes sense. Should batch_size be tied to the event listener instead of the queue? Do batch functions usually support a single batch size? What happens if there are fewer samples than that (e.g. 5 data points in the batch for a batch_fn that takes batches of 8 when the timeout goes off)?

aliabid94 avatar Sep 09 '22 19:09 aliabid94

No, I don't think batched functions typically support only a single batch size. Usually, you use the largest batch size that you expect to fit in memory. So there would be no real problem if a batch of 5 samples were fed into a batch_fn for which we've set batch_size to 8.

Would you agree with that @apolinario @pcuenca?

abidlabs avatar Sep 10 '22 07:09 abidlabs

Yes, except when using JAX, where a multiple of 8 is often required for parallelization and batch sizes are therefore fixed, usually. But this could be handled by the user in the batch_fn, I guess. I don't think it's necessary to add more complexity to the queue by introducing a fixed_size parameter or something like that, what do you think?

pcuenca avatar Sep 10 '22 09:09 pcuenca

@abidlabs This looks great! I agree with @aliabid94 that it may be better to tie the batch fn with the event listener as opposed to the queue. With this approach, it seems only one function can be batched per demo.

freddyaboulton avatar Sep 10 '22 14:09 freddyaboulton

@freddyaboulton yes, for sure batch_fn will be a parameter in the event listener. Given @pcuenca's point, I think we might want to implement a per-function batch size as well, but I don't think we need this in v1. So in summary:

The .queue() will have the following parameters:

  • batch_size (by default None)
  • batch_timeout (by default 1 second)

The event handlers will have the following parameters:

  • batch_fn
  • Eventually, we may add a batch_size that overrides the global batch_size in queue()
  • Eventually, we may add a batch_timeout that overrides the global batch_timeout in queue()
  • Eventually, we may add a boolean force_batch_size that manually increases a smaller batch size to a larger one by repeating input samples.
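
The force_batch_size idea in the last bullet (and the fixed-size JAX case raised earlier) could look roughly like the sketch below. The helper names pad_batch and run_fixed_size are hypothetical, and the sketch assumes the batched function returns one output per input sample, in order.

```python
def pad_batch(samples, batch_size):
    """Repeat samples cyclically until the batch has exactly batch_size entries."""
    if not samples:
        raise ValueError("cannot pad an empty batch")
    if len(samples) > batch_size:
        raise ValueError("batch is larger than batch_size")
    padded = [samples[i % len(samples)] for i in range(batch_size)]
    return padded, len(samples)


def run_fixed_size(fn, samples, batch_size):
    """Call fn on a full-size batch, then drop the outputs of the padding."""
    padded, real_count = pad_batch(samples, batch_size)
    outputs = fn(padded)
    return outputs[:real_count]


if __name__ == "__main__":
    upper = lambda batch: [s.upper() for s in batch]
    print(run_fixed_size(upper, ["a", "b", "c"], batch_size=8))
```

The function always sees a batch of exactly batch_size samples, while callers only get back results for the samples they actually submitted.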

abidlabs avatar Sep 12 '22 01:09 abidlabs

@abidlabs why is there a batch_fn? Isn't it unnecessary, and doesn't it make things more complex? Sending an arbitrary-length list of inputs to the fn seems fine to me, and I had verified this with @pcuenca in the past. I feel like instead of batch_fn, there should be batch_size and batch_threshold in the event listener.

omerXfaruq avatar Sep 12 '22 13:09 omerXfaruq

@abidlabs Thank you for the summary! Another question: can we avoid adding another API endpoint? Not adding one would make the API easier to use for its "consumers" (e.g. people loading interfaces, the front-end, people using the API manually), and it would make it easier to avoid bugs like #1316.

freddyaboulton avatar Sep 12 '22 14:09 freddyaboulton

Agree with @FarukOzderim, why do we need a separate batch fn? And with @freddyaboulton that we shouldn't add another API endpoint. Here's the syntax I have in mind: btn.click(fn, inputs, outputs, batch=True, max_batch_size=8) and everything should automatically be handled from there.

aliabid94 avatar Sep 13 '22 01:09 aliabid94

Thanks for all of the suggestions guys! This makes a lot of sense, we don't actually need a separate batch_fn. If batch mode is enabled, we require the function to work with a batch of inputs, which also avoids us creating a separate API endpoint. So the new proposed design is:

User API

  • To enable "batch mode", the user creates an event trigger that looks like this: btn.click(fn, inputs, outputs, batch=True).
  • To configure the batch processing, the user can also pass in two optional parameters: max_batch_size=8, batch_timeout=1
  • We also add these parameters (default_batch=False, default_max_batch_size=8, default_batch_timeout=1) to queue(), which applies them to all of the backend functions for Blocks. The main advantage of this would be making it easier to get batching to work with Interface (open to other suggestions here)

Backend changes

  • The queue will have a new process_batch_events() method to handle a batch of events, but this is called from the regular /api/{api_name} endpoint if batch=True for that particular function

abidlabs avatar Sep 13 '22 15:09 abidlabs

We also add these parameters (default_batch=False, default_max_batch_size=8, default_batch_timeout=1) to queue(), which applies them to all of the backend functions for Blocks. The main advantage of this would be making it easier to get batching to work with Interface (open to other suggestions here)

I'm not sure if adding those parameters to the queue method is the best because I don't think it's a good idea to apply the same batching behavior to all backend functions. For Interfaces, for example, I don't think we want to batch the flagging or at least not the same way we batch the main prediction function.

What if we control the batching behavior for interfaces via the constructor?

freddyaboulton avatar Sep 13 '22 16:09 freddyaboulton

I'm not sure if adding those parameters to the queue method is the best because I don't think it's a good idea to apply the same batching behavior to all backend functions. For Interfaces, for example, I don't think we want to batch the flagging or at least not the same way we batch the main prediction function.

What if we control the batching behavior for interfaces via the constructor?

I had originally avoided that because (a) I didn't want to add too many parameters to the Interface constructor, and (b) I thought that these parameters only made sense if queueing was enabled

But now that I think about it, if batch mode is enabled, we require the user to pass in a fn that can handle batch inputs, so it would make sense to put it in the same method as the fn, which in this case would be the Interface constructor. And strictly speaking, we don't need queuing to enable batch mode. You could pass in a batch of inputs through HTTP requests as well.

So yeah I think this is a better approach, thanks @freddyaboulton

abidlabs avatar Sep 13 '22 16:09 abidlabs

Thinking aloud here...

When batch=True, the following should happen:

  • If a user calls the Interface/Blocks as a function, it should always work with a single sample (regardless of whether the underlying function is a batched function or not)
  • If a user calls the /api/predict endpoint, it should also accept a batch of data up to the max_batch_size (the API docs will need to be updated to reflect this) (batch_timeout is irrelevant here) OR it should be able to take in a single sample (there will be a parameter, batched, that is included in the payload that indicates whether the incoming sample is a single sample or a batch)
  • If a user launch()es the UI without queueing, the UI will directly call /api/predict with batched=False (the default) in the payload
  • If a user launch()es the UI with queueing, the queue should wait until batch_timeout seconds pass or the batch is filled before internally calling /api/predict with batched=True
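
The single-sample versus batch dispatch described in these bullets can be sketched in plain Python. The payload layout used here ({"data": ..., "batched": ...}) and the handle_payload helper are assumptions for illustration, not the actual Gradio request format.

```python
def trim_words(words, lens):
    """Batched function: each argument is a list with one entry per sample."""
    return [[w[:int(l)] for w, l in zip(words, lens)]]


def handle_payload(fn, payload, max_batch_size=4):
    """Dispatch a request to a batched fn (payload layout is an assumption)."""
    data = payload["data"]
    batched = payload.get("batched", False)
    # A single sample is wrapped into a batch of size 1, one list per input.
    columns = data if batched else [[value] for value in data]
    if len(columns[0]) > max_batch_size:
        raise ValueError("batch exceeds max_batch_size")
    outputs = fn(*columns)
    # Unwrap again so single-sample callers never see the batch dimension.
    return outputs if batched else [column[0] for column in outputs]


if __name__ == "__main__":
    print(handle_payload(trim_words, {"data": ["hello", 2]}))
    print(handle_payload(
        trim_words, {"data": [["hello", "world"], [2, 3]], "batched": True}
    ))
```

In this sketch the same batched function serves both kinds of request, and only the wrapper knows whether the caller sent one sample or many.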

abidlabs avatar Sep 21 '22 20:09 abidlabs

All right folks, this is finally ready for review!

User API

  • To enable "batch mode", the user creates an event trigger that looks like this: btn.click(fn, inputs, outputs, batch=True).
  • To configure the batch processing, the user can also pass an optional parameter: max_batch_size=8
  • Note that I got rid of the batch_timeout parameter, as it adds unnecessary complexity for v1 of queuing. And for viral demos with long queues, it shouldn't make any difference.
  • When using Interface, a full example might look like this:
import gradio as gr

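def trim_words(words, lens):
    # In batch mode, each argument is a list with one entry per sample.
    trimmed_words = []
    for w, l in zip(words, lens):
        # "number" inputs may arrive as floats, so cast before slicing.
        trimmed_words.append(w[:int(l)])
    # Return one list per output component.
    return [trimmed_words]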

interface = gr.Interface(trim_words, ["textbox", "number"], ["textbox"], batch=True, max_batch_size=16)
interface.queue()
interface.launch()

Changes Made to the Code

Besides the basic changes needed to enable batching, I also made a few other changes:

  • Cleaned up the blocks.py code and separated the processing steps (serialize, preprocess, postprocess, deserialize) into separate methods
  • Added docs and test for batching
  • We need to know which function an Event is related to before we collect data for the event, so this information is now passed in via query parameters when the websocket connection is first made

Some Additional Notes on Batch Mode

  • Note that a "batched" function has a very different function signature (it always has to take a single parameter, which is a nested list of lists) than a regular Gradio function (the docstring explains this I hope, feedback welcome)
  • If a user calls the Interface/Blocks as a function, it should always work with a single sample (regardless of whether the underlying function is a batched function or not)
  • If a user calls the /api/predict endpoint, it should also accept a batch of data up to the max_batch_size (the API docs should be updated to reflect this, but leaving that for later) OR it should be able to take in a single sample (there is a key in the data dictionary, batched, that is included in the payload and indicates whether the incoming sample is a single sample or a batch)
  • If someone launches a Gradio/Blocks demo with batched functions, the user who is using the demo should not be able to tell the difference. Their data gets uploaded as a batch of size 1, and Gradio takes care of the rest!
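
To make the difference in signatures concrete, here is a small sketch that adapts a regular per-sample function to the batched calling convention used by trim_words in this thread (one list per input, one list per output). The batchify helper is hypothetical, is not part of Gradio, and only handles functions with a single output.

```python
def batchify(fn):
    """Wrap a per-sample function so it accepts one list per input.

    Hypothetical helper for illustration; supports a single output only.
    """
    def batched_fn(*columns):
        # zip(*columns) turns per-input lists back into per-sample rows.
        results = [fn(*row) for row in zip(*columns)]
        return [results]  # one list per output component
    return batched_fn


def trim_word(word, length):
    """Regular (unbatched) function: one sample in, one result out."""
    return word[:int(length)]


if __name__ == "__main__":
    trim_words = batchify(trim_word)
    print(trim_words(["hello", "world"], [2, 3]))
```

A real batched function would normally process the whole batch at once (for example, in one model forward pass) rather than loop over samples, so this wrapper shows only the shape of the data, not the speed-up.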

How to Test This

  • Run the following code
import gradio as gr
import time

def trim_words(words, lens):
    trimmed_words = []
    time.sleep(5)
    for w, l in zip(words, lens):
        trimmed_words.append(w[:l])        
    return [trimmed_words]

interface = gr.Interface(trim_words, ["textbox", gr.Number(precision=0)], ["textbox"], batch=True, max_batch_size=16)
interface.queue()
interface.launch()
  • Open the Gradio UI on three different tabs
  • Run a prediction on the first tab
  • Before the first prediction completes, run the 2nd and 3rd tab

You should see the following behavior: the first prediction should complete (after 5 seconds), then the 2nd and 3rd prediction should start running. They should be "batched together" and their predictions should complete at the same time (after another 5 seconds).
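
The expected behaviour can be reproduced without Gradio using a small asyncio simulation. It is a sketch under stated assumptions: a single hypothetical worker takes whatever jobs are waiting (up to max_batch_size) each time it becomes free, and the sleep durations stand in for the 5-second function.

```python
import asyncio


async def worker(queue, max_batch_size, process_time, log):
    """Hypothetical worker: process all waiting jobs as one batch."""
    while True:
        batch = [await queue.get()]
        # Drain jobs that queued up while the previous batch was running.
        while len(batch) < max_batch_size and not queue.empty():
            batch.append(queue.get_nowait())
        await asyncio.sleep(process_time)  # stand-in for the slow function
        log.append(batch)
        for _ in batch:
            queue.task_done()


async def simulate():
    queue = asyncio.Queue()
    log = []
    task = asyncio.create_task(worker(queue, 16, 0.05, log))
    queue.put_nowait("tab 1")   # first prediction starts immediately
    await asyncio.sleep(0.01)   # the other tabs submit while it is running
    queue.put_nowait("tab 2")
    queue.put_nowait("tab 3")
    await queue.join()          # wait until every job has been processed
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)
    return log


if __name__ == "__main__":
    print(asyncio.run(simulate()))
```

The log records one batch containing only the first tab, followed by one batch containing the second and third tabs together.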

abidlabs avatar Sep 27 '22 00:09 abidlabs

Before anyone does a review, I realized that this PR still needs to:

  • add support for examples
  • handle the case where some of the examples in the batch error out

Fun stuff :D

abidlabs avatar Oct 03 '22 20:10 abidlabs

All righty, this is finally ready for another review! @freddyaboulton would you be able to take a look? I've merged in your changes as well as added support for examples in batch mode and error handling.

abidlabs avatar Oct 14 '22 05:10 abidlabs

Merged in main, the biggest sources of conflicts were between this branch and the branch that added the ability to cancel events. I've resolved the conflicts, but it would be good to ensure that canceling continues to work as expected. Also you can now cancel some of the events within a batch and the rest of the events within the batch will continue to work as expected. Here's how you can test it out:

Run the following:

import gradio as gr
import time

def trim_words(words, lens):
    trimmed_words = []
    time.sleep(5)
    for w, l in zip(words, lens):
        trimmed_words.append(w[:l])        
    return [trimmed_words]

with gr.Blocks() as demo:
    with gr.Row():
        word = gr.Textbox(label="word", value="abc")
        leng = gr.Number(label="leng", precision=0, value=1)
        output = gr.Textbox(label="Output")
    with gr.Row():
        run = gr.Button()
        canc = gr.Button("Cancel", variant="stop")

    event = run.click(trim_words, [word, leng], output, batch=True, max_batch_size=16)
    canc.click(None, None, None, cancels=event)

demo.queue()
demo.launch()

Open up 4 separate tabs. Then:

  1. Run all 4 of the tabs
  2. Go to the 2nd tab, and cancel the event
  3. Tabs 3 and 4 should continue to run and complete together
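
One way to picture cancelling part of a batch is to drop the cancelled events before the batched function runs. The sketch below is illustrative only: the Event class and run_batch helper are hypothetical and do not describe how Gradio's queue actually tracks or cancels events.

```python
from dataclasses import dataclass


@dataclass
class Event:
    """Hypothetical stand-in for a queued request from one browser tab."""
    session: str
    data: list
    cancelled: bool = False


def trim_words(words, lens):
    return [[w[:int(l)] for w, l in zip(words, lens)]]


def run_batch(fn, events):
    """Run fn on the events that are still active; map results per session."""
    active = [e for e in events if not e.cancelled]
    if not active:
        return {}
    # Transpose per-event rows into one list per input component.
    columns = [list(col) for col in zip(*(e.data for e in active))]
    outputs = fn(*columns)
    return {e.session: [col[i] for col in outputs] for i, e in enumerate(active)}


if __name__ == "__main__":
    events = [Event(f"tab {i}", ["abcdef", i]) for i in range(1, 5)]
    events[1].cancelled = True  # the user pressed Cancel in the 2nd tab
    print(run_batch(trim_words, events))
```

Here the remaining events still run together as one batch, and the cancelled session simply receives no result.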

abidlabs avatar Oct 19 '22 19:10 abidlabs

Thanks for the detailed review @freddyaboulton! Let me go through the issues you identified and fix this up.

Another thing we need to do is add batched functions to our guides and show example usage. But I'll do that as part of a separate PR (maybe one tackling #2016) so that we don't keep this open for too much longer and accumulate more conflicts.

abidlabs avatar Oct 20 '22 02:10 abidlabs

@abidlabs Good news I just did some benchmarking on batching and I think the results are favorable!

What I did

I created this space: https://huggingface.co/spaces/gradio/queue-batch-benchmark which is a copy of our existing queue benchmark but with batching added. I also created this space: https://huggingface.co/spaces/gradio/queue-batch-benchmark-max-size-50 which uses max_batch_size=20 (I titled the space incorrectly lol) for each function.

Results

Main branch (no batching)

| Input type | Average time to complete prediction over 1000 requests |
| --- | --- |
| Audio | 1.5452740753398222 |
| Image | 1.5123367877716714 |
| Text | 1.5666406706197937 |
| Video | 2.029747503848115 |
This PR (max_batch_size 4)

| Input type | Average time to complete prediction over 1000 requests |
| --- | --- |
| Audio | 1.165496392960244 |
| Image | 0.8552168062177755 |
| Text | 0.841041189251524 |
| Video | 2.532181569765199 |
This PR (max_batch_size 20)

| Input type | Average time to complete prediction over 1000 requests |
| --- | --- |
| Audio | 1.2939001665276997 |
| Image | 0.9907303559825795 |
| Text | 0.993362974653057 |
| Video | 3.0297954572785284 |
I think it's interesting that a higher max_batch_size did not uniformly bring down the latency.
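
For a quick comparison, the relative change against the no-batching baseline can be computed from the three tables above. This is just arithmetic on the reported averages (rounded to four decimals here); percent_change is a helper name introduced for the example.

```python
# Reported averages from the tables above, rounded to 4 decimals.
no_batching = {"Audio": 1.5453, "Image": 1.5123, "Text": 1.5666, "Video": 2.0297}
batch_4 = {"Audio": 1.1655, "Image": 0.8552, "Text": 0.8410, "Video": 2.5322}
batch_20 = {"Audio": 1.2939, "Image": 0.9907, "Text": 0.9934, "Video": 3.0298}


def percent_change(baseline, candidate):
    """Percent change per input type; negative means lower latency."""
    return {
        key: 100.0 * (candidate[key] - baseline[key]) / baseline[key]
        for key in baseline
    }


if __name__ == "__main__":
    for name, table in (("max_batch_size 4", batch_4), ("max_batch_size 20", batch_20)):
        changes = percent_change(no_batching, table)
        summary = ", ".join(f"{k}: {v:+.1f}%" for k, v in changes.items())
        print(f"{name} vs. no batching -> {summary}")
```

On these numbers, audio, image, and text get faster with batching while video gets slower, and max_batch_size 20 is worse than max_batch_size 4 for every input type.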

I tried again, only sending requests for the image function:

Max batch size Average time to complete prediction over 1000 requests
4 0.9002170345783234
20 1.3551465120315551

Using a max_batch_size of 20 is still faster than no batching but not faster than a batch size of 4 🤔

I think this is evidence that we can optimize the batch logic (maybe gathering data for each event in the batch is adding overhead?) but we can do so in a future PR. IMO this can go out and we can get user feedback.

Other thing

gr.Blocks.load does not work if the space uses batching. I tried with the queue benchmark and you'll see that the websocket returns "['foo']" and then we take the first element of it to return "[".
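
The symptom is what you get when a string representation of a list is indexed as if it were the list itself. The snippet below only illustrates that symptom in plain Python; it makes no claim about where in gr.Blocks.load the conversion actually goes wrong.

```python
import ast

raw = "['foo']"  # the value described above, received as a string

# Indexing the string returns its first character, not the first item.
first_char = raw[0]

# Parsing the string back into a list first gives the intended element.
parsed = ast.literal_eval(raw)
first_item = parsed[0]

if __name__ == "__main__":
    print(repr(first_char), repr(first_item))
```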

Other than that (and the comments I already made) I think this is good to merge! Feel free to merge without fixing gr.Blocks.load and we can do that in a separate PR.

freddyaboulton avatar Oct 20 '22 15:10 freddyaboulton

Amaaazing thank you so much @freddyaboulton for the benchmarking! Have a few meetings, but then I'll parse through this more fully and address the comments above

abidlabs avatar Oct 20 '22 16:10 abidlabs

Also, just noticed that a demo without batching built off this PR is showing higher latency than the same demo built off main. Would be good to look into this as well!

python benchmark_script_gather.py wss://spaces.huggingface.tech/gradio/queue-benchmark/queue/join --batch_size 50

{'fn_index': ['audio', 'image', 'text', 'video'], 'duration': [1.724048194614982, 1.6302076853891285, 1.5274697332470506, 2.1359671726822853]}

python benchmark_script_gather.py wss://spaces.huggingface.tech/gradio/queue-benchmark-batch-branch/queue/join --batch_size 50

{'fn_index': ['audio', 'image', 'text', 'video'], 'duration': [2.0668310448826555, 1.888936443934365, 1.9982213218865446, 2.88502193372184]}

freddyaboulton avatar Oct 20 '22 16:10 freddyaboulton

Also, just noticed that a demo without batching built off this PR is showing higher latency than the same demo built off main. Would be good to look into this as well!

How are you running --batch_size 50 with the demo built off main?

abidlabs avatar Oct 20 '22 16:10 abidlabs

My bad for the confusing terminology. There are two batch sizes here: --batch_size 50 means send 50 requests to the space at once. The other batch size is max_batch_size in the event listener. I’ll edit my previous comment to make that clearer.

freddyaboulton avatar Oct 20 '22 17:10 freddyaboulton

I’m also going to run the two benchmark apps locally (not hosted on spaces) and run the benchmark to make sure it’s not something weird with my internet connection.

freddyaboulton avatar Oct 20 '22 17:10 freddyaboulton

Thanks @freddyaboulton this is super helpful

abidlabs avatar Oct 20 '22 17:10 abidlabs

Actually disregard that last comment about not-batching being slower in this branch. I can't reproduce locally (while I still reproduce the fact batching is faster than not batching). I did a factory reset on https://huggingface.co/spaces/gradio/queue-benchmark-batch-branch and the times are now pretty much in line with https://huggingface.co/spaces/gradio/queue-benchmark 🤷‍♂️ . False alarm I think.

freddyaboulton avatar Oct 20 '22 17:10 freddyaboulton

All righty, this is finally ready to be merged. I've added documentation about how to use batched functions as well. There are further optimizations that can be added, like @FarukOzderim pointed out, but I'll save that for a future PR (probably when I write the Guide on batched functions and concurrency)

@freddyaboulton I didn't understand your question about asyncgenfunctions, but otherwise everything else has been addressed.

I'll merge this in in a couple of hours -- thanks all!

abidlabs avatar Oct 24 '22 21:10 abidlabs