Model Suggestion
import os
import sys

import yaml
from loguru import logger
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_community.chat_models import ChatLiteLLMRouter
import litellm

litellm.set_verbose = True

# NOTE: models with streaming=True will send tokens as they are generated
# if the /stream endpoint is called with stream_tokens=True (the default)
class LiteLLMRouterFactory:
    @staticmethod
    def create_model_list():
        # Load the catalog of supported models (models_support.yaml).
        with open("models_support.yaml", "r", encoding="utf-8") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        model_list = []
        # An OpenAI-compatible gateway (key + base URL) doubles as a fallback
        # for providers whose native API key is not configured.
        is_compatible = os.getenv("OPENAI_API_KEY") and os.getenv("OPENAI_API_BASE")
        if is_compatible:
            model_list.append({
                "model_name": "gpt-4o-mini",
                "provider": "OpenAI",
                "litellm_params": {
                    "model": "gpt-4o-mini",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "streaming": True,
                }
            })
        if os.getenv("GROQ_API_KEY"):
            model_list.append({
                "model_name": "llama-3.1-70b",
                "provider": "Groq",
                "litellm_params": {
                    "model": "groq/llama-3.1-70b",
                    "api_key": os.getenv("GROQ_API_KEY"),
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "llama-3.1-70b",
                "provider": "Groq",
                "litellm_params": {
                    "model": "openai/llama-3.1-70b",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                }
            })
if os.getenv("GOOGLE_API_KEY"):
model_list.append({
"model_name": "gemini-1.5-flash",
"provider": "Google",
"litellm_params": {
"model": "gemini/gemini-1.5-flash",
"api_key": os.getenv("GOOGLE_API_KEY"),
"streaming": True,
}
})
elif is_compatible:
model_list.append({
"model_name": "gemini-1.5-flash",
"provider": "Google",
"litellm_params": {
"model": "openai/gemini-1.5-flash",
"api_key": os.getenv("OPENAI_API_KEY"),
"api_base": os.getenv("OPENAI_API_BASE"),
"streaming": True,
}
})
if os.getenv("OLLAMA_BASE_URL"):
model_list.append({
"model_name": "llama3",
"provider": "Ollama",
"litellm_params": {
"model": "ollama_chat/llama3",
"api_base": os.getenv("OLLAMA_BASE_URL"),
"stream": True,
}
})
elif is_compatible:
model_list.append({
"model_name": "llama3",
"provider": "Ollama",
"litellm_params": {
"model": "openai/llama3",
"api_key": os.getenv("OPENAI_API_KEY"),
"api_base": os.getenv("OPENAI_API_BASE"),
"streaming": True,
}
})
if os.getenv("ANTHROPIC_API_KEY"):
model_list.append({
"model_name": "claude-3-5-sonnet",
"provider": "Anthropic",
"litellm_params": {
"model": "anthropic/claude-3-5-sonnet-20241022",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
"temperature": 0.5,
"streaming": True
}
})
elif is_compatible:
model_list.append({
"model_name": "claude-3-5-sonnet",
"provider": "Anthropic",
"litellm_params": {
"model": "openai/claude-3-5-sonnet-20241022",
"api_key": os.getenv("OPENAI_API_KEY"),
"api_base": os.getenv("OPENAI_API_BASE"),
"temperature": 0.5,
"streaming": True,
}
})
if os.getenv("MISTRAL_API_KEY"):
model_list.append({
"model_name": "mistral-medium",
"provider": "Mistral",
"litellm_params": {
"model": "mistral/mistral-medium",
"api_key": os.getenv("MISTRAL_API_KEY"),
}
})
elif is_compatible:
model_list.append({
"model_name": "mistral-medium",
"provider": "Mistral",
"litellm_params": {
"model": "openai/mistral-medium",
"api_key": os.getenv("OPENAI_API_KEY"),
"api_base": os.getenv("OPENAI_API_BASE"),
"streaming": True,
}
})
if os.getenv("USE_AWS_BEDROCK") == "true":
model_list.append({
"model_name": "bedrock-haiku",
"provider":"AWS",
"litellm_params": {
"model": "anthropic.claude-3-5-haiku-20241022-v1:0",
"model_id": "anthropic.claude-3-5-haiku-20241022-v1:0",
"temperature": 0.5,
}
})
return model_list
class ModelManager:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
            cls._instance.model_list = LiteLLMRouterFactory.create_model_list()
            cls._instance.models = cls._instance.initialize_models(cls._instance.model_list)
        return cls._instance

    def initialize_models(self, model_list) -> dict[str, BaseChatModel]:
        models: dict[str, BaseChatModel] = {}
        try:
            router = litellm.Router(model_list=model_list)
            model = ChatLiteLLMRouter(router=router)
            # Every configured entry maps to the same router-backed chat model;
            # the dict is keyed by the litellm model string.
            for model_info in model_list:
                models[model_info["litellm_params"]["model"]] = model
        except Exception as e:
            logger.error(f"Error loading models: {e}")
        if not models:
            logger.error("No LLM available. Please set environment variables to enable at least one LLM.")
            if os.getenv("MODE") == "dev":
                logger.error("FastAPI initialization failed. Please use Ctrl + C to exit uvicorn.")
            sys.exit(1)
        return models
Hi, I just wrapped models.py with LiteLLM support. It simplifies a lot of the work of dealing with various LLM providers. See if anyone needs it.
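For anyone who wants to try it, here is a minimal, hypothetical usage sketch. The models.py import path is an assumption, and the dictionary key follows initialize_models above, which keys the models dict by the litellm model string.

# Hypothetical usage sketch, not part of the file above: assumes it lives in
# models.py and that manager.models is keyed by the litellm model string.
from models import ModelManager

manager = ModelManager()  # singleton: builds the model_list and router once
llm = manager.models.get("gpt-4o-mini")
if llm is not None:
    reply = llm.invoke("Say hello in one sentence.")
    print(reply.content)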
There is a bug where LangChain simply selects the very first model from the model_list as the default model; see the github-langchain link for more detail. For now I'm patiently waiting for the fix.
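Until that is fixed, one possible workaround (just a sketch, untested against this repo) is to bypass ChatLiteLLMRouter's implicit default and call the underlying litellm Router directly, naming the deployment you want per request:

# Workaround sketch: select a deployment explicitly through the litellm Router
# rather than relying on ChatLiteLLMRouter's first-entry default.
import litellm

router = litellm.Router(model_list=LiteLLMRouterFactory.create_model_list())
response = router.completion(
    model="claude-3-5-sonnet",  # must match a "model_name" in model_list
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)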
Hey, I'm not sure I want to adopt LiteLLM in this repo. I expect that in most real usage someone is connecting to just a few models, so this adds more complexity and dependencies. What do you see as the advantages of this approach?
I'm going to close this, but if you have a more specific proposal or a draft pull request, I'm open to discussing it.