langchain
RFC: Improve LLMChain output type with strong parsing support
As you all know, one of the big issues with LLMs is the output format (type / format / syntax...).
I propose to add a new chain, LLMChainWithValidator, that accepts a validator: a function that tries to parse the output, validates its syntax, and returns the appropriate type.
Here are examples of possible validators:
- boolean: "true" -> True
- age: "34" -> 34 (parse int + check 0 < age < 130)
- data structure
- json
- SQL ...
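For illustration, the boolean and age validators above could be sketched like this (function names are hypothetical, not part of the proposal's code):

```python
def validate_boolean(output: str) -> bool:
    # Map common textual answers to a Python bool; raise on anything else.
    normalized = output.strip().lower()
    if normalized in ("true", "yes"):
        return True
    if normalized in ("false", "no"):
        return False
    raise ValueError(f"not a boolean: {output!r}")


def validate_age(output: str) -> int:
    # Parse an int and check it falls in a plausible human-age range.
    age = int(output.strip())
    if not 0 < age < 130:
        raise ValueError(f"age out of range: {age}")
    return age
```

Raising on invalid input (rather than returning None) lets the chain catch the exception and decide whether to retry or correct.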
The chain also accepts two other parameters, correct_on_error and retry:
- correct_on_error: try to correct the output with the LLM
- retry: the number of times the LLM call can be retried

The chain returns None if it does not succeed in producing a validated output.
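The retry / correct behavior described above could look roughly like this (a minimal sketch with a plain callable standing in for the LLM; not the actual chain implementation):

```python
from typing import Any, Callable, Optional


def call_with_validator(
    llm: Callable[[str], str],
    prompt: str,
    validator: Callable[[str], Any],
    retry: int = 2,
    correct_on_error: bool = False,
) -> Optional[Any]:
    text = llm(prompt)
    for _ in range(retry + 1):
        try:
            return validator(text)
        except Exception as error:
            if correct_on_error:
                # Ask the LLM to fix its own output, feeding back the parse error.
                text = llm(f"Fix this output so it parses ({error}): {text}")
            else:
                # Plain retry: re-ask the original prompt.
                text = llm(prompt)
    # All attempts failed to produce a validated output.
    return None
```
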
Some notes:
- I initially wanted to add validators directly in LLMChain, but I figured it was probably best to start with a simple custom chain.
- So far, this implementation doesn't handle multiple calls / responses. It added too much complexity due to the retry / correct logic.
Examples
import os
from langchain.chains.llm_validator import VALIDATORS, LLMChainWithValidator
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
llm = OpenAI()
template = """What is the price of {product}?"""
prompt = PromptTemplate(
    template=template,
    input_variables=["product"],
)
chain = LLMChainWithValidator(
    llm=llm,
    prompt=prompt,
    output_key="response",
    validator=int,
    retry=2,
    correct_on_error=True,
)
chain(
    {
        "product": "an iphone",
    }
)
# > {'product': 'an iphone', 'response': 699}
prompt = PromptTemplate(
    template="generate a json of top 3 french cities",
    input_variables=[],
)
chain = LLMChainWithValidator(
    llm=llm,
    prompt=prompt,
    output_key="response",
    validator=VALIDATORS["json"],
    retry=2,
    correct_on_error=True,
)
chain({})
# > {'response': {'Top 5 French Cities': [{'City': 'Paris', 'Population': 2244000},{'City': 'Marseille', 'Population': 861635},{'City': 'Lyon', 'Population': 495268}]}}
@hwchase17 I made some changes based on your suggestions.
- use OutputParser
- merge retry and correct into LLMChain
- ...and modify the new chain to be about correction only (LLMCorrectChain)
Here is an example of how that looks now:
import os

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.prompts.base import BaseOutputParser

llm = OpenAI(openai_api_key=os.environ["OPENAI_API_KEY"])

class IntegerParser(BaseOutputParser):
    def parse(self, text: str) -> int:
        return int(text)

template = """What is the price of {product}?"""
prompt = PromptTemplate(
    template=template,
    input_variables=["product"],
    output_parser=IntegerParser(),
)
chain = LLMChain(
    llm=llm,
    prompt=prompt,
    output_key="response",
    retries=2,
    correct_on_error=True,
)
chain({"product": "an iphone"}, parse=True)
What do you think? (before I go any further)
Following our discussion @hwchase17
- rename parameters (now parser_retries & parser_correct_on_error)
- add the possibility to use a function as output_parser
- make parse=True the default (when there is an output_parser...)
import os

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm = OpenAI(openai_api_key=os.environ["OPENAI_API_KEY"])

def parse_integer(output):
    return int(output)

template = """What is the price of {product}?"""
prompt = PromptTemplate(
    template=template,
    input_variables=["product"],
    output_parser=parse_integer,  # or could be directly "int" here
)
chain = LLMChain(
    llm=llm,
    prompt=prompt,
    output_key="response",
    parser_retries=2,
    parser_correct_on_error=True,
)
chain({"product": "an iphone"})  # => 699
chain({"product": "an iphone"}, parse=False)  # "The price of an iPhone depends on the model and the retailer. Generally, the starting price for an iPhone 11 is around $699"
chain.predict(product="an iphone")  # "The price of an iPhone..."
chain.predict_and_parse(product="an iphone")  # 699
Note 1: I need to add async support.
Note 2: This doesn't support list calls (chain.apply...), because it would complicate the issue: what should we do if only 2 of 3 parses fail? Should we retry just those two?
Note 3: We also need to change all return types to Any...? I don't know what we can do so it doesn't add complexity throughout the code. Is there a way to link it to the type of the prompt's output_parser?
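On Note 3, one possible direction (a sketch only, with a hypothetical TypedParser class, not LangChain code) is to make the parser generic so static checkers can track the chain's return type:

```python
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class TypedParser(Generic[T]):
    # Wraps a parsing function so tools like mypy can infer the output type
    # of anything that calls the parser.
    def __init__(self, parse: Callable[[str], T]) -> None:
        self._parse = parse

    def __call__(self, text: str) -> T:
        return self._parse(text)


int_parser: TypedParser[int] = TypedParser(int)
# int_parser("699") is typed as int by static checkers.
```

A generic chain could then be parameterized the same way, so its return type follows the output_parser instead of collapsing to Any.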
closing as stale