DSPy COPRO tutorial for HotPotQA doesn't work with Bedrock Claude Sonnet model
I copied, word for word, the tutorial on the DSPy website for using COPRO to optimize the HotPotQA example, but it didn't work.
from dsp import AWSAnthropic
import dspy
from dspy.datasets import HotPotQA
from dspy.evaluate import Evaluate
from dspy.teleprompt import COPRO

dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0)
trainset, devset = dataset.train, dataset.dev
print(trainset)
print(devset)

sonnet = AWSAnthropic(
    aws_provider=dspy.Bedrock(region_name="us-west-2"),
    model="anthropic.claude-3-sonnet-20240229-v1:0",
)
dspy.configure(lm=sonnet)

class CoTSignature(dspy.Signature):
    """Answer the question and give the reasoning for the same."""
    question = dspy.InputField(desc="question about something")
    answer = dspy.OutputField(desc="often between 1 and 5 words")

class CoTPipeline(dspy.Module):
    def __init__(self):
        super().__init__()
        self.signature = CoTSignature
        self.predictor = dspy.ChainOfThought(self.signature)

    def forward(self, question):
        result = self.predictor(question=question)
        return dspy.Prediction(
            answer=result.answer,
            reasoning=result.rationale,
        )

def validate_context_and_answer(example, pred, trace=None):
    answer_EM = dspy.evaluate.answer_exact_match(example, pred)
    return answer_EM

# Baseline evaluation of the unoptimized pipeline
NUM_THREADS = 5
evaluate = Evaluate(devset=devset, metric=validate_context_and_answer, num_threads=NUM_THREADS, display_progress=True, display_table=False)
cot_baseline = CoTPipeline()
devset_with_input = [dspy.Example({"question": r["question"], "answer": r["answer"]}).with_inputs("question") for r in devset]
evaluate(cot_baseline, devset=devset_with_input)

# COPRO optimization; the traceback below is raised by the compile() call
kwargs = dict(num_threads=64, display_progress=True, display_table=0)  # Used in Evaluate class in the optimization process
teleprompter = COPRO(
    metric=validate_context_and_answer,
    verbose=True,
)
compiled_prompt_opt = teleprompter.compile(cot_baseline, trainset=devset, eval_kwargs=kwargs)
Here is the error message I received.
Traceback (most recent call last):
  File "/Users/esthersu/workplace/p-chatbot-ws/src/P-Chatbot-Bedrock/src/p_chatbot_bedrock/sample_prompt_engineering.py", line 55, in <module>
    compiled_prompt_opt = teleprompter.compile(cot_baseline, trainset=devset, eval_kwargs=kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/esthersu/workplace/p-chatbot-ws/src/P-Chatbot-Bedrock/src/p_chatbot_bedrock/path/to/venv/lib/python3.12/site-packages/dspy/teleprompt/copro_optimizer.py", line 171, in compile
    instruct = dspy.Predict(
               ^^^^^^^^^^^^^
  File "/Users/esthersu/workplace/p-chatbot-ws/src/P-Chatbot-Bedrock/src/p_chatbot_bedrock/path/to/venv/lib/python3.12/site-packages/dspy/predict/predict.py", line 61, in __call__
    return self.forward(**kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/esthersu/workplace/p-chatbot-ws/src/P-Chatbot-Bedrock/src/p_chatbot_bedrock/path/to/venv/lib/python3.12/site-packages/dspy/predict/predict.py", line 103, in forward
    x, C = dsp.generate(template, **config)(x, stage=self.stage)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/esthersu/workplace/p-chatbot-ws/src/P-Chatbot-Bedrock/src/p_chatbot_bedrock/path/to/venv/lib/python3.12/site-packages/dsp/primitives/predict.py", line 113, in do_generate
    completions: list[Example] = [template.extract(example, p) for p in completions]
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/esthersu/workplace/p-chatbot-ws/src/P-Chatbot-Bedrock/src/p_chatbot_bedrock/path/to/venv/lib/python3.12/site-packages/dsp/templates/template_v2.py", line 152, in extract
    raw_pred = raw_pred.strip()
               ^^^^^^^^^^^^^^
AttributeError: 'list' object has no attribute 'strip'
Hi @suesther, does this work now with the latest commits in the repo? I just merged #843, which affects the AWS providers/models and had been in the backlog.
No, it unfortunately has the same error.
@suesther, can you specify which version of DSPy you used when you tried it again?
I ran this command:
pip install git+https://github.com/stanfordnlp/dspy.git
@lebsral @arnavsinghvi11 Do you have any ideas about what the issue could be?
I have no idea. But if I were trying to figure it out, I would start by asking: how is the wrong data type, a list instead of a string, ending up at this point?
https://github.com/stanfordnlp/dspy/blob/55510eec1b83fa77f368e191a363c150df8c5b02/dsp/templates/template_v2.py#L152
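One way to answer that question is to call the LM directly, outside the template code, and check the shape of what it returns. A minimal sketch, assuming only that the LM object is callable as lm(prompt, **kwargs) and that downstream code expects a flat list of strings; checked_lm_call, flat_lm and nested_lm are hypothetical names used for illustration, not part of DSPy.

```python
from typing import Any, Callable


def checked_lm_call(lm: Callable[..., Any], prompt: str, **kwargs: Any) -> list[str]:
    """Call `lm` and raise a clear error unless it returns a flat list of strings.

    Hypothetical debugging helper; it is not part of DSPy.
    """
    completions = lm(prompt, **kwargs)
    if not isinstance(completions, list):
        raise TypeError(f"expected a list of completions, got {type(completions).__name__}")
    for i, item in enumerate(completions):
        if not isinstance(item, str):
            # A non-string element here is what later surfaces as
            # "'list' object has no attribute 'strip'" in template extraction.
            raise TypeError(f"completion {i} is {type(item).__name__}, not str: {item!r}")
    return completions


# Hypothetical stand-in LMs: one well-behaved, one that nests its completions.
def flat_lm(prompt: str, **kwargs: Any) -> list[str]:
    return ["Paris"]


def nested_lm(prompt: str, **kwargs: Any) -> list[list[str]]:
    return [["Paris", "Lyon"]]
```

The stand-ins only make the sketch runnable without AWS credentials; in principle the same check could be pointed at the configured Bedrock LM object, assuming it is called with a prompt the same way.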
Is there any update on this issue? I am still running into it today.
This issue still persists. Are there any plans to fix it?
Can you give us a complete picture of your environment? What version of Python? Virtual env, Conda, or Poetry? Please spell out what you are seeing now.
Python 3.11
Pip list:
aiobotocore 2.12.4
aiohappyeyeballs 2.4.0
aiohttp 3.10.5
aioitertools 0.11.0
aiosignal 1.3.1
alembic 1.13.2
annotated-types 0.7.0
anyio 4.4.0
attrs 24.2.0
backoff 2.2.1
beautifulsoup4 4.12.3
blinker 1.8.2
boto3 1.34.106
botocore 1.34.106
botocore-stubs 1.35.5
Brotli 1.1.0
certifi 2024.7.4
cffi 1.17.0
charset-normalizer 3.3.2
click 8.1.7
colorlog 6.8.2
ConfigArgParse 1.7
cryptography 42.0.8
dataclasses-json 0.6.7
datasets 2.14.7
defusedxml 0.7.1
dill 0.3.7
distro 1.9.0
dspy-ai 2.4.13
Events 0.5
execnet 2.1.1
fastapi 0.110.3
filelock 3.15.4
Flask 3.0.3
Flask-Cors 4.0.1
Flask-Login 0.6.3
frozenlist 1.4.1
fsspec 2023.10.0
gevent 24.2.1
geventhttpclient 2.3.1
greenlet 3.0.3
gunicorn 21.2.0
h11 0.14.0
httpcore 1.0.5
httpx 0.27.0
huggingface-hub 0.24.6
idna 3.8
iniconfig 2.0.0
itsdangerous 2.2.0
Jinja2 3.1.4
jmespath 1.0.1
joblib 1.3.2
jsonpatch 1.33
jsonpointer 3.0.0
langchain 0.1.20
langchain-community 0.0.38
langchain-core 0.1.52
langchain-text-splitters 0.0.2
langsmith 0.1.104
locust 2.29.1
loguru 0.7.2
lxml 5.1.1
Mako 1.3.5
MarkupSafe 2.1.5
marshmallow 3.22.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.15
mypy-extensions 1.0.0
numpy 1.26.4
openai 1.31.2
opensearch-py 2.6.0
optuna 3.6.1
orjson 3.10.7
packaging 23.2
pandas 2.2.2
pip 24.2
pluggy 1.5.0
psutil 6.0.0
pyarrow 17.0.0
pyarrow-hotfix 0.6
PyAthena 3.8.3
pycparser 2.22
pydantic 2.5.0
pydantic_core 2.14.1
PyJWT 2.8.0
pytest 8.3.2
pytest-html 4.1.1
pytest-metadata 3.1.1
pytest-repeat 0.9.3
pytest-xdist 3.6.1
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
pytz 2024.1
PyYAML 6.0.2
pyzmq 26.2.0
regex 2024.7.24
requests 2.32.3
requests-aws4auth 1.2.3
s3transfer 0.10.2
setuptools 72.1.0
six 1.16.0
sniffio 1.3.1
soupsieve 2.6
SQLAlchemy 2.0.32
starlette 0.37.2
structlog 24.4.0
tenacity 8.5.0
tqdm 4.66.5
types-aiobotocore 2.13.3
types-aiobotocore-bedrock-runtime 2.13.3
types-awscrt 0.21.2
typing_extensions 4.12.2
typing-inspect 0.9.0
tzdata 2024.1
ujson 5.10.0
urllib3 2.2.2
uvicorn 0.29.0
Werkzeug 3.0.4
wheel 0.43.0
wrapt 1.16.0
xxhash 3.5.0
yarl 1.9.4
zope.event 5.0
zope.interface 7.0.1
I am initializing the Bedrock provider as follows:
bedrock = dspy.Bedrock(region_name="us-east-1")
bedrock_claude_sonnet = dspy.AWSAnthropic(bedrock, model="anthropic.claude-3-haiku-20240307-v1:0")
dspy.settings.configure(lm=bedrock_claude_sonnet)
FYI, I'm getting the same error with other Claude models as well.
Line 73 in predict.py says:
completions: list[dict[str, Any]] = generator(prompt, **kwargs)
But the completions object ends up being a length-one list that itself contains an object of type list[dict[str, Any]], i.e. it is actually completions: list[list[dict[str, Any]]].
In the __call__ method of AWSModel, the final result is always wrapped in a list (return [generated]), so when generated is already a list, the method returns a list containing a list of strings rather than the plain list of strings that the type hint suggests.
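The double wrapping can be reproduced without AWS at all. A toy sketch, assuming (this is not confirmed in the thread) that basic_request returns a single string when one completion is requested and a list of strings when several are requested; that would fit the baseline evaluate call succeeding while COPRO's instruction-generation step fails. FakeAWSModel and extract are hypothetical stand-ins, not DSPy code.

```python
from __future__ import annotations

from typing import Any


class FakeAWSModel:
    """Toy stand-in for the AWSModel.__call__ behaviour described above.

    Assumption: basic_request returns a str for n == 1 and a list[str] for n > 1.
    """

    def basic_request(self, prompt: str, n: int = 1, **kwargs: Any) -> str | list[str]:
        if n == 1:
            return f"answer to {prompt}"
        return [f"answer {i} to {prompt}" for i in range(n)]

    def __call__(self, prompt: str, **kwargs: Any) -> list[str]:
        generated = self.basic_request(prompt, **kwargs)
        # Unconditional wrap: a list result becomes a nested list.
        return [generated]


def extract(raw_pred: Any) -> str:
    # Mirrors the failing line in template_v2.py: raw_pred = raw_pred.strip()
    return raw_pred.strip()


model = FakeAWSModel()
print(model("q"))       # ['answer to q']
print(model("q", n=2))  # [['answer 0 to q', 'answer 1 to q']]
```

Calling extract on the single element of model("q", n=2) then raises AttributeError: 'list' object has no attribute 'strip', the same error as in the original traceback.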
The following change to AWSModel.__call__ appears to fix it, but it still raises the question of why it is necessary:
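generated = self.basic_request(prompt, **kwargs)
# If basic_request already returned a list of completions, return it as-is instead of nesting it
if isinstance(generated, list):
    return generated
return [generated]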
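Applied to a toy model, the guard makes __call__ return a flat list of strings in both cases. This is a sketch under the assumption that basic_request returns a str for one completion and a list[str] for several; FixedFakeAWSModel is a hypothetical stand-in, not the real AWSModel. isinstance is used instead of type(...) is list so list subclasses are passed through too.

```python
from __future__ import annotations

from typing import Any


class FixedFakeAWSModel:
    """Toy stand-in for AWSModel with the proposed guard in __call__.

    Assumption: basic_request returns a str for n == 1 and a list[str] for n > 1.
    """

    def basic_request(self, prompt: str, n: int = 1, **kwargs: Any) -> str | list[str]:
        if n == 1:
            return f"answer to {prompt}"
        return [f"answer {i} to {prompt}" for i in range(n)]

    def __call__(self, prompt: str, **kwargs: Any) -> list[str]:
        generated = self.basic_request(prompt, **kwargs)
        # Already a list of completions: return it as-is instead of nesting it.
        if isinstance(generated, list):
            return generated
        # A single completion: wrap it so callers always get list[str].
        return [generated]


model = FixedFakeAWSModel()
# Every element is now a string, so template code can safely call .strip() on it.
assert all(isinstance(c, str) for c in model("q", n=3))
```

This only normalizes the shape at the __call__ boundary; it does not explain why basic_request returns a list in the first place.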
If you reinstall dspy-ai with pip install git+https://github.com/stanfordnlp/dspy.git, does that change anything for you?
Nope, that ended up with exactly the same issue.
+1, running into the exact same issue and error.
I'm also seeing the exact same error with Anthropic Haiku, Mistral Small, and Meta Llama 2 models, using the dspy.AWSAnthropic, dspy.AWSMistral, and dspy.AWSMeta classes.