How to best use TypedPredictors?
I want to extract information using pydantic types and generate many examples, as in the README.
```python
from pydantic import BaseModel, Field
from typing import Literal

from dspy.functional import TypedPredictor
from dspy import Signature, InputField, OutputField

class AssertReason(BaseModel):
    assertion: str = Field()
    reason: str = Field()
    answer: Literal["A", "B", "C", "D"] = Field()

class AssertReasonDSPy(Signature):
    """Generate a list of assertions and reasons for the given context."""
    context: str = InputField()
    items: list[AssertReason] = OutputField()

predictor = TypedPredictor(AssertReasonDSPy)
text = "Coffee One day around 850 CE, a goat herd named Kaldi observed that, after nibbling on some berries, his goats started acting abnormally. Kaldi tried them himself, and soon enough, he was just as hyper. This was humanity's first run-in with coffee— or so the story goes."
predictor(context=text)
```
I get an error:
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[27], line 16
     13 items: list[AssertReason] = OutputField()
     15 predictor = TypedPredictor(AssertReasonDSPy)
---> 16 predictor(context=text)

File ~/miniforge3/envs/qabot/lib/python3.12/site-packages/dspy/primitives/program.py:26, in Module.__call__(self, *args, **kwargs)
     25 def __call__(self, *args, **kwargs):
---> 26     return self.forward(*args, **kwargs)

File ~/miniforge3/envs/qabot/lib/python3.12/site-packages/dspy/functional/functional.py:190, in TypedPredictor.forward(self, **kwargs)
    188     value = completion[name]
    189     parser = field.json_schema_extra.get("parser", lambda x: x)
--> 190     parsed[name] = parser(value)
    191 except (pydantic.ValidationError, ValueError) as e:
    192     errors[name] = _format_error(e)

File ~/miniforge3/envs/qabot/lib/python3.12/site-packages/dspy/functional/functional.py:152, in TypedPredictor._prepare_signature.<locals>.<lambda>(x, from_json)
    145 from_json = lambda x, type_=type_: type_.model_validate_json(x)
    146 schema = json.dumps(type_.model_json_schema())
    147 signature = signature.with_updated_fields(
    148     name,
    149     desc=field.json_schema_extra.get("desc", "")
    150     + (". Respond with a single JSON object. JSON Schema: " + schema),
...
--> 282 output = output.strip()
    283 if output.startswith("```"):
    284     if not output.startswith("```json"):

AttributeError: 'builtin_function_or_method' object has no attribute 'strip'
```
What is the best way to use Typed Predictors?
I was trying something like this too, but it failed.
Hi @arunpatro, I believe

```python
class AssertReason(BaseModel):
    assertion: str = Field()
    reason: str = Field()
    answer: Literal["A", "B", "C", "D"] = Field()
```

should not have the Field() assignments here, since it inherits from BaseModel, not Signature (which uses dspy.InputField and dspy.OutputField):

```python
class AssertReason(BaseModel):
    assertion: str
    reason: str
    answer: Literal["A", "B", "C", "D"]
```

Let me know if that fixes it, or if any other issues follow.
Feel free to reference documentation on Functional types in DSPy!
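For what it's worth, once the Field() assignments are gone the nested model is plain pydantic, so you can sanity-check it in isolation before handing it to TypedPredictor. A minimal standalone sketch (no DSPy involved) that exercises model_validate_json, which per the traceback is what TypedPredictor ultimately calls on the LLM output:

```python
from typing import Literal
from pydantic import BaseModel

class AssertReason(BaseModel):
    assertion: str
    reason: str
    answer: Literal["A", "B", "C", "D"]

# model_validate_json is the same entry point TypedPredictor uses
# (see from_json in the traceback above)
ar = AssertReason.model_validate_json(
    '{"assertion": "Kaldi noticed his goats acting hyper.",'
    ' "reason": "They had eaten coffee berries.", "answer": "A"}'
)
print(ar.answer)  # A
```

If this standalone validation fails, the problem is in your model definition rather than in DSPy.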
Hi @arunpatro. I had the same problem and figured out how to implement this. Try this:
```python
from pydantic import BaseModel, Field
from typing import Literal, List

from dspy.functional import TypedPredictor
from dspy import Signature, InputField, OutputField

class AssertReason(BaseModel):
    assertion: str = Field()
    reason: str = Field()
    answer: List[Literal["A", "B", "C", "D"]] = Field()

class AssertReasonDSPy(Signature):
    """Generate a list of assertions and reasons for the given context."""
    context: str = InputField()
    items: AssertReason = OutputField()

predictor = TypedPredictor(AssertReasonDSPy)
text = "Coffee One day around 850 CE, a goat herd named Kaldi observed that, after nibbling on some berries, his goats started acting abnormally. Kaldi tried them himself, and soon enough, he was just as hyper. This was humanity's first run-in with coffee— or so the story goes."
predictor(context=text)
```
All this is great, but I want to add a couple of additional items:
- Watch max_tokens; I know that OllamaLocal defaults it to 150:

```python
ollama_model = dspy.OllamaLocal(
    model="llama3.1",
    model_type="instruct",
    format="json",
    max_tokens=20000,
)
```

- The schema can get so complex that the LLM needs optimization just to get off the ground. To fix this, skip validation on parts of the model:

```python
class AssertReason(BaseModel):
    assertion: str = Field()
    reason: str = Field()
    answer: List[pydantic.SkipValidation[Literal["A", "B", "C", "D"]]] = Field()
```
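In case it helps to see what SkipValidation buys you: with pydantic v2, wrapping the item type skips per-item checks entirely, so values outside the Literal pass through instead of failing validation. A minimal standalone sketch (unrelated to DSPy):

```python
from typing import List, Literal
import pydantic
from pydantic import BaseModel

class Strict(BaseModel):
    answer: List[Literal["A", "B", "C", "D"]]

class Relaxed(BaseModel):
    # per-item validation is skipped; the list itself is still validated
    answer: List[pydantic.SkipValidation[Literal["A", "B", "C", "D"]]]

relaxed = Relaxed(answer=["A", "Z"])  # accepted: "Z" is never checked
try:
    Strict(answer=["A", "Z"])
except pydantic.ValidationError:
    print("strict model rejects 'Z'")
```

The trade-off is that out-of-vocabulary answers now reach your downstream code, so you may want to filter them yourself after parsing.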
I wish there were some configurability in how the JSON is parsed; I'd love to handle invalid JSON specially. If the JSON gets cut off by the max_tokens limit, you either want to surface that error or chunk your text. Either way, there's almost zero operational visibility into what's going on. No bueno.
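DSPy doesn't expose such a hook today, but as a rough sketch of the kind of visibility I mean, a hypothetical wrapper around json.loads can at least distinguish output that was cut off at max_tokens from JSON that is malformed for other reasons, since truncation typically makes the decoder fail at the very end of the string:

```python
import json

def parse_or_report(raw: str) -> dict:
    """Parse LLM output, flagging likely max_tokens truncation.
    Hypothetical helper -- not part of the DSPy API."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        # A decode error at the end of the string usually means the
        # output was cut off mid-object (e.g. by a low max_tokens).
        if e.pos >= len(raw.rstrip()):
            raise ValueError(
                f"JSON appears truncated at char {e.pos}; "
                "raise max_tokens or chunk the input text"
            ) from e
        raise  # genuinely malformed JSON: surface the original error

print(parse_or_report('{"items": [{"assertion": "ok"}]}'))
```

This is only a heuristic, but even that much would make the "silent bad parse" failure mode far easier to debug.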