
Ability to process in Batches

Open charliemday opened this issue 1 year ago • 5 comments

Is there any ability to batch process, similar to how OpenAI does it? https://help.openai.com/en/articles/9197833-batch-api-faq

charliemday avatar Apr 29 '24 07:04 charliemday

Hi @charliemday. No, Replicate doesn't currently implement a batch processing API like OpenAI. It's something we're considering, though. Can you share more about your intended use case?

mattt avatar Apr 29 '24 11:04 mattt

Hi @mattt, sure.

My intended use case is that I have a CSV file of ~1k rows and I want to send a request for each row. Given that the response takes ~1-2 seconds, this is going to take ~20 minutes, give or take. I would like to batch the rows into groups of 10 and send them all at once, bringing the time to completion down considerably.

charliemday avatar Apr 29 '24 11:04 charliemday

@charliemday Replicate does support creating up to 6000 concurrent predictions per minute. Depending on how much the model is scaled out, you could process all of them more quickly using our async API. Here's an example from the README that you can adapt (I'd recommend processing rows in async batches of 100 or so, and keeping track of successful and failed rows):

import asyncio
import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]

async def main():
    # asyncio.TaskGroup (Python 3.11+) waits for every task to finish
    # before the `async with` block exits.
    async with asyncio.TaskGroup() as tg:
        tasks = [
            tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
            for prompt in prompts
        ]
    # All tasks are complete once the TaskGroup exits.
    results = [task.result() for task in tasks]
    print(results)

asyncio.run(main())
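The batch-of-100 approach with success/failure tracking could be sketched as follows. This is only a sketch: `process_row` is a hypothetical stand-in for the per-row `replicate.async_run` call, and loading the CSV rows is assumed to happen elsewhere.

```python
import asyncio

async def process_row(row):
    # Hypothetical stand-in for replicate.async_run(model_version, input={...}).
    # Any per-row failure surfaces here as an exception.
    if row is None:
        raise ValueError("bad row")
    return f"result for {row}"

async def process_in_batches(rows, batch_size=100):
    succeeded, failed = [], []
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        # return_exceptions=True keeps one failed row from
        # cancelling the rest of the batch.
        results = await asyncio.gather(
            *(process_row(row) for row in batch),
            return_exceptions=True,
        )
        for row, result in zip(batch, results):
            if isinstance(result, Exception):
                failed.append((row, result))
            else:
                succeeded.append((row, result))
    return succeeded, failed

succeeded, failed = asyncio.run(process_in_batches(["a", None, "b"], batch_size=2))
```

Because `asyncio.gather` is awaited once per batch, at most `batch_size` requests are in flight at a time, which keeps you under the rate limit while still letting failed rows be retried from the `failed` list afterwards.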

mattt avatar Apr 29 '24 11:04 mattt

Thanks @mattt , this should work for my use case.

Maybe I'm missing something but I don't see 6000 RPM on the link (only 600)?

charliemday avatar Apr 29 '24 14:04 charliemday

@charliemday Apologies, yes — the rate limit is 600 / minute. That was a typo on my part.

mattt avatar Apr 29 '24 14:04 mattt