dspy
dspy copied to clipboard
Gemini throws 400 error while compiling signature
The code throws google.api_core.exceptions.InvalidArgument: 400 Unable to submit request because candidateCount must be 1 but the entered value was 2. Update the candidateCount value and try again. error in my compile code.
class GenerateExtraction(dspy.Signature):
    """Extracting requested information from a document."""

    # The raw document text to extract from.
    context = dspy.InputField(desc="The document text.")
    # Instructions naming the keys/values to pull out of the document.
    task = dspy.InputField(desc="The task to extract necessary information from document. Use the mapping values for the keys.")
    answer = dspy.OutputField(
        # Plain triple-quoted string: the original used an f-string with no
        # placeholders (ruff F541) — the runtime value is byte-identical.
        desc="""A list of expected key-value pairs, if a key doesn't have a value return N/A.
IMPORTANT!!! The list must be semi-colon separated.
Do not include any other information.""",
    )
class SimpleDocumentTextQA(dspy.Module):
    """DSPy module that runs an extraction signature over a document's text."""

    def __init__(self, signature: dspy.Signature | None = None):
        super().__init__()
        # Fall back to the default extraction signature when none is supplied.
        chosen_signature = GenerateExtraction if signature is None else signature
        self.predictor = dspy.Predict(chosen_signature)

    def forward(self, context, question):
        # The signature's second input field is named `task`, so the incoming
        # `question` is mapped onto it.
        prediction = self.predictor(context=context, task=question)
        return dspy.Prediction(prediction)
