distilabel icon indicating copy to clipboard operation
distilabel copied to clipboard

Add basic implementation of `VectorSearch` `Step` and `KnowledgeBases`

Open davidberenstein1957 opened this issue 4 months ago • 1 comment

  • Went for LanceDB because it works in memory.
  • @frascuchon as a follow-up, we can consider adding Argilla based on your vector search PR :)

Run a vector search using a `KnowledgeBase` and an integrated `Embeddings` model.

from distilabel.embeddings import SentenceTransformerEmbeddings
from distilabel.knowledge_bases.lancedb import LanceDB
from distilabel.steps.knowledge_bases.vector_search import VectorSearch

embedding = SentenceTransformerEmbeddings(
    model="mixedbread-ai/mxbai-embed-large-v1",
)

knowledge_base = LanceDB(
    uri="data/sample-lancedb",
    table_name="my_table",
)

vector_search = VectorSearch(
    knowledge_base=knowledge_base,
    embeddings=embedding,
    n_retrieved_documents=5
)

vector_search.load()
result = next(vector_search.process([{"text": "Hello, how are you?"}]))
# [{
#   'text': 'Hello, how are you?',
#   'embedding': [0.06209656596183777, -0.015797119587659836, ...],
#   'knowledge_base_col_1': [10.0],
#   'knowledge_base_col_2': ['foo']
# }]

Run a vector search using a `KnowledgeBase` and a pre-computed query embedding column.

from distilabel.embeddings import SentenceTransformerEmbeddings
from distilabel.knowledge_bases.lancedb import LanceDB
from distilabel.steps.knowledge_bases.vector_search import VectorSearch

knowledge_base = LanceDB(
    uri="data/sample-lancedb",
    table_name="my_table",
)

vector_search = VectorSearch(
    knowledge_base=knowledge_base,
    n_retrieved_documents=5
)

vector_search.load()
result = next(vector_search.process([{'embedding': [0.06209656596183777, -0.015797119587659836, ...]}]))
# [{'embedding': [0.06209656596183777, -0.015797119587659836, ...], "knowledge_base_col_1": [10.0], "knowledge_base_col_2": ["foo"]}]

Or with Argilla as the knowledge base:

import os

from distilabel.knowledge_bases.argilla import ArgillaKnowledgeBase
from distilabel.steps.knowledge_bases.vector_search import VectorSearch

knowledge_base = ArgillaKnowledgeBase(
    dataset_name="ag_news_with_suggestions",
    dataset_workspace="argilla",
    vector_field="mini-lm-sentence-transformers",
    api_url=os.environ["ARGILLA_API_URL_DEV"],
    api_key=os.environ["ARGILLA_API_KEY_DEV"],
)

vector_search = VectorSearch(knowledge_base=knowledge_base, n_retrieved_documents=5)

vector_search.load()
result = next(
    vector_search.process([{"text": "Hello, how are you?", "embedding": [1] * 384}])
)
print(result)
# [{'text': ["Italy's Pennetta Wins Idea Prokom Open (AP) AP - Italy's Flavia Pennetta won the Idea Prokom Open for her first WTA Tour title, beating Klara Koukalova of the Czech Republic 7-5, 3-6, 6-3 Saturday after French Open champion Anastasia Myskina withdrew before the semifinals because of a rib injury."], 'embedding': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'id': ['65a57d53-1d1d-4acb-9f12-456d6989e905'], 'status': ['completed'], '_server_id': ['b0a29605-21ed-48d1-b2fe-163243947c2d'], 'split': ['unlabelled'], 'class.responses': [['Sci/Tech']], 'class.responses.users': [['3b1a58ff-6213-4365-880b-17532d13978c']], 'class.responses.status': [['submitted']], 'class.suggestion': ['Sports'], 'class.suggestion.score': [0.3421393299797776], 'class.suggestion.agent': ['setfit']}]

davidberenstein1957 avatar Sep 29 '24 08:09 davidberenstein1957