[Bug]: Too strict type requirement on EmbeddingFunction

Open LumenYoung opened this issue 1 year ago • 5 comments

What happened?

Dear team on Chroma,

Strictly speaking, this is not a bug but a constraint that chroma imposes. However, I find that chroma's typing for EmbeddingFunction is too strict for my use case, and it prevents me from implementing it properly.

In short, chroma only allows a single argument of type D (either images or Documents) to be passed to the EmbeddingFunction. However, I'm using both image and text to create a joint embedding with the llava model, so this restriction makes it very hard to use the chroma vector store.

The following is my definition of the embedding function, which makes total sense if you are trying to generate an embedding from multiple modalities.

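from typing import Dict, List, Optional

from chromadb import Documents, EmbeddingFunction, Embeddings


class LlavaEmbedding(EmbeddingFunction):
    def __call__(
        self, texts: Documents, images: Optional[List[Dict[str, str]]] = None
    ) -> Embeddings:
        # get_embedding_from_llava is my own helper (not shown here).
        # Text-only case: return early so the length check below is skipped.
        if images is None:
            return [get_embedding_from_llava(text)[0] for text in texts]

        # Joint case: each text is paired with its image entry.
        assert len(texts) == len(images), "text and image length must be equal"
        embeddings = []
        for text, imgs in zip(texts, images):
            embeddings.append(get_embedding_from_llava(text, file=imgs))

        return embeddings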

But the current typing and signature validation are too strict for this kind of usage:

class EmbeddingFunction(Protocol[D]):
    def __call__(self, input: D) -> Embeddings:
        ...


def validate_embedding_function(
    embedding_function: EmbeddingFunction[Embeddable],
) -> None:
    function_signature = signature(
        embedding_function.__class__.__call__
    ).parameters.keys()
    protocol_signature = signature(EmbeddingFunction.__call__).parameters.keys()

    if not function_signature == protocol_signature:
        raise ValueError(
            f"Expected EmbeddingFunction.__call__ to have the following signature: {protocol_signature}, got {function_signature}\n"
            "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n"
            "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
        )
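To illustrate why the check above rejects my class, here is a minimal sketch that reproduces only the parameter-name comparison. `ProtocolLike` and `TwoArgEmbedding` are stand-ins I made up for this example, not chroma classes, so it runs without chromadb installed.

```python
from inspect import signature
from typing import List, Optional


class ProtocolLike:
    """Stand-in for the EmbeddingFunction protocol (not the real chroma class)."""

    def __call__(self, input: List[str]) -> List[List[float]]:
        ...


class TwoArgEmbedding:
    """Stand-in for an embedding function with an extra optional parameter."""

    def __call__(
        self, texts: List[str], images: Optional[List[dict]] = None
    ) -> List[List[float]]:
        # Dummy embeddings; a real implementation would call a model here.
        return [[0.0] for _ in texts]


# Same comparison as in validate_embedding_function: parameter names only.
protocol_params = signature(ProtocolLike.__call__).parameters.keys()
function_params = signature(TwoArgEmbedding.__call__).parameters.keys()
signatures_match = function_params == protocol_params

print(list(protocol_params))  # ['self', 'input']
print(list(function_params))  # ['self', 'texts', 'images']
print(signatures_match)       # False
```

Because only the names are compared, renaming `texts` to `input` would not be enough: the extra `images` parameter alone makes the two key sets differ, even though it has a default value.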

I would suggest allowing additional keyword arguments to be passed to EmbeddingFunction. The signature would be:

class EmbeddingFunction(Protocol[D]):
    def __call__(self, input: D, **kwargs) -> Embeddings:
        ...
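Note that changing the protocol signature alone may not be enough: with an exact comparison of parameter names, the expected set would become `self`, `input`, `kwargs`, and any function without a parameter literally named `kwargs` would then fail. The validation would probably need to be relaxed as well. The following is only a sketch of what such a check could look like; `accepts_single_input` and the example classes are hypothetical names of mine, not part of chroma.

```python
from inspect import Parameter, signature
from typing import Any, Callable


def accepts_single_input(call: Callable[..., Any]) -> bool:
    """Hypothetical relaxed check: True if `call` works as call(self, input=...).

    The first parameter after `self` must be named `input`; every further
    parameter must be optional (has a default, or is *args / **kwargs).
    """
    params = list(signature(call).parameters.values())
    if len(params) < 2 or params[1].name != "input":
        return False
    for param in params[2:]:
        if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
            continue
        if param.default is Parameter.empty:
            return False
    return True


class Strict:
    def __call__(self, input): ...


class WithKwargs:
    def __call__(self, input, **kwargs): ...


class WithOptionalImages:
    def __call__(self, input, images=None): ...


class RequiresImages:
    def __call__(self, input, images): ...


class WrongName:
    def __call__(self, texts): ...


results = {
    cls.__name__: accepts_single_input(cls.__call__)
    for cls in (Strict, WithKwargs, WithOptionalImages, RequiresImages, WrongName)
}
print(results)
```

With a check like this, existing single-argument embedding functions would still pass, while functions such as mine could add optional inputs as long as chroma can still call them with `input` alone.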

Versions

Version: 0.4.17

Relevant log output

No response

LumenYoung, Nov 20 '23 19:11