dspy
fix/Extend generation for all candidate completions
Following on from https://github.com/stanfordnlp/dspy/pull/918: this PR attempts a fuller implementation of the fix suggested there, so that every candidate completion is extended until it contains valid content for every field in the template.
Fixes https://github.com/stanfordnlp/dspy/issues/914 and, I suspect, https://github.com/stanfordnlp/dspy/issues/749 and some other reported issues.
The current, unpatched, behaviour also seems to be replicated in the backend-refactor branch.
This PR makes the following changes to the logic of `do_generate` in `dsp/primitives/predict.py`:
- Differentiates between `completions` and `finished_completions`, where `finished_completions` have entries for every field in the template.
- Introduces the function `extend_generation` to `dsp/primitives/predict.py`, in the namespace of `_generate`, which ensures (up to two levels of recursion) that every template field has valid content for all `n` completions requested.
- Introduces `get_last_field` to acquire the last valid field of an Example.
- Maintains the overwriting of temperature in extended completions.
Further suggestions:
- Should this method raise warnings suggesting that `max_tokens` should be increased? This implementation could slow down forward passes significantly.
Thanks @mikeedjones for the PR! This makes sense, but, following from #734, I'm curious whether setting a large max_tokens would solve this issue. I feel it's better to give the user control over request parameters and let them adjust accordingly than to increase them beyond what is needed (unless the recursion has some control over this).
> Should this method raise warnings suggesting that max_tokens might be increased?

This should definitely be logged, since any change to request parameters is important to flag!
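One way the logging could look, sketched with Python's standard `logging` module. The helper `warn_if_extending` and the logger name `"dsp.primitives.predict"` are hypothetical, not part of dsp; the point is only that a warning is emitted whenever the extension path is about to issue extra LM calls.

```python
import logging

logging.basicConfig(level=logging.WARNING)
# Hypothetical logger name; dsp's real logging setup may differ.
logger = logging.getLogger("dsp.primitives.predict")


def warn_if_extending(num_incomplete: int, n: int, max_tokens: int) -> bool:
    """Log a warning when the extension/fallback path will change request behaviour."""
    if num_incomplete == 0:
        return False
    logger.warning(
        "%d of %d completions are missing template fields at max_tokens=%d; "
        "extending generation with extra LM calls. Consider increasing max_tokens.",
        num_incomplete,
        n,
        max_tokens,
    )
    return True


warn_if_extending(0, 10, 100)  # no warning
warn_if_extending(7, 10, 100)  # logs a warning
```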
What do you mean by "solves the issue"? Increasing max_tokens would make it more likely that the fallback logic is not entered into, but very long or complicated signatures might still exceed even very high token-generation limits. For example, whilst Claude 3's context window is 200k tokens, the generation limit is 4096.
I've gone into more detail of the problem in the other, more atomised fix I proposed: https://github.com/stanfordnlp/dspy/pull/918#issue-2267428297
I'm not sure I understand. If the generation limit is restricted, does setting max_tokens = 4096 not capture what's done here? If long signatures exceed even very high token-generation limits, it would not work anyway, right? Maybe I'm misinterpreting, so feel free to correct me with an example!
The current flow, implemented now on main, checks which template fields are present in the n completions made in the first pass. If none of the completions contains all the fields, some fallback logic is entered, in which the LM (the generate function) is called recursively until the fields are created:
```python
# If none of the completions is completed (i.e., none has the final field set).
if last_field_idx < len(field_names):
    # Pick the first completion that has gone farthest.
    completion = completions[0]
    ...
    new_kwargs = {
        **kwargs,
        max_tokens_key: max_tokens,
        "n": 1,
        "temperature": 0.0,
    }

    assert max_depth > 0
    return generate(template, **new_kwargs)(
        completion,
        stage=stage,
        max_depth=max_depth - 1,
        original_example=original_example,
    )
```
The fallback logic takes the "most complete" completion and uses it to make a further call to the LM to generate an extra k tokens (k is chosen by some more logic in primitives/predict.py).
If max_tokens is increased, the LM becomes more likely to generate all the required fields, but it is not guaranteed to.
For example, with the following:
```python
import dspy
import os

llm = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=100)
dspy.settings.configure(lm=llm)

class SciGen(dspy.Signature):
    """context and answer on science questions"""
    question = dspy.InputField(desc="Input question")
    foo = dspy.OutputField(desc="foo_context")
    bar = dspy.OutputField(desc="bar_context")
    bazz = dspy.OutputField(desc="bazz_context")
    buzz = dspy.OutputField(desc="buzz_context")
    fuzz = dspy.OutputField(desc="fuzz_context")
    fooz = dspy.OutputField(desc="fooz_context")
    answer = dspy.OutputField(desc="Answer using only the contexts")

context_generator = dspy.Predict(SciGen, n=10)
response = context_generator(question="Why do we have solar eclipse?")
print(len(response.completions))
# 1
```
The first call to gpt-3.5-turbo (round 1) produces 10 completions, none of which contain the required fields fooz or answer.
The fallback logic is therefore entered, where the "most complete" completion is used (maybe the one which contains fuzz) as the Example for another call of generate, with updated kwargs {"n":1, "max_tokens": k, "temperature":0} (round 2).
Round 2 produces only 1 completion, built from the most complete completion of round 1, because n has been overwritten with 1.
It is this fallback logic which is causing only one completion to be returned in https://github.com/stanfordnlp/dspy/issues/914.
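The collapse from n=10 to a single completion can be reproduced with a toy model of that flow. This is a hypothetical simplification, not the dsp code: completions are dicts, the "extra round" is faked by filling in the missing fields, and the user temperature of 0.7 is an arbitrary assumption. What it mirrors from the excerpt is the kwargs override `{**kwargs, "n": 1, "temperature": 0.0}` applied to only the farthest-reaching candidate.

```python
from typing import Dict, List

Completion = Dict[str, str]
FIELDS = ["foo", "bar", "bazz", "buzz", "fuzz", "fooz", "answer"]


def last_field_idx(completion: Completion) -> int:
    # How far this completion got: number of leading template fields that are filled.
    idx = 0
    while idx < len(FIELDS) and completion.get(FIELDS[idx]):
        idx += 1
    return idx


def main_branch_fallback(completions: List[Completion], n: int) -> List[Completion]:
    """Toy model of the pre-PR fallback; not the real dsp implementation."""
    complete = [c for c in completions if last_field_idx(c) == len(FIELDS)]
    if complete:
        return complete
    # No candidate is complete: continue only the one that has gone farthest...
    best = max(completions, key=last_field_idx)
    kwargs = {"n": n, "temperature": 0.7}  # hypothetical user settings
    # ...and the user's n is overwritten with 1 for the extra round.
    new_kwargs = {**kwargs, "n": 1, "temperature": 0.0}
    # Pretend the extra round fills every missing field for each requested sample.
    return [
        {**best, **{f: "..." for f in FIELDS if not best.get(f)}}
        for _ in range(new_kwargs["n"])
    ]


# Ten first-pass candidates that all stop after "fuzz" (no "fooz" or "answer").
partial = {f: "..." for f in FIELDS[:5]}
result = main_branch_fallback([dict(partial) for _ in range(10)], n=10)
print(len(result))  # 1
```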
For an arbitrarily long and complex signature, there is no guarantee that the model will generate the required fields; I suspect that's why the fallback logic was included in the first place! The fallback logic (and my update to it) extends the generation, using the completion from round 1 as input to the calls in round 2, to allow for arbitrarily long signatures, up to the context limit of the LM. But the current implementation replaces the user's n with 1.
The number of "rounds" is capped by max_depth, so the output is ultimately limited to 4096*max_depth tokens rather than 4096.
Thanks @mikeedjones , this is really helpful! I do see the issue now lies more in the response parsing which triggers the fallback completion logic.
With your code above and the proposed changes in the PR, there are indeed 10 completions returned, but these are actually 10 "incomplete" completions due to output parsing errors, e.g.:
```python
Prediction(
    foo='....',
    bar='....',
    bazz='....',
    buzz='....',
    fuzz='....\n\nFooz',
    fooz='',  # empty because of the parsing error in fuzz, likely the LM not producing the Fooz prefix as "Fooz:"
    answer='.... \n\nAnswer: ....'
)
```
whereas with the existing logic, there are only 2 completions outputted, but they are "complete" with all fields parsed correctly (from the fallback logic).
This occurs even when I remove `"n": 1` from #918. Based on this, I believe we need to tackle a deeper parsing issue rather than extending generation for all candidates, especially since it's better to have 2 "complete" completions than 10 incomplete ones; ideally, though, we want 10 "complete" completions!
Let me know if that makes sense, as this PR doesn't fully correct the existing issue (but is potentially on the right track!).
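Why a missing colon empties `fooz` can be seen with a simplified prefix-based parser. This is a hypothetical sketch, not dsp's actual template parser: it splits the raw LM output on `"Prefix:"` markers and attributes text to the most recent marker, so a section the LM labels `Fooz` (without a colon) is swallowed by the preceding `Fuzz` field.

```python
import re
from typing import Dict, List


def parse_fields(text: str, prefixes: List[str]) -> Dict[str, str]:
    # Map each "Prefix:" marker back to its field name.
    markers = {f"{p}:": p for p in prefixes}
    # The capturing group keeps the markers in the output of re.split.
    pattern = "(" + "|".join(re.escape(m) for m in markers) + ")"
    result = {p: "" for p in prefixes}
    current = None
    for part in re.split(pattern, text):
        if part in markers:
            current = markers[part]
        elif current is not None:
            # Text is attributed to whichever marker was seen most recently.
            result[current] += part
    return {k: v.strip() for k, v in result.items()}


# The LM wrote "Fooz" without a colon, so that section ends up inside Fuzz.
output = (
    "Fuzz: the moon blocks the sun\n\n"
    "Fooz\nthe moon orbits the earth\n\n"
    "Answer: the moon passes between the sun and the earth"
)
parsed = parse_fields(output, ["Fuzz", "Fooz", "Answer"])
print(repr(parsed["Fooz"]))  # ''
```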
Good catch @arnavsinghvi11! Thank you :)
Yes, looks like I was using the last filled field to restart the completion as opposed to the first missing field.
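The difference between the two restart points matters when a field is skipped. The helpers below are hypothetical illustrations (they are not the PR's `get_last_field`): with `moon_context` empty but `earth_context` filled, restarting after the last filled field would skip the gap, whereas resuming at the first missing field does not.

```python
from typing import Dict, List, Optional

FIELDS = ["sun_context", "moon_context", "earth_context", "answer"]


def last_filled_field(completion: Dict[str, str], fields: List[str]) -> Optional[str]:
    # The last field, in template order, that has any content.
    filled = [f for f in fields if completion.get(f)]
    return filled[-1] if filled else None


def first_missing_field(completion: Dict[str, str], fields: List[str]) -> Optional[str]:
    # The first field, in template order, that is still empty: the place to resume.
    for f in fields:
        if not completion.get(f):
            return f
    return None


# moon_context was skipped, but earth_context was filled.
partial = {"sun_context": "...", "earth_context": "..."}
print(last_filled_field(partial, FIELDS))    # earth_context
print(first_missing_field(partial, FIELDS))  # moon_context
```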
Updated the test, as the LM didn't reliably fill the nonsense fields, which led to inconsistent results.
```python
import dspy

llm = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=75)
dspy.settings.configure(lm=llm)

class SciGen(dspy.Signature):
    """context and answer on science questions"""
    question = dspy.InputField(desc="Input question")
    sun_context = dspy.OutputField(desc="sun_context")
    moon_context = dspy.OutputField(desc="moon_context")
    earth_context = dspy.OutputField(desc="earth_context")
    relative_distances_context = dspy.OutputField(desc="relative_distances_context")
    answer = dspy.OutputField(desc="Answer only when you have all the context fields.")

context_generator = dspy.Predict(SciGen, n=10)
response = context_generator(question="Why do we have solar eclipse?")

assert len(response.completions) == 10
for answer in response.completions:
    for key in [
        "sun_context",
        "moon_context",
        "earth_context",
        "answer",
    ]:
        assert key in answer
        assert answer[key] is not None
        assert answer[key] != ""
```
I think the parsing errors you're seeing are also due to the LM producing junk when given the unusual prompt generated by the unusual signature. This could be a larger problem for dspy and its attempts to make LM output reliably parsable.
Thanks @mikeedjones. Could you run `ruff check . --fix-only` and push again? Ready to merge after that.
To confirm, this change is more comprehensive than #918 and that PR can be closed after this is merged?
@arnavsinghvi11 yes that's correct. It should pick up a few other issues as well relating to n!=1
Linting applied - cheers!
Cheers
@arnavsinghvi11 - is there anything outstanding for this PR? cheers! :)
Thanks @mikeedjones ! Very useful PR that caught an elaborate issue!
This should not have been merged.
There's an open issue on empty input fields (#1108) which is being caused by this PR, but I think the older logic would have the same issue. Or is the problem more serious?
EDIT: I'm not sure if the logic as originally implemented works as expected.