
feat(dspy): try from json before applying formatting logic in typed predictor

Open • opened by mikeedjones • 2 comments

Before running any string parsing, the typed predictor now first tries the function supplied as `from_json` on the raw generation. This lets users override a Pydantic model's `model_validate_json` to customise how the generated text is deserialized.

The remaining changes are only to satisfy ruff.

This PR gives users a way to inject custom deserialization code that can address issues such as https://github.com/stanfordnlp/dspy/issues/1024 and https://github.com/stanfordnlp/dspy/issues/1001

After this PR, the user could implement something like the following:

from typing import Any

import dspy
from dspy import InputField, OutputField, Signature
from dspy.functional import TypedPredictor
from pydantic import BaseModel, Field, ValidationError


class TaxField(BaseModel):
    key: str = Field()
    value: str = Field()
    truthful: bool = Field(
        description="A boolean indicating whether the extracted information is truthful and correct."
    )
    confidence: float = Field(ge=0, le=1, description="The confidence score for the answer")

    @classmethod
    def model_validate_json(
        cls, json_data: str, *, strict: bool | None = None, context: dict[str, Any] | None = None
    ) -> "TaxField":
        # Try Pydantic's standard JSON validation first.
        try:
            __tracebackhide__ = True
            return cls.__pydantic_validator__.validate_json(
                json_data, strict=strict, context=context
            )
        except ValidationError as last_exc:
            # Custom parsing logic goes here: repair or reinterpret json_data and
            # return a TaxField on success; otherwise fall through and raise.
            raise ValueError("Could not find valid json") from last_exc


class QAExtractionSignature(Signature):
    """Your task is to extract the key-value pairs from the document that will follow the instructions on the field to extract.\nPlease keep the content you extract truthful and correct based on the document text provided."""

    document: str = InputField()
    instruction: str = InputField(
        description="The instruction on which field to extract and how to extract it."
    )
    tax_from_fields: TaxField = OutputField(
        description="The list of fields extracted from the document text. IMPORTANT!! This must follow a semicolon separated list of values!"
    )


class TypedQAExtraction(dspy.Module):
    def __init__(self):
        super().__init__()
        self.qa_extraction = TypedPredictor(QAExtractionSignature)

    def forward(self, document: str, instruction: str) -> TaxField:
        return self.qa_extraction(document=document, instruction=instruction).tax_from_fields
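The `# custom parsing logic here` placeholder could be filled in many ways. Since the signature asks the model for a semicolon-separated list of values, one hypothetical option is a helper that maps `key; value; truthful; confidence` onto the four `TaxField` fields and re-emits JSON. `semicolon_to_json` and `FIELD_ORDER` are illustrative names, and the assumed field order and accepted truthy strings are guesses, not anything DSPy or this PR defines.

```python
import json
from typing import Any

# Assumed field order for this sketch; it mirrors TaxField's declaration order.
FIELD_ORDER = ("key", "value", "truthful", "confidence")


def semicolon_to_json(text: str) -> str:
    """Hypothetical fallback: turn 'key; value; true; 0.9' into a JSON object string."""
    parts = [part.strip() for part in text.strip().split(";")]
    if len(parts) != len(FIELD_ORDER):
        raise ValueError(
            f"expected {len(FIELD_ORDER)} semicolon-separated values, got {len(parts)}"
        )
    record: dict[str, Any] = dict(zip(FIELD_ORDER, parts))
    # Convert the two non-string fields; Pydantic still validates them afterwards.
    record["truthful"] = record["truthful"].lower() in {"true", "yes", "1"}
    record["confidence"] = float(record["confidence"])
    return json.dumps(record)
```

Inside the `except ValidationError` branch, `cls.__pydantic_validator__.validate_json(semicolon_to_json(json_data), strict=strict, context=context)` would then re-run normal validation, so the `ge=0, le=1` bounds on `confidence` still apply. Returning a JSON string rather than a model keeps Pydantic as the single source of validation; the trade-off is that a value containing a semicolon breaks this naive split.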

mikeedjones, May 22 '24 07:05