distilabel
distilabel copied to clipboard
Add basic implementation of `VectorSearch` `Step` and `KnowledgeBases`
- Went for LanceDB because it works in memory.
- @frascuchon as a follow-up we can consider adding Argilla, based on your vector search PR :)
Do vector search using a `KnowledgeBase` and an integrated `Embeddings` model.
from distilabel.embeddings import SentenceTransformerEmbeddings
from distilabel.knowledge_bases.lancedb import LanceDB
from distilabel.steps.knowledge_bases.vector_search import VectorSearch
# Embeddings model handed to the step through its `embeddings` argument.
embedding = SentenceTransformerEmbeddings(model="mixedbread-ai/mxbai-embed-large-v1")

# Knowledge base pointing at the `my_table` table under `data/sample-lancedb`.
knowledge_base = LanceDB(uri="data/sample-lancedb", table_name="my_table")

vector_search = VectorSearch(
    knowledge_base=knowledge_base,
    embeddings=embedding,
    n_retrieved_documents=5,
)
vector_search.load()

# Only a raw `text` column is supplied here; `next` pulls the first batch that
# `process` yields.
query_rows = [{"text": "Hello, how are you?"}]
result = next(vector_search.process(query_rows))
# [{
# 'text': 'Hello, how are you?',
# 'embedding': [0.06209656596183777, -0.015797119587659836, ...],
# 'knowledge_base_col_1': [10.0],
# 'knowledge_base_col_2': ['foo']
# }]
Do vector search using a `KnowledgeBase` and a pre-computed `query` column.
from distilabel.embeddings import SentenceTransformerEmbeddings
from distilabel.knowledge_bases.lancedb import LanceDB
from distilabel.steps.knowledge_bases.vector_search import VectorSearch
# Knowledge base pointing at the `my_table` table under `data/sample-lancedb`.
knowledge_base = LanceDB(
    uri="data/sample-lancedb",
    table_name="my_table",
)

# No `embeddings` model is passed: the input rows already carry an
# `embedding` column.
vector_search = VectorSearch(
    knowledge_base=knowledge_base,
    n_retrieved_documents=5,
)
vector_search.load()

# Call the step defined above. The original snippet referenced an undefined
# `embedding_generation` name here, which would raise a NameError.
# The trailing `...` stands in for the rest of the (elided) query vector.
result = next(vector_search.process([{'embedding': [0.06209656596183777, -0.015797119587659836, ...]}]))
# [{'embedding': [0.06209656596183777, -0.015797119587659836, ...], "knowledge_base_col_1": [10.0], "knowledge_base_col_2": ["foo"]}]
Or with Argilla
import os
from distilabel.knowledge_bases.argilla import ArgillaKnowledgeBase
from distilabel.steps.knowledge_bases.vector_search import VectorSearch
# Argilla-backed knowledge base; connection details are read from the
# environment rather than hard-coded.
knowledge_base = ArgillaKnowledgeBase(
    dataset_name="ag_news_with_suggestions",
    dataset_workspace="argilla",
    vector_field="mini-lm-sentence-transformers",
    api_url=os.environ["ARGILLA_API_URL_DEV"],
    api_key=os.environ["ARGILLA_API_KEY_DEV"],
)

vector_search = VectorSearch(
    knowledge_base=knowledge_base,
    n_retrieved_documents=5,
)
vector_search.load()

# The row already carries a 384-element `embedding`, used here as a dummy
# query vector.
query_row = {"text": "Hello, how are you?", "embedding": [1] * 384}
result = next(vector_search.process([query_row]))
print(result)
# [{'text': ["Italy's Pennetta Wins Idea Prokom Open (AP) AP - Italy's Flavia Pennetta won the Idea Prokom Open for her first WTA Tour title, beating Klara Koukalova of the Czech Republic 7-5, 3-6, 6-3 Saturday after French Open champion Anastasia Myskina withdrew before the semifinals because of a rib injury."], 'embedding': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'id': ['65a57d53-1d1d-4acb-9f12-456d6989e905'], 'status': ['completed'], '_server_id': ['b0a29605-21ed-48d1-b2fe-163243947c2d'], 'split': ['unlabelled'], 'class.responses': [['Sci/Tech']], 'class.responses.users': [['3b1a58ff-6213-4365-880b-17532d13978c']], 'class.responses.status': [['submitted']], 'class.suggestion': ['Sports'], 'class.suggestion.score': [0.3421393299797776], 'class.suggestion.agent': ['setfit']}]