# Build the QA module and run COPRO prompt optimization over the training set.
module = SimpleDocumentTextQA()
teleprompter = COPRO(metric=metric_fn, verbose=True, depth=2, breadth=2)
config = {"num_threads": thread_count, "display_progress": True}
optmized = teleprompter.compile(module, trainset=train, eval_kwargs=config)
The stack trace
Average Metric: 0.7307692307692308 / 9 (8.1): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:26<00:00, 2.96s/it]
Traceback (most recent call last):
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/google/api_core/grpc_helpers.py", line 76, in error_remapped_callable
return callable_(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/grpc/_channel.py", line 1181, in __call__
return _end_unary_response_blocking(state, call, False, None)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/grpc/_channel.py", line 1006, in _end_unary_response_blocking
raise _InactiveRpcError(state) # pytype: disable=not-instantiable
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
grpc._channel._InactiveRpcError: <_InactiveRpcError of RPC that terminated with:
status = StatusCode.INVALID_ARGUMENT
details = "Unable to submit request because candidateCount must be 1 but the entered value was 2. Update the candidateCount value and try again."
debug_error_string = "UNKNOWN:Error received from peer ipv4:142.251.215.234:443 {grpc_message:"Unable to submit request because candidateCount must be 1 but the entered value was 2. Update the candidateCount value and try again.", grpc_status:3, created_time:"2024-05-31T09:51:39.510654-07:00"}"
>
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/ggao/projects/gemini/gemini/doc_qa.py", line 308, in <module>
fire.Fire(main)
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
component, remaining_args = _CallAndUpdateTrace(
^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/gemini/doc_qa.py", line 286, in main
optmized = teleprompter.compile(
^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dspy/teleprompt/copro_optimizer.py", line 307, in compile
instr = dspy.Predict(
^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dspy/predict/predict.py", line 61, in __call__
return self.forward(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dspy/predict/predict.py", line 103, in forward
x, C = dsp.generate(template, **config)(x, stage=self.stage)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dsp/primitives/predict.py", line 77, in do_generate
completions: list[dict[str, Any]] = generator(prompt, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dsp/modules/googlevertexai.py", line 177, in __call__
return self.request(prompt, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/backoff/_sync.py", line 105, in retry
ret = target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dsp/modules/googlevertexai.py", line 168, in request
return self.basic_request(prompt, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dsp/modules/googlevertexai.py", line 126, in basic_request
response = self.client.generate_content(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py", line 407, in generate_content
return self._generate_content(
^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py", line 496, in _generate_content
gapic_response = self._prediction_client.generate_content(request=request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py", line 2103, in generate_content
response = rpc(
^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/google/api_core/gapic_v1/method.py", line 131, in __call__
return wrapped_func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/google/api_core/grpc_helpers.py", line 78, in error_remapped_callable
raise exceptions.from_grpc_error(exc) from exc
google.api_core.exceptions.InvalidArgument: 400 Unable to submit request because candidateCount must be 1 but the entered value was 2. Update the candidateCount value and try again.
Any ideas on what the cause of this is?
Is thread_count set to 2? Maybe an issue with multithreading/batching
breadth is passed to the model as n - so you need to set breadth to 1 if you're using gemini.
I set thread_count to 1, and it still failed.
Here is the updated code
def signature_optimization(module, train, metric_fn, thread_count: int = 1) -> SimpleDocumentTextQA:
    """Run COPRO prompt optimization over `module` and return a tuned QA module.

    Args:
        module: The dspy.Module whose predictor signature should be optimized.
        train: Training examples handed to the teleprompter as `trainset`.
        metric_fn: Metric callable used to score candidate prompts.
        thread_count: Number of evaluation threads (keep at 1 for Gemini).

    Returns:
        A SimpleDocumentTextQA wrapping the optimized program.
    """
    from dspy.teleprompt import COPRO

    # COPRO requires breadth > 1 (it raises ValueError otherwise) and asks the
    # LM for n = breadth - 1 candidate instructions per step. Gemini only
    # accepts candidateCount == 1, so breadth=2 is the largest value that
    # works with Gemini endpoints.
    teleprompter = COPRO(metric=metric_fn, verbose=True, depth=2, breadth=2)
    config = dict(num_threads=thread_count, display_progress=True)
    optimized = teleprompter.compile(module, trainset=train, eval_kwargs=config)
    optimized.save("optimized_signature.json")
    return SimpleDocumentTextQA(optimized)
When I set breadth to 1, the code failed with the following error:
Evaluation Result (Before Optimization): 0.0
Signature Optimization...
Traceback (most recent call last):
File "/Users/ggao/projects/gemini/gemini/doc_qa.py", line 341, in <module>
fire.Fire(main)
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
component, remaining_args = _CallAndUpdateTrace(
^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/gemini/doc_qa.py", line 331, in main
optmized = signature_optimization(module, train, metric_fn, thread_count)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/gemini/doc_qa.py", line 263, in signature_optimization
teleprompter = COPRO(metric=metric_fn, verbose=True, depth=2, breadth=1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ggao/projects/gemini/.venv/lib/python3.11/site-packages/dspy/teleprompt/copro_optimizer.py", line 69, in __init__
raise ValueError("Breadth must be greater than 1")
ValueError: Breadth must be greater than 1
The error makes sense because the COPRO code specifically requires breadth to be greater than 1. Can someone help me understand what this breadth variable is? There isn't enough documentation on what the variable controls.
class COPRO(Teleprompter):
    # NOTE(review): excerpt quoted from dspy/teleprompt/copro_optimizer.py to
    # show where the breadth constraint comes from; the rest of the class is
    # not shown here.
    def __init__(
        self,
        prompt_model=None,      # LM used to propose new instructions
        metric=None,            # scoring function for candidate prompts
        breadth=10,             # candidate instructions per depth step; must be > 1
        depth=3,                # number of refinement iterations
        init_temperature=1.4,   # sampling temperature for the first proposals
        track_stats=False,      # whether to record optimization statistics
        **_kwargs,
    ):
        # Per the discussion below, COPRO later requests n = breadth - 1 fresh
        # candidates from the LM, so breadth == 1 would leave nothing to
        # generate — hence the hard check here.
        if breadth <= 1:
            raise ValueError("Breadth must be greater than 1")
        self.metric = metric
        self.breadth = breadth
        self.depth = depth
        self.init_temperature = init_temperature
        self.prompt_model = prompt_model
        self.track_stats = track_stats
I just checked the code, and n is set to:
n=self.breadth - 1,
I would set breadth to 2, so that n becomes one. Not confident at all though that this works
Later in COPRO it uses n=breadth, iirc. I think it's a problem with Gemini + the Model Garden API.
Hmm are we seeing a collision between parameter names in DSPy and parameter names in Gemini? ...
The model garden API serves multiple models with a common generation config model. For gemini [candidateCount] must be 1 - but it's being set to 2 by n
Thanks for opening this! We released DSPy 2.5 yesterday. I think the new dspy.LM and the underlying dspy.ChatAdapter will probably resolve this problem.
Here's the (very short) migration guide, it should typically take you 2-3 minutes to change the LM definition and you should be good to go: https://github.com/stanfordnlp/dspy/blob/main/examples/migration.ipynb
Please let us know if this resolves your issue. I will close for now but please feel free to re-open if the problem persists.