Errors with Custom Local Model
I am trying to build a custom Local Model Client as per the documentation. This is how I have set it up:
from dsp import LM
from huggingface_hub import InferenceClient

class CustomLMClient(LM):
    def __init__(self, model_endpoint):
        self.provider = "default"
        self.client = InferenceClient(model=model_endpoint)
        self.history = []

    def basic_request(self, prompt, **kwargs):
        output_text = self.client.text_generation(prompt=prompt, **kwargs)
        self.history.append({
            "prompt": prompt,
            "response": output_text,
            "kwargs": kwargs,
        })
        return output_text

    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
        response = self.basic_request(prompt, **kwargs)
        completions = [response]
        return completions
I am not using the built-in dspy.HFClientTGI because it requires a model argument, which is the model's name on the Hugging Face Hub; my model is not on the Hub and is already deployed on a TGI endpoint. However, I encounter the following error when trying to use it with dspy.ChainOfThought.
This is the code I tried:
import dspy

llm = CustomLMClient(model_endpoint)
dspy.configure(lm=llm)

# Example DSPy CoT QA program
qa = dspy.ChainOfThought('question -> answer')
response = qa(question="What is the capital of Paris?", max_new_tokens=20)  # Prompted to llm
print(response.answer)
This is my traceback:
AttributeError Traceback (most recent call last)
Cell In[22], line 6
3 #Example DSPy CoT QA program
4 qa = dspy.ChainOfThought('question -> answer')
----> 6 response = qa(question="What is the capital of Paris?", max_new_tokens=20) #Prompted to llm
7 print(response.answer)
File ~/env/lib/python3.10/site-packages/dspy/predict/predict.py:49, in Predict.__call__(self, **kwargs)
48 def __call__(self, **kwargs):
---> 49 return self.forward(**kwargs)
File ~/env/lib/python3.10/site-packages/dspy/predict/chain_of_thought.py:59, in ChainOfThought.forward(self, **kwargs)
57 signature = new_signature
58 # template = dsp.Template(self.signature.instructions, **new_signature)
---> 59 return super().forward(signature=signature, **kwargs)
File ~/env/lib/python3.10/site-packages/dspy/predict/predict.py:64, in Predict.forward(self, **kwargs)
62 # If temperature is 0.0 but its n > 1, set temperature to 0.7.
63 temperature = config.get("temperature")
---> 64 temperature = lm.kwargs["temperature"] if temperature is None else temperature
66 num_generations = config.get("n")
67 if num_generations is None:
AttributeError: 'CustomLMClient' object has no attribute 'kwargs'
I am a bit confused, since the docs don't mention passing kwargs to the __init__ function. I did try that as well, but it did not work either.
I am using dspy-ai==2.4.0.
In addition, I have a model to which I have added custom tags (similar to the [INST][/INST] tags for Mistral Instruct). How do I ensure that the optimized prompt takes these special tags into account?
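For illustration, here is a hypothetical wrap_prompt helper showing the kind of tagging I mean (the markers below are Mistral-style; mine are custom but analogous):

def wrap_prompt(prompt: str) -> str:
    # Hypothetical helper: wrap each prompt in the instruct markers the
    # model was fine-tuned with (Mistral-style shown as a stand-in).
    return f"[INST] {prompt} [/INST]"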
@ss2342 can you try adding the following to __init__ and check?
self.kwargs = {
    "temperature": 0.0,
    "max_tokens": 150,
    "top_p": 1,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "n": 1,
    **kwargs,
}
or make it like this:
from dsp import LM
from huggingface_hub import InferenceClient

class CustomLMClient(LM):
    def __init__(self, model_endpoint, **kwargs):
        self.provider = "default"
        self.client = InferenceClient(model=model_endpoint)
        # Defaults that dspy's Predict.forward expects to find on lm.kwargs
        self.kwargs = {
            "temperature": 0.0,
            "max_tokens": 150,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "n": 1,
            **kwargs,
        }
        self.history = []

    def basic_request(self, prompt, **kwargs):
        # Per-call kwargs override the instance defaults
        kwargs = {**self.kwargs, **kwargs}
        output_text = self.client.text_generation(prompt=prompt, **kwargs)
        self.history.append({
            "prompt": prompt,
            "response": output_text,
            "kwargs": kwargs,
        })
        return output_text

    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
        response = self.basic_request(prompt, **kwargs)
        completions = [response]
        return completions
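With that in place, a minimal usage sketch (the endpoint URL below is a placeholder) would look like the following. One caveat: InferenceClient.text_generation does not accept keys like max_tokens, n, frequency_penalty, or presence_penalty, so you may need to filter or rename them (e.g. to max_new_tokens) inside basic_request before making the call.

import dspy

# Minimal usage sketch; "http://localhost:8080" is a placeholder endpoint.
llm = CustomLMClient("http://localhost:8080")
dspy.configure(lm=llm)

qa = dspy.ChainOfThought('question -> answer')
# max_new_tokens is forwarded through basic_request to text_generation
response = qa(question="What is the capital of France?", max_new_tokens=20)
print(response.answer)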
And I know you mentioned it gave an error when you added it in __init__; can you provide the error traceback for that?
Another solution could be adding super().__init__(model=model_endpoint) in your __init__, since the LM base class already has default kwargs inside:
https://github.com/stanfordnlp/dspy/blob/4c6a2ffb6e9dd64dd6e5a587b468c695487f9136/dsp/modules/lm.py#L8
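A sketch of that variant, assuming LM.__init__ at the commit linked above populates self.kwargs, self.provider, and self.history:

from dsp import LM
from huggingface_hub import InferenceClient

class CustomLMClient(LM):
    def __init__(self, model_endpoint, **kwargs):
        # LM.__init__ sets up self.kwargs (temperature, max_tokens, n, ...),
        # self.provider, and self.history, so Predict.forward finds lm.kwargs.
        super().__init__(model=model_endpoint)
        self.kwargs.update(kwargs)  # allow per-instance overrides
        self.client = InferenceClient(model=model_endpoint)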
I had the same issue; it would be great to update the documentation here: https://dspy-docs.vercel.app/docs/deep-dive/language_model_clients/custom-lm-client
Ran into this and got around it with @krypticmouse's solution. Please update the documentation!