Issues Combining BootstrapFewShot and dspy.context

Open thiagodsd opened this issue 6 months ago • 4 comments

I'm encountering a problem when using BootstrapFewShot with custom language models (LMs) and datasets. Specifically, I get an AttributeError when max_rounds is set to a value greater than 1:

...

File ~/venv/lib/python3.9/site-packages/dspy/teleprompt/bootstrap.py:182 in BootstrapFewShot._bootstrap_one_example(self, example, round_idx)
   180 with dsp.settings.context(trace=[], **self.teacher_settings):
   181 lm = dsp.settings.lm
-->182 lm.lm.copy(temperature=0.7+0.001*round_idx) if round_idx > 0 else lm

AttributeError: 'NoneType' object has no attribute 'copy'
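
One possible reading of the traceback, sketched without dspy: the failing line only calls .copy() when round_idx > 0, so a missing LM would go unnoticed in the first round. The names pick_lm and FakeLM below are hypothetical stand-ins for illustration, not dspy's actual code.

```python
# Simplified stand-in for the logic around bootstrap.py:182; not dspy's code.
def pick_lm(configured_lm, round_idx):
    """Return the LM to use for one bootstrapping round.

    Round 0 uses the configured LM untouched; later rounds ask for a copy
    with a slightly higher temperature, which requires a real LM object.
    """
    return (
        configured_lm.copy(temperature=0.7 + 0.001 * round_idx)
        if round_idx > 0
        else configured_lm
    )


class FakeLM:
    """Hypothetical LM with a copy() method, for illustration only."""

    def __init__(self, temperature=0.0):
        self.temperature = temperature

    def copy(self, **kwargs):
        return FakeLM(**kwargs)


# Round 0 never calls .copy(), so a missing LM (None) passes through silently.
first_round_lm = pick_lm(None, round_idx=0)

# Round 1 calls .copy() on None and raises an AttributeError.
try:
    pick_lm(None, round_idx=1)
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

# With an LM configured, later rounds get a copy at a nudged temperature.
later_round_lm = pick_lm(FakeLM(), round_idx=1)
```

If this reading holds, it would be consistent with the error appearing only when more than one round is requested.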

I’m unable to share detailed logs or code snippets due to company policies, but I can provide a general outline of my setup:

import dspy
from dspy.teleprompt import BootstrapFewShot
from sagemaker.predictor import Predictor

class CustomLMClient(LM):
    def __init__(self, model, **kwargs):
        self.provider = "default"
        self.model = model
        self.client = Predictor(endpoint_name=self.model)
        ...

class CustomParquetDataset(Dataset):
    ...

class Summarizer(dspy.Signature):
    ...

class Judge(dspy.Signature):
    ...

lm_1 = CustomLMClient(model="mixtral")
lm_2 = CustomLMClient(model="mixtral")

class CustomModule(dspy.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.summarizer = dspy.ChainOfThought(Summarizer)
        self.summarizer._compiled = True
        self.generate_answer = dspy.ChainOfThought(Judge)

    def forward(self, text_field):
        with dspy.context(lm=lm_1, model="summarizer"):
            ...
        with dspy.context(lm=lm_2, model="judge"):
            ...
        return dspy.Prediction(...)

def risk_metric(example, pred, trace=None):
    ...

# dataset definitions here

teleprompter = BootstrapFewShot(metric=risk_metric, max_rounds=5, max_errors=3)
compile_lm = teleprompter.compile(CustomModule(), trainset=trainset)

The content of the experiment seems unrelated to the issue, since everything works as expected with max_rounds=1. With any larger value, the process completes one iteration (the progress bar reaches 100%) and then raises the error shown above.

thiagodsd, Aug 13 '24 12:08