langcorn
Allow to override `retrieval_chain` flag (for custom chains with multi-key output)
The title says it all. The current implementation assumes that any chain with `len(chain.output_keys) > 1` is a retrieval chain, but that is not always the case: it can be any other type of chain, such as a custom one that has no source key.
Alternatively, the detection could be made stricter by checking the actual output keys.
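To illustrate the stricter check suggested above, here is a minimal sketch. The helper name `looks_like_retrieval_chain` is hypothetical and not part of langcorn; the `source_documents` key name is taken from the traceback further down in this thread. It only shows the idea of inspecting the declared output keys instead of counting them.

```python
from typing import Iterable

# Key that retrieval-style chains are assumed to expose (name taken from
# the traceback in this issue; treat it as an assumption).
SOURCE_KEY = "source_documents"


def looks_like_retrieval_chain(output_keys: Iterable[str]) -> bool:
    """Hypothetical helper: True only if the chain declares a source key.

    This replaces a length-based heuristic such as len(output_keys) > 1,
    which also matches custom chains with several unrelated output keys.
    """
    return SOURCE_KEY in set(output_keys)
```

With this check, a chain declaring `["output", "other"]` would not be treated as a retrieval chain, while one declaring `["result", "source_documents"]` would.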
Hi @nb-programmer, do you have a code example that reproduces this issue?
Sure:
demo_bug.py
from langchain.chains.base import Chain
from langchain.callbacks.manager import CallbackManagerForChainRun
from typing import Dict, Optional, Any, List


class CustomChain(Chain):
    input_key: str = "input"
    output_key: str = "output"

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        print("input:", inputs)
        return {self.output_key: "Hello", "other": "test"}

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """:meta private:"""
        return [self.output_key, "other"]


chain = CustomChain()
$ langcorn server demo_bug:chain
Then call the endpoint with any input, e.g.:
{
  "input": "test"
}
500 Error: Internal Server Error
  File "venv\lib\site-packages\langcorn\server\api.py", line 118, in handler
    source_documents=[str(t) for t in output.get("source_documents")],
TypeError: 'NoneType' object is not iterable
Expected: it should return all output keys, or just the output key, without raising an exception.
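For reference, the failure comes from iterating over `output.get("source_documents")` when that key is absent. A defensive version of that step might look like the sketch below. The function name `build_response` and the response shape are assumptions for illustration only, not langcorn's actual handler or the fix that was eventually committed.

```python
from typing import Any, Dict


def build_response(output: Dict[str, Any], output_key: str = "output") -> Dict[str, Any]:
    """Hypothetical response builder that tolerates a missing source key."""
    # Fall back to an empty list so the comprehension never sees None.
    docs = output.get("source_documents") or []
    return {
        "output": str(output.get(output_key, "")),
        "source_documents": [str(doc) for doc in docs],
    }
```

Called with the demo chain's result, `{"output": "Hello", "other": "test"}`, this returns the output value with an empty `source_documents` list instead of raising `TypeError`.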
Thx for the example, @nb-programmer. Fixed the bug in https://github.com/msoedov/langcorn/commit/e803e5d24500e0edd6a56724e1177e097c14a165. Going to release a new version today.