llm-guard
Callback error in Llama Index RAG pipeline integration for Node post processors
Describe the bug
Hi, I followed the tutorial on the Llama Index RAG implementation for llm-guard (ref: https://llm-guard.com/tutorials/notebooks/llama_index_rag/). I am getting a callback error (NoneType object has no attribute "callback manager"). I am using the latest Llama Index package, where I pass the callback manager at vector index creation rather than through the service context (which is being deprecated).
To Reproduce
Steps to reproduce the behavior:
- All steps are the same as in the tutorial, except that I do not initialize a service context; instead, I pass the callback manager at the vector index creation stage.
- Here, self._llamaindex_settings contains the initialized callback_manager.
- I kept the rest of the code the same: I created the LLMGuardNodePostProcessor as per the tutorial and passed it to the query engine. For reference:
import logging
from typing import List, Optional

from llama_index.core.bridge.pydantic import Field
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler

# Module-level logger used by _postprocess_nodes below.
logger = logging.getLogger(__name__)

llama_debug = LlamaDebugHandler(print_trace_on_end=True)


class LLMGuardNodePostProcessor(BaseNodePostprocessor):
    scanners: List = Field(description="Scanner objects")
    fail_fast: bool = Field(
        description="If True, the postprocessor will stop after the first scanner failure.",
    )
    skip_scanners: List[str] = Field(
        description="List of scanner names to skip when failed e.g. Anonymize.",
    )

    def __init__(
        self,
        scanners: List,
        fail_fast: bool = True,
        skip_scanners: Optional[List[str]] = None,
    ) -> None:
        if skip_scanners is None:
            skip_scanners = []

        try:
            import llm_guard  # noqa: F401
        except ImportError:
            # Adjacent string literals are joined into a single message.
            raise ImportError(
                "Cannot import llm_guard package, please install it: "
                "pip install llm-guard"
            )

        super().__init__(
            scanners=scanners,
            fail_fast=fail_fast,
            skip_scanners=skip_scanners,
        )

    @classmethod
    def class_name(cls) -> str:
        return "LLMGuardNodePostProcessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        from llm_guard import scan_prompt

        safe_nodes = []
        for node_with_score in nodes:
            node = node_with_score.node

            sanitized_text, results_valid, results_score = scan_prompt(
                self.scanners,
                node.get_content(metadata_mode=MetadataMode.LLM),
                self.fail_fast,
            )

            # Treat skipped scanners as passed, whatever they reported.
            for scanner_name in self.skip_scanners:
                results_valid[scanner_name] = True

            # Drop the node if any remaining scanner marked it invalid.
            if any(not result for result in results_valid.values()):
                logger.warning(f"Node `{node.node_id}` is not valid, scores: {results_score}")
                continue

            node.set_content(sanitized_text)
            safe_nodes.append(NodeWithScore(node=node, score=node_with_score.score))

        return safe_nodes
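To make the filtering step in _postprocess_nodes easier to check in isolation, here is a minimal sketch of the same skip-and-filter logic using only the standard library. It is not the llm-guard or Llama Index API: fake_scan is a hypothetical stand-in that only mimics the (sanitized_text, results_valid, results_score) shape returned by scan_prompt, and filter_texts, the scanner names, and the sample strings are made up for illustration.

```python
from typing import Callable, Dict, List, Tuple

# Same shape as the tuple unpacked from scan_prompt in the post-processor.
ScanResult = Tuple[str, Dict[str, bool], Dict[str, float]]


def fake_scan(text: str) -> ScanResult:
    """Hypothetical stand-in for a scanner run (NOT the real llm_guard call)."""
    has_secret = "password" in text.lower()
    has_email = "@" in text
    results_valid = {"Secrets": not has_secret, "Anonymize": not has_email}
    results_score = {
        "Secrets": 1.0 if has_secret else 0.0,
        "Anonymize": 1.0 if has_email else 0.0,
    }
    sanitized = text.replace("@", "[at]")
    return sanitized, results_valid, results_score


def filter_texts(
    texts: List[str],
    scan: Callable[[str], ScanResult],
    skip_scanners: List[str],
) -> List[str]:
    """Keep sanitized texts whose non-skipped scanners all passed."""
    safe = []
    for text in texts:
        sanitized, results_valid, _scores = scan(text)
        # Skipped scanners are forced to "valid", as in the post-processor.
        for name in skip_scanners:
            results_valid[name] = True
        if not all(results_valid.values()):
            continue  # at least one remaining scanner rejected this text
        safe.append(sanitized)
    return safe


texts = ["contact bob@example.com", "the password is hunter2", "plain text"]
kept_with_skip = filter_texts(texts, fake_scan, skip_scanners=["Anonymize"])
kept_without_skip = filter_texts(texts, fake_scan, skip_scanners=[])
```

With "Anonymize" skipped, the e-mail text is kept in its sanitized form; without the skip it is dropped, which is the behaviour the skip_scanners field is meant to control.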
For the query engine:
query_engine = self.documents.index.as_query_engine(
    similarity_top_k=similarity_top_k,
    response_synthesizer=response_synthesizer,
    llm=llamaindex_settings['llm'],
    node_postprocessors=[llm_guard_postprocessor],
)
Screen Shot of Error Message
Can you please help with the issue described above?
Hey @nashugame, I updated the notebook and removed the usage of ServiceContext.