chroma
collection.query() should return all elements if n_results is greater than the total number of elements in the collection
import chromadb

inputs = [
    "Experts expect demand for remote work to grow after the pandemic",
    "New study finds link between social media use and mental health problems",
    "Sustainable fashion is growing in popularity among younger generations",
    "Advances in AI technology offer hope for medical diagnosis and treatment",
    "Environmental activists urge the government to act on reducing plastic waste",
]

client = chromadb.Client()
collection = client.get_or_create_collection("chat_logs")
collection.add(
    documents=inputs,
    metadatas=[{"user": "david"} for _ in inputs],
    ids=[f"uid_{i}" for i in range(len(inputs))],
)

# Only 5 documents exist, but 10 results are requested
collection.query(query_texts=["."], where={"user": "david"}, n_results=10)
The code above raises an exception complaining that n_results is greater than the total number of elements in the collection:
237 def _query(
...
495 results = self.get(
496 collection_uuid=collection_uuid, where=where, where_document=where_document
497 )
NotEnoughElementsException: Number of requested results 10 cannot be greater than number of elements in index 5
I think this should fail more gracefully. If the number of elements is less than n_results, all elements could be returned instead of raising an exception. Otherwise, at the very least, there should be a sensible way to check the number of elements before calling query().
Why haven't the maintainers fixed this problem?
Great point @fritzprix, this is being worked on upstream in hnswlib (https://github.com/nmslib/hnswlib/issues/444#issuecomment-1493224811). We may also handle it at the logic layer.
You can call collection.count() to get the total number of elements in a collection, or call collection.get(**filter logic**) and take the len() of the returned ids. Neither is an awesome solution, and we are working on it! :)
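A minimal sketch of that workaround, assuming a Chroma collection that exposes count(), get(where=...) and query(query_texts=..., n_results=..., where=...). The helper name capped_query is hypothetical and not part of Chroma. Because collection.get() returns a dict-like result, the sketch counts its "ids" list rather than calling len() on the result itself.

```python
from typing import Any, List, Optional


def capped_query(collection: Any, query_texts: List[str], n_results: int,
                 where: Optional[dict] = None) -> Any:
    """Hypothetical helper: never request more results than are available.

    Assumes a Chroma-like collection with count(), get(where=...) and
    query(...). Not part of the Chroma API.
    """
    if where is None:
        # Total number of elements in the collection.
        available = collection.count()
    else:
        # get() returns a dict-like result; count its ids, not its keys.
        available = len(collection.get(where=where)["ids"])
    if available == 0:
        return None  # sketch-specific choice: nothing to query
    kwargs = {"query_texts": query_texts, "n_results": min(n_results, available)}
    if where is not None:
        kwargs["where"] = where
    return collection.query(**kwargs)
```

For the repro at the top of this issue, capped_query(collection, ["."], 10, where={"user": "david"}) would request 5 results instead of 10. Returning None for an empty collection is an arbitrary choice made for this sketch.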
Hi, I am new here and would like to contribute to this issue. I found that the error above is raised by the get_nearest_neighbors function. Could I edit that function so that k is set to the size of the collection whenever it exceeds it? One more question: can the authors of the repo see my questions, and if not, how can I reach them?
Thank you for any help you can provide.
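To illustrate the change proposed above, here is a toy brute-force nearest-neighbour search in which k is clamped to the number of stored vectors instead of triggering an exception. It is a stand-in for the HNSW index, not Chroma's actual get_nearest_neighbors; the function name nearest_neighbors and the use of plain Euclidean distance are assumptions made for the sketch.

```python
import math
from typing import List, Sequence, Tuple


def nearest_neighbors(vectors: Sequence[Sequence[float]], query: Sequence[float],
                      k: int) -> List[Tuple[int, float]]:
    """Hypothetical brute-force kNN that returns every vector when k is too large.

    Returns (index, euclidean_distance) pairs, closest first.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    # The proposed clamp: never ask for more neighbours than exist.
    k = min(k, len(vectors))
    scored = [(i, math.dist(v, query)) for i, v in enumerate(vectors)]
    scored.sort(key=lambda pair: pair[1])
    return scored[:k]
```

Clamping at this layer means every caller silently receives at most len(vectors) results; whether the clamp belongs here or in query() is a design choice for the maintainers.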
import chromadb

client = chromadb.Client()
collection = client.create_collection("test")
collection.add(
    embeddings=[
        [1.1, 2.3, 3.2],
        [4.5, 6.9, 4.4],
        [1.1, 2.3, 3.2],
        [4.5, 6.9, 4.4],
        [1.1, 2.3, 3.2],
        [4.5, 6.9, 4.4],
        [1.1, 2.3, 3.2],
        [4.5, 6.9, 4.4],
    ],
    metadatas=[
        {"uri": "img1.png", "style": "style1"},
        {"uri": "img2.png", "style": "style2"},
        {"uri": "img3.png", "style": "style1"},
        {"uri": "img4.png", "style": "style1"},
        {"uri": "img5.png", "style": "style1"},
        {"uri": "img6.png", "style": "style1"},
        {"uri": "img7.png", "style": "style1"},
        {"uri": "img8.png", "style": "style1"},
    ],
    documents=["doc1", "doc2", "doc3", "doc4", "doc5", "doc6", "doc7", "doc8"],
    ids=["id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"],
)
query_result = collection.query(
    query_embeddings=[[1.1, 2.3, 3.2], [5.1, 4.3, 2.2]],
    n_results=-1,
)
print(query_result)
The code above raises TypeError: knn_query(): incompatible function arguments:
query_result = collection.query(
...
File "\chroma\chromadb\db\index\hnswlib.py", line 248, in get_nearest_neighbors
database_labels, distances = self._index.knn_query(query, k=k, filter=filter_function)
TypeError: knn_query(): incompatible function arguments. The following argument types are supported:
1. (self: hnswlib.Index, data: object, k: int = 1, num_threads: int = -1, filter: Callable[[int], bool] = None) -> object
Invoked with: <hnswlib.Index(space='l2', dim=3)>, [[1.1, 2.3, 3.2], [5.1, 4.3, 2.2]]; kwargs: k=-1, filter=None
I think this should fail more gracefully, and it could be resolved together with issue #301. If n_results < 0, n_results should be validated before query() runs, and an AssertionError should be raised instead of TypeError: knn_query(): incompatible function arguments.
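A minimal sketch of that up-front check, using a hypothetical validate_n_results helper that would run before the value reaches hnswlib's knn_query(). It raises ValueError (and TypeError for non-integers) rather than AssertionError, because assert statements are skipped when Python runs with -O; that swap is this sketch's choice, not something proposed in the issue.

```python
def validate_n_results(n_results: object) -> int:
    """Hypothetical check: reject bad n_results before it reaches the index.

    Gives a readable error instead of hnswlib's
    "knn_query(): incompatible function arguments".
    """
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(n_results, bool) or not isinstance(n_results, int):
        raise TypeError(f"n_results must be an int, got {type(n_results).__name__}")
    if n_results < 1:
        raise ValueError(f"n_results must be a positive integer, got {n_results}")
    return n_results
```

With this in place, n_results=-1 from the repro above would fail with "n_results must be a positive integer, got -1".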
While this is an issue, the workaround is pretty much a one-liner:
# Query the collection for up to 5 of the most relevant results
# (`query` holds your query string)
count = collection.count()
results = collection.query(
    query_texts=[query],
    n_results=min(5, count),  # here is the fix: never request more than the collection holds
)