
DatabricksRM retrieval using dspy.Retrieve() throws TypeError

Open • josh-melton-db opened this issue 8 months ago • 12 comments

Following the pattern from the simple RAG example in the docs, I've created a DatabricksRM that works when called directly, e.g. rm(query="Model serving API", query_type="text").

But when I register it with dspy.settings.configure(rm=rm) and call it through dspy.Retrieve(), as below:

import dspy
from dspy.retrieve.databricks_rm import DatabricksRM

# Workspace API token and URL from the Databricks notebook context
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
serving_url = url + '/serving-endpoints'

lm = dspy.Databricks(model='databricks-dbrx-instruct', model_type='chat', api_key=token, 
                     api_base=serving_url, max_tokens=1000, temperature=0.85)
teacher = dspy.Databricks(model='databricks-meta-llama-3-70b-instruct', model_type='chat', api_key=token, 
                          api_base=serving_url, max_tokens=1000, temperature=0)
rm = DatabricksRM( # This index was created using the Databricks Demo Center RAG Tutorial
    databricks_index_name="catalog.schema.databricks_documentation_vs_index",
    databricks_endpoint=url,
    databricks_token=token,
    columns=["content"],
    text_column_name="content",
    docs_id_column_name="id",
)
dspy.settings.configure(lm=lm, rm=rm)

# dspy.Retrieve() uses the rm registered via dspy.settings.configure
retrieve = dspy.Retrieve()
retrieve(query_or_queries="What is Apache Spark?", query_type="text")

I get TypeError: DatabricksRM.forward() got an unexpected keyword argument 'k'
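The error message suggests that dspy.Retrieve passes a `k` keyword (the number of passages to return) through to the configured retriever, while DatabricksRM.forward() in this version does not declare a `k` parameter. Below is a minimal pure-Python sketch of that signature mismatch and one possible workaround, assuming Retrieve forwards `k` this way. `StrictRM`, `retrieve_like` and `KTolerantRM` are hypothetical stand-ins, not dspy classes.

```python
from typing import Any, Callable, List


class StrictRM:
    """Hypothetical retriever whose forward() has no `k` parameter."""

    def __init__(self, docs: List[str]) -> None:
        self.docs = docs

    def forward(self, query: str, query_type: str = "text") -> List[str]:
        # Naive case-insensitive substring match, just to return something deterministic.
        return [d for d in self.docs if query.lower() in d.lower()]

    def __call__(self, *args: Any, **kwargs: Any) -> List[str]:
        return self.forward(*args, **kwargs)


def retrieve_like(rm: Callable[..., List[str]], query: str, k: int = 3, **kwargs: Any) -> List[str]:
    # Assumption: mimics a caller that always passes `k` through to the retriever.
    return rm(query, k=k, **kwargs)


class KTolerantRM:
    """Hypothetical adapter: accepts `k`, calls the wrapped retriever without it, then truncates."""

    def __init__(self, inner: Callable[..., List[str]]) -> None:
        self.inner = inner

    def __call__(self, query: str, k: int = 3, **kwargs: Any) -> List[str]:
        return self.inner(query, **kwargs)[:k]


rm = StrictRM(["Apache Spark is an engine", "Spark SQL", "Model serving"])

# Calling the strict retriever with `k` reproduces the same kind of TypeError.
try:
    retrieve_like(rm, "spark", query_type="text")
    error = None
except TypeError as err:
    error = err
print(error)  # a TypeError mentioning the unexpected keyword argument 'k'

# Wrapping it in the adapter lets the same call succeed.
result = retrieve_like(KTolerantRM(rm), "spark", k=1, query_type="text")
print(result)  # ['Apache Spark is an engine']
```

The adapter drops `k` before delegating and applies it by truncating the results, so the wrapped retriever's signature stays untouched. Whether a wrapper like this is needed depends on the installed dspy version; it is worth checking whether that version's DatabricksRM.forward() accepts `k`, or calling the retriever directly as in the working rm(query=..., query_type="text") call.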

josh-melton-db, Jun 22 '24 13:06