
DSPy parallel processing

Open anslin-raj opened this issue 1 year ago • 12 comments

I tried to run a DSPy module inside a thread, but it's not working...

I have used

dspy.Module

and inside it I used a custom retriever module:

elastic_rm(dspy.Retrieve)

Can anyone suggest a proper way to do parallel processing?

anslin-raj avatar Jun 10 '24 06:06 anslin-raj

This might not be working because evaluation and some optimizers use threading themselves. Are you trying to speed up optimization or inference?

tom-doerr avatar Jun 10 '24 08:06 tom-doerr

@tom-doerr I'm trying to speed up inference; do you have any parallel inference code?

anslin-raj avatar Jun 10 '24 11:06 anslin-raj

Yes I do: https://x.com/tom_doerr/status/1798806436476334123

This is some separate code I use somewhere for evaluation:

import threading
import time

from dspy.teleprompt import BootstrapFewShot

# Compile against each training example in its own thread, then run a full compile at the end.
fewshot_optimizer = BootstrapFewShot(metric=great_tweet_metric, max_bootstrapped_demos=4, metric_threshold=metric_threshold)
compile_start = time.time()

threads = []
for dataset_idx in range(TRAIN_SIZE):
    t = threading.Thread(
        target=fewshot_optimizer.compile,
        kwargs=dict(student=tweet_generator, teacher=teacher, trainset=[trainset[dataset_idx]]),
    )
    threads.append(t)
    t.start()

print("All threads have been created.")
for t in threads:
    t.join()

print("====== num_nesting_levels_dict:", num_nesting_levels_dict)
print("All threads have completed.")

tweet_generator_compiled = fewshot_optimizer.compile(student=tweet_generator, teacher=teacher, trainset=trainset)
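
For plain inference rather than optimization, a thread pool should work along the same lines. This is just a rough, untested sketch, assuming you already have a configured DSPy module called qa whose forward takes a question keyword and returns a prediction with an answer field:

from concurrent.futures import ThreadPoolExecutor

questions = ["What is DSPy?", "How do I parallelize inference?"]  # hypothetical inputs

# Each call runs the module's forward pass in a worker thread; this mostly helps
# because the underlying OpenAI calls are I/O-bound.
with ThreadPoolExecutor(max_workers=8) as pool:
    predictions = list(pool.map(lambda q: qa(question=q), questions))

for question, prediction in zip(questions, predictions):
    print(question, "->", prediction.answer)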

tom-doerr avatar Jun 10 '24 11:06 tom-doerr

Thank you @tom-doerr.

I have tried TypedPredictor and TypedChainOfThought as well, but I'm facing an error; I've attached the code snippet and the error message below. I'm using asyncio for parallel processing and FastAPI for the app.

Code:

from typing import List, Optional, Union

import dspy
from openai import OpenAI

client = OpenAI(api_key=OPENAI_API_KEY)  # used by the embedding function below


def oai_ef(text, model="text-embedding-ada-002"):
    return client.embeddings.create(input=text, model=model).data[0].embedding


class elastic_rm(dspy.Retrieve):
    def __init__(self, es_client, es_index, es_field, embedding_function, k=3):
        super().__init__()
        self.k = k
        self.es_index = es_index
        self.es_client = es_client
        self.field = es_field
        self.ef = embedding_function

    def _get_embeddings(self, queries: List[str]) -> List[List[float]]:
        return self.ef(queries)

    def forward(self, query: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction:
        passages = []  # Add retriever logic here (kNN search against Elasticsearch)
        return dspy.Prediction(passages=passages)


class RAG(dspy.Module):

    def __init__(self):
        super().__init__()
        # self.generate_answer = dspy.ChainOfThought("context, question -> answer")
        # self.generate_answer = dspy.TypedPredictor("context, question -> answer")
        self.generate_answer = dspy.TypedChainOfThought("context, question -> answer")

    def forward(self, question, rm, lm, num_passages=3):
        self.retrieve = rm(query_or_queries=question, k=num_passages)
        context = self.retrieve.passages

        prediction = self.generate_answer(question=question, context=context)
        return dspy.Prediction(answer=prediction.answer, context=context)


if __name__ == '__main__':
    rm = elastic_rm(es_client_01, "embeddings_index", "embedding_field", embedding_function=oai_ef)
    lm = dspy.OpenAI(model="gpt-4o", max_tokens=4000, api_key=OPENAI_API_KEY)
    dspy.settings.configure(lm=lm, rm=rm)
    qa = RAG()
    prompt = "Question"
    response = qa(prompt, rm, lm, num_passages=50)

Error:

Traceback (most recent call last):
  File "D:\app\worker\llm_worker.py", line 174, in process_request
    response = qa(prompt, rm, lm, num_passages=50)
  File "D:\app\venv\lib\site-packages\dspy\primitives\program.py", line 26, in __call__
    return self.forward(*args, **kwargs)
  File "D:\app\worker\llm_worker.py", line 152, in forward
    prediction = self.generate_answer(question=question, context=context)
  File "D:\app\venv\lib\site-packages\dspy\primitives\program.py", line 26, in __call__
    return self.forward(*args, **kwargs)
  File "D:\app\venv\lib\site-packages\dspy\functional\functional.py", line 295, in forward
    result = self.predictor(**modified_kwargs, new_signature=signature)
  File "D:\app\venv\lib\site-packages\dspy\predict\predict.py", line 61, in __call__
    return self.forward(**kwargs)
  File "D:\app\venv\lib\site-packages\dspy\predict\predict.py", line 111, in forward
    x, C = dsp.generate(template, **config)(x, stage=self.stage)
  File "D:\app\venv\lib\site-packages\dsp\primitives\predict.py", line 78, in do_generate
    completions: list[dict[str, Any]] = generator(prompt, **kwargs)
  File "D:\app\venv\lib\site-packages\dsp\modules\gpt3.py", line 178, in __call__
    response = self.request(prompt, **kwargs)
  File "D:\app\venv\lib\site-packages\backoff\_sync.py", line 105, in retry
    ret = target(*args, **kwargs)
  File "D:\app\venv\lib\site-packages\dsp\modules\gpt3.py", line 144, in request
    return self.basic_request(prompt, **kwargs)
  File "D:\app\venv\lib\site-packages\dsp\modules\gpt3.py", line 116, in basic_request
    kwargs = {"stringify_request": json.dumps(kwargs)}
  File "C:\Users\user\.pyenv\pyenv-win\versions\3.10.6\lib\json\__init__.py", line 231, in dumps
    return _default_encoder.encode(obj)
  File "C:\Users\user\.pyenv\pyenv-win\versions\3.10.6\lib\json\encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "C:\Users\user\.pyenv\pyenv-win\versions\3.10.6\lib\json\encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
  File "C:\Users\user\.pyenv\pyenv-win\versions\3.10.6\lib\json\encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type GPT3 is not JSON serializable

anslin-raj avatar Jun 10 '24 19:06 anslin-raj

Is this related to parallelism? I can't see any code related to that.

tom-doerr avatar Jun 10 '24 19:06 tom-doerr

@tom-doerr, sorry for the inconvenience; here is the updated code with parallel processing. As of now, I haven't developed the FastAPI part.

Code:

import asyncio

import dspy
from elasticsearch import Elasticsearch


async def process_request(es: Elasticsearch, prompt: str):
    rm = elastic_rm(es_client_01, "embeddings_index", "embedding_field", embedding_function=oai_ef)
    lm = dspy.OpenAI(model="gpt-4o", max_tokens=4000, api_key=OPENAI_API_KEY)
    dspy.settings.configure(lm=lm, rm=rm)
    qa = RAG()
    response = qa(prompt, rm, lm, num_passages=50)


async def process_requests(es: Elasticsearch):
    while True:
        new_requests = await retrieve_new_requests()
        tasks = [asyncio.create_task(process_request(es, request)) for request in new_requests]

        for task in asyncio.as_completed(tasks):
            try:
                result = await task
                print("Request processed successfully")
            except Exception as e:
                print(f"Error processing request: {e}")
                exit(0)

        await asyncio.sleep(10)


async def main():
    es = Elasticsearch(ELASTICSEARCH_HOST)
    await process_requests(es)


if __name__ == "__main__":
    asyncio.run(main())

anslin-raj avatar Jun 10 '24 19:06 anslin-raj

Could you just switch to a process-based worker model? That should still give you parallelism without needing to serialize the GPT3 instance.

tom-doerr avatar Jun 10 '24 19:06 tom-doerr

@tom-doerr I'm not familiar with the process-based worker model. If you have any sample code for it, could you please share it?

anslin-raj avatar Jun 10 '24 19:06 anslin-raj

As far as I know, it works using uvicorn or gunicorn.
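
Roughly something like this (an untested sketch reusing the RAG, elastic_rm, and oai_ef definitions from your snippets above; the /ask route is just illustrative):

# app.py -- each worker process imports this module fresh, so every process
# gets its own LM/RM instances and nothing needs to be serialized across workers.
import dspy
from fastapi import FastAPI

rm = elastic_rm(es_client_01, "embeddings_index", "embedding_field", embedding_function=oai_ef)
lm = dspy.OpenAI(model="gpt-4o", max_tokens=4000, api_key=OPENAI_API_KEY)
dspy.settings.configure(lm=lm, rm=rm)

app = FastAPI()

@app.post("/ask")
def ask(question: str):
    qa = RAG()
    prediction = qa(question, rm, lm, num_passages=50)
    return {"answer": prediction.answer}

and start it with several worker processes:

uvicorn app:app --workers 4
# or
gunicorn -w 4 -k uvicorn.workers.UvicornWorker app:app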

tom-doerr avatar Jun 10 '24 19:06 tom-doerr

Thank you @tom-doerr, I understand. However, I planned to manage the OpenAI calls in one centralized place: when we send more requests with more data, we may hit token-limit and rate-limit exceptions. That's why I chose this structure; it gives us full control over the OpenAI calls and lets us customize data retrieval.

Does DSPy automatically handle these exceptions?

Do you know any other way to solve this?

anslin-raj avatar Jun 10 '24 19:06 anslin-raj

Not really sure why having multiple instances would trigger token or rate limits faster, or how centralizing it helps with data retrieval. You could try to make it serializable, though I'm not sure how feasible that is. Other ideas:

  • create an openai.py that is loaded instead of the package, process the requests however you like, and then forward them to OpenAI
  • run a proxy server that requests go through before reaching OpenAI; you could do that by changing the base URL (see the sketch after this list)
  • implement your own OpenAI client
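
For the second idea, here's a very rough, untested sketch of a pass-through proxy; the queueing/rate-limiting logic is only a placeholder. You'd point the client's base URL at it (if your DSPy version's dspy.OpenAI accepts an api_base argument, you can pass the proxy URL there; otherwise configure the base URL on the underlying OpenAI client):

# proxy.py -- forwards /v1/* calls to OpenAI; centralizing them here lets you
# queue, throttle, or count tokens in one place before they hit the API.
import os

import httpx
from fastapi import FastAPI, Request, Response

OPENAI_BASE = "https://api.openai.com"
app = FastAPI()

@app.post("/v1/{path:path}")
async def forward(path: str, request: Request):
    body = await request.body()
    # (insert your own rate limiting / token accounting here)
    async with httpx.AsyncClient(timeout=120) as client:
        upstream = await client.post(
            f"{OPENAI_BASE}/v1/{path}",
            content=body,
            headers={
                "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
                "Content-Type": "application/json",
            },
        )
    return Response(content=upstream.content, status_code=upstream.status_code,
                    media_type="application/json")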

tom-doerr avatar Jun 10 '24 20:06 tom-doerr

No @tom-doerr, both multiple instances and a single instance trigger these exceptions depending on the count and size of the requests; if we handle all requests in a single instance, it's easier to avoid hitting them.

anslin-raj avatar Jun 11 '24 05:06 anslin-raj