
Support customizing embedding model parameters

Open • DresAaron opened this issue 1 year ago • 1 comment

Self Checks

  • [X] I have searched for existing issues, including closed ones.
  • [X] I confirm that I am using English to submit this report (I have read and agree to the Language Policy).
  • [X] [FOR CHINESE USERS] Please be sure to submit issues in English, otherwise they will be closed. Thank you! :)
  • [X] Please do not modify this template :) and fill in all the required fields.

1. Is this request related to a challenge you're experiencing? Tell me about your story.

Jina's next text embedding model will support task-oriented embeddings and control over the embedding dimension: users need to provide task_type and dimensions in the request payload. Based on your documentation of the model invocation implementation, extra parameters are not supported for now.

Could you please provide a way to extend or customize invoke parameters?

Current embedding parameters:

{
  "input":  [...],
  "model": "..."
}

Expected embedding parameters:

{
  "input": [...],
  "model": "...",
  "task_type": "...",
  ...
}
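A minimal sketch of how a provider could build the expected payload while keeping the current one as the default. The field names task_type and dimensions follow this issue; Jina's released API may name them differently, so check its API reference. The function name build_embedding_payload is hypothetical.

```python
from typing import Any, Optional


def build_embedding_payload(
    model: str,
    texts: list[str],
    task_type: Optional[str] = None,
    dimensions: Optional[int] = None,
    **extra: Any,
) -> dict[str, Any]:
    """Build an embedding request body, adding optional fields only when set."""
    payload: dict[str, Any] = {"input": texts, "model": model}
    # Leave unset options out entirely so the provider never receives nulls.
    if task_type is not None:
        payload["task_type"] = task_type
    if dimensions is not None:
        if dimensions <= 0:
            raise ValueError("dimensions must be a positive integer")
        payload["dimensions"] = dimensions
    # Any further provider-specific parameters are passed through unchanged.
    payload.update({k: v for k, v in extra.items() if v is not None})
    return payload
```

With no options the result is exactly the current payload; build_embedding_payload("my-model", ["hello"], task_type="retrieval.query", dimensions=256) adds the two new fields. Because unknown extras pass straight through, a future provider field would not need another signature change.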

Similar issue: #914
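For context on what a dimensions parameter typically does: models trained with Matryoshka representation learning let you keep only the first k components of an embedding and re-normalize them. Whether Jina's model does exactly this on the server is an assumption here; the sketch below only shows the equivalent client-side operation, and truncate_embedding is a hypothetical helper.

```python
import math


def truncate_embedding(vector: list[float], dimensions: int) -> list[float]:
    """Keep the first `dimensions` components and rescale to unit L2 norm."""
    if not 0 < dimensions <= len(vector):
        raise ValueError("dimensions must be between 1 and len(vector)")
    head = vector[:dimensions]
    norm = math.sqrt(sum(x * x for x in head))
    if norm == 0.0:
        # An all-zero prefix cannot be normalized; return it unchanged.
        return head
    return [x / norm for x in head]
```

Re-normalizing matters because cosine or dot-product similarity on the shortened vectors assumes unit length.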

2. Additional context or comments

No response

3. Can you help us with this feature?

  • [ ] I am interested in contributing to this feature.

DresAaron, Sep 13 '24 12:09

@DresAaron Hello there! I'm here to help you with any bugs, questions, or becoming a contributor. Let's work together to solve this issue.

To extend the invoke parameters of an embedding model with task_type and dimensions, you need to modify the _invoke method of the provider's TextEmbeddingModel implementation. The example below adapts the AWS Bedrock provider's implementation:

# Module-level imports this method relies on.
# TextEmbeddingResult comes from Dify's model runtime entities.
import logging
from typing import Optional

import boto3
from botocore.config import Config

logger = logging.getLogger(__name__)


def _invoke(
    self,
    model: str,
    credentials: dict,
    texts: list[str],
    user: Optional[str] = None,
    *,
    # New options are keyword-only and come after `user`, so existing
    # positional calls (model, credentials, texts, user) keep working.
    task_type: Optional[str] = None,
    dimensions: Optional[int] = None,
) -> TextEmbeddingResult:
    """
    Invoke text embedding model

    :param model: model name
    :param credentials: model credentials
    :param texts: texts to embed
    :param user: unique user id
    :param task_type: type of task for embedding (optional, keyword-only)
    :param dimensions: number of dimensions for the embedding (optional, keyword-only)
    :return: embeddings result
    """
    client_config = Config(region_name=credentials["aws_region"])

    bedrock_runtime = boto3.client(
        service_name="bedrock-runtime",
        config=client_config,
        aws_access_key_id=credentials.get("aws_access_key_id"),
        aws_secret_access_key=credentials.get("aws_secret_access_key"),
    )

    # Only send the new fields when they are set, since a null value may be
    # rejected. Whether a given Bedrock model accepts these fields at all
    # depends on that model's request schema.
    extra_params = {}
    if task_type is not None:
        extra_params["taskType"] = task_type
    if dimensions is not None:
        extra_params["dimensions"] = dimensions

    embeddings = []
    token_usage = 0

    model_prefix = model.split(".")[0]

    if model_prefix == "amazon":
        for text in texts:
            body = {"inputText": text, **extra_params}
            response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
            embeddings.append(response_body.get("embedding"))
            token_usage += response_body.get("inputTextTokenCount", 0)
        logger.debug(f"Total Tokens: {token_usage}")
        return TextEmbeddingResult(
            model=model,
            embeddings=embeddings,
            usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
        )

    if model_prefix == "cohere":
        input_type = "search_document" if len(texts) > 1 else "search_query"
        for text in texts:
            body = {"texts": [text], "input_type": input_type, **extra_params}
            response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
            embeddings.extend(response_body.get("embeddings"))
            # Character count is used as a rough stand-in for token usage.
            token_usage += len(text)
        return TextEmbeddingResult(
            model=model,
            embeddings=embeddings,
            usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
        )

    # others
    raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")

  • Parameters:

    • model (string) Model name
    • credentials (object) Credential information
    • texts (array[string]) List of texts, capable of batch processing
    • task_type (string) [optional] Type of task for embedding
    • dimensions (int) [optional] Number of dimensions for the embedding
    • user (string) [optional] Unique identifier of the user
  • Returns:

    TextEmbeddingResult entity [1][2].
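Changing _invoke alone is not enough: the caller also has to pass the new values through. A minimal sketch of that plumbing, assuming a base-class invoke that forwards extra keyword arguments to _invoke. BaseEmbeddingModel, EchoEmbeddingModel, and FakeEmbeddingResult are hypothetical stand-ins, not Dify's classes.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeEmbeddingResult:
    model: str
    embeddings: list[list[float]]


class BaseEmbeddingModel:
    """Hypothetical stand-in for a TextEmbeddingModel base class."""

    def invoke(self, model: str, credentials: dict, texts: list[str],
               user: Optional[str] = None, **model_parameters: Any) -> FakeEmbeddingResult:
        # Forward optional parameters as keywords so the positional
        # arguments (model, credentials, texts, user) keep their meaning.
        return self._invoke(model, credentials, texts, user, **model_parameters)

    def _invoke(self, model, credentials, texts, user=None, **model_parameters):
        raise NotImplementedError


class EchoEmbeddingModel(BaseEmbeddingModel):
    """Toy provider that records the request body it would have sent."""

    def __init__(self) -> None:
        self.last_body: dict[str, Any] = {}

    def _invoke(self, model, credentials, texts, user=None, *,
                task_type: Optional[str] = None, dimensions: Optional[int] = None):
        body: dict[str, Any] = {"model": model, "input": texts}
        if task_type is not None:
            body["task_type"] = task_type
        if dimensions is not None:
            body["dimensions"] = dimensions
        self.last_body = body
        size = dimensions or 3  # placeholder vectors; no real model is called
        return FakeEmbeddingResult(model=model, embeddings=[[0.0] * size for _ in texts])
```

Because each provider's _invoke declares only the keywords it supports, an unsupported parameter such as invoke(..., foo=1) fails with a TypeError instead of being silently dropped.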


dosubot[bot], Sep 13 '24 12:09