chroma
[Bug]: Too strict type requirement on EmbeddingFunction
What happened?
Dear Chroma team,
Strictly speaking, this is not a bug but rather a constraint that Chroma imposes. However, I find that the EmbeddingFunction typing is too strict: it does not fit my use case and prevents me from doing what I need.
In short, Chroma only allows a single argument of type variable D (either an image or a Document) to be passed to the EmbeddingFunction. However, I'm using both image and text to create a joint embedding with a LLaVA model, so this restriction makes it very hard to use the Chroma vector store.
The following is my definition of the embedding function, which is a natural shape for generating an embedding from multiple modalities.
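from typing import Dict, List, Optional

from chromadb import Documents, EmbeddingFunction, Embeddings


class LlavaEmbedding(EmbeddingFunction):
    def __call__(
        self, texts: Documents, images: Optional[List[Dict[str, str]]] = None
    ) -> Embeddings:
        # embed the documents somehow (get_embedding_from_llava is my own helper)
        if images is None:
            # Text-only case: return early so len(images) is never evaluated on None.
            return [get_embedding_from_llava(text)[0] for text in texts]
        assert len(texts) == len(images), "text and image length must be equal"
        embeddings = []
        # Joint case: pair each text with its image(s).
        for text, imgs in zip(texts, images):
            embeddings.append(get_embedding_from_llava(text, file=imgs))
        return embeddings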
But the protocol and its signature validation are too strict for this kind of usage:
class EmbeddingFunction(Protocol[D]):
    def __call__(self, input: D) -> Embeddings:
        ...


def validate_embedding_function(
    embedding_function: EmbeddingFunction[Embeddable],
) -> None:
    function_signature = signature(
        embedding_function.__class__.__call__
    ).parameters.keys()
    protocol_signature = signature(EmbeddingFunction.__call__).parameters.keys()
    if not function_signature == protocol_signature:
        raise ValueError(
            f"Expected EmbeddingFunction.__call__ to have the following signature: {protocol_signature}, got {function_signature}\n"
            "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n"
            "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
        )
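The rejection can be reproduced without Chroma installed. The following is a minimal sketch that mirrors the parameter-name comparison from validate_embedding_function using only the standard library; ProtocolLike and MultiModalFn are hypothetical stand-ins for illustration, not Chroma classes.

```python
from inspect import signature
from typing import List, Optional


class ProtocolLike:
    """Hypothetical stand-in for the EmbeddingFunction protocol."""

    def __call__(self, input: List[str]) -> List[List[float]]:
        ...


class MultiModalFn:
    """Hypothetical stand-in for LlavaEmbedding, with an extra optional parameter."""

    def __call__(
        self, texts: List[str], images: Optional[List[dict]] = None
    ) -> List[List[float]]:
        # Dummy embeddings; a real implementation would call the model here.
        return [[0.0] for _ in texts]


protocol_keys = signature(ProtocolLike.__call__).parameters.keys()
function_keys = signature(MultiModalFn.__call__).parameters.keys()

# Same comparison as in the validator: the parameter names must be identical.
signatures_match = function_keys == protocol_keys
print(list(protocol_keys), list(function_keys), signatures_match)
```

Because the check compares parameter names, naming the first parameter texts instead of input is already enough to fail it, independently of the extra images parameter.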
I would suggest allowing additional keyword arguments to be passed to EmbeddingFunction. The signature would be
class EmbeddingFunction(Protocol[D]):
    def __call__(self, input: D, **kwargs) -> Embeddings:
        ...
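Changing the protocol alone would probably not be enough, since the validator compares parameter names for exact equality. As a rough sketch of what a relaxed check could look like (this is not Chroma's implementation; accepts_protocol_call and the small test classes are hypothetical names introduced here), one option is to require the input parameter and accept any further parameters that are optional:

```python
from inspect import Parameter, signature
from typing import Any, Callable, List, Optional


def accepts_protocol_call(call: Callable[..., Any], required: str = "input") -> bool:
    """Return True if `call` takes `required` and every other parameter is optional."""
    params = signature(call).parameters
    if required not in params:
        return False
    for name, param in params.items():
        if name in ("self", required):
            continue
        if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
            continue  # *args / **kwargs never have to be supplied by the caller
        if param.default is Parameter.empty:
            return False  # an extra required parameter would break plain calls
    return True


class Strict:
    def __call__(self, input: List[str]) -> List[List[float]]:
        return [[0.0] for _ in input]


class WithImages:
    def __call__(
        self, input: List[str], images: Optional[List[dict]] = None
    ) -> List[List[float]]:
        return [[0.0] for _ in input]


class Renamed:
    def __call__(
        self, texts: List[str], images: Optional[List[dict]] = None
    ) -> List[List[float]]:
        return [[0.0] for _ in texts]


results = [accepts_protocol_call(cls.__call__) for cls in (Strict, WithImages, Renamed)]
print(results)
```

Under a check like this, the LlavaEmbedding class above would still need its first parameter renamed from texts to input, but the optional images argument would be accepted.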
Versions
Version: 0.4.17
Relevant log output
No response