
Issue setting huggingface.prompt_builder = 'llama2' when using sagemaker as client

Open bcarsley opened this issue 2 years ago • 2 comments

So I'm building a class that can alternate between both the huggingface and sagemaker clients and I declare all my os.environs at the top of the class like so:

os.environ["AWS_ACCESS_KEY_ID"] = "<key_id>"
os.environ["AWS_SECRET_ACCESS_KEY"] = '<access_key>'
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
os.environ["HF_TOKEN"] = "<hf_token>"
os.environ["HUGGINGFACE_PROMPT"] = "llama2"

Later on in the class, just to be sure, I also set huggingface.prompt_builder = 'llama2'. I tried importing build_llama2_prompt directly and passing it as a callable, which also didn't work. I even tried setting sagemaker.prompt_builder = 'llama2' just to see if that would do anything...nope.

I still get the warning telling me I haven't set a prompt builder, which is kinda weird. It's also clear that the prompt is occasionally being formatted a bit oddly: the same prompt, formatted as in the example below and passed directly to the sagemaker endpoint, yields a somewhat better response from the same endpoint.

It's nbd that this doesn't work super well for me, and I might just be being stupid about it. Below is how I've worked around it by manually implementing the call with sagemaker's HuggingFacePredictor class:

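import sagemaker.huggingface.model

# `sess` is assumed to be an existing sagemaker.Session created earlier (not shown)
llm = sagemaker.huggingface.model.HuggingFacePredictor('llama-party', sess)
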
def build_llama2_prompt(messages):
    startPrompt = "<s>[INST] "
    endPrompt = " [/INST]"
    conversation = []
    for index, message in enumerate(messages):
        if message["role"] == "system" and index == 0:
            conversation.append(f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n")
        elif message["role"] == "user":
            conversation.append(message["content"].strip())
        else:
            # messages are dicts, so index by key (attribute access raises AttributeError)
            conversation.append(f" [/INST] {message['content']}</s><s>[INST] ")

    return startPrompt + "".join(conversation) + endPrompt

prompt = build_llama2_prompt(messages)

payload = {
  "inputs":  prompt,
  "parameters": {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.9,
    "top_k": 50,
    "max_new_tokens": 512,
    "repetition_penalty": 1.03,
    "stop": ["</s>"]
  }
}

chat = llm.predict(payload)

print(chat[0]["generated_text"][len(prompt):])
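
A small aside on the last line: slicing with len(prompt) assumes the endpoint echoes the prompt at the start of generated_text, which depends on how the endpoint is configured. A minimal sketch of a more defensive version, using a hypothetical helper name (strip_prompt) and a made-up response instead of a real endpoint call:

```python
def strip_prompt(generated_text, prompt):
    """Return only the completion, whether or not the prompt was echoed back."""
    # Only slice when the response really starts with the prompt;
    # otherwise the slice would cut off part of the actual answer.
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):]
    return generated_text


# Made-up response shaped like the list of dicts used above (not real endpoint output)
sample_prompt = "<s>[INST] Hi [/INST]"
fake_response = [{"generated_text": "<s>[INST] Hi [/INST] Hello there!"}]

completion = strip_prompt(fake_response[0]["generated_text"], sample_prompt)
print(repr(completion))
```

If the endpoint returns only the completion, the helper passes the text through unchanged.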

This code was taken almost verbatim from the sagemaker llama deployment blog post here: https://www.philschmid.de/sagemaker-llama-llm

It works fine, I just don't know why the same approach doesn't work right inside the lib (easyllm).
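
For anyone trying to compare what the library sends against the manual format, here is a self-contained sketch that prints the exact prompt string for a sample conversation. The sample messages are made up for illustration, and the builder is a standalone copy of the workaround function above (with dict-style access throughout), not easyllm's internal implementation:

```python
def build_llama2_prompt(messages):
    """Format OpenAI-style chat messages as a single Llama 2 chat prompt."""
    start_prompt = "<s>[INST] "
    end_prompt = " [/INST]"
    conversation = []
    for index, message in enumerate(messages):
        if message["role"] == "system" and index == 0:
            # The system prompt is only honored as the very first message
            conversation.append(f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n")
        elif message["role"] == "user":
            conversation.append(message["content"].strip())
        else:
            # Assistant turn: close the instruction, add the reply, open a new turn
            conversation.append(f" [/INST] {message['content']}</s><s>[INST] ")
    return start_prompt + "".join(conversation) + end_prompt


# Hypothetical conversation, for illustration only
sample_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is SageMaker?"},
    {"role": "assistant", "content": "A managed ML service."},
    {"role": "user", "content": "  Thanks, and what is an endpoint?  "},
]

prompt = build_llama2_prompt(sample_messages)
print(repr(prompt))  # repr makes newlines and stray spaces visible
```

Printing with repr makes it easier to spot a missing newline or doubled space when diffing against the payload that actually reaches the endpoint.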

bcarsley • Aug 17 '23 02:08