langchaingo
SequentialChain discards the inputs of previous calls
I think it would be better to add each call's outputs to the inputs map instead of replacing it. Currently, a new inputs map is created that contains only the outputs of the last call.
The Call method of SequentialChain looks like this today:
```go
func (c *SequentialChain) Call(ctx context.Context, inputs map[string]any, options ...ChainCallOption) (map[string]any, error) { //nolint:lll
	var outputs map[string]any
	var err error
	for _, chain := range c.chains {
		outputs, err = Call(ctx, chain, inputs, options...)
		if err != nil {
			return nil, err
		}
		// Set the input for the next chain to the output of the current chain
		inputs = outputs
	}
	return outputs, nil
}
```
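To make the data flow concrete, here is a minimal, self-contained sketch of the current replacing behavior. It does not use langchaingo: `chainFunc`, `runReplacing`, and `exampleChains` are hypothetical stand-ins that model the code/test pipeline discussed in this issue, with stub chains instead of LLM calls.

```go
package main

import "fmt"

// chainFunc is a hypothetical stand-in for a chain, used only to show the
// data flow; it is not langchaingo's Chain interface.
type chainFunc func(inputs map[string]any) (map[string]any, error)

// runReplacing mirrors the current loop: each chain's outputs become the
// entire input map for the next chain.
func runReplacing(chs []chainFunc, inputs map[string]any) (map[string]any, error) {
	var outputs map[string]any
	var err error
	for _, ch := range chs {
		outputs, err = ch(inputs)
		if err != nil {
			return nil, err
		}
		inputs = outputs
	}
	return outputs, nil
}

// exampleChains models the code/test pipeline: the second chain needs
// "language", which only the caller supplies.
func exampleChains() []chainFunc {
	codeChain := func(in map[string]any) (map[string]any, error) {
		return map[string]any{"code": fmt.Sprintf("%v func to %v", in["language"], in["task"])}, nil
	}
	testChain := func(in map[string]any) (map[string]any, error) {
		lang, ok := in["language"]
		if !ok {
			return nil, fmt.Errorf("missing input key %q", "language")
		}
		return map[string]any{"test": fmt.Sprintf("%v test for %v", lang, in["code"])}, nil
	}
	return []chainFunc{codeChain, testChain}
}

func main() {
	_, err := runReplacing(exampleChains(), map[string]any{"task": "reverse a string", "language": "Go"})
	fmt.Println(err) // missing input key "language"
}
```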
With the current `Call` implementation, the second chain does not have access to the inputs of the first chain. In Python LangChain this works, for example:
```python
from langchain.chains import LLMChain, SequentialChain
from langchain.prompts import PromptTemplate

# llml and llmg are two LLM instances created elsewhere.
code_prompt = PromptTemplate(
    input_variables=["task", "language"],
    template="Write a short {language} function that will {task}.",
)
test_prompt = PromptTemplate(
    input_variables=["language", "code"],
    template="Write a test for {language} code:\n{code}",
)
code_chain = LLMChain(llm=llml, prompt=code_prompt, output_key="code")
test_chain = LLMChain(llm=llmg, prompt=test_prompt, output_key="test")
chain = SequentialChain(
    chains=[code_chain, test_chain],
    input_variables=["task", "language"],
    output_variables=["test", "code"],
)
```
In langchaingo, the equivalent setup fails because the test chain (the second chain) only receives the outputs of the code chain (the first chain), so `language` is missing. I updated the method as follows, and it works:
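```go
func (c *SequentialChain) Call(ctx context.Context, inputs map[string]any, options ...ChainCallOption) (map[string]any, error) { //nolint:lll
	// Copy the caller's inputs so accumulated values do not leak back into
	// the map the caller passed in (and so a nil map does not panic on write).
	known := make(map[string]any, len(inputs))
	for key, value := range inputs {
		known[key] = value
	}
	var outputs map[string]any
	var err error
	for _, chain := range c.chains {
		// Each chain sees the original inputs plus every earlier output.
		outputs, err = Call(ctx, chain, known, options...)
		if err != nil {
			return nil, err
		}
		// Append the output of this chain to the accumulated inputs.
		for key, value := range outputs {
			known[key] = value
		}
	}
	return outputs, nil
}
```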
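A minimal sketch of the accumulating behavior, again with hypothetical stub chains (`chainFunc`, `runAccumulating`, and `exampleChains` are illustrations, not langchaingo APIs). It also checks that the caller's map is left untouched, which is why the loop merges into a copy rather than into `inputs` itself.

```go
package main

import "fmt"

// chainFunc is a hypothetical stand-in for a chain; not langchaingo's API.
type chainFunc func(inputs map[string]any) (map[string]any, error)

// runAccumulating copies the caller's inputs, then merges each chain's
// outputs into that copy so later chains see every earlier value.
func runAccumulating(chs []chainFunc, inputs map[string]any) (map[string]any, error) {
	known := make(map[string]any, len(inputs))
	for k, v := range inputs {
		known[k] = v
	}
	var outputs map[string]any
	for _, ch := range chs {
		var err error
		outputs, err = ch(known)
		if err != nil {
			return nil, err
		}
		for k, v := range outputs {
			known[k] = v
		}
	}
	// Like the proposed method, return only the last chain's outputs.
	return outputs, nil
}

// exampleChains models the code/test pipeline from the Python example.
func exampleChains() []chainFunc {
	codeChain := func(in map[string]any) (map[string]any, error) {
		return map[string]any{"code": fmt.Sprintf("%v func to %v", in["language"], in["task"])}, nil
	}
	testChain := func(in map[string]any) (map[string]any, error) {
		lang, ok := in["language"]
		if !ok {
			return nil, fmt.Errorf("missing input key %q", "language")
		}
		return map[string]any{"test": fmt.Sprintf("%v test for %v", lang, in["code"])}, nil
	}
	return []chainFunc{codeChain, testChain}
}

func main() {
	inputs := map[string]any{"task": "reverse a string", "language": "Go"}
	out, err := runAccumulating(exampleChains(), inputs)
	// Prints: Go test for Go func to reverse a string <nil> 2
	fmt.Println(out["test"], err, len(inputs))
}
```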
A check to avoid key collisions may also be necessary.
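One possible shape for such a check, as a hypothetical helper (`mergeOutputs` is not part of langchaingo): it refuses to overwrite a key that is already known and leaves the map unchanged when it reports an error.

```go
package main

import "fmt"

// mergeOutputs is a hypothetical helper: it copies outputs into known, but
// returns an error instead of silently overwriting an existing key.
// All keys are checked before any write, so known is unchanged on error.
func mergeOutputs(known, outputs map[string]any) error {
	for k := range outputs {
		if _, exists := known[k]; exists {
			return fmt.Errorf("output key %q collides with an existing input or output", k)
		}
	}
	for k, v := range outputs {
		known[k] = v
	}
	return nil
}

func main() {
	known := map[string]any{"task": "reverse a string", "language": "Go"}
	fmt.Println(mergeOutputs(known, map[string]any{"code": "func Reverse(s string) string"})) // <nil>
	fmt.Println(mergeOutputs(known, map[string]any{"language": "Python"}))
	fmt.Println(known["language"]) // Go
}
```

Failing fast is one design choice; the alternative is to let later outputs overwrite earlier values, which is what the proposed loop does.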
I had the same problem.
Based on the example, I rewrote the unit test as follows:
```go
testLLM1 := &testLanguageModel{expResult: "In the year 3000, chickens have taken over the world"}
testLLM2 := &testLanguageModel{expResult: "An egg-citing adventure"}
testLLM3 := &testLanguageModel{expResult: "test output"}

chain1 := chains.NewLLMChain(
	testLLM1,
	prompts.NewPromptTemplate("Write a story titled {{.title}} set in the year {{.year}}", []string{"title", "year"}),
)
chain1.OutputKey = "story"

chain2 := chains.NewLLMChain(testLLM2, prompts.NewPromptTemplate("Review this story: {{.story}}", []string{"story"}))
chain2.OutputKey = "review"

chain3 := chains.NewLLMChain(
	testLLM3,
	prompts.NewPromptTemplate("Please expand this story according to the story and review: {{.story}} {{.review}}", []string{"story", "review"}),
)

chs := []chains.Chain{chain1, chain2, chain3}
seqChain, err := chains.NewSequentialChain(chs, []string{"title", "year"}, []string{"text"})
require.NoError(t, err)

_, err = chains.Call(context.Background(), seqChain, map[string]any{"title": "Chicken Takeover", "year": 3000})
require.NoError(t, err)
```

This test currently fails, because `chain3` needs `story`, which only `chain1` produces, while a similar example passes in the Python version.
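Separately, the proposed loop still returns only the last chain's `outputs`. The Python example declares `output_variables=["test", "code"]`, and `code` comes from the first chain, so one option is to return only the declared output keys from the accumulated map. A minimal sketch with a hypothetical helper `selectOutputs` (not a langchaingo API):

```go
package main

import "fmt"

// selectOutputs is a hypothetical helper that returns only the declared
// output keys from the accumulated map, failing if any of them is missing.
func selectOutputs(known map[string]any, outputKeys []string) (map[string]any, error) {
	result := make(map[string]any, len(outputKeys))
	for _, k := range outputKeys {
		v, ok := known[k]
		if !ok {
			return nil, fmt.Errorf("missing output key %q", k)
		}
		result[k] = v
	}
	return result, nil
}

func main() {
	// Accumulated values after both chains in the Python example have run.
	known := map[string]any{
		"task":     "reverse a string",
		"language": "Go",
		"code":     "func Reverse(s string) string",
		"test":     "func TestReverse(t *testing.T)",
	}
	out, err := selectOutputs(known, []string{"test", "code"})
	fmt.Println(len(out), err) // 2 <nil>
}
```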