TGI crashes on complex JSON schemas provided as a grammar, without any error output (even at debug/trace log level)
System Info
Tech stack: TGI 2.0.1 on an A100 80 GB GPU, running on Kubernetes. Model: Mixtral-8x7B-Instruct-v0.1
Information
- [x] Docker
- [ ] The CLI directly
Tasks
- [x] An officially supported command
- [ ] My own modifications
Reproduction
Running the following query crashes TGI on the stack above. It would be great if someone could reproduce the issue.
The same example works as expected without the `grammar` parameter.
# Define the response format: Start
from enum import Enum
from pydantic import BaseModel, Field
from typing import Optional
from text_generation.types import GrammarType, Grammar


class Gender(str, Enum):
    male = 'male'
    female = 'female'
    diverse = 'diverse'


class SmokerStatus(str, Enum):
    smoker = "yes"
    non_smoker = "no"


# All 22 fields are Optional with a default, so none of them ends up in the schema's "required" list.
class AllFields(BaseModel):
    CompanyName: Optional[str] = Field(None, description="Name of the insurance company. "
                                       "Correct answer can never be a bank, "
                                       "comparison portal")
    ApplicationDate: Optional[str] = Field(None, description="Date when the application form was signed.")
    NameInsuredPerson: Optional[str] = Field(None, description="First name of Insured Person.")
    SurnameInsuredPerson: Optional[str] = Field(None, description="For German applications, include "
                                                "Dr. title as part of the surname if applicable. "
                                                "Dr. does not need to be denoted further "
                                                "such as in Dr.med. Other titles must not appear.")
    DateOfBirthInsuredPerson: Optional[str] = None
    CompanyReference: Optional[str] = Field(None, description="ID assigned by the company to the "
                                            "insurance application filed. The answer must "
                                            "not be an IBAN number.")
    Occupation: Optional[str] = None
    MaritalStatus: Optional[str] = None
    Sex: Optional[Gender] = None
    Height: Optional[float] = Field(None, gt=1.0)
    Weight: Optional[float] = Field(None, gt=1.0)
    Smoker: Optional[SmokerStatus] = None
    NamePolicyHolder: Optional[str] = Field(None, description="First name of Policyholder.")
    SurnamePolicyHolder: Optional[str] = Field(None, description="For German applications, include "
                                               "Dr. title as part of the surname if applicable. "
                                               "Dr. does not need to be denoted further "
                                               "such as in Dr.med.")
    Name: Optional[str] = Field(None, description="Name of insurance product.")
    StartDate: Optional[str] = Field(None, description="Start date of insurance contract.")
    Term: Optional[str] = Field(None, description="Duration of insurance contract.")
    MonthlyPension: Optional[float] = Field(None, gt=1.0)
    YearlyPension: Optional[float] = Field(None, gt=1.0)
    SumAssured: Optional[float] = Field(None, gt=1.0)
    BenefitAppliedFor: Optional[float] = Field(None, gt=1.0)
    ExistingCover: Optional[float] = Field(None, gt=1.0, description="Sum assured (cover) of existing insurance contracted earlier.")


response_format = AllFields.model_json_schema()
response_type = GrammarType.Json
response_grammar = Grammar(type=response_type, value=response_format)
# Define the response format: End
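# Define the query
# specify_query (German) reads: "Your task is to summarize the most important data and details
# from the application. Here is the text from the application that you should summarize: "
specify_query = "Deine Aufgabe ist, es die wichtigsten Daten und Angaben aus dem Antrag zusammenzufassen. Hier ist der Text aus dem Antrag, den du zusammenfassen sollst: "
# txt (German): a short two-page sample term life insurance application used as the input document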
txt = """Hier beginnt Seite 1 im Dokument:
Betreff: Ergänzungen zum Antrag Müller / RiLV / 123456788 An: [email protected] Datum: 2024-03-01 08:49:10
Der Versicherungsnehmer Max Müller, geboren am 1. Februar 1982, ist von Beruf Rechtsanwalt.
Hier beginnt Seite 2 im Dokument:
Abschluss einer Risikolebensversicherung bei der Allgemeine Versicherungs AG in Höhe von 800.000€. Versicherte Person ist Max Müller. Versicherungsnehmerin ist die Ehefrau Marina Müller.
Versicherungsstart ist 01.04.2024. Die Versicherung läuft bis zum 65. Lebensjahr."""
PROMPT = f"""[INST] {specify_query + txt} [/INST]"""
# Call the client using the grammar parameter
# llm_client is an instance of the text_generation Python client
llm_client.generate(PROMPT, max_new_tokens=450, grammar=response_grammar)
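When a request does return (for example with a reduced schema), it is worth checking that the generated text really conforms to the schema. With pydantic installed, `AllFields.model_validate_json(...)` does this directly. The stdlib-only sketch below is a minimal checker, and a hypothetical helper rather than part of the `text_generation` client. It assumes it only needs to understand the constructs pydantic v2 emits for this model: `anyOf` for `Optional`, a local `$ref` into `$defs` for the enums, `enum`, `type`, and `exclusiveMinimum` for `gt=1.0`.

```python
import json
from typing import Any

# JSON-schema type names mapped to the Python types json.loads produces.
_JSON_TYPES = {
    "string": str,
    "number": (int, float),
    "integer": int,
    "boolean": bool,
    "null": type(None),
    "object": dict,
    "array": list,
}


def _resolve(node: dict, schema: dict) -> dict:
    """Follow a local reference such as '#/$defs/Gender' (the form pydantic v2 emits)."""
    ref = node.get("$ref")
    if ref is None:
        return node
    target: Any = schema
    for part in ref.lstrip("#/").split("/"):
        target = target[part]
    return target


def _matches(value: Any, node: dict, schema: dict) -> bool:
    node = _resolve(node, schema)
    if "anyOf" in node:  # Optional[X] is emitted as anyOf [X, {"type": "null"}]
        return any(_matches(value, sub, schema) for sub in node["anyOf"])
    if "enum" in node and value not in node["enum"]:
        return False
    expected = node.get("type")
    if expected is not None:
        if isinstance(value, bool) and expected != "boolean":
            return False  # bool is a subclass of int in Python, so reject it explicitly
        if not isinstance(value, _JSON_TYPES[expected]):
            return False
    if "exclusiveMinimum" in node and not value > node["exclusiveMinimum"]:
        return False  # Field(..., gt=1.0) is emitted as exclusiveMinimum
    return True


def check_output(text: str, schema: dict) -> list:
    """Return a list of problems in `text`; an empty list means it passed these checks."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError as exc:
        return [f"not valid JSON: {exc}"]
    if not isinstance(data, dict):
        return ["top-level value is not an object"]
    props = schema.get("properties", {})
    problems = [f"unexpected field: {key}" for key in data if key not in props]
    problems += [f"missing required field: {key}"
                 for key in schema.get("required", []) if key not in data]
    for key, value in data.items():
        if key in props and not _matches(value, props[key], schema):
            problems.append(f"bad value for {key}: {value!r}")
    return problems
```

With the reproduction above, `check_output(llm_client.generate(PROMPT, max_new_tokens=450, grammar=response_grammar).generated_text, response_format)` returns an empty list when the output passes. An answer cut off by the token limit shows up as "not valid JSON".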
We download the weights and launch TGI as follows:
"text-generation-server"
- "download-weights"
- "--revision"
- "125c431e2ff41a156b9f9076f744d2f35dd6e67a"
- "mistralai/Mixtral-8x7B-Instruct-v0.1"
text-generation-launcher \
--model-id mistralai/Mixtral-8x7B-Instruct-v0.1 \
--num-shard 1 \
--quantize bitsandbytes-fp4 \
--max-total-tokens 32000 \
--max-batch-size 1 \
--max-client-batch-size 1 \
--max-input-tokens 16000 \
--max-concurrent-requests 1 \
--json-output \
--trust-remote-code \
--revision 125c431e2ff41a156b9f9076f744d2f35dd6e67a \
--env
We restricted the batch size to 1 because we suspected a problem with batches that mix grammar and non-grammar requests, but it made no difference.
The debug logs contain no error message whatsoever. The only remedy is restarting the pod.
Expected behavior
The code should produce valid JSON. It does so once a certain share of the fields in the pydantic class are removed, and it does not matter which ones: any 5-6 fields work fine.
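Because the failure seems to depend on how many fields the schema has, not which ones, the threshold can be located by binary search over prefixes of the property list. This is a minimal sketch that assumes the behaviour is monotonic: if n fields work, fewer fields work too. `works` is a hypothetical caller-supplied probe. For instance, it could send one `llm_client.generate(PROMPT, max_new_tokens=450, grammar=Grammar(type=GrammarType.Json, value=reduced))` request and return `False` when a client-side timeout fires. Every failing probe may leave the server hung until the pod is restarted, and binary search keeps the number of probes to at most five for the 22 fields of `AllFields`.

```python
from typing import Callable, Sequence


def subset_schema(schema: dict, fields: Sequence[str]) -> dict:
    """Shallow copy of `schema` that keeps only the listed properties ($defs are kept as-is)."""
    reduced = dict(schema)
    reduced["properties"] = {name: schema["properties"][name] for name in fields}
    if "required" in schema:
        reduced["required"] = [name for name in schema["required"] if name in fields]
    return reduced


def largest_working_prefix(schema: dict, works: Callable[[dict], bool]) -> int:
    """Binary-search the largest n for which the first n properties still work.

    Assumes monotonic behaviour (if n fields work, any smaller n works too) and that
    an empty schema works. `works` is a caller-supplied probe (hypothetical here).
    """
    names = list(schema["properties"])
    lo, hi = 0, len(names)  # invariant: n = lo works, and the answer is <= hi
    while lo < hi:
        mid = (lo + hi + 1) // 2  # round up so the loop always makes progress
        if works(subset_schema(schema, names[:mid])):
            lo = mid
        else:
            hi = mid - 1
    return lo
```

`largest_working_prefix(response_format, works)` reports the largest number of leading `AllFields` properties for which the probe still succeeded.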
Hey @o1iv3r, thanks for sharing! I'll try to reproduce it soon and share an update here.
FYI, I think it might be a problem in the outlines library, which also fails for me with a large number of fields.
Hi! The LLM has to generate the JSON structure as well as the field values defined in the schema, and I think that's the issue. With smaller schemas, grammar works really well 99% of the time. I have to admit I've never seen a schema this long, but the use case is absolutely something that should work well. I've been reading about schema-based generation and came across this article from Lamini here... It looks like they present the pre-filled JSON to the LLM, which saves compute, and all the LLM has to do is generate the field contents. Schema parsing would never fail this way. @drbh I'm not totally sure how this is currently implemented in TGI, but I assume the LLM is also generating the JSON structure right now. Is there scope to implement something like this going forward? I can see great benefit in this if so :) Thanks.
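Here is a minimal sketch of the pre-filled-JSON idea from the comment above, assuming it is done client-side. This is not how TGI's grammar support works today. The client walks the schema's properties, shows the model the partially filled object, and asks for one value at a time. `generate` is a hypothetical callable standing in for a per-field model call, such as a short `llm_client.generate` request. Keys, braces and quotes always come from `json.dumps`, so the assembled object is well-formed JSON by construction. The values themselves are not validated against the schema.

```python
import json
from typing import Any, Callable


def _parse_value(raw: str) -> Any:
    """Turn the model's answer for one field into a JSON value."""
    raw = raw.strip()
    if raw in ("", "null"):
        return None
    try:
        return json.loads(raw)  # numbers, quoted strings, true/false
    except json.JSONDecodeError:
        return raw  # fall back to treating the bare text as a string


def fill_template(schema: dict, context: str, generate: Callable[[str], str]) -> dict:
    """Fill the schema's properties one at a time; the model never writes keys or braces."""
    result: dict = {}
    for name, prop in schema.get("properties", {}).items():
        hint = prop.get("description", "")
        partial = json.dumps(result, ensure_ascii=False)  # the pre-filled object so far
        prompt = (
            f"{context}\n"
            f"JSON filled in so far: {partial}\n"
            f'Value for "{name}" ({hint}); answer null if unknown:'
        )
        result[name] = _parse_value(generate(prompt))
    return result
```

The trade-off is one request per field (22 for `AllFields`) instead of a single constrained generation. Each per-field call could still use a small grammar of its own, for example a regex for numeric fields.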
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.