
`langchain.llms.base.BaseLLM` has a bug

Open creatorrr opened this issue on Jan 4, 2023 • 3 comments

In `BaseLLM.generate` (`langchain/llms/base.py`), instead of:

        generations = [existing_prompts[i] for i in range(len(prompts))]

it should be

        generations = [
            existing_prompts[i]
            for i in range(len(prompts))
            if i not in missing_prompt_idxs
        ]

@hwchase17 lemme know if I am missing something.

creatorrr avatar Jan 04 '23 05:01 creatorrr
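
For context, a toy run (values here are hypothetical, not langchain internals) shows what the proposed comprehension would return when one prompt misses the cache:

    # Hypothetical toy values, for illustration only.
    prompts = ["foo", "bar"]
    existing_prompts = {1: "cached-bar"}   # only "bar" was found in the cache
    missing_prompt_idxs = [0]              # "foo" still needs to be generated

    generations = [
        existing_prompts[i]
        for i in range(len(prompts))
        if i not in missing_prompt_idxs
    ]
    print(generations)  # ['cached-bar'] -- one result for two prompts

This avoids the KeyError but silently drops the result for the missing prompt, which is the concern raised in the reply below.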

nope - we want to return results for ALL inputs, either through generation or caching

for example:

inputs = ["foo", "bar"]

let's presume missing_prompt_idxs = [0] (this means that the result at index 0 was NOT in the cache and we need to generate it)

we want to return results for both "foo" and "bar", even though one is in the cache and the other isn't, so we iterate over the full length of `prompts`

hwchase17 avatar Jan 04 '23 07:01 hwchase17
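
For illustration, here is a minimal sketch of the merge flow described above: split the prompts into cache hits and misses, generate only the misses, then write the fresh results back under their original indices so the final list covers every prompt. The function and helper names are hypothetical, not the actual langchain source:

    from typing import Callable, Dict, List

    def generate_with_cache(
        prompts: List[str],
        cache: Dict[str, str],
        generate: Callable[[List[str]], List[str]],
    ) -> List[str]:
        # Hypothetical sketch of the intended flow; not langchain's real API.
        existing_prompts: Dict[int, str] = {}  # original index -> result
        missing_prompts: List[str] = []
        missing_prompt_idxs: List[int] = []

        # Split into cache hits and misses, remembering each miss's original index.
        for i, prompt in enumerate(prompts):
            if prompt in cache:
                existing_prompts[i] = cache[prompt]
            else:
                missing_prompts.append(prompt)
                missing_prompt_idxs.append(i)

        # Generate only the misses, then merge them back in by original index.
        for i, result in enumerate(generate(missing_prompts)):
            existing_prompts[missing_prompt_idxs[i]] = result
            cache[missing_prompts[i]] = result

        # Every index 0..len(prompts)-1 is now present, so this cannot KeyError.
        return [existing_prompts[i] for i in range(len(prompts))]

    cache = {"bar": "cached-bar"}
    print(generate_with_cache(["foo", "bar"], cache, lambda ps: [p.upper() for p in ps]))
    # ['FOO', 'cached-bar'] -- results for ALL inputs, cached or freshly generated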

it throws an IndexError from time to time. I noticed it while using the QAGenerateChain

creatorrr avatar Jan 04 '23 07:01 creatorrr

Example code:

import time

import numpy as np
from tqdm.notebook import tqdm_notebook

# essay_texts, wiki_texts, webpages, and example_gen_chain are defined
# earlier in the notebook.
texts = essay_texts + wiki_texts + webpages

examples = []

num_batches = 10
batches = np.array_split(texts, num_batches)

for batch in tqdm_notebook(batches):
    docs = [
        {"doc": t} for t in batch.tolist()
    ]

    examples += example_gen_chain.apply_and_parse(docs)

    # pause between batches (likely to avoid rate limits)
    time.sleep(10)

Stacktrace:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[11], line 17
     12 for batch in tqdm_notebook(batches):
     13     docs = [
     14         {"doc": t} for t in batch.tolist()
     15     ]
---> 17     examples += example_gen_chain.apply_and_parse(docs)
     19     time.sleep(10)

File ~/.cache/pypoetry/virtualenvs/lm-bot-w2c4PdJl-py3.8/lib/python3.8/site-packages/langchain/chains/llm.py:116, in LLMChain.apply_and_parse(self, input_list)
    112 def apply_and_parse(
    113     self, input_list: List[Dict[str, Any]]
    114 ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
    115     """Call apply and then parse the results."""
--> 116     result = self.apply(input_list)
    117     if self.prompt.output_parser is not None:
    118         new_result = []

File ~/.cache/pypoetry/virtualenvs/lm-bot-w2c4PdJl-py3.8/lib/python3.8/site-packages/langchain/chains/llm.py:75, in LLMChain.apply(self, input_list)
     73 def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
     74     """Utilize the LLM generate method for speed gains."""
---> 75     response = self.generate(input_list)
     76     outputs = []
     77     for generation in response.generations:
     78         # Get the text of the top generated string.

File ~/.cache/pypoetry/virtualenvs/lm-bot-w2c4PdJl-py3.8/lib/python3.8/site-packages/langchain/chains/llm.py:70, in LLMChain.generate(self, input_list)
     66         raise ValueError(
     67             "If `stop` is present in any inputs, should be present in all."
     68         )
     69     prompts.append(prompt)
---> 70 response = self.llm.generate(prompts, stop=stop)
     71 return response

File ~/.cache/pypoetry/virtualenvs/lm-bot-w2c4PdJl-py3.8/lib/python3.8/site-packages/langchain/llms/base.py:70, in BaseLLM.generate(self, prompts, stop)
     68     prompt = prompts[i]
     69     langchain.llm_cache.update(prompt, llm_string, result)
---> 70 generations = [existing_prompts[i] for i in range(len(prompts))]
     71 return LLMResult(generations=generations, llm_output=new_results.llm_output)

File ~/.cache/pypoetry/virtualenvs/lm-bot-w2c4PdJl-py3.8/lib/python3.8/site-packages/langchain/llms/base.py:70, in <listcomp>(.0)
     68     prompt = prompts[i]
     69     langchain.llm_cache.update(prompt, llm_string, result)
---> 70 generations = [existing_prompts[i] for i in range(len(prompts))]
     71 return LLMResult(generations=generations, llm_output=new_results.llm_output)

KeyError: 25

creatorrr avatar Jan 04 '23 07:01 creatorrr
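
To see how the KeyError: 25 above can arise: existing_prompts is a dict keyed by prompt index, so if a freshly generated result is never written back under its index, the final comprehension hits a missing key. A contrived reproduction (the specific index is illustrative):

    # Contrived reproduction of the failure mode, for illustration only.
    prompts = [f"prompt-{i}" for i in range(30)]
    # Suppose index 25 was a cache miss and its fresh result was never
    # merged back into existing_prompts, so the key is simply absent.
    existing_prompts = {i: f"result-{i}" for i in range(30) if i != 25}

    try:
        generations = [existing_prompts[i] for i in range(len(prompts))]
    except KeyError as e:
        print(f"KeyError: {e}")  # KeyError: 25, as in the traceback above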

@creatorrr you were completely right - I think I fixed it with this PR: https://github.com/hwchase17/langchain/pull/538

added a unit test so hopefully it doesn't happen again

thank you for flagging!!!

hwchase17 avatar Jan 04 '23 22:01 hwchase17