
free provision for embedding providers

zhengwj533 opened this issue 8 months ago

Problem:

There is no way to plug in a new embedding provider without changing R2R itself. Supporting a new provider requires modifying multiple source files.

Expected behavior:

Users can supply their own embedding provider directly and select it in the configuration file.

Proposed solution:

  1. Change the embedding config validation so that the configured provider name is checked against the supported_providers() method of each EmbeddingProvider subclass, including user-defined ones.
  2. Relax the types of the embedding and LLM fields in R2RProviders to their parent classes, so R2RProviders no longer has to be edited every time a new provider is added.
  3. Relax the same type restrictions in the factory.
  4. A sample provider would look like this:
```python
import logging
import os
from typing import Any

from openai import AsyncOpenAI, AuthenticationError, OpenAI

from core.base import (
    ChunkSearchResult,
    EmbeddingConfig,
    EmbeddingProvider,
)

logger = logging.getLogger(__name__)


class DashscopeEmbeddingProvider(EmbeddingProvider):

    def __init__(self, config: EmbeddingConfig):
        ...

    @classmethod
    def supported_providers(cls) -> list[str]:
        # Provider names handled by this class; compared with the
        # provider name set in the embedding config.
        return ["dashscope"]

    def _get_embedding_kwargs(self, **kwargs):
        # Default request parameters; per-task kwargs override them.
        embedding_kwargs = {
            "model": self.base_model,
            "dimensions": self.base_dimension,
        }
        embedding_kwargs.update(kwargs)
        return embedding_kwargs

    async def _execute_task(
        self, task: dict[str, Any]
    ) -> list[float] | list[list[float]]:
        # The base provider may send a single string under "text" instead
        # of a list under "texts"; normalize to a list either way.
        texts = task.get("text", task.get("texts"))
        one_text = isinstance(texts, str)
        if one_text:
            texts = [texts]

        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))

        try:
            response = await self.aclient.embeddings.create(
                input=texts,
                **kwargs,
            )
            embeddings = [data.embedding for data in response.data]
            # Return a single vector when a single string was passed in.
            return embeddings[0] if one_text else embeddings
...
```
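Steps 1 and 3 can be sketched in a self-contained way. The classes below are hypothetical stand-ins for R2R's `EmbeddingConfig` and `EmbeddingProvider` (the real ones live in `core.base` and carry more fields), and `all_subclasses()` and `resolve_embedding_provider()` are illustrative names, not R2R's API. The idea is to walk the subclass tree of the base class, so any provider class that has been defined is found without editing a hard-coded list.

```python
from dataclasses import dataclass


@dataclass
class EmbeddingConfig:
    # Hypothetical stand-in: only the field this sketch needs.
    provider: str


class EmbeddingProvider:
    # Hypothetical stand-in for core.base.EmbeddingProvider.
    def __init__(self, config: EmbeddingConfig):
        self.config = config

    @classmethod
    def supported_providers(cls) -> list[str]:
        # The base class itself supports no concrete provider.
        return []


def all_subclasses(cls: type) -> set[type]:
    # Recursively collect direct and indirect subclasses of `cls`.
    found: set[type] = set()
    for sub in cls.__subclasses__():
        found.add(sub)
        found |= all_subclasses(sub)
    return found


def resolve_embedding_provider(config: EmbeddingConfig) -> EmbeddingProvider:
    # Step 1: validate the configured name against every subclass's
    # supported_providers(). Step 3: instantiate the matching class,
    # typed only as the base class.
    for sub in all_subclasses(EmbeddingProvider):
        if config.provider in sub.supported_providers():
            return sub(config)
    known = sorted(
        name
        for sub in all_subclasses(EmbeddingProvider)
        for name in sub.supported_providers()
    )
    raise ValueError(
        f"Unsupported embedding provider {config.provider!r}; known: {known}"
    )


# A user-defined provider: defining the class is enough to make it visible.
class DashscopeEmbeddingProvider(EmbeddingProvider):
    @classmethod
    def supported_providers(cls) -> list[str]:
        return ["dashscope"]


provider = resolve_embedding_provider(EmbeddingConfig(provider="dashscope"))
print(type(provider).__name__)  # DashscopeEmbeddingProvider
```

One consequence of this design: `__subclasses__()` only sees classes that have already been defined, so the module containing a custom provider must be imported before the config is resolved. If two classes claim the same name, whichever is found first wins, so names should be unique.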

[!IMPORTANT] Enable user-defined embedding providers by dynamically discovering subclasses and relaxing type restrictions in embedding.py, abstractions.py, and factory.py.

  • Behavior:
    • supported_providers() in embedding.py now dynamically discovers subclasses of EmbeddingProvider to include user-defined providers.
    • R2RProviders in abstractions.py and create_providers() in factory.py now use base class types for embedding and llm to allow custom providers.
  • Code Structure:
    • Added get_all_subclasses() in embedding.py to recursively find all subclasses of a given class.
    • Replaced the specific provider types in R2RProviders and create_providers() in factory.py with EmbeddingProvider and CompletionProvider.
  • Misc:
    • Added a placeholder supported_providers() method in EmbeddingProvider class.
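The typing change in abstractions.py and factory.py comes down to annotating with base classes. The sketch below assumes simplified stand-ins: a dataclass in place of the real `R2RProviders` (which has more fields), and a hypothetical `create_providers()` with override parameters whose names are illustrative, not R2R's actual signature.

```python
from dataclasses import dataclass
from typing import Optional


class EmbeddingProvider:
    # Hypothetical stand-in for core.base.EmbeddingProvider.
    @classmethod
    def supported_providers(cls) -> list[str]:
        # Placeholder on the base class; concrete providers override it.
        return []


class CompletionProvider:
    # Hypothetical stand-in for core.base.CompletionProvider.
    pass


@dataclass
class R2RProviders:
    # Annotated with base classes, so adding a provider subclass
    # requires no change here.
    embedding: EmbeddingProvider
    llm: CompletionProvider


def create_providers(
    default_embedding: EmbeddingProvider,
    default_llm: CompletionProvider,
    embedding_override: Optional[EmbeddingProvider] = None,
    llm_override: Optional[CompletionProvider] = None,
) -> R2RProviders:
    # Prefer a user-supplied instance when one is given.
    return R2RProviders(
        embedding=(
            embedding_override
            if embedding_override is not None
            else default_embedding
        ),
        llm=llm_override if llm_override is not None else default_llm,
    )


class DashscopeEmbeddingProvider(EmbeddingProvider):
    @classmethod
    def supported_providers(cls) -> list[str]:
        return ["dashscope"]


class DummyCompletionProvider(CompletionProvider):
    pass


providers = create_providers(
    EmbeddingProvider(),
    DummyCompletionProvider(),
    embedding_override=DashscopeEmbeddingProvider(),
)
print(type(providers.embedding).__name__)  # DashscopeEmbeddingProvider
```

Dataclasses do not check types at runtime. The benefit shows up under a static type checker or a validating model class: with a narrow annotation such as a union of concrete provider classes, the new class would be rejected, while the base-class annotation accepts it.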

This description was created by Ellipsis for befec6d78392d2ebaa5ac930301f0b659d3ba63a. It will automatically update as commits are pushed.

zhengwj533, Mar 06 '25 02:03