dspy
DatabricksRM retrieval using dspy.Retrieve() throws TypeError
Following the pattern from the simple RAG example in the docs, I've created a DatabricksRM, which works when I call it directly, like:
```python
rm(query="Model serving API", query_type="text")
```
But when I configure it globally with dspy.settings.configure(rm=rm) and call it through dspy.Retrieve(), as below:
```python
import dspy
from dspy.retrieve.databricks_rm import DatabricksRM

# dbutils is only available inside a Databricks notebook
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
serving_url = url + '/serving-endpoints'

lm = dspy.Databricks(model='databricks-dbrx-instruct', model_type='chat', api_key=token,
                     api_base=serving_url, max_tokens=1000, temperature=0.85)
teacher = dspy.Databricks(model='databricks-meta-llama-3-70b-instruct', model_type='chat', api_key=token,
                          api_base=serving_url, max_tokens=1000, temperature=0)

rm = DatabricksRM(  # This index was created using the Databricks Demo Center RAG Tutorial
    databricks_index_name="catalog.schema.databricks_documentation_vs_index",
    databricks_endpoint=url,
    databricks_token=token,
    columns=["content"],
    text_column_name="content",
    docs_id_column_name="id",
)

dspy.settings.configure(lm=lm, rm=rm)
retrieve = dspy.Retrieve()
retrieve(query_or_queries="What is Apache Spark?", query_type="text")
```
I get the following error:

```
TypeError: DatabricksRM.forward() got an unexpected keyword argument 'k'
```
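The error message points to a signature mismatch: dspy.Retrieve calls the configured retriever with a `k` keyword, but `DatabricksRM.forward()` does not accept one, while the direct call with only `query` and `query_type` succeeds. The snippet below is a minimal, library-free sketch of that call pattern, assuming this mismatch is the cause. `FakeRetrievalModel`, `retrieve`, and `AcceptsK` are hypothetical names used only for illustration; they are not dspy or Databricks APIs.

```python
from typing import List, Optional


class FakeRetrievalModel:
    """Hypothetical stand-in for a retriever whose forward() has no `k`
    parameter, mirroring the signature implied by the TypeError above."""

    def __init__(self, k: int = 3):
        # Assumption: the number of passages is fixed at construction time.
        self.k = k

    def __call__(self, *args, **kwargs):
        # Calling the object dispatches to forward(), like a dspy module.
        return self.forward(*args, **kwargs)

    def forward(self, query: str, query_type: str = "text") -> List[str]:
        return [f"{query_type} passage {i} for {query!r}" for i in range(self.k)]


def retrieve(rm, query: str, k: Optional[int] = None, **kwargs) -> List[str]:
    """Hypothetical caller that always passes `k`, as the error suggests
    dspy.Retrieve does."""
    return rm(query, k=k, **kwargs)


class AcceptsK:
    """Hypothetical adapter: accepts `k`, uses it to truncate the results,
    and never forwards it to the wrapped retriever."""

    def __init__(self, rm):
        self.rm = rm

    def __call__(self, query: str, k: Optional[int] = None, **kwargs) -> List[str]:
        passages = self.rm(query, **kwargs)
        return passages if k is None else passages[:k]


rm = FakeRetrievalModel(k=3)

# Direct call: no `k` is passed, so forward() accepts the arguments.
direct = rm(query="Model serving API", query_type="text")
print(len(direct))

# Going through the caller injects `k`, which forward() rejects.
try:
    retrieve(rm, "What is Apache Spark?", query_type="text")
    error_message = None
except TypeError as err:
    error_message = str(err)
print(error_message)

# Wrapping the retriever so that it consumes `k` avoids the error.
adapted = retrieve(AcceptsK(rm), "What is Apache Spark?", k=2, query_type="text")
print(len(adapted))
```

The adapter only illustrates the shape of a possible workaround; presumably the cleaner fix is for the retriever's own `forward()` to accept `k` (and any extra keyword arguments) so it works with dspy.Retrieve unchanged.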