PromptBuilder yields huggingface_hub.errors.ValidationError: Input validation error: `inputs` must have less than 4095 tokens. Given: 4701
Describe the bug
When using a standard RAG pipeline I get the above error.
Error message
File "/home/felix/PycharmProjects/anychat/src/anychat/analysis/rag.py", line 124, in query_rag_in_document_store
result = self.llm_pipeline.run(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/core/pipeline/pipeline.py", line 197, in run
res = comp.run(**last_inputs[name])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/components/generators/hugging_face_api.py", line 187, in run
return self._run_non_streaming(prompt, generation_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/components/generators/hugging_face_api.py", line 211, in _run_non_streaming
tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/huggingface_hub/inference/_client.py", line 2061, in text_generation
raise_text_generation_error(e)
File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/huggingface_hub/inference/_common.py", line 457, in raise_text_generation_error
raise exception from http_error
huggingface_hub.errors.ValidationError: Input validation error: `inputs` must have less than 4095 tokens. Given: 4701
Expected behavior
My expectation would be that truncation is built in, so that not too many tokens are passed to the model. Ideally the input should be truncated not at the end of the prompt (in which case the question would be cut off) but at a specific part, e.g., by truncating my top_k=10 documents instead of including all of their tokens.
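For illustration, something along these lines is what I have in mind: a small component that sits between the retriever and the prompt builder and only keeps as many documents as fit into a token budget. This is just a sketch, not code I'm running; the component name, the model_id argument, and the 3000-token budget are made up.

```python
from typing import List

from haystack import Document, component
from transformers import AutoTokenizer


@component
class DocumentTokenBudget:
    """Sketch: keep only as many retrieved documents as fit into a token budget."""

    def __init__(self, model_id: str, max_doc_tokens: int = 3000):
        # model_id and max_doc_tokens are assumptions, not part of my actual code
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.max_doc_tokens = max_doc_tokens

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        kept, used = [], 0
        for doc in documents:
            n_tokens = len(self.tokenizer.encode(doc.content, add_special_tokens=False))
            if used + n_tokens > self.max_doc_tokens:
                break
            kept.append(doc)
            used += n_tokens
        return {"documents": kept}
```

It would then be connected with `pipeline.connect("retriever", "budget.documents")` and `pipeline.connect("budget", "prompt_builder.documents")`. Having something like this (or smarter truncation) built into Haystack would be much nicer, of course.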
To Reproduce
query_template = """Beantworte die Frage basierend auf dem nachfolgenden Kontext und Chatverlauf. Antworte so detailliert wie möglich, aber nur mit Informationen aus dem Kontext. Wenn du eine Antwort nicht weißt, sag, dass du sie nicht kennst. Wenn sich eine Frage nicht auf den Kontext bezieht, sag, dass du die Frage nicht beantworten kannst und höre auf. Stelle nie selbst eine Frage.
Kontext:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Vorheriger Chatverlauf:
{{ history }}
Frage: {{ question }}
Antwort: """
```python
def _create_generator(self):
    if AnyChatConfig.hf_use_local_generator:
        return HuggingFaceLocalGenerator(
            model=self.model_id,
            task="text2text-generation",
            device=ComponentDevice.from_str("cuda:0"),
            huggingface_pipeline_kwargs={
                "device_map": "auto",
                "model_kwargs": {
                    "load_in_4bit": True,
                    "bnb_4bit_use_double_quant": True,
                    "bnb_4bit_quant_type": "nf4",
                    "bnb_4bit_compute_dtype": torch.bfloat16,
                },
            },
            generation_kwargs={"max_new_tokens": 350},
        )
    else:
        return HuggingFaceAPIGenerator(
            api_type="text_generation_inference",
            api_params={"url": AnyChatConfig.hf_api_generator_url},
        )


def create_llm_pipeline(self, document_store):
    """
    Creates an LLM pipeline that employs RAG on a document store that must have been set up before.
    :return:
    """
    # create the pipeline with the individual components
    self.llm_pipeline = Pipeline()
    self.llm_pipeline.add_component(
        "embedder",
        SentenceTransformersTextEmbedder(
            model=DocumentManager.embedding_model_id,
            device=ComponentDevice.from_str(
                AnyChatConfig.hf_device_rag_text_embedder
            ),
        ),
    )
    self.llm_pipeline.add_component(
        "retriever",
        InMemoryEmbeddingRetriever(document_store=document_store, top_k=8),
    )
    self.llm_pipeline.add_component(
        "prompt_builder", PromptBuilder(template=query_template)
    )
    self.llm_pipeline.add_component("llm", self._create_generator())
    # connect the individual nodes to create the final pipeline
    self.llm_pipeline.connect("embedder.embedding", "retriever.query_embedding")
    self.llm_pipeline.connect("retriever", "prompt_builder.documents")
    self.llm_pipeline.connect("prompt_builder", "llm")


def _get_formatted_history(self):
    history = ""
    for message in self.conversation_history:
        history += f"{message[0]}: {message[1]}\n"
    history = history.strip()
    return history


def query_rag_in_document_store(self, query):
    """
    Uses the LLM and RAG to provide an answer to the given query based on the documents in the document store.
    :param query:
    :return:
    """
    logger.debug("querying using rag with: {}", query)
    # run the query through the pipeline
    result = self.llm_pipeline.run(
        {
            "embedder": {"text": query},
            "prompt_builder": {
                "question": query,
                "history": self._get_formatted_history(),
            },
            "llm": {"generation_kwargs": {"max_new_tokens": 350}},
        }
    )
    response = result["llm"]["replies"][0]
    post_processed_response = self._post_process_response(response, query)
    logger.debug(post_processed_response)
    self.conversation_history.append(("Frage", query))
    self.conversation_history.append(("Antwort", post_processed_response))
    return post_processed_response
```
FAQ Check
- [x] Have you had a look at our new FAQ page?
System:
- OS: debian
- GPU/CPU: NVIDIA RTX A6000 (CUDA 12.4)
- Haystack version (commit or version number): 2.2.1
- DocumentStore: InMemoryDocumentStore
- Reader: PyPDFToDocument
- Retriever: InMemoryEmbeddingRetriever
Related to #6593
Thank you @fhamborg for the suggestion to truncate specific parts of the prompt. We are tracking this with #6593.
Regarding the error caused by the input length, you could set a maximum length for truncation as part of the `generation_kwargs` of the `HuggingFaceAPIGenerator`. Does that work for you as a workaround? https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation
Thanks @julian-risch for the quick reply! As for setting the truncation parameter to some value, I guess that while it would help to avoid the error above, it would cut off the actual question in those cases where the input is too long (since the question is the last item in my prompt), which would be even worse.
Is there a way to retrieve the actual input to the LLM (or rather the text that is converted to that input), i.e., the potentially truncated input? That way I could compare my full prompt with the actual one (after potential truncation), and if it was in fact truncated I could rerun the pipeline with the retriever's top_k set one lower, for example. Or would you think it'd be better to just catch the exception above and then rerun with a decreased top_k?
EDIT: I just realized that the top_k parameter has to be set when the pipeline is created, not when it is run. So the above idea unfortunately wouldn't work (unless I recreated the pipeline each time the situation above occurs). Do you have any idea how to avoid both the error above and cutting off the question, other than setting top_k to a very low value (in which case the error would still come up at some point, e.g., once the chat history gets long)?
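To make the idea concrete, this is roughly what I mean (just a sketch; it assumes create_llm_pipeline took a top_k argument, which it currently doesn't, and that I keep a reference to the document store around):

```python
from huggingface_hub.errors import ValidationError


def query_with_top_k_backoff(self, query, start_top_k=8, min_top_k=2):
    # Sketch only: rebuild the pipeline with a smaller top_k whenever the
    # input validation error comes up. create_llm_pipeline(document_store, top_k)
    # and self.document_store are assumptions, not part of my actual code.
    for top_k in range(start_top_k, min_top_k - 1, -2):
        self.create_llm_pipeline(self.document_store, top_k=top_k)
        try:
            return self.query_rag_in_document_store(query)
        except ValidationError:
            continue  # prompt still too long, retry with fewer documents
    raise RuntimeError("Prompt is still too long even with the smallest top_k")
```

Recreating the pipeline on every failure feels wasteful, though, which is why I'm wondering whether there is a better way.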
@julian-risch
> Thank you @fhamborg for the suggestion to truncate specific parts of the prompt. We are tracking this with #6593. Regarding the error caused by the input length, you could set a maximum length for truncation as part of the `generation_kwargs` of the `HuggingFaceAPIGenerator`. Does that work for you as a workaround? https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation
What's the exact name of the parameter? The linked page contains neither a max_length nor a truncation parameter.
My use case is slightly different, as I'm trying to get around this issue with HuggingFaceLocalGenerator. Setting max_length in generation_kwargs there only applies to the output length; it does not truncate the input, so my application still crashes.
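For now I'm experimenting with clipping the prompt myself with the model's tokenizer before it goes into the generator, roughly like this (just a sketch; model_id and the 4000-token budget are placeholders, and I'm not sure this is the intended way):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)  # model_id: assumed to be the generator's model


def clip_prompt(prompt: str, max_input_tokens: int = 4000) -> str:
    # Keep only the last max_input_tokens tokens so the question at the end of
    # the prompt survives; note that this can still drop the instruction and
    # parts of the context at the top of the prompt.
    ids = tokenizer.encode(prompt, add_special_tokens=False)
    if len(ids) <= max_input_tokens:
        return prompt
    return tokenizer.decode(ids[-max_input_tokens:], skip_special_tokens=True)
```

That avoids the crash but clearly isn't a real fix, so built-in truncation of a specific part of the prompt would still be very welcome.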