[Bug]: using supplied embeddings for add and query broken

Open schwebke opened this issue 1 year ago • 2 comments

What happened?

Supplying precomputed embeddings when adding to or querying a collection raises an exception.

Code snippet:

...
collection.add(embeddings=seg_embeddings, ids=...)
...
search_result = collection.query(query_embeddings=q_embeddings, n_results=5)
...

Raised exception:

  File "/.../chromadb/api/models/Collection.py", line 99, in add
    embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
...
  File "/.../chromadb/api/models/Collection.py", line 208, in query
    query_embeddings = maybe_cast_one_to_many(query_embeddings) if query_embeddings else None
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
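
For context, a minimal sketch of why the `if embeddings` check fails, assuming the supplied embeddings are NumPy arrays (the error message above is NumPy's). The two guard functions are hypothetical stand-ins written for illustration, not Chroma's code, and the `is not None` variant is only one possible guard, not necessarily what the fix does:

```python
import numpy as np


def guard_truthy(embeddings):
    # Same pattern as in the traceback: a truthiness test on the argument.
    # For a NumPy array with more than one element this raises ValueError.
    return embeddings if embeddings else None


def guard_none(embeddings):
    # Hypothetical alternative: only distinguish "missing" from "supplied".
    return embeddings if embeddings is not None else None


# Made-up embedding values, two vectors of dimension 2.
emb = np.array([[0.1, 0.2], [0.3, 0.4]])

try:
    guard_truthy(emb)
    raised = False
except ValueError as exc:
    raised = True
    print(exc)

# The None-based guard passes the array through unchanged.
result = guard_none(emb)
```

A plain Python list of lists does not hit this problem, because a non-empty list is simply truthy.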

Versions

Chroma current dev (https://github.com/chroma-core/chroma/commit/79c891f8f597dad8bd3eb5a42645cb99ec553440) Python 3.8.16

Relevant log output

No response

schwebke avatar Apr 22 '23 05:04 schwebke

@schwebke thanks for reporting this and opening a PR - what specifically is the failure case?

jeffchuber avatar Apr 22 '23 05:04 jeffchuber

@jeffchuber here is the complete example to reproduce: https://gist.github.com/schwebke/ddb72b4ce7b8a612b83b4e6dbebfbfb5

Without these fixes, the exceptions mentioned above are thrown. With the fix applied, it runs without unhandled exceptions and yields this output:

...
Embedding time: 2.0751430988311768 seconds
...
Index time: 0.021298646926879883 seconds
lady
woman

schwebke avatar Apr 22 '23 05:04 schwebke

@schwebke getting merged today, thank you for fixing this! 👏

jeffchuber avatar May 08 '23 17:05 jeffchuber

@schwebke here is the solution:

The type returned by InstructorEmbedding is not compatible with what we expect. Simply cast it to a list:

search_result = collection.query(query_embeddings=list(q_embeddings), n_results=2)
ids = [int(s) for s in search_result["ids"][0]]
for i in ids:
    print(docs[i])
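
A rough sketch of what the cast does, assuming the embeddings come back as a 2-D NumPy array (the values below are made up). `list()` converts only the outer level, while NumPy's `tolist()` converts all the way down to plain Python floats:

```python
import numpy as np

# Made-up query embedding: one vector of dimension 3.
q_embeddings = np.array([[0.1, 0.2, 0.3]])

# list() yields a Python list whose items are still 1-D NumPy arrays.
as_list = list(q_embeddings)

# tolist() yields nested Python lists of floats.
as_nested = q_embeddings.tolist()

# Both results are ordinary non-empty lists, so a truthiness check
# on them no longer raises the "ambiguous truth value" error.
list_is_truthy = bool(as_list)
nested_is_truthy = bool(as_nested)
```

Either form gets past the truthiness check; `tolist()` is the option to reach for if fully plain Python types are wanted.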

jeffchuber avatar May 08 '23 18:05 jeffchuber