
How to best use TypedPredictors?

Open · arunpatro opened this issue 1 year ago · 3 comments

I want to extract information using Pydantic types and generate many examples, like in the README.

```python
from pydantic import BaseModel, Field
from typing import Literal
from dspy.functional import TypedPredictor
from dspy import Signature, InputField, OutputField

class AssertReason(BaseModel):
    assertion: str = Field()
    reason: str = Field()
    answer: Literal["A", "B", "C", "D"] = Field()

class AssertReasonDSPy(Signature):
    """ Generate a list of assertions and reasons for the given context."""
    context: str = InputField()
    items: list[AssertReason] = OutputField()

predictor = TypedPredictor(AssertReasonDSPy)
text = "Coffee One day around 850 CE, a goat herd named Kaldi observed that, after nibbling on some berries, his goats started acting abnormally. Kaldi tried them himself, and soon enough, he was just as hyper. This was humanity's first run-in with coffee— or so the story goes."
predictor(context=text)
```

I get an error:

````
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[27], line 16
     13     items: list[AssertReason] = OutputField()
     15 predictor = TypedPredictor(AssertReasonDSPy)
---> 16 predictor(context=text)

File ~/miniforge3/envs/qabot/lib/python3.12/site-packages/dspy/primitives/program.py:26, in Module.__call__(self, *args, **kwargs)
     25 def __call__(self, *args, **kwargs):
---> 26     return self.forward(*args, **kwargs)

File ~/miniforge3/envs/qabot/lib/python3.12/site-packages/dspy/functional/functional.py:190, in TypedPredictor.forward(self, **kwargs)
    188     value = completion[name]
    189     parser = field.json_schema_extra.get("parser", lambda x: x)
--> 190     parsed[name] = parser(value)
    191 except (pydantic.ValidationError, ValueError) as e:
    192     errors[name] = _format_error(e)

File ~/miniforge3/envs/qabot/lib/python3.12/site-packages/dspy/functional/functional.py:152, in TypedPredictor._prepare_signature.<locals>.<lambda>(x, from_json)
    145             from_json = lambda x, type_=type_: type_.model_validate_json(x)
    146             schema = json.dumps(type_.model_json_schema())
    147         signature = signature.with_updated_fields(
    148             name,
    149             desc=field.json_schema_extra.get("desc", "")
    150             + (". Respond with a single JSON object. JSON Schema: " + schema),
...
--> 282     output = output.strip()
    283     if output.startswith("```"):
    284         if not output.startswith("```json"):

AttributeError: 'builtin_function_or_method' object has no attribute 'strip'
````
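The traceback shows the two halves of what `TypedPredictor` does with a Pydantic output type. It embeds `model_json_schema()` in the field description, and it later parses the completion with `model_validate_json()`. Below is a minimal sketch of that round trip using plain Pydantic v2 only, with no DSPy and no LLM. The `completion` string is made-up sample data standing in for the model's answer:

```python
import json
from typing import Literal

from pydantic import BaseModel


class AssertReason(BaseModel):
    assertion: str
    reason: str
    answer: Literal["A", "B", "C", "D"]


# Prompt side: the model's JSON Schema, serialized to a string.
schema = json.dumps(AssertReason.model_json_schema())

# Hypothetical completion text standing in for the LLM's output.
completion = (
    '{"assertion": "Kaldi tried the berries", '
    '"reason": "His goats acted strangely", "answer": "A"}'
)

# Parsing side: decode the JSON and validate it against the model in one step.
item = AssertReason.model_validate_json(completion)
print(item.answer)
```

If this round trip works in isolation, the schema and model are fine. Any remaining failure then lies in how the completion reaches the parser.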

What is the best way to use Typed Predictors?

I was trying something like this too, but it failed.

arunpatro, Mar 27 '24

Hi @arunpatro, I believe this model:

```python
class AssertReason(BaseModel):
    assertion: str = Field()
    reason: str = Field()
    answer: Literal["A", "B", "C", "D"] = Field()
```

should not need the `= Field()` declarations, since it inherits from `BaseModel`, not from `Signature` (which is where `dspy.InputField` and `dspy.OutputField` are used). Try:

```python
class AssertReason(BaseModel):
    assertion: str
    reason: str
    answer: Literal["A", "B", "C", "D"]
```
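With the annotation-only model, Pydantic still treats every field as required and still enforces the `Literal` choices. Here is a quick check using hand-written sample JSON (not real model output). One item validates, and an out-of-range answer is rejected:

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class AssertReason(BaseModel):
    # Plain annotations with no defaults: every field is required.
    assertion: str
    reason: str
    answer: Literal["A", "B", "C", "D"]


good = AssertReason.model_validate_json(
    '{"assertion": "Goats ate berries", "reason": "They became hyper", "answer": "B"}'
)

error_loc = None
try:
    # "E" is not one of the allowed literals, so validation should fail.
    AssertReason.model_validate_json(
        '{"assertion": "Goats ate berries", "reason": "They became hyper", "answer": "E"}'
    )
except ValidationError as e:
    # Record which field failed.
    error_loc = e.errors()[0]["loc"]

print(good.answer, error_loc)
```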

Let me know if that fixes it, or if you have any other follow-up questions.

Feel free to refer to the documentation on functional types in DSPy!

arnavsinghvi11, Apr 01 '24

Hi @arunpatro. I had the same problem and figured out how to implement this. Try this:

```python
from pydantic import BaseModel, Field
from typing import Literal, List
from dspy.functional import TypedPredictor
from dspy import Signature, InputField, OutputField

class AssertReason(BaseModel):
    assertion: str = Field()
    reason: str = Field()
    answer: List[Literal["A", "B", "C", "D"]] = Field()

class AssertReasonDSPy(Signature):
    """ Generate a list of assertions and reasons for the given context."""
    context: str = InputField()
    items: AssertReason = OutputField()

predictor = TypedPredictor(AssertReasonDSPy)
text = "Coffee One day around 850 CE, a goat herd named Kaldi observed that, after nibbling on some berries, his goats started acting abnormally. Kaldi tried them himself, and soon enough, he was just as hyper. This was humanity's first run-in with coffee— or so the story goes."
predictor(context=text)
```
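This version changes the output field from `list[AssertReason]` to a single `AssertReason` object. That matches the "Respond with a single JSON object" instruction that the traceback shows `TypedPredictor` adding to the prompt. If you still want several assertion/reason pairs, one option is to wrap the list in its own `BaseModel` so the output type is still one object. This is an assumption here, not something tested with DSPy in this thread. Below is just the Pydantic side of that idea; `AssertReasonList` is a hypothetical name and the JSON is sample data:

```python
from typing import Literal

from pydantic import BaseModel


class AssertReason(BaseModel):
    assertion: str
    reason: str
    answer: Literal["A", "B", "C", "D"]


# Hypothetical wrapper: a single JSON object that holds the list.
class AssertReasonList(BaseModel):
    items: list[AssertReason]


raw = (
    '{"items": ['
    '{"assertion": "Kaldi was a goat herder", "reason": "The story says so", "answer": "A"},'
    '{"assertion": "The goats acted abnormally", "reason": "They ate berries", "answer": "C"}'
    "]}"
)

parsed = AssertReasonList.model_validate_json(raw)
print([item.answer for item in parsed.items])
```

In the signature, this would presumably become `items: AssertReasonList = OutputField()`, following the same pattern as the snippet above.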

MarkusOdenthal, May 03 '24

All this is great, but I want to add a couple of additional points:

  1. Watch `max_tokens`: I know that `OllamaLocal` defaults it to 150.

     ```python
     import dspy

     ollama_model = dspy.OllamaLocal(
         model="llama3.1",
         model_type="instruct",
         format="json",
         max_tokens=20000,
     )
     ```

  2. The schema can get so complex that the LLM needs optimization just to get off the ground. To work around this, skip validation on parts of the model:

     ```python
     from typing import List, Literal
     import pydantic
     from pydantic import BaseModel, Field

     class AssertReason(BaseModel):
         assertion: str = Field()
         reason: str = Field()
         answer: List[pydantic.SkipValidation[Literal["A", "B", "C", "D"]]] = Field()
     ```
    
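The following Pydantic-only check shows what `SkipValidation` changes, using made-up sample data. With the plain `Literal`, an out-of-range answer is rejected. Wrapped in `SkipValidation`, the same value is accepted as-is, so you would have to check it yourself afterwards. The class names `Strict` and `Lenient` are hypothetical:

```python
from typing import List, Literal

import pydantic
from pydantic import BaseModel, ValidationError


class Strict(BaseModel):
    answer: List[Literal["A", "B", "C", "D"]]


class Lenient(BaseModel):
    # Each list item skips validation; the list itself is still checked.
    answer: List[pydantic.SkipValidation[Literal["A", "B", "C", "D"]]]


raw = '{"answer": ["A", "E"]}'

try:
    Strict.model_validate_json(raw)
    strict_ok = True
except ValidationError:
    # "E" is not an allowed literal, so the strict model rejects it.
    strict_ok = False

# The lenient model passes "E" through unchanged.
lenient = Lenient.model_validate_json(raw)
print(strict_ok, lenient.answer)
```

The trade-off is that the LLM gets fewer validation failures to retry on, but bad values reach your own code unflagged.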

I wish JSON parsing were configurable. I'd love to handle invalid JSON specially: if the JSON gets cut off because of `max_tokens`, you either want to raise that as an error or chunk your text. Either way, there's almost zero operational visibility into what's going on. No bueno.
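Independent of what DSPy does internally, you can parse the raw completion yourself. That lets you tell a truncated or malformed JSON string apart from a schema violation, because Pydantic v2 reports the former with the error type `json_invalid`. Below is a minimal sketch. `classify_completion` is a hypothetical helper, and the truncated string stands in for output cut off by `max_tokens`:

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class AssertReason(BaseModel):
    assertion: str
    reason: str
    answer: Literal["A", "B", "C", "D"]


def classify_completion(raw: str) -> str:
    """Return 'ok', 'invalid_json' (e.g. truncated), or 'schema_error'."""
    try:
        AssertReason.model_validate_json(raw)
        return "ok"
    except ValidationError as e:
        # Malformed JSON, including output cut off mid-string, is reported
        # with the error type "json_invalid".
        if any(err["type"] == "json_invalid" for err in e.errors()):
            return "invalid_json"
        # Otherwise the JSON parsed but did not match the model.
        return "schema_error"


truncated = '{"assertion": "Kaldi tried the berries", "reason": "His go'
wrong_answer = '{"assertion": "x", "reason": "y", "answer": "E"}'
complete = '{"assertion": "x", "reason": "y", "answer": "D"}'

for raw in (truncated, wrong_answer, complete):
    print(classify_completion(raw))
```

Separating the two cases means the `invalid_json` branch is where you would raise or re-chunk the text, as described above.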

tkellogg, Sep 02 '24