paper-qa
LiteLLM configuration for caching and streaming
Hi,
I am trying to configure LiteLLM to enable caching and streaming.
# -*- coding: utf-8 -*-
from paperqa.settings import Settings, AgentSettings, AnswerSettings
from paperqa.agents import ask, build_index
from paperqa import Docs, HybridEmbeddingModel, SparseEmbeddingModel, LiteLLMEmbeddingModel, SentenceTransformerEmbeddingModel, NumpyVectorStore, embedding_model_factory
from prompts_french import summary_prompt, summary_json_prompt, qa_prompt, select_paper_prompt, citation_prompt, structured_citation_prompt, default_system_prompt, summary_json_system_prompt
import os
import csv
import litellm
import glob
import compress_pickle
from paperqa.agents.main import agent_query
from paperqa.agents.search import get_directory_index
import asyncio
from datetime import datetime
from litellm.caching.caching import Cache

litellm.set_verbose = True
litellm.stream = True
litellm.cache = True
litellm.request_timeout = 600

os.environ["OPENAI_API_KEY"] = "sk-no-key-required"
os.environ["MISTRAL_API_KEY"] = "sk-no-key-required"

local_llm_config = dict(
    model_list=[
        dict(
            model_name="ollama/dolphin-mixtral:8x7b",
            litellm_params=dict(
                model="ollama/dolphin-mixtral:8x7b",
                api_type="ollama",
                api_base="http://localhost:11434",
                stream=True,
                caching=True,
                cache=True,
                # format='json',
                mirostat=0,
                mirostat_eta=0.1,
                mirostat_tau=5.0,
                num_ctx=8192,  # Default: 2048
                repeat_last_n=64,  # Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
                repeat_penalty=1.1,
                temperature=0.5,
                seed=0,
                tfs_z=1,
                num_predict=8192,  # Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)
                max_tokens=8192,
                # low_vram=True,
                top_k=20,
                top_p=0.9,
                min_p=0.0,
                timeout=600,
            ),
        ),
    ]
)
The 'stream' parameter seems to be unrecognized as a kwarg by litellm.
Setting litellm.cache=True leads to an error indicating this argument cannot be a bool. I also tried passing a Cache() object, but that leads to an error as well.
What is the correct way to enable streaming and caching?
To enable streaming and caching correctly in the litellm library, you need to adjust the parameters as follows:

- Streaming: The stream parameter should be set within the litellm_params dictionary for the specific model configuration. Ensure that it is correctly recognized by the library. If it is not recognized, it might be due to a version mismatch or incorrect usage. Double-check the library's documentation or source code to confirm the correct parameter name and usage.
- Caching: Instead of setting litellm.cache=True, you should instantiate a Cache() object and assign it to litellm.cache. Here's an example of how to do this:

from litellm.caching.caching import Cache

# Create a Cache object
litellm.cache = Cache()

Ensure that the Cache() object is correctly imported and used. This should resolve the error related to the cache parameter not accepting a boolean value.
By following these steps, you should be able to configure litellm with caching and streaming enabled correctly. If issues persist, consider checking for any updates or changes in the library's API that might affect these configurations.
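For completeness, here is a minimal, self-contained sketch of that setup called directly through litellm (not taken from the thread). The Ollama model name and api_base are assumptions copied from the configuration above, and the per-call caching=True / stream=True kwargs follow litellm's caching and streaming documentation, so verify them against your installed version:

import litellm
from litellm import completion
from litellm.caching.caching import Cache

# Assign a Cache object (not a bool); defaults to an in-memory cache
litellm.cache = Cache()

messages = [{"role": "user", "content": "Say hello in one short sentence."}]

# Cached, non-streamed call (repeated calls should hit the cache)
cached_response = completion(
    model="ollama/dolphin-mixtral:8x7b",  # assumption: the same local model as above
    api_base="http://localhost:11434",
    messages=messages,
    caching=True,
)
print(cached_response.choices[0].message.content)

# Streamed call: completion() returns an iterator of chunks when stream=True
for chunk in completion(
    model="ollama/dolphin-mixtral:8x7b",
    api_base="http://localhost:11434",
    messages=messages,
    stream=True,
):
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)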
- The 'stream' parameter is always set to 'True' in litellm_params, but I can see it is still False in LiteLLM verbose mode.
- I passed a Cache() object (litellm.cache = Cache()) and got this error:
[12:47:07] Trajectory failed.
╭──────────────────────────────────── Traceback (most recent call last) ────────────────────────────────────╮
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\paperqa\agents\main.py:158 in │
│ _run_with_timeout_failure │
│ │
│ 155 ) -> tuple[PQASession, AgentStatus]: │
│ 156 │ try: │
│ 157 │ │ async with asyncio.timeout(settings.agent.timeout): │
│ ❱ 158 │ │ │ status = await rollout() │
│ 159 │ except TimeoutError: │
│ 160 │ │ logger.warning( │
│ 161 │ │ │ f"Agent timeout after {settings.agent.timeout}-sec, just answering." │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\paperqa\agents\main.py:309 in rollout │
│ │
│ 306 │ │ │ │ ) │
│ 307 │ │ │ │ return AgentStatus.TRUNCATED │
│ 308 │ │ │ agent_state.messages += obs │
│ ❱ 309 │ │ │ for attempt in Retrying( │
│ 310 │ │ │ │ stop=stop_after_attempt(5), │
│ 311 │ │ │ │ retry=retry_if_exception_type(MalformedMessageError), │
│ 312 │ │ │ │ before_sleep=before_sleep_log(logger, logging.WARNING), │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\tenacity\__init__.py:443 in __iter__ │
│ │
│ 440 │ │ │
│ 441 │ │ retry_state = RetryCallState(self, fn=None, args=(), kwargs={}) │
│ 442 │ │ while True: │
│ ❱ 443 │ │ │ do = self.iter(retry_state=retry_state) │
│ 444 │ │ │ if isinstance(do, DoAttempt): │
│ 445 │ │ │ │ yield AttemptManager(retry_state=retry_state) │
│ 446 │ │ │ elif isinstance(do, DoSleep): │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\tenacity\__init__.py:376 in iter │
│ │
│ 373 │ │ self._begin_iter(retry_state) │
│ 374 │ │ result = None │
│ 375 │ │ for action in self.iter_state.actions: │
│ ❱ 376 │ │ │ result = action(retry_state) │
│ 377 │ │ return result │
│ 378 │ │
│ 379 │ def _begin_iter(self, retry_state: "RetryCallState") -> None: # noqa │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\tenacity\__init__.py:398 in <lambda> │
│ │
│ 395 │ │
│ 396 │ def _post_retry_check_actions(self, retry_state: "RetryCallState") -> None: │
│ 397 │ │ if not (self.iter_state.is_explicit_retry or self.iter_state.retry_run_result): │
│ ❱ 398 │ │ │ self._add_action_func(lambda rs: rs.outcome.result()) │
│ 399 │ │ │ return │
│ 400 │ │ │
│ 401 │ │ if self.after is not None: │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\concurrent\futures\_base.py:449 in result │
│ │
│ 446 │ │ │ │ if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]: │
│ 447 │ │ │ │ │ raise CancelledError() │
│ 448 │ │ │ │ elif self._state == FINISHED: │
│ ❱ 449 │ │ │ │ │ return self.__get_result() │
│ 450 │ │ │ │ │
│ 451 │ │ │ │ self._condition.wait(timeout) │
│ 452 │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\concurrent\futures\_base.py:401 in __get_result │
│ │
│ 398 │ def __get_result(self): │
│ 399 │ │ if self._exception: │
│ 400 │ │ │ try: │
│ ❱ 401 │ │ │ │ raise self._exception │
│ 402 │ │ │ finally: │
│ 403 │ │ │ │ # Break a reference cycle with the exception in self._exception │
│ 404 │ │ │ │ self = None │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\paperqa\agents\main.py:316 in rollout │
│ │
│ 313 │ │ │ │ reraise=True, │
│ 314 │ │ │ ): │
│ 315 │ │ │ │ with attempt: # Retrying if ToolSelector fails to select a tool │
│ ❱ 316 │ │ │ │ │ action = await agent(agent_state.messages, tools) │
│ 317 │ │ │ agent_state.messages = [*agent_state.messages, action] │
│ 318 │ │ │ if on_agent_action_callback: │
│ 319 │ │ │ │ await on_agent_action_callback(action, agent_state) │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\aviary\tools\utils.py:84 in __call__ │
│ │
│ 81 │ │ │ self._ledger.messages.extend(messages) │
│ 82 │ │ │ messages = self._ledger.messages │
│ 83 │ │ │
│ ❱ 84 │ │ model_response = await self._bound_acompletion( │
│ 85 │ │ │ messages=MessagesAdapter.dump_python( │
│ 86 │ │ │ │ messages, exclude_none=True, by_alias=True │
│ 87 │ │ │ ), │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:907 in acompletion │
│ │
│ 904 │ │ │ │ │ original_exception=e, │
│ 905 │ │ │ │ ) │
│ 906 │ │ │ ) │
│ ❱ 907 │ │ │ raise e │
│ 908 │ │
│ 909 │ async def _acompletion( │
│ 910 │ │ self, model: str, messages: List[Dict[str, str]], **kwargs │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:883 in acompletion │
│ │
│ 880 │ │ │ if request_priority is not None and isinstance(request_priority, int): │
│ 881 │ │ │ │ response = await self.schedule_acompletion(**kwargs) │
│ 882 │ │ │ else: │
│ ❱ 883 │ │ │ │ response = await self.async_function_with_fallbacks(**kwargs) │
│ 884 │ │ │ end_time = time.time() │
│ 885 │ │ │ _duration = end_time - start_time │
│ 886 │ │ │ asyncio.create_task( │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:3079 in │
│ async_function_with_fallbacks │
│ │
│ 3076 │ │ │ │ │ │ ) │
│ 3077 │ │ │ │ │ ) │
│ 3078 │ │ │ │
│ ❱ 3079 │ │ │ raise original_exception │
│ 3080 │ │
│ 3081 │ def _handle_mock_testing_fallbacks( │
│ 3082 │ │ self, │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:2893 in │
│ async_function_with_fallbacks │
│ │
│ 2890 │ │ │ │ │ *args, **kwargs, mock_timeout=mock_timeout │
│ 2891 │ │ │ │ ) │
│ 2892 │ │ │ else: │
│ ❱ 2893 │ │ │ │ response = await self.async_function_with_retries(*args, **kwargs) │
│ 2894 │ │ │ verbose_router_logger.debug(f"Async Response: {response}") │
│ 2895 │ │ │ response = add_fallback_headers_to_response( │
│ 2896 │ │ │ │ response=response, │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:3269 in │
│ async_function_with_retries │
│ │
│ 3266 │ │ │ │ setattr(original_exception, "max_retries", num_retries) │
│ 3267 │ │ │ │ setattr(original_exception, "num_retries", current_attempt) │
│ 3268 │ │ │ │
│ ❱ 3269 │ │ │ raise original_exception │
│ 3270 │ │
│ 3271 │ async def make_call(self, original_function: Any, *args, **kwargs): │
│ 3272 │ │ """ │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:3162 in │
│ async_function_with_retries │
│ │
│ 3159 │ │ │ │ model_group=model_group, kwargs=kwargs │
│ 3160 │ │ │ ) │
│ 3161 │ │ │ # if the function call is successful, no exception will be raised and we'll │
│ break out of the loop │
│ ❱ 3162 │ │ │ response = await self.make_call(original_function, *args, **kwargs) │
│ 3163 │ │ │ response = add_retry_headers_to_response( │
│ 3164 │ │ │ │ response=response, attempted_retries=0, max_retries=None │
│ 3165 │ │ │ ) │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:3278 in make_call │
│ │
│ 3275 │ │ model_group = kwargs.get("model") │
│ 3276 │ │ response = original_function(*args, **kwargs) │
│ 3277 │ │ if inspect.iscoroutinefunction(response) or inspect.isawaitable(response): │
│ ❱ 3278 │ │ │ response = await response │
│ 3279 │ │ ## PROCESS RESPONSE HEADERS │
│ 3280 │ │ response = await self.set_response_headers( │
│ 3281 │ │ │ response=response, model_group=model_group │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:1045 in _acompletion │
│ │
│ 1042 │ │ │ ) │
│ 1043 │ │ │ if model_name is not None: │
│ 1044 │ │ │ │ self.fail_calls[model_name] += 1 │
│ ❱ 1045 │ │ │ raise e │
│ 1046 │ │
│ 1047 │ def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None: │
│ 1048 │ │ """ │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:1004 in _acompletion │
│ │
│ 1001 │ │ │ │ │ parent_otel_span=parent_otel_span, │
│ 1002 │ │ │ │ ) │
│ 1003 │ │ │ │ │
│ ❱ 1004 │ │ │ │ response = await _response │
│ 1005 │ │ │ │
│ 1006 │ │ │ ## CHECK CONTENT FILTER ERROR ## │
│ 1007 │ │ │ if isinstance(response, ModelResponse): │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\utils.py:1397 in wrapper_async │
│ │
│ 1394 │ │ │ │
│ 1395 │ │ │ timeout = _get_wrapper_timeout(kwargs=kwargs, exception=e) │
│ 1396 │ │ │ setattr(e, "timeout", timeout) │
│ ❱ 1397 │ │ │ raise e │
│ 1398 │ │
│ 1399 │ is_coroutine = inspect.iscoroutinefunction(original_function) │
│ 1400 │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\utils.py:1287 in wrapper_async │
│ │
│ 1284 │ │ │ ) │
│ 1285 │ │ │ │
│ 1286 │ │ │ ## Add response to cache │
│ ❱ 1287 │ │ │ await _llm_caching_handler.async_set_cache( │
│ 1288 │ │ │ │ result=result, │
│ 1289 │ │ │ │ original_function=original_function, │
│ 1290 │ │ │ │ kwargs=kwargs, │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\caching\caching_handler.py:684 in │
│ async_set_cache │
│ │
│ 681 │ │ │ ) │
│ 682 │ │ ) │
│ 683 │ │ # [OPTIONAL] ADD TO CACHE │
│ ❱ 684 │ │ if self._should_store_result_in_cache( │
│ 685 │ │ │ original_function=original_function, kwargs=new_kwargs │
│ 686 │ │ ): │
│ 687 │ │ │ if ( │
│ │
│ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\caching\caching_handler.py:758 in │
│ _should_store_result_in_cache │
│ │
│ 755 │ │ │ (litellm.cache is not None) │
│ 756 │ │ │ and litellm.cache.supported_call_types is not None │
│ 757 │ │ │ and (str(original_function.__name__) in litellm.cache.supported_call_types) │
│ ❱ 758 │ │ │ and (kwargs.get("cache", {}).get("no-store", False) is not True) │
│ 759 │ │ ) │
│ 760 │ │
│ 761 │ def _is_call_type_supported_by_cache( │
╰───────────────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'bool' object has no attribute 'get'
Any solution to these problems?
To address the issues you're encountering with the LiteLLM library:

- Streaming Parameter: Ensure that the stream parameter is correctly set within the litellm_params dictionary. If it appears as False in verbose mode despite being set to True, it might be due to an incorrect configuration or a bug in the library. Double-check that the parameter is correctly spelled and placed within the dictionary. If the issue persists, consider checking the library's documentation or source code for any updates or known issues related to this parameter.
- Caching Configuration: The error you're encountering (AttributeError: 'bool' object has no attribute 'get') suggests that the cache parameter is being treated as a boolean instead of a dictionary. Ensure that you are correctly instantiating the Cache object and assigning it to litellm.cache. Here's how you should do it:

from litellm.caching.caching import Cache

# Create and assign a Cache object
litellm.cache = Cache()

Make sure that the Cache object is properly imported and used. If the error persists, it might be due to an incorrect implementation or a bug in the library's caching mechanism. You may need to review the library's documentation or source code for further guidance on the correct usage of the caching feature.
By ensuring these configurations are correctly set, you should be able to resolve the issues with streaming and caching in the LiteLLM library. If problems continue, consider checking for any updates or changes in the library's API that might affect these configurations.
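Reading the traceback above, the offending bool most likely comes from the cache=True entry inside litellm_params: litellm's caching handler expects the per-call cache kwarg to be a dict of cache controls (it calls kwargs.get("cache", {}).get("no-store", ...)), so a bare True reaches it and raises the AttributeError. Below is a hedged sketch of the likely fix, keeping the global Cache object and dropping the boolean flags from the model config; the parameter names are assumptions carried over from the configuration earlier in the thread:

import litellm
from litellm.caching.caching import Cache

litellm.cache = Cache()  # global Cache object instead of litellm.cache = True

local_llm_config = dict(
    model_list=[
        dict(
            model_name="ollama/dolphin-mixtral:8x7b",
            litellm_params=dict(
                model="ollama/dolphin-mixtral:8x7b",
                api_base="http://localhost:11434",
                num_ctx=8192,
                temperature=0.5,
                # No cache=True / caching=True here: litellm treats a per-call
                # `cache` value as a dict of controls (e.g. cache={"no-store": True}),
                # so a bare bool triggers the AttributeError shown above.
            ),
        ),
    ]
)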
I guess the Cache() object is not correctly instantiated. Is there any documentation for that?
Hello @Snikch63200
Please note that the LLMs we use in paperqa come from lmi, which is a wrapper over litellm. Of course the usage might be very similar, but there are a few differences.
That said, I am assuming you want to cache locally, is that right? This page shows how to do it with litellm.
Please check this code snippet to get it done in lmi:
from lmi import LiteLLMModel
from aviary.core import Message
import litellm
from litellm.caching.caching import Cache
# Set up caching -- I did not need this setup. But since it's in the litellm page, I left it here
# litellm.cache = Cache()
llm = LiteLLMModel(name="gpt-4o-mini") # Change it to your llm
r = await llm.call_single(
    messages=[Message(role="user", content="Tell me a very very very random joke")],
    caching=True,
)
print(r.text)
I could rerun it multiple times and get the same output. Let me know if it works for you.
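For reference, here is a hedged adaptation of that snippet to a local Ollama model. The config={"model_list": [...]} shape mirrors the llm_config dictionary used with paper-qa's Settings, so treat the exact field names as an assumption and double-check lmi's documentation:

from aviary.core import Message
from lmi import LiteLLMModel

llm = LiteLLMModel(
    name="ollama/llama3.3",  # assumption: a locally served Ollama model
    config={
        "model_list": [
            {
                "model_name": "ollama/llama3.3",
                "litellm_params": {
                    "model": "ollama/llama3.3",
                    "api_base": "http://localhost:11434",
                },
            }
        ]
    },
)

result = await llm.call_single(
    messages=[Message(role="user", content="Tell me a very very very random joke")],
    caching=True,
)
print(result.text)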
Hello @maykcaldas
Thanks for your answer. Your snippet works perfectly with llama3.3.
So, I think there could be two explanations:
- PQA overrides these parameters in its code
- the parameters are not passed through the LiteLLM API ('stream' appears as 'False' in LiteLLM verbose mode, but 'caching' does not)
Best regards.
Hey @Snikch63200
I don't think PQA uses caching anywhere. But you're right about streaming.
PQA uses LLMs from lmi. In lmi, you can see an example here of how we override the stream option depending on the use case.
Could you tell me why you need the answer to be streamed back to you? In our ecosystem, we use lmi to abstract all the LLM operations, and PQA needs to deal only with the curated LLMResults objects.
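If streamed tokens are still wanted when calling the model directly through lmi, a rough sketch is below. It assumes call_single accepts a callbacks argument that receives text chunks as they arrive (as in earlier paper-qa LLM wrappers); this is an assumption to verify against the lmi example linked above, not a confirmed API:

from aviary.core import Message
from lmi import LiteLLMModel

llm = LiteLLMModel(name="gpt-4o-mini")  # change to your model

def print_chunk(chunk: str) -> None:
    # Hypothetical callback: print each streamed chunk as it arrives
    print(chunk, end="", flush=True)

result = await llm.call_single(
    messages=[Message(role="user", content="Tell me a very very very random joke")],
    callbacks=[print_chunk],  # assumption: providing callbacks enables streaming in lmi
)
print(result.text)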
Best,