
Model Suggestion

Open doncat99 opened this issue 1 year ago • 2 comments

import os
import sys
import yaml
from loguru import logger

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_community.chat_models import ChatLiteLLMRouter
import litellm

litellm.set_verbose = True  # newer litellm versions prefer the LITELLM_LOG env var instead


# NOTE: models with streaming=True will send tokens as they are generated
# if the /stream endpoint is called with stream_tokens=True (the default)
class LiteLLMRouterFactory:
    @staticmethod
    def create_model_list():
        # Load provider config with a context manager (so the file handle is closed)
        # and safe_load (no arbitrary YAML tags). Note: config is loaded here but
        # not yet used below.
        with open("models_support.yaml", "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        model_list = []

        # An OpenAI-compatible gateway (api_key + api_base) can act as a fallback
        # route for providers whose native API key is not set.
        is_compatible = os.getenv("OPENAI_API_KEY") and os.getenv("OPENAI_API_BASE")

        if is_compatible:
            model_list.append({
                "model_name": "gpt-4o-mini",
                "provider": "OpenAI",
                "litellm_params": {
                    "model": "gpt-4o-mini",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "streaming": True,
                }
            })

        if os.getenv("GROQ_API_KEY"):
            model_list.append({
                "model_name": "llama-3.1-70b",
                "provider": "Groq",
                "litellm_params": {
                    "model": "groq/llama-3.1-70b",
                    "api_key": os.getenv("GROQ_API_KEY"),
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "llama-3.1-70b",
                "provider": "Groq",
                "litellm_params": {
                    "model": "openai/llama-3.1-70b",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                }
            })

        if os.getenv("GOOGLE_API_KEY"):
            model_list.append({
                "model_name": "gemini-1.5-flash",
                "provider": "Google",
                "litellm_params": {
                    "model": "gemini/gemini-1.5-flash",
                    "api_key": os.getenv("GOOGLE_API_KEY"),
                    "streaming": True,
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "gemini-1.5-flash",
                "provider": "Google",
                "litellm_params": {
                    "model": "openai/gemini-1.5-flash",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "streaming": True,
                }
            })

        if os.getenv("OLLAMA_BASE_URL"):
            model_list.append({
                "model_name": "llama3",
                "provider": "Ollama",
                "litellm_params": {
                    "model": "ollama_chat/llama3",
                    "api_base": os.getenv("OLLAMA_BASE_URL"),
                    "stream": True,
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "llama3",
                "provider": "Ollama",
                "litellm_params": {
                    "model": "openai/llama3",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "streaming": True,
                }
            })

        if os.getenv("ANTHROPIC_API_KEY"):
            model_list.append({
                "model_name": "claude-3-5-sonnet",
                "provider": "Anthropic",
                "litellm_params": {
                    "model": "anthropic/claude-3-5-sonnet-20241022",
                    "api_key": os.getenv("ANTHROPIC_API_KEY"),
                    "temperature": 0.5,
                    "streaming": True
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "claude-3-5-sonnet",
                "provider": "Anthropic",
                "litellm_params": {
                    "model": "openai/claude-3-5-sonnet-20241022",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "temperature": 0.5,
                    "streaming": True,
                }
            })

        if os.getenv("MISTRAL_API_KEY"):
            model_list.append({
                "model_name": "mistral-medium",
                "provider": "Mistral",
                "litellm_params": {
                    "model": "mistral/mistral-medium",
                    "api_key": os.getenv("MISTRAL_API_KEY"),
                }
            })
        elif is_compatible:
            model_list.append({
                "model_name": "mistral-medium",
                "provider": "Mistral",
                "litellm_params": {
                    "model": "openai/mistral-medium",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "api_base": os.getenv("OPENAI_API_BASE"),
                    "streaming": True,
                }
            })
            
        if os.getenv("USE_AWS_BEDROCK") == "true":
            model_list.append({
                "model_name": "bedrock-haiku",
                "provider":"AWS",
                "litellm_params": {
                    "model": "anthropic.claude-3-5-haiku-20241022-v1:0",
                    "model_id": "anthropic.claude-3-5-haiku-20241022-v1:0",
                    "temperature": 0.5,
                }
            })

        return model_list


class ModelManager:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.model_list = LiteLLMRouterFactory.create_model_list()
            cls._instance.models = cls._instance.initialize_models(cls._instance.model_list)
        return cls._instance
    
    def initialize_models(self, model_list) -> dict[str, BaseChatModel]:
        models: dict[str, BaseChatModel] = {}

        try:
            # One litellm Router serves all deployments; every entry in `models`
            # therefore shares the same ChatLiteLLMRouter instance.
            router = litellm.Router(model_list=model_list)
            model = ChatLiteLLMRouter(router=router)
        except Exception as e:
            logger.error(f"Error loading models: {e}")
            model_list = []

        for model_info in model_list:
            models[model_info["litellm_params"]["model"]] = model

        if not models:
            logger.error("No LLM available. Please set environment variables to enable at least one LLM.")
            if os.getenv("MODE") == "dev":
                logger.error("FastAPI initialization failed. Please use Ctrl + C to exit uvicorn.")
            sys.exit(1)

        return models

Hi, I just wrapped models.py with LiteLLM support. It simplifies a lot of the work of dealing with various LLM providers. Posting it here in case anyone needs it.
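
For anyone trying it out, here is a minimal usage sketch. It assumes models_support.yaml exists next to the script and that at least one provider's environment variables are set; "gpt-4o-mini" is just one of the litellm model strings registered above.

from langchain_core.messages import HumanMessage

# ModelManager is a singleton: constructing it again returns the same instance.
manager = ModelManager()

# Keys are the litellm model strings, e.g. "gpt-4o-mini" or "gemini/gemini-1.5-flash".
model = manager.models["gpt-4o-mini"]

# Standard LangChain chat-model call; token streaming works via .stream() too.
response = model.invoke([HumanMessage(content="Hello!")])
print(response.content)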

doncat99 avatar Nov 14 '24 04:11 doncat99

There is a bug where LangChain simply selects the very first model from the model_list as the default model; see the linked github-langchain issue for more detail. For now I'm patiently waiting for the fix.
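
Until that is fixed, one possible workaround (a sketch only, not tested against this repo) is to build a single-deployment Router per model name, so "the first model in model_list" is always the model you asked for:

    def initialize_models(self, model_list) -> dict[str, BaseChatModel]:
        models: dict[str, BaseChatModel] = {}
        for model_info in model_list:
            try:
                # Each router sees exactly one deployment, so the default
                # selection can never pick the wrong model.
                router = litellm.Router(model_list=[model_info])
                models[model_info["litellm_params"]["model"]] = ChatLiteLLMRouter(router=router)
            except Exception as e:
                logger.error(f"Error loading {model_info['model_name']}: {e}")
        return models

This gives up litellm's load balancing across same-named deployments, but each entry in models then resolves to the right underlying model.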

doncat99 avatar Nov 14 '24 07:11 doncat99

Hey, I'm not sure if I want to adopt LiteLLM in this repo. I expect that in most real usage someone connects to just a few models, so this adds more complexity and dependencies. What do you see as the advantages of this approach?

JoshuaC215 avatar Nov 27 '24 05:11 JoshuaC215

I’m going to close this, but if you have a more specific proposal or a draft pull request, I'm open to discussing it.

JoshuaC215 avatar Jan 14 '25 16:01 JoshuaC215