
BedrockLlamaPromptModelDriver does not apply input truncation, resulting in ValidationError

Open njedema opened this issue 11 months ago • 4 comments

Describe the bug BedrockLlamaPromptModelDriver() does not enforce input truncation for any PromptTask. If the rendered task prompt is longer than 2048 tokens, the following ValidationError is raised on every retry until the code exits.

WARNING:root:<RetryCallState ************: attempt #2; slept for 2.0; last result: failed (ValidationException An error occurred (ValidationException) when calling the InvokeModel operation: Validation Error)>

To Reproduce Steps to reproduce the behavior:

The following snippet is sufficient to reproduce the error.

First, verify that you are able to invoke Llama 2 on Bedrock; be sure to change PROFILE_NAME:

import boto3
from griptape.drivers import (
    AmazonBedrockPromptDriver,
    BedrockLlamaPromptModelDriver
)
from griptape.structures import Agent

REGION = "us-west-2"
PROFILE_NAME = "<CHANGE ME>"  # your AWS profile name

session = boto3.Session(region_name=REGION, profile_name=PROFILE_NAME)

# Llama 2 70B Chat on Bedrock via the Llama prompt model driver
prompt_driver = AmazonBedrockPromptDriver(
    model="meta.llama2-70b-chat-v1",
    prompt_model_driver=BedrockLlamaPromptModelDriver(),
    session=session,
)

agent = Agent(prompt_driver=prompt_driver)

agent.run(
    "Write a haiku about academic research"
)

Now run the code with a very long prompt; fetching the raw content of a long Wikipedia page will do. Again, be sure to change PROFILE_NAME:

import boto3
import requests
from griptape.drivers import (
    AmazonBedrockPromptDriver,
    BedrockLlamaPromptModelDriver
)
from griptape.structures import Agent

REGION = "us-west-2"
PROFILE_NAME = "<CHANGE ME>"  # your AWS profile name

session = boto3.Session(region_name=REGION, profile_name=PROFILE_NAME)

prompt_driver = AmazonBedrockPromptDriver(
    model="meta.llama2-70b-chat-v1",
    prompt_model_driver=BedrockLlamaPromptModelDriver(),
    session=session,
)

agent = Agent(prompt_driver=prompt_driver)

# Fetch a page far longer than Llama 2's 2048-token context window
long_content = requests.get("https://en.wikipedia.org/wiki/Barack_Obama?action=raw").content

agent.run(
    f"{long_content}"
)

Expected behavior I expect Griptape to truncate the prompt to the maximum number of input tokens allowed by Bedrock Llama 2, as it does for other Bedrock prompt model drivers such as BedrockClaudePromptModelDriver().
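
As a stopgap, callers can truncate the prompt themselves before handing it to the agent. A minimal sketch, assuming the 2048-token limit above and a rough four-characters-per-token heuristic (a proper fix would count tokens with the model's own tokenizer):

def truncate_prompt(text: str, max_tokens: int = 2048, chars_per_token: int = 4) -> str:
    # Crude heuristic: assume ~4 characters per token, so the result
    # may still overshoot the real token limit; leave headroom if needed.
    return text[: max_tokens * chars_per_token]

agent.run(truncate_prompt(f"{long_content}"))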


Desktop (please complete the following information):

  • OS: macOS
  • Version: Ventura 13.6.4
  • griptape: 0.22.3
  • boto3: 1.34.11
  • botocore: 1.34.11


njedema commented Feb 28 '24

Thanks for the report @njedema. I'm able to reproduce on my end; will keep you updated on a solution.

collindutter commented Feb 28 '24

Thanks!

njedema commented Feb 29 '24

After thinking about this a bit more, this will actually be an issue with all Drivers. It just so happens that Llama has a relatively small context window.

Do we want to truncate prompt input to models? It might hide an underlying issue with the input data being too large.

CC @andrewfrench @vasinov

collindutter commented Mar 21 '24

If we do, then we need to log a warning. I am on the fence about this, leaning towards not truncating and failing explicitly; IMO, it should be on the user to truncate. Another option is adding an optional truncate_input parameter to models.

vasinov commented Mar 21 '24
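
For illustration, a minimal sketch of the optional truncate_input idea floated above. This is a hypothetical wrapper around any prompt-running callable, not an actual Griptape parameter, and whitespace splitting is only a crude stand-in for the model's real tokenizer:

import logging
from typing import Callable

def with_input_truncation(run: Callable[[str], object], max_tokens: int) -> Callable[[str], object]:
    # Hypothetical helper, not part of Griptape: truncate the prompt
    # and log a warning, per the discussion above.
    def wrapped(prompt: str):
        tokens = prompt.split()  # approximation of real tokenization
        if len(tokens) > max_tokens:
            logging.warning(
                "Prompt truncated from %d to %d tokens", len(tokens), max_tokens
            )
            prompt = " ".join(tokens[:max_tokens])
        return run(prompt)

    return wrapped

Usage could look like agent.run = with_input_truncation(agent.run, 2048), though in practice such a flag would presumably live on the driver itself.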