
collection.query() should return all elements if n_results is greater than the total number of elements in the collection

Open fritzprix opened this issue 1 year ago • 5 comments

import chromadb
from sentence_transformers import SentenceTransformer

inputs = [
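    "Experts expect demand for remote work to rise after the pandemic",
    "New study finds a link between social media use and mental health problems",
    "Sustainable fashion grows in popularity among younger generations",
    "Advances in AI technology offer hope for medical diagnosis and treatment",
    "Environmental activists urge government action to reduce plastic waste"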
]


client = chromadb.Client()
collection = client.get_or_create_collection("chat_logs")
collection.add(
    documents=inputs, 
    metadatas=[{'user': 'david'} for _ in inputs],
    ids=[f'uid_{i}' for i in range(len(inputs))]
)
collection.query(query_texts=".", where={"user": 'david'}, n_results=10)

The code above raises an exception complaining that n_results is greater than the total number of elements in the collection:

    237 def _query(
...
    495     results = self.get(
    496         collection_uuid=collection_uuid, where=where, where_document=where_document
    497     )

NotEnoughElementsException: Number of requested results 10 cannot be greater than number of elements in index 5

I think this should fail more gracefully. If the number of elements is less than n_results, all elements could be returned instead of raising an exception. Otherwise, there should at least be a sensible way to check the number of elements before calling query().

fritzprix avatar Apr 07 '23 13:04 fritzprix

Why haven't the maintainers fixed this problem?

sixdjango avatar Apr 07 '23 15:04 sixdjango

Great point @fritzprix - this is being worked on upstream inside hnswlib - https://github.com/nmslib/hnswlib/issues/444#issuecomment-1493224811 - we also may handle it at the logic layer.

You can check collection.count() to get the total number of elements in a collection. Or you can call collection.get() with your filter logic and take the len() of the ids it returns. Not awesome solutions and we are working on it! :)
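The two checks described above could be wrapped in a small helper, sketched below. The name `query_at_most` is hypothetical and not part of chromadb; the sketch only assumes that `collection.count()` returns an integer, that `collection.get(where=...)` returns a dict with an `"ids"` list, and that `collection.query()` accepts `query_texts`, `n_results`, and `where`. Returning `None` when nothing matches is a choice made for this example.

```python
from typing import Any, Optional


def query_at_most(collection: Any, query_texts: list, n_results: int,
                  where: Optional[dict] = None) -> Optional[dict]:
    """Hypothetical helper: never request more results than are available."""
    if where is None:
        # No filter: every element in the collection is a candidate.
        available = collection.count()
    else:
        # With a filter: count only the elements that match it.
        available = len(collection.get(where=where)["ids"])

    if available == 0:
        # Nothing to search. Returning None is an assumption of this sketch.
        return None

    kwargs = {
        "query_texts": query_texts,
        "n_results": min(n_results, available),  # clamp to what exists
    }
    if where is not None:
        kwargs["where"] = where
    return collection.query(**kwargs)
```

With the collection from the original report, `query_at_most(collection, ["."], 10, where={"user": "david"})` would ask for 5 results instead of 10.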

jeffchuber avatar Apr 07 '23 16:04 jeffchuber

Hi, I am new here and would like to contribute to this issue. I found that the error above is thrown by the get_nearest_neighbors function. Could I edit that function so that k is set to the size of the collection whenever it exceeds that size? One more question: can the authors of the repo see my questions here? If not, how can I reach them?

Thank you for any help you can provide.

DK-77 avatar Apr 11 '23 08:04 DK-77

import chromadb

client = chromadb.Client()

collection = client.create_collection("test")

collection.add(
    embeddings=[
        [1.1, 2.3, 3.2],
        [4.5, 6.9, 4.4],
        [1.1, 2.3, 3.2],
        [4.5, 6.9, 4.4],
        [1.1, 2.3, 3.2],
        [4.5, 6.9, 4.4],
        [1.1, 2.3, 3.2],
        [4.5, 6.9, 4.4],
    ],
    metadatas=[
        {"uri": "img1.png", "style": "style1"},
        {"uri": "img2.png", "style": "style2"},
        {"uri": "img3.png", "style": "style1"},
        {"uri": "img4.png", "style": "style1"},
        {"uri": "img5.png", "style": "style1"},
        {"uri": "img6.png", "style": "style1"},
        {"uri": "img7.png", "style": "style1"},
        {"uri": "img8.png", "style": "style1"},
    ],
    documents=["doc1", "doc2", "doc3", "doc4", "doc5", "doc6", "doc7", "doc8"],
    ids=["id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"],
)

query_result = collection.query(
    query_embeddings=[[1.1, 2.3, 3.2], [5.1, 4.3, 2.2]],
    n_results=-1,
)

print(query_result)

The code above raises an exception: TypeError: knn_query(): incompatible function arguments.


    query_result = collection.query(
 ...
  File "\chroma\chromadb\db\index\hnswlib.py", line 248, in get_nearest_neighbors
    database_labels, distances = self._index.knn_query(query, k=k, filter=filter_function)
TypeError: knn_query(): incompatible function arguments. The following argument types are supported:
    1. (self: hnswlib.Index, data: object, k: int = 1, num_threads: int = -1, filter: Callable[[int], bool] = None) -> object

Invoked with: <hnswlib.Index(space='l2', dim=3)>, [[1.1, 2.3, 3.2], [5.1, 4.3, 2.2]]; kwargs: k=-1, filter=None

I think this should fail more gracefully, and it could be resolved along with this issue, #301.

If n_results < 0, n_results should be checked before query() is called, and an AssertionError should be raised instead of TypeError: knn_query(): incompatible function arguments.
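A minimal sketch of such a pre-query check follows. The function name `check_n_results` is hypothetical and not part of chromadb. It raises ValueError rather than the AssertionError suggested above, since assert statements are skipped when Python runs with -O; that swap is a design choice of this sketch, not something the maintainers have confirmed.

```python
def check_n_results(n_results: int) -> int:
    """Hypothetical pre-query check: n_results must be a positive integer."""
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(n_results, bool) or not isinstance(n_results, int):
        raise TypeError(
            f"n_results must be an int, got {type(n_results).__name__}"
        )
    if n_results <= 0:
        raise ValueError(
            f"n_results must be a positive integer, got {n_results}"
        )
    return n_results
```

Calling `check_n_results(-1)` before `collection.query()` would then fail with a readable message instead of the knn_query() TypeError.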

Satyam-79 avatar Apr 16 '23 03:04 Satyam-79

While this is an issue, the workaround is pretty much a one-liner:

    # Query the collection to get up to 5 of the most relevant results
    count = collection.count()
    results = collection.query(
        query_texts=[query],
        n_results=min(5, count),  # here is the fix
    )

swyxio avatar Apr 24 '23 16:04 swyxio