dspy
dspy copied to clipboard
feat(dspy): try from json before applying formatting logic in typed predictor
Adds an attempt to use the function supplied to `from_json`
before running the string-parsing logic. This allows users to override pydantic's `model_validate_json`
to customise how the generation is deserialized.
Other changes were to satisfy ruff.
This PR gives users a route to inject code which can solve issues like https://github.com/stanfordnlp/dspy/issues/1024 and https://github.com/stanfordnlp/dspy/issues/1001
After this PR, the user could implement something like the following:
class TaxField(BaseModel):
    """One extracted key/value entry with a truthfulness flag and a confidence score.

    ``model_validate_json`` is overridden so that custom parsing can be added
    for generations that are not valid JSON for this model.
    """

    key: str = Field()
    value: str = Field()
    truthful: bool = Field(
        description="is a tuple contains the extracted value and a boolean indicating if the information is truthful and correct."
    )
    # Constrained to the closed interval [0, 1] via ge/le.
    confidence: float = Field(ge=0, le=1, description="The confidence score for the answer")

    @classmethod
    def model_validate_json(
        cls, json_data: str, *, strict: bool | None = None, context: dict[str, Any] | None = None
    ) -> "TaxField":
        """Validate ``json_data`` against this model.

        Args:
            json_data: The raw JSON text to validate.
            strict: Forwarded unchanged to the underlying validator.
            context: Forwarded unchanged to the underlying validator.

        Returns:
            The result of ``cls.__pydantic_validator__.validate_json``.

        Raises:
            ValueError: If validation fails; the original ``ValidationError``
                is attached as the cause.
        """
        try:
            # Convention used by pytest to omit this frame from tracebacks.
            __tracebackhide__ = True
            return cls.__pydantic_validator__.validate_json(
                json_data, strict=strict, context=context
            )
        except ValidationError as exc:
            # custom parsing logic here
            # Chain the caught error so the validation details are not lost.
            raise ValueError("Could not find valid json") from exc
class QAExtractionSignature(Signature):
    """Your task is to extract the key-value pairs from the document that will follow the instructions on the field to extract.\nplease keep the content you extract truthful and correct based on the document text provided"""

    # Inputs: the source text and the instruction describing what to extract.
    document: str = InputField()
    instruction: str = InputField(
        description="The instruction on which field to extract and how to extract it."
    )
    # NOTE(review): the description below speaks of a "list" and a
    # semicolon-separated format, while the annotation is a single TaxField;
    # confirm which one is intended.
    tax_from_fields: TaxField = OutputField(
        description="The list of fields extracted from the document text. IMPORTANT!! This must follow a semicolon separated list of values!"
    )
class TypedQAExtraction(dspy.Module):
    """Module that runs a ``TypedPredictor`` over ``QAExtractionSignature``."""

    def __init__(self):
        """Create the typed predictor used by ``forward``."""
        # Let the base module initialise its own state before adding attributes.
        super().__init__()
        self.qa_extraction = dspy.functional.TypedPredictor(QAExtractionSignature)

    def forward(self, document: str, instruction: str) -> TaxField:
        """Run the predictor and return its ``tax_from_fields`` output.

        Args:
            document: The document text passed to the predictor.
            instruction: The extraction instruction passed to the predictor.

        Returns:
            The ``tax_from_fields`` attribute of the prediction.
        """
        prediction = self.qa_extraction(document=document, instruction=instruction)
        return prediction.tax_from_fields