ragas icon indicating copy to clipboard operation
ragas copied to clipboard

Adding vLLM model support for evaluation

Open MaximeSongIdris opened this issue 5 months ago • 0 comments

Describe the Feature: LlamaIndex- and LangChain-wrapped models are available for evaluation. LangChain contains a wrapper for vLLM, but it is not up to date and has some issues, so it would be useful to add native support for vLLM.

Why is the feature important for you? I couldn't use vLLM through LangChain.

Additional context: I have already written code that works (below).

import os
os.environ["OUTLINES_CACHE_DIR"] = "vllm_cache"
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'  # fork method doesn't work with when cuda is already initialized

import copy
import typing as t
import uuid

from langchain_core.callbacks import Callbacks
from langchain_core.outputs import LLMResult, Generation
from langchain_core.prompt_values import PromptValue
from ragas import evaluate, EvaluationDataset
from ragas.cache import CacheInterface
from ragas.llms import BaseRagasLLM
from ragas.metrics import LLMContextPrecisionWithReference, LLMContextRecall
from ragas.run_config import RunConfig
from vllm import AsyncLLMEngine, LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs


class vLLMWrapper(BaseRagasLLM):
    """
    A wrapper class that adapts vLLM's inference engine to the Ragas-compatible BaseRagasLLM interface.

    This class enables using vLLM for scoring and evaluation tasks within the Ragas framework by implementing
    the `generate_text` (sync) and `agenerate_text` (async) methods, both of which produce LangChain-compatible
    `LLMResult` objects.
    Source: https://github.com/explodinggradients/ragas/blob/main/ragas/src/ragas/llms/base.py#L123

    Attributes:
        llm: The vLLM engine.
            - `generate_text` calls `llm.generate(prompt, sampling_params)` and indexes the returned list
              (e.g. `vllm.LLM(...)`).
            - `agenerate_text` calls `llm.generate(prompt, sampling_params, request_id=...)` and consumes the
              result with `async for` (e.g. `vllm.AsyncLLMEngine`).
            Use the engine type that matches the path Ragas will call.
        sampling_params: A `SamplingParams` object defining temperature, top_p, max_tokens, etc. It is used as a
            template and never mutated: each request works on a copy whose `n` is set for that call.
        run_config: Configuration for controlling how evaluations are executed.
        cache: Optional cache for storing/reusing model outputs.

    Note:
        `temperature`, `stop` and `callbacks` passed by Ragas are accepted only for API compatibility and are
        ignored; sampling is entirely controlled by `sampling_params`.
    """

    def __init__(
        self,
        vllm_model,
        sampling_params,
        run_config: t.Optional[RunConfig] = None,
        cache: t.Optional[CacheInterface] = None,
    ):
        super().__init__(cache=cache)
        self.llm = vllm_model
        self.sampling_params = sampling_params

        if run_config is None:  # legacy code
            run_config = RunConfig()
        self.set_run_config(run_config)

    def is_finished(self, response: LLMResult) -> bool:
        """
        Return True when every output generated for the (single) prompt stopped normally.

        `response` contains the n outputs of a single input, thus:
            len(response.generations) == 1
            len(response.generations[0]) == n
        An output counts as finished only if its `finish_reason` is "stop"; any other reason
        (e.g. "length" when max_tokens was hit, or "abort") or a missing reason counts as unfinished.
        """
        return all(
            (generation.generation_info or {}).get("finish_reason") == "stop"
            for generation in response.generations[0]
        )

    def _request_params(self, n: int):
        """Return a per-request copy of `self.sampling_params` configured to produce `n` outputs.

        Copying instead of mutating the shared template keeps concurrent async requests
        with different `n` from interfering with each other.
        """
        clone = getattr(self.sampling_params, "clone", None)
        params = clone() if callable(clone) else copy.deepcopy(self.sampling_params)
        params.n = n  # generate n outputs per input
        # Keep `best_of` in sync with `n` only when the installed vLLM populated it.
        # NOTE(review): newer vLLM engines default `best_of` to None and reject
        # `best_of > 1` — verify against the installed vLLM version.
        if getattr(params, "best_of", None) is not None:
            params.best_of = n
        return params

    @staticmethod
    def _to_llm_result(vllm_result) -> LLMResult:
        """Convert one vLLM `RequestOutput` (all outputs of a single prompt) into an `LLMResult`."""
        # LangChain's LLMResult expects a list of lists:
        # - The outer list corresponds to each input prompt (here: exactly one)
        # - The inner list contains one or more Generations per prompt (e.g. multiple outputs for a single input)
        # - We register the reason why the generation ended: stop, length, abort
        # source: https://docs.vllm.ai/en/stable/api/vllm/v1/engine/index.html?h=finish_reason#vllm.v1.engine.FINISH_REASON_STRINGS
        generations = [
            [
                Generation(
                    text=output.text.strip(),
                    generation_info={'finish_reason': output.finish_reason},
                )
                for output in vllm_result.outputs
            ]
        ]
        return LLMResult(generations=generations)

    def generate_text(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: t.Optional[float] = None,
        stop: t.Optional[t.List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """
        Generate a LangChain-compatible LLMResult from a PromptValue using a synchronous vLLM engine.

        Args:
            prompt (PromptValue): The input prompt wrapped in a LangChain PromptValue.
            n (int): The number of outputs to generate for the prompt.
            temperature, stop, callbacks: Accepted for BaseRagasLLM compatibility; ignored.

        Returns:
            LLMResult: A single entry in `generations` holding one Generation per output,
            each tagged with its vLLM `finish_reason`.
        """
        # vLLM requires text as input and returns list[vllm.outputs.RequestOutput];
        # since we only send 1 prompt, the list has 1 entry.
        vllm_result = self.llm.generate(prompt.to_string(), self._request_params(n))[0]
        return self._to_llm_result(vllm_result)

    async def agenerate_text(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: t.Optional[float] = None,
        stop: t.Optional[t.List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """
        Asynchronously generate a LangChain-compatible LLMResult using an async vLLM engine.

        Args:
            prompt (PromptValue): The input prompt wrapped in a LangChain PromptValue.
            n (int): The number of outputs to generate for the prompt.
            temperature, stop, callbacks: Accepted for BaseRagasLLM compatibility; ignored.

        Returns:
            LLMResult: Same structure as `generate_text`.

        Raises:
            RuntimeError: If the engine's output stream ends without yielding any result.
        """
        request_id = str(uuid.uuid4())  # id used for tracking purpose
        # Non-blocking call: registers the request with the vLLM engine and returns a stream.
        results_generator = self.llm.generate(
            prompt.to_string(), self._request_params(n), request_id=request_id
        )

        # While awaiting results, other coroutines keep running. Only the last streamed item is kept;
        # this assumes vLLM's default cumulative output mode, where each item carries the full text so far.
        vllm_result = None
        async for request_output in results_generator:
            vllm_result = request_output

        if vllm_result is None:
            raise RuntimeError(f"vLLM returned no output for request {request_id}")

        return self._to_llm_result(vllm_result)

    def set_run_config(self, run_config: RunConfig):
        self.run_config = run_config

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(llm={self.llm.__class__.__name__}(...))"

if __name__ == "__main__":

    # Judge LLM. This is an AsyncLLMEngine, so it works with the wrapper's async path
    # (`agenerate_text`), which consumes `generate(...)` with `async for`.
    judge_llm = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(
            model='meta-llama/Llama-3.1-8B-Instruct',
            task='generate',              # generation task
            tokenizer_mode="auto",        # use fast implementation in Rust if available
            skip_tokenizer_init=False,    # use tokenizer and detokenizer
            trust_remote_code=False,
            dtype='bfloat16',             # bf16 precision (a dtype choice, not quantization)
            quantization=None,            # e.g. "awq", "gptq", "fp8"
            compilation_config=3,         # vLLM compilation level
            enforce_eager=False,          # True would force eager mode and disable CUDA graphs
            max_seq_len_to_capture=8192,  # fall back to eager mode if seq_len > 8192
            tensor_parallel_size=4,       # shard the model across 4 GPUs
            gpu_memory_utilization=0.9,
            swap_space=4,                 # offloading of KV cache if necessary
            cpu_offload_gb=0,             # no model offloading on CPU
            max_model_len=32768,          # set max number of tokens per sequence (prefill + decode) to 2^15
            enable_chunked_prefill=True,  # split batch of sequences to avoid potential vRAM overflow
            max_num_batched_tokens=32768  # threshold on total input tokens (num_seq * seq_len) before chunking
        )
    )  # Load LLM in bf16 with tensor_parallel_size=4 (expects 4 GPUs)

    metrics = [  # both metrics rely on an LLM judge
        LLMContextPrecisionWithReference(),   # Are the retrieved contexts relevant, or is there a lot of noise? (LLM)
        LLMContextRecall(),                   # Were all the relevant contexts retrieved? (LLM)
    ]

    # NOTE(review): 'ߤଵ' below looks like a mis-encoded "µ1" from PDF extraction; kept verbatim.
    dataset_ragas = [
        {
            'user_input': "What is the effect of adding stearic acid to PAO4 lubricant on the viscous damping friction contribution in steel/steel contacts?",
            'retrieved_contexts': [
                "The viscous damping friction value ߤଵ corresponding to PAO4 and 150NS lubricated and loaded steel/steel contacts is considered as negligible at any temperature since its corresponding value is less than that of the apparatus.",
                "When additives are used in PAO4 lubricant, the damping behavior of the system is reduced.",
                "Three different additives are tested with PAO4: oleic acid, linoleic acid, and stearic acid. Results show that linoleic acid has the lowest µ0; however, stearic acid has the lowest ߤଵ."
            ],
            'response': "Adding stearic acid to PAO4 lubricant reduces the viscous damping friction contribution in steel/steel contacts.",
            'reference': "Stearic acid reduces viscous damping friction more than oleic acid and linoleic acid, as it results in the lowest ߤଵ value among tested additives."
        }
    ]

    dataset_ragas = EvaluationDataset.from_list(dataset_ragas)

    # Judge LLM sampling parameters (kept at module scope under this guard).
    # NOTE(review): max_tokens=2024 may be a typo for 2048 — confirm.
    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.9,
        max_tokens=2024,
    )  # Base sampling params from: https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct/blob/main/generation_config.json

    results = evaluate(dataset=dataset_ragas,
                       metrics=metrics,
                       llm=vLLMWrapper(judge_llm, sampling_params),
                       raise_exceptions=False,
                       run_config=RunConfig(timeout=180, max_retries=5, max_wait=60, max_workers=30),
                       show_progress=True)
    print(results)

MaximeSongIdris avatar Jul 03 '25 07:07 MaximeSongIdris