
Saving/Loading Method for signatures in dspy.Predict

Open GFarnon opened this issue 1 year ago • 7 comments

Currently, the dump_state / load_state methods in dspy.Predict only save signature_instructions and signature_prefix (the prefix of the last field in the signature):

def dump_state(self):
       
        # ............
        # Cache the signature instructions and the last field's name.
        state["signature_instructions"] = self.signature.instructions

        *_, last_key = self.signature.fields.keys()
        state["signature_prefix"] = self.signature.fields[last_key].json_schema_extra["prefix"]

        return state
        
def load_state(self, state):
        for name, value in state.items():
            setattr(self, name, value)

        # Reconstruct the signature.
        if "signature_instructions" in state:
            instructions = state["signature_instructions"]
            self.signature = self.signature.with_instructions(instructions)

        if "signature_prefix" in state:
            prefix = state["signature_prefix"]
            *_, last_key = self.signature.fields.keys()
            self.signature = self.signature.with_updated_fields(last_key, prefix=prefix)

But the optimizers also optimize the descriptions and prefixes of all the other fields in the signature, so those are lost when the state is saved. Could we change it to something like this?

def dump_state(self):
      state_keys = ["lm", "traces", "train"]
      state = {k: getattr(self, k) for k in state_keys}

      state["demos"] = []
      for demo in self.demos:
          demo = demo.copy()

          for field in demo:
              if isinstance(demo[field], BaseModel):
                  demo[field] = demo[field].model_dump_json()

          state["demos"].append(demo)

      # Cache the signature instructions and every field's prefix and description.
      state["signature_instructions"] = self.signature.instructions

      field_details = {}
      for key, field in self.signature.fields.items():
          field_details[key] = {
              "prefix": field.json_schema_extra.get("prefix", ""),
              "desc": field.json_schema_extra.get("desc", "")
          }
      state["signature_field_details"] = field_details

      return state

def load_state(self, state):
      for name, value in state.items():
          setattr(self, name, value)

      # Reconstruct the signature.
      if "signature_instructions" in state:
          instructions = state["signature_instructions"]
          self.signature = self.signature.with_instructions(instructions)

      if "signature_field_details" in state:
          field_details = state["signature_field_details"]
          for key, details in field_details.items():
              prefix = details.get("prefix", "")
              desc = details.get("desc", "")
              self.signature = self.signature.with_updated_fields(
                  key, prefix=prefix, desc=desc
              )
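
To sanity-check the idea without touching dspy itself, here is a minimal sketch of the same round trip. ToySignature, dump_signature and load_signature are hypothetical stand-ins written only for this illustration. They are not part of dspy; they just mirror the with_instructions / with_updated_fields calls used above.

```python
import json
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToySignature:
    """Hypothetical stand-in for a signature: instructions plus per-field metadata."""
    instructions: str
    fields: dict  # field name -> {"prefix": ..., "desc": ...}

    def with_instructions(self, instructions):
        return replace(self, instructions=instructions)

    def with_updated_fields(self, name, **kwargs):
        # Copy the metadata so the original signature is left untouched.
        updated = {key: dict(meta) for key, meta in self.fields.items()}
        updated[name].update(kwargs)
        return replace(self, fields=updated)


def dump_signature(sig):
    """Save the instructions and the prefix/desc of every field."""
    return {
        "signature_instructions": sig.instructions,
        "signature_field_details": {
            name: {"prefix": meta.get("prefix", ""), "desc": meta.get("desc", "")}
            for name, meta in sig.fields.items()
        },
    }


def load_signature(base, state):
    """Rebuild a signature by applying the saved state on top of a base signature."""
    sig = base
    if "signature_instructions" in state:
        sig = sig.with_instructions(state["signature_instructions"])
    for name, details in state.get("signature_field_details", {}).items():
        sig = sig.with_updated_fields(
            name, prefix=details.get("prefix", ""), desc=details.get("desc", "")
        )
    return sig


original = ToySignature(
    instructions="Generate related questions",
    fields={
        "query": {"prefix": "Query:", "desc": "${query}"},
        "related_questions": {"prefix": "Related Questions:", "desc": "A list of questions"},
    },
)
optimized = original.with_instructions("Be creative and clear.").with_updated_fields(
    "query", prefix="Main Query:"
)

# Go through JSON to mimic writing the state to disk and reading it back.
state = json.loads(json.dumps(dump_signature(optimized)))
restored = load_signature(original, state)
round_trip_ok = restored == optimized
```

With per-field details in the state, the non-final field's optimized prefix survives the save/load cycle, which is the behaviour I'm after.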

GFarnon avatar May 20 '24 09:05 GFarnon

Yes, this is reasonable — but no current optimizers optimize the little parts that are not saved.

okhat avatar May 20 '24 09:05 okhat

Thanks, I'm pretty sure optimize_signature for TypedPredictor does.

For example, after optimization, this signature:

generate_related_questions.generate_related_questions.predictor.predictor = Predict(RelatedQuestionsSignature(query, topics -> related_questions
    instructions='\n    Generate related questions for a given query and topics\n    '
    query = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Query:', 'desc': '${query}'})
    topics = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Topics:', 'desc': '${topics}'})
    related_questions = Field(annotation=List[str] required=True json_schema_extra={'desc': 'A list of 4 related questions to the input question', '__dspy_field_type': 'output', 'prefix': 'Related Questions:'})
))

Becomes:

generate_related_questions.generate_related_questions.predictor.predictor = Predict(StringSignature(query, topics -> related_questions
    instructions='You are a professor of mathematics. Generate related questions for a given query and topics. Use lots of creativity and be as clear as possible. I really need your help!'
    query = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Main Query:', 'desc': '${query}'})
    topics = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Associated Topics:', 'desc': '${topics}'})
    related_questions = Field(annotation=List[str] required=True json_schema_extra={'desc': 'A list of 4 creative and clear questions related to the input query.', '__dspy_field_type': 'output', 'prefix': 'Related Questions:'})
))
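
To make the difference concrete, here is a small sketch that diffs the field metadata of the two signatures above and lists what the current dump_state would drop. The before/after dicts are copied by hand from the printouts; changed_metadata is a hypothetical helper for this illustration, not a dspy function.

```python
# Field metadata copied from the two signature printouts above.
before = {
    "query": {"prefix": "Query:", "desc": "${query}"},
    "topics": {"prefix": "Topics:", "desc": "${topics}"},
    "related_questions": {
        "prefix": "Related Questions:",
        "desc": "A list of 4 related questions to the input question",
    },
}
after = {
    "query": {"prefix": "Main Query:", "desc": "${query}"},
    "topics": {"prefix": "Associated Topics:", "desc": "${topics}"},
    "related_questions": {
        "prefix": "Related Questions:",
        "desc": "A list of 4 creative and clear questions related to the input query.",
    },
}


def changed_metadata(before, after):
    """Return {field: {key: (old, new)}} for every metadata value that differs."""
    changes = {}
    for name, old_meta in before.items():
        diff = {
            key: (old_meta[key], after[name][key])
            for key in old_meta
            if old_meta[key] != after[name][key]
        }
        if diff:
            changes[name] = diff
    return changes


changes = changed_metadata(before, after)

# The current dump_state only keeps the prefix of the last field,
# so every other changed (field, key) pair is lost on save.
last_field = list(after)[-1]
lost = [
    (name, key)
    for name, diff in changes.items()
    for key in diff
    if (name, key) != (last_field, "prefix")
]
```

Here lost contains the prefixes of query and topics and the desc of related_questions, so all three field-level changes in this example are dropped by the current save format (the instructions are kept).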

GFarnon avatar May 20 '24 09:05 GFarnon

Thank you @GFarnon! I guess I've reached the point where I don't know every little detail of all the optimizers anymore. OK, yes, this is a change we should merge. Would you be so kind as to open a PR?

okhat avatar May 20 '24 10:05 okhat

🫡

GFarnon avatar May 20 '24 10:05 GFarnon

I'm very inexperienced so it's likely something I'm doing wrong, but this has tripped me up a lot.

The docs tell me that a saved program "contains all the parameters and steps in the source program", but it seems that the reality is more like: "running an optimiser and saving its output will save the specific output of that optimiser, not the whole program".

JarradJarrad avatar Jun 14 '24 00:06 JarradJarrad

@chenmoneygithub Wanna take a look at this if you get a chance? Seems relevant to your current efforts.

okhat avatar Sep 27 '24 15:09 okhat

Sure! I will revisit which state should be saved when dump_state is called.

chenmoneygithub avatar Sep 27 '24 20:09 chenmoneygithub