fix/Extend generation for all candidate completions

Open mikeedjones opened this issue 4 months ago • 6 comments

Following from https://github.com/stanfordnlp/dspy/pull/918 - This PR attempts a fuller implementation of the fix suggested there such that every candidate completion is extended until it contains valid content for every field in the template.

Fixes https://github.com/stanfordnlp/dspy/issues/914 and I suspect https://github.com/stanfordnlp/dspy/issues/749 and some other raised issues.

The current, unpatched, behaviour also seems to be replicated in the backend-refactor branch.

This PR makes the following changes to the logic of do_generate in dsp/primitives/predict.py:

  • Differentiates between completions and finished_completions where finished_completions have entries for every field in the template.
  • Introduces the function extend_generation to dsp/primitives/predict.py, in the namespace of _generate, which ensures (up to two levels of recursion) that every template field has valid content for all n completions requested.
  • Introduces get_last_field to acquire the last valid field of an Example.
  • Maintains overwriting of temperature in extended completions.
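The completions / finished_completions split described above could be sketched roughly as follows. This is a minimal illustration rather than the code in the PR: it assumes completions behave like plain dicts keyed by field name, and the helper name split_finished is hypothetical.

```python
def split_finished(completions, field_names):
    """Partition completions into (finished, unfinished).

    Assumption: a completion counts as finished only when every
    template field is present and non-empty.
    """
    finished, unfinished = [], []
    for completion in completions:
        if all(completion.get(name) not in (None, "") for name in field_names):
            finished.append(completion)
        else:
            unfinished.append(completion)
    return finished, unfinished


fields = ["foo", "bar", "answer"]
candidates = [
    {"foo": "a", "bar": "b", "answer": "c"},  # every field filled
    {"foo": "a", "bar": "b"},                 # answer missing
    {"foo": "a", "bar": "", "answer": None},  # empty values count as missing
]
finished, unfinished = split_finished(candidates, fields)
```

Only the entries that land in unfinished would need a further call to the LM.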

Further suggestions:

  • Should this method raise warnings suggesting that max_tokens should be increased? This implementation could slow down forward passes significantly.

mikeedjones avatar Apr 28 '24 13:04 mikeedjones

Thanks @mikeedjones for the PR! This makes sense, but I'm curious, following #734, whether setting a large number of tokens would solve this issue. I feel it's better to give the user control over request parameters and let them adjust accordingly than to increase them beyond what is needed (unless the recursion has some control over this).

Should this method raise warnings suggesting that max_tokens might be increased?

This should definitely be logged, as any impact on request parameters is important to flag!
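As a rough idea of what such logging might look like, here is a sketch using only the standard logging module. The function name warn_if_extended, the logger name, and the message wording are all hypothetical and not part of dspy.

```python
import logging

logger = logging.getLogger("extend_generation_sketch")


def warn_if_extended(extra_calls: int, max_tokens: int) -> bool:
    """Log a warning when completions needed extra LM calls to finish.

    Returns True if a warning was emitted, False otherwise.
    """
    if extra_calls <= 0:
        return False
    logger.warning(
        "Completions were extended with %d extra LM call(s); "
        "consider increasing max_tokens (currently %d).",
        extra_calls,
        max_tokens,
    )
    return True
```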

arnavsinghvi11 avatar Apr 28 '24 18:04 arnavsinghvi11

What do you mean by "solves the issue"? Increasing max_tokens would make it more likely that the fallback logic is not entered into, but very long or complicated signatures might still exceed even very high token-generation limits. For example, whilst Claude 3's context window is 200k tokens, the generation limit is 4096.

I've gone into more detail of the problem in the other, more atomised fix I proposed: https://github.com/stanfordnlp/dspy/pull/918#issue-2267428297

mikeedjones avatar Apr 28 '24 18:04 mikeedjones

I'm not sure I understand. If the generation limit is restricted, does setting max_tokens = 4096 not capture what's done here? If the long signatures exceed the very high token-generation limits, it would not work anyways right? maybe I'm misinterpreting so feel free to correct with an example!

arnavsinghvi11 avatar Apr 28 '24 18:04 arnavsinghvi11

The current flow, implemented now on main, checks which template fields are in the n completions made in the first pass. If none of the completions contain all the fields, some fallback logic is entered, in which the LM (generate function) is called recursively until the fields are created:

        # If none of the completions is completed (i.e., none has the final field set).
        if last_field_idx < len(field_names):
            # Pick the first completion that has gone farthest.
            completion = completions[0]
            ...

            new_kwargs = {
                **kwargs,
                max_tokens_key: max_tokens,
                "n": 1,
                "temperature": 0.0,
            }

            assert max_depth > 0
            return generate(template, **new_kwargs)(
                completion,
                stage=stage,
                max_depth=max_depth - 1,
                original_example=original_example,
            )

The fallback logic gets the "most complete" completion and uses it to make a further call to the LM to generate an extra k tokens (k is chosen by some more logic in primitives/predict.py).
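To make "most complete" concrete, here is a small sketch of one way to pick the completion that has gone farthest through the template. It assumes dict-like completions and uses a hypothetical helper, farthest_completion; the actual selection in predict.py may differ.

```python
def farthest_completion(completions, field_names):
    """Return the completion with the longest run of filled leading fields.

    Ties go to the earliest completion, since max() keeps the first maximum.
    """
    def filled_prefix_length(completion):
        count = 0
        for name in field_names:
            if completion.get(name) in (None, ""):
                break
            count += 1
        return count

    return max(completions, key=filled_prefix_length)


fields = ["foo", "bar", "fuzz", "answer"]
round_one = [
    {"foo": "a"},
    {"foo": "a", "bar": "b", "fuzz": "c"},
    {"foo": "a", "bar": "b"},
]
best = farthest_completion(round_one, fields)
```

Here best is the second candidate, and only that one would be carried into the next call.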

If max_tokens is increased then the likelihood the LM generates all the required fields goes up, but it is not certain.

For example, with the below:

import dspy

llm = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=100)
dspy.settings.configure(lm=llm)

class SciGen(dspy.Signature):
    """context and answer on science questions"""
    question = dspy.InputField(desc="Input question")
    foo = dspy.OutputField(desc="foo_context") 
    bar = dspy.OutputField(desc="bar_context") 
    bazz = dspy.OutputField(desc="bazz_context") 
    buzz = dspy.OutputField(desc="buzz_context") 
    fuzz = dspy.OutputField(desc="fuzz_context") 
    fooz = dspy.OutputField(desc="fooz_context") 
    answer = dspy.OutputField(desc="Answer using only the contexts")

context_generator = dspy.Predict(SciGen, n=10)
response = context_generator(question="Why do we have solar eclipse?")

print(len(response.completions))
# 1

The first call to gpt-3.5-turbo (round 1) produces 10 completions, none of which contain the required fields fooz or answer.

The fallback logic is therefore entered, where the "most complete" completion is used (maybe the one which contains fuzz) as the Example for another call of generate, with updated kwargs {"n":1, "max_tokens": k, "temperature":0} (round 2).

Round 2 produces only 1 completion, as n has been overwritten, and that completion is based upon the most complete completion from round 1.

It is this fallback logic which is causing only one completion to be returned in https://github.com/stanfordnlp/dspy/issues/914.

For an arbitrarily long and complex signature, there is no guarantee that the model will generate the required fields - I suspect that's why the fallback logic was included in the first place! The fallback logic (and my update to it) extends the generation (using the completion from round 1 as input to the calls in round 2) to allow for arbitrarily long signatures - up to the context limit of the LM. But the current implementation replaces the user's n with 1.
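The difference can be sketched without calling a real LM. In the toy below, fake_continue is a hypothetical stand-in for one extra generate call and extend_all is an illustrative name, not the function in the PR; the point is only that every candidate is extended, so n is preserved.

```python
def fake_continue(partial, field_names):
    """Hypothetical stand-in for one extra LM call: fills the first missing field."""
    for name in field_names:
        if not partial.get(name):
            return {name: f"<generated {name}>"}
    return {}


def extend_all(completions, field_names, continue_fn, max_depth=2):
    """Extend each completion until all fields are filled or max_depth is spent."""
    extended = []
    for completion in completions:
        current = dict(completion)  # copy so the round-1 result is not mutated
        for _ in range(max_depth):
            if all(current.get(name) for name in field_names):
                break
            current.update(continue_fn(current, field_names))
        extended.append(current)
    return extended


fields = ["foo", "bar", "answer"]
round_one = [
    {"foo": "a"},
    {"foo": "b", "bar": "c"},
    {"foo": "d", "bar": "e", "answer": "f"},
]
round_two = extend_all(round_one, fields, fake_continue, max_depth=2)
```

With this shape, round_two has as many entries as round_one, whereas the current fallback would return a single completion.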

The ultimate limit on "rounds" is set by max_depth - so an ultimate limit to the output of 4096*max_depth as opposed to 4096.

mikeedjones avatar Apr 28 '24 19:04 mikeedjones

Thanks @mikeedjones , this is really helpful! I do see the issue now lies more in the response parsing which triggers the fallback completion logic.

With your code above and the proposed changes in the PR, there are indeed 10 outputted completions, but these are actually 10 "incomplete" completions due to output parsing errors, e.g.:

Prediction(
    foo='....',
    bar='....',
    bazz='...',
    buzz='....',
    fuzz='....\n\nFooz',
    fooz='',  # empty, likely because the parsing error in fuzz meant the Fooz prefix was not produced as "Fooz:"
    answer="'.... \n\nAnswer: '...."
)

whereas with the existing logic, there are only 2 completions outputted, but they are "complete" with all fields parsed correctly (from the fallback logic).

This occurs even when I remove "n": 1 from #918. Based on this, I believe we need to tackle a deeper parsing issue rather than extending generation for all candidates, especially since it's better to have 2 "complete" completions instead of 10 "incomplete" ones - but ideally we want 10 "complete" completions!

Let me know if that makes sense as this PR doesn't fully correct the existing issue (but potentially is on the right track!).

arnavsinghvi11 avatar Apr 28 '24 19:04 arnavsinghvi11

Good catch @arnavsinghvi11! Thank you :)

Yes, looks like I was using the last filled field to restart the completion as opposed to the first missing field.
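A toy illustration of the distinction, assuming dict-like completions; both helper names are hypothetical and this is not the code in the PR:

```python
def last_filled_field(completion, field_names):
    """Name of the last field, in template order, that has content (or None)."""
    last = None
    for name in field_names:
        if completion.get(name) not in (None, ""):
            last = name
    return last


def first_missing_field(completion, field_names):
    """Name of the first field, in template order, without content (or None)."""
    for name in field_names:
        if completion.get(name) in (None, ""):
            return name
    return None


fields = ["sun_context", "moon_context", "answer"]
partial = {"sun_context": "The sun is ...", "moon_context": ""}

restart_from_wrong = last_filled_field(partial, fields)
restart_from_right = first_missing_field(partial, fields)
```

Restarting from restart_from_right (moon_context here) asks the LM to continue at the field that still needs content, rather than at one that is already filled.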

Updated the test as the LM didn't reliably fill the nonsense fields - leading to inconsistent results.

import dspy


llm = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=75)
dspy.settings.configure(lm=llm)
class SciGen(dspy.Signature):
    """context and answer on science questions"""
    question = dspy.InputField(desc="Input question")
    sun_context = dspy.OutputField(desc="sun_context") 
    moon_context = dspy.OutputField(desc="moon_context") 
    earth_context = dspy.OutputField(desc="earth_context") 
    relative_distances_context = dspy.OutputField(desc="relative_distances_context") 
    answer = dspy.OutputField(desc="Answer only when you have all the context fields.")

context_generator = dspy.Predict(SciGen, n=10)
response = context_generator(question="Why do we have solar eclipse?")

assert len(response.completions) == 10

for answer in response.completions:
    for key in [
        "sun_context",
        "moon_context",
        "earth_context",
        "answer",
    ]:
        assert key in answer
        assert answer[key] is not None
        assert answer[key] != ""

I think the parsing errors you're seeing are also due to the LM producing junk when given the odd prompt generated by the odd signature. I think this could be a larger problem with dspy and attempts to make reliably parsable LM output.

mikeedjones avatar Apr 28 '24 21:04 mikeedjones

Thanks @mikeedjones . Could you run ruff check . --fix-only and push again? Ready to merge after that.

To confirm, this change is more comprehensive than #918 and that PR can be closed after this is merged?

arnavsinghvi11 avatar May 06 '24 00:05 arnavsinghvi11

@arnavsinghvi11 yes that's correct. It should pick up a few other issues as well relating to n!=1

Linting applied - cheers!

Cheers

mikeedjones avatar May 06 '24 07:05 mikeedjones