Batching
Here is the design I propose for GPU Batching:
User API
- To enable "batch mode", the user provides `batch_size` and `batch_timeout` arguments to the `.queue()` method (see the sketch below). The queue will process jobs in batches of size `batch_size`, unless fewer than `batch_size` jobs have been queued within `batch_timeout` seconds (by default 1.0). Once `batch_timeout` seconds have elapsed, the Queue will process whatever jobs are in the batch.
- The user must also provide a `batch_fn` in the Blocks event trigger or in `Interface`, which is a version of the function that is designed to take in a batch of samples.
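For illustration, here is a minimal sketch of the collection rule described in the first bullet: flush a batch once `batch_size` jobs have arrived, or once `batch_timeout` seconds have elapsed with a partially filled batch. This is not Gradio's actual queue code; `collect_batch` and its arguments are hypothetical.

```python
import asyncio

async def collect_batch(jobs: asyncio.Queue, batch_size: int = 8, batch_timeout: float = 1.0):
    """Collect up to `batch_size` jobs, flushing early if `batch_timeout` elapses."""
    batch = [await jobs.get()]                       # wait for at least one job
    deadline = asyncio.get_running_loop().time() + batch_timeout
    while len(batch) < batch_size:
        remaining = deadline - asyncio.get_running_loop().time()
        if remaining <= 0:
            break                                    # timeout: process a partial batch
        try:
            batch.append(await asyncio.wait_for(jobs.get(), timeout=remaining))
        except asyncio.TimeoutError:
            break
    return batch
```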
Backend changes
- The queue will have a new `process_batch_events()` method to handle a batch of events
- There will be a new `/api/batch/{api_name}` route created to handle batched prediction
- This route will internally use the `batch_fn` instead of the regular `fn`

Would appreciate any feedback on this design before I start working on it @aliabid94 @freddyaboulton @apolinario
Fixes: #1597
All the demos for this PR have been deployed at https://huggingface.co/spaces/gradio-pr-deploys/pr-2218-all-demos
Makes sense. Should `batch_size` be tied to the event listener instead of the queue? Do batch functions usually support only a single batch size? What happens if there are fewer data points than that when the timeout goes off (e.g. 5 data points in the batch for a `batch_fn` that expects batches of 8)?
No, I don't think batched functions typically support only a single batch size. Usually, you use the largest batch size that you expect to fit in memory, so there would be no real problem if a batch of 5 samples were fed into a `batch_fn` where we've set `batch_size` to 8.
Would you agree with that @apolinario @pcuenca?
Yes, except when using JAX, where a multiple of 8 is often required for parallelization and batch sizes are therefore usually fixed. But this could be handled by the user in the `batch_fn`, I guess. I don't think it's necessary to add more complexity to the queue by introducing a `fixed_size` parameter or something like that, what do you think?
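For concreteness, here is a minimal sketch of how a user could handle a fixed batch size inside their own batched function, as suggested above. `FIXED_BATCH` and `model_fn` are illustrative stand-ins, not Gradio or JAX APIs.

```python
FIXED_BATCH = 8  # e.g. a JAX model compiled for a fixed batch size

def model_fn(batch):
    # stand-in for a model call that requires exactly FIXED_BATCH inputs
    assert len(batch) == FIXED_BATCH
    return [x.upper() for x in batch]

def batched_predict(texts):
    # assumes 1 <= len(texts) <= FIXED_BATCH (the queue caps batches at batch_size)
    n = len(texts)
    padded = texts + [texts[-1]] * (FIXED_BATCH - n)   # pad by repeating the last sample
    outputs = model_fn(padded)
    return [outputs[:n]]                               # drop the padded outputs before returning
```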
@abidlabs This looks great! I agree with @aliabid94 that it may be better to tie the batch fn to the event listener as opposed to the queue. With the queue approach, it seems only one function can be batched per demo.
@freddyaboulton yes, for sure `batch_fn` will be a parameter in the event listener. Given @pcuenca's point, I think we might want to implement a per-function batch size as well, but I don't think we need this in v1. So in summary:
The `.queue()` method will have the following parameters:
- `batch_size` (by default `None`)
- `batch_timeout` (by default `1` second)

The event handlers will have the following parameters:
- `batch_fn`
- Eventually, we may add a `batch_size` that overrides the global `batch_size` in `queue()`
- Eventually, we may add a `batch_timeout` that overrides the global `batch_timeout` in `queue()`
- Eventually, we may add a boolean `force_batch_size` that pads a smaller batch up to the full batch size by repeating input samples.
@abidlabs why is there a `batch_fn`? Isn't it unnecessary, and doesn't it make things more complex? Sending a list of inputs of arbitrary length to the `fn` seems fine to me, and I had verified it with @pcuenca in the past. I feel like instead of `batch_fn`, there should be `batch_size` and `batch_threshold` in the event listener.
@abidlabs Thank you for the summary! Another question: can we avoid adding another API endpoint? I think it will make it easier for API "consumers" (e.g. people loading interfaces, the front-end, people using the API manually) if we don't add another endpoint, and it will make it easier to avoid bugs like #1316.
Agree with @FarukOzderim, why do we need a separate batch fn? And with @freddyaboulton that we shouldn't add another API endpoint. Here's the syntax I have in mind:
`btn.click(fn, inputs, outputs, batch=True, max_batch_size=8)`
and everything should automatically be handled from there.
Thanks for all of the suggestions guys! This makes a lot of sense, we don't actually need a separate `batch_fn`. If batch mode is enabled, we require the function to work with a batch of inputs, which also avoids creating a separate API endpoint. So the new proposed design is:
User API
- To enable "batch mode", the user creates an event trigger that looks like this: `btn.click(fn, inputs, outputs, batch=True)`.
- To configure the batch processing, the user can also pass in two optional parameters: `max_batch_size=8`, `batch_timeout=1`
- We also add these parameters (`default_batch=False`, `default_max_batch_size=8`, `default_batch_timeout=1`) to `queue()`, which applies them to all of the backend functions for Blocks. The main advantage of this would be making it easier to get batching to work with `Interface` (open to other suggestions here)
Backend changes
- The queue will have a new `process_batch_events()` method to handle a batch of events, but this is called from the regular `/api/{api_name}` endpoint if `batch=True` for that particular function
> We also add these parameters (`default_batch=False`, `default_max_batch_size=8`, `default_batch_timeout=1`) to `queue()`, which applies them to all of the backend functions for Blocks. The main advantage of this would be making it easier to get batching to work with `Interface` (open to other suggestions here)
I'm not sure adding those parameters to the queue method is the best approach, because I don't think it's a good idea to apply the same batching behavior to all backend functions. For Interfaces, for example, I don't think we want to batch the flagging, or at least not the same way we batch the main prediction function.
What if we control the batching behavior for interfaces via the constructor?
> I'm not sure adding those parameters to the queue method is the best approach, because I don't think it's a good idea to apply the same batching behavior to all backend functions. For Interfaces, for example, I don't think we want to batch the flagging, or at least not the same way we batch the main prediction function.
> What if we control the batching behavior for interfaces via the constructor?
I had originally avoided that because (a) I didn't want to add too many parameters to the `Interface` constructor, and (b) I thought that these parameters only made sense if queueing was enabled.
But now that I think about it, if batch mode is enabled, we require the user to pass in a `fn` that can handle batch inputs, so it would make sense to put it in the same method as the `fn`, which in this case would be the `Interface` constructor. And strictly speaking, we don't need queuing to enable batch mode; you could pass in a batch of inputs through HTTP requests as well.
So yeah I think this is a better approach, thanks @freddyaboulton
Thinking aloud here...
When `batch=True`, the following should happen:
- If a user calls the Interface/Blocks as a function, it should always work with a single sample (regardless of whether the underlying function is a batched function or not)
- If a user calls the `/api/predict` endpoint, it should also accept a batch of data up to the `max_batch_size` (the API docs will need to be updated to reflect this; `batch_timeout` is irrelevant here), OR it should be able to take in a single sample (there will be a parameter, `batched`, included in the payload that indicates whether the incoming sample is a single sample or a batch). See the sketch after this list.
- If a user launches the UI without queueing, the UI will directly call `/api/predict` by setting `batched=False` (the default) in the payload
- If a user launches the UI with queueing, the queue should wait until `batch_timeout` seconds pass or the batch is filled before internally calling `/api/predict` with `batched=True`
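Here is a hedged sketch of what calling `/api/predict` directly might look like under this design. The payload keys (`data`, `batched`) follow the description above, but the exact wire format (including the nesting of `data` for a batch) is an assumption, not a confirmed API contract.

```python
import requests

URL = "http://localhost:7860/api/predict"  # assumes a demo running locally

# A single sample (batched=False, the default the UI would use without queueing)
single = {"data": ["hello", 3], "batched": False}

# A batch of two samples (batched=True); the nesting of `data` here is an assumption
batch = {"data": [["hello", "world"], [3, 2]], "batched": True}

for payload in (single, batch):
    resp = requests.post(URL, json=payload)
    print(resp.status_code, resp.json())
```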
All right folks, this is finally ready for review!
User API
- To enable "batch mode", the user creates an event trigger that looks like this: `btn.click(fn, inputs, outputs, batch=True)`.
- To configure the batch processing, the user can also pass an optional parameter: `max_batch_size=8`
- Note that I got rid of the `batch_timeout` parameter, as it adds unnecessary complexity for v1 of queuing, and for viral demos with long queues it shouldn't make any difference.
- When using Interface, a full example might look like this:
```python
import gradio as gr

def trim_words(words, lens):
    trimmed_words = []
    for w, l in zip(words, lens):
        trimmed_words.append(w[:l])
    return [trimmed_words]

interface = gr.Interface(trim_words, ["textbox", "number"], ["textbox"], batch=True, max_batch_size=16)
interface.queue()
interface.launch()
```
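As a quick sanity check (not part of the original example), the batched function above can also be called directly as plain Python: each argument is a list of samples, and the return value is a list of output lists.

```python
print(trim_words(["hello", "world"], [3, 2]))  # [['hel', 'wo']]
```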
Changes Made to the Code
Besides the basic changes needed to enable batching, I also made a few other changes:
- Cleaned up the `blocks.py` code and separated the processing steps (serialize, preprocess, postprocess, deserialize) into separate methods
- Added docs and tests for batching
- We need to know which function an `Event` is related to before we collect data for the event, so this information is now passed in via query parameters when the websocket connection is first made
Some Additional Notes on Batch Mode
- Note that a "batched" function has a very different signature than a regular Gradio function: each parameter receives a list of samples, and the return value must be a nested list of lists, one inner list per output component (the docstring explains this, I hope; feedback welcome). See the sketch after this list.
- If a user calls the Interface/Blocks as a function, it should always work with a single sample (regardless of whether the underlying function is a batched function or not)
- If a user calls the `/api/predict` endpoint, it should also accept a batch of data up to the `max_batch_size` (the API docs should be updated to reflect this, but leaving that for later) OR it should be able to take in a single sample (there is a key in the data dictionary, `batched`, that is included in the payload and indicates whether the incoming sample is a single sample or a batch)
- If someone launches a Gradio/Blocks demo with batched functions, the user who is using the demo should not be able to tell the difference. Their data gets uploaded as a batch of size 1, and Gradio takes care of the rest!
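To make the signature difference concrete, here is an illustrative pair of functions for the same textbox + number inputs: a regular single-sample version (`trim_word`, a hypothetical name) next to a batched counterpart like the `trim_words` used in the example above.

```python
def trim_word(word, length):
    # regular fn: one sample per parameter, one value per output
    return word[:length]

def trim_words(words, lens):
    # batched fn: a list of samples per parameter, a list of lists as output
    return [[w[:l] for w, l in zip(words, lens)]]
```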
How to Test This
- Run the following code
```python
import gradio as gr
import time

def trim_words(words, lens):
    trimmed_words = []
    time.sleep(5)
    for w, l in zip(words, lens):
        trimmed_words.append(w[:l])
    return [trimmed_words]

interface = gr.Interface(trim_words, ["textbox", gr.Number(precision=0)], ["textbox"], batch=True, max_batch_size=16)
interface.queue()
interface.launch()
```
- Open the Gradio UI on three different tabs
- Run a prediction on the first tab
- Before the first prediction completes, run predictions on the 2nd and 3rd tabs
You should see the following behavior: the first prediction should complete (after 5 seconds), then the 2nd and 3rd prediction should start running. They should be "batched together" and their predictions should complete at the same time (after another 5 seconds).
Before anyone does a review, I realized that this PR still needs to:
- add support for examples
- handle the case where some of the examples in the batch error out

Fun stuff :D
All righty, this is finally ready for another review! @freddyaboulton would you be able to take a look? I've merged in your changes as well as added support for examples in batch mode and error handling.
Merged in `main`; the biggest sources of conflicts were between this branch and the branch that added the ability to cancel events. I've resolved the conflicts, but it would be good to ensure that canceling continues to work as expected. Also, you can now cancel some of the events within a batch, and the rest of the events within the batch will continue to work as expected. Here's how you can test it out:
Run the following:
```python
import gradio as gr
import time

def trim_words(words, lens):
    trimmed_words = []
    time.sleep(5)
    for w, l in zip(words, lens):
        trimmed_words.append(w[:l])
    return [trimmed_words]

with gr.Blocks() as demo:
    with gr.Row():
        word = gr.Textbox(label="word", value="abc")
        leng = gr.Number(label="leng", precision=0, value=1)
        output = gr.Textbox(label="Output")
    with gr.Row():
        run = gr.Button()
        canc = gr.Button("Cancel", variant="stop")
    event = run.click(trim_words, [word, leng], output, batch=True, max_batch_size=16)
    canc.click(None, None, None, cancels=event)

demo.queue()
demo.launch()
```
Open up 4 separate tabs. Then:
- Run all 4 of the tabs
- Go to the 2nd tab and cancel the event
- Tabs 3 and 4 should continue to run and complete together
Thanks for the detailed review @freddyaboulton! Let me go through the issues you identified and fix this up.
Another thing we need to do is add batched functions to our guides and show example usage. But I'll do that as part of a separate PR (maybe one tackling #2016) so that we don't keep this open for too much longer and accumulate more conflicts.
@abidlabs Good news, I just did some benchmarking on batching and I think the results are favorable!
What I did
I created this space: https://huggingface.co/spaces/gradio/queue-batch-benchmark, which is a copy of our existing queue benchmark but with batching added. I also created this space: https://huggingface.co/spaces/gradio/queue-batch-benchmark-max-size-50, which uses a `max_batch_size=20` for each function (I titled the space incorrectly lol).
Results
Main branch (no batching)
Input type | Average time to complete prediction over 1000 requests |
---|---|
Audio | 1.5452740753398222 |
Image | 1.5123367877716714 |
Text | 1.5666406706197937 |
Video | 2.029747503848115 |
This PR (max_batch_size 4)
Input type | Average time to complete prediction over 1000 requests |
---|---|
Audio | 1.165496392960244 |
Image | 0.8552168062177755 |
Text | 0.841041189251524 |
Video | 2.532181569765199 |
This PR (max_batch_size 20)
Input type | Average time to complete prediction over 1000 requests |
---|---|
Audio | 1.2939001665276997 |
Image | 0.9907303559825795 |
Text | 0.993362974653057 |
Video | 3.0297954572785284 |
I think it's interesting that a higher max_batch_size did not uniformly bring down the latency.
I tried again, only sending requests for the image function:
Max batch size | Average time to complete prediction over 1000 requests |
---|---|
4 | 0.9002170345783234 |
20 | 1.3551465120315551 |
Using a max_batch_size of 20 is still faster than no batching but not faster than a batch size of 4 🤔
I think this is evidence that we can optimize the batch logic (maybe gathering data for each event in the batch is adding overhead?) but we can do so in a future PR. IMO this can go out and we can get user feedback.
Other thing
`gr.Blocks.load` does not work if the space is using batching. I tried with the queue benchmark and you'll see that the websocket returns "['foo']" and then we take the first element of it to return "[".
Other than that (and the comments I already made) I think this is good to merge! Feel free to merge without fixing `gr.Blocks.load` and we can do that in a separate PR.
Amaaazing thank you so much @freddyaboulton for the benchmarking! Have a few meetings, but then I'll parse through this more fully and address the comments above
Also, just noticed that a demo without batching built off this PR is showing higher latency than the same demo built off main. Would be good to look into this as well!
```
python benchmark_script_gather.py wss://spaces.huggingface.tech/gradio/queue-benchmark/queue/join --batch_size 50
{'fn_index': ['audio', 'image', 'text', 'video'], 'duration': [1.724048194614982, 1.6302076853891285, 1.5274697332470506, 2.1359671726822853]}

python benchmark_script_gather.py wss://spaces.huggingface.tech/gradio/queue-benchmark-batch-branch/queue/join --batch_size 50
{'fn_index': ['audio', 'image', 'text', 'video'], 'duration': [2.0668310448826555, 1.888936443934365, 1.9982213218865446, 2.88502193372184]}
```
> Also, just noticed that a demo without batching built off this PR is showing higher latency than the same demo built off main. Would be good to look into this as well!
How are you running `--batch_size 50` with the demo built off main?
My bad for confusing terminology. There are two batches here: `--batch_size 50` means send 50 requests to the space at once. The other batch size is `max_batch_size` in the event listener. I'll edit my previous comment to make that clearer.
I’m also going to run the two benchmark apps locally (not hosted on spaces) and run the benchmark to make sure it’s not something weird with my internet connection.
Thanks @freddyaboulton this is super helpful
Actually disregard that last comment about not-batching being slower in this branch. I can't reproduce locally (while I still reproduce the fact batching is faster than not batching). I did a factory reset on https://huggingface.co/spaces/gradio/queue-benchmark-batch-branch and the times are now pretty much in line with https://huggingface.co/spaces/gradio/queue-benchmark 🤷♂️ . False alarm I think.
All righty, this is finally ready to be merged. I've added documentation about how to use batched functions as well. There are further optimizations that can be added, as @FarukOzderim pointed out, but I'll save those for a future PR (probably when I write the Guide on batched functions and concurrency).
@freddyaboulton I didn't understand your question about `asyncgenfunctions`, but otherwise everything else has been addressed.
I'll merge it in a couple of hours -- thanks all!