RAG Evaluation with Prometheus 2

Evaluating the responses of Language Models and LLM-based applications often involves using model-based metrics that do not require ground truth labels. Large proprietary models like GPT-4 and Claude 3 Opus are frequently employed as evaluators and demonstrate a good correlation with human evaluations.

However, relying on closed models poses several challenges:

  • fairness: the training data of these models is unknown.
  • controllability: the behavior of these models can change unpredictably.
  • data privacy: sending data to external providers may raise privacy concerns.
  • affordability: using these powerful models can be expensive.

Using open models for evaluation is an active research area, but their practical usefulness is often limited: they typically correlate poorly with human judgments and offer little flexibility.

🔥 Prometheus 2 is a new family of open-source models designed to address these gaps:

  • two variants, fine-tuned from Mistral-7B and Mixtral 8x7B respectively
  • trained on open-source data
  • demonstrate high correlation with human evaluations and proprietary models
  • highly flexible: capable of performing direct assessments and pairwise rankings, and allowing the definition of custom evaluation criteria.

In this experimental notebook, we will use Prometheus 2 to evaluate the responses of a RAG pipeline.

First, we will build the RAG pipeline and collect some results. Then, we will code a custom Prometheus Evaluator component for Haystack. Finally, we will initialize three different evaluators and run them in an evaluation pipeline.

Create the RAG pipeline to evaluate

We want to use Prometheus 2 to evaluate the answers generated by a RAG, so we first need to build our RAG Pipeline.

This part is quite similar to the “Evaluating RAG Pipelines” tutorial. Take a look at it for more details.

If you prefer, you can simply read through this section: we will provide the generated data needed for the later evaluation steps.

!pip install haystack-ai datasets sentence-transformers accelerate huggingface_hub bitsandbytes

We will be using a labeled PubMed dataset with questions, contexts and answers. This allows us to use the contexts as Documents and provides the necessary labeled data for some of the evaluation metrics we will define.

In this example, we will use the first 100 rows.

First, let’s fetch the dataset and extract all_documents, all_questions and all_ground_truth_answers.

from datasets import load_dataset
from haystack import Document

dataset = load_dataset("vblagoje/PubMedQA_instruction", split="train")
dataset = dataset.select(range(100))
all_documents = [Document(content=doc["context"]) for doc in dataset]
all_questions = [doc["instruction"] for doc in dataset]
all_ground_truth_answers = [doc["response"] for doc in dataset]

Indexing pipeline

Next, let’s build a simple indexing pipeline and write the documents into a Document Store.

from haystack import Pipeline
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy

document_store = InMemoryDocumentStore()

document_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
document_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)

indexing = Pipeline()
indexing.add_component(instance=document_embedder, name="document_embedder")
indexing.add_component(instance=document_writer, name="document_writer")

indexing.connect("document_embedder.documents", "document_writer.documents")

indexing.run({"document_embedder": {"documents": all_documents}})

RAG pipeline

Now that we have our data ready, we can create a simple RAG pipeline.

In this example, we’ll be using:

  • InMemoryEmbeddingRetriever to retrieve the relevant documents for the query.
  • HuggingFaceLocalGenerator with google/gemma-1.1-2b-it to generate answers to queries. It is a small model, and later we will evaluate the quality of the generated responses based on custom criteria.

import os
from getpass import getpass
from haystack import Pipeline
from haystack.components.builders import AnswerBuilder, PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.generators import HuggingFaceLocalGenerator
from haystack.utils import ComponentDevice


# to access Gemma
# 1. you need a Hugging Face account
# 2. you have to accept Google conditions: https://huggingface.co/google/gemma-1.1-2b-it
# 3. copy your HF token (https://huggingface.co/settings/tokens) and paste it below
os.environ["HF_API_TOKEN"] = getpass("Your Hugging Face token")

generator = HuggingFaceLocalGenerator(
    "google/gemma-1.1-2b-it",
    huggingface_pipeline_kwargs={"device_map": "auto"},
    device=ComponentDevice.from_str("cuda:0"),
)

template = """
<bos><start_of_turn>user
You have to answer the following question based on the given context information only.

Context:
{% for document in documents %}
    {{ document.content }}
{% endfor %}

Question: {{question}}
Answer:<end_of_turn>
<start_of_turn>model"""

rag_pipeline = Pipeline()
rag_pipeline.add_component(
    "query_embedder",
    SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
)
rag_pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store, top_k=3))
rag_pipeline.add_component("prompt_builder", PromptBuilder(template=template))
rag_pipeline.add_component("generator", generator)
rag_pipeline.add_component("answer_builder", AnswerBuilder())

rag_pipeline.connect("query_embedder", "retriever.query_embedding")
rag_pipeline.connect("retriever", "prompt_builder.documents")
rag_pipeline.connect("prompt_builder", "generator")
rag_pipeline.connect("generator.replies", "answer_builder.replies")
rag_pipeline.connect("retriever", "answer_builder.documents")

You can try the RAG pipeline by asking a question:

question = "Do high levels of procalcitonin in the early phase after pediatric liver transplantation indicate poor postoperative outcome?"

response = rag_pipeline.run(
    {
        "query_embedder": {"text": question},
        "prompt_builder": {"question": question},
        "answer_builder": {"query": question},
    }
)
print(response["answer_builder"]["answers"][0].data)
response["answer_builder"]

Run the RAG pipeline and save results

Let’s run our RAG pipeline with a set of questions, and make sure to save the data we need for evaluation: questions, ground truth answers and generated answers.

  • In this example, we will use 10 random questions.
  • In the evaluation part, we will not evaluate the retrieved context, so we will not save it. However, you can choose to include the context in the evaluation: as we will see later, evaluation with Prometheus is highly customizable (a sketch for saving the retrieved contexts follows the code below).

import random

questions, ground_truth_answers = zip(*random.sample(list(zip(all_questions, all_ground_truth_answers)), 10))
rag_answers = []

for question in list(questions):
    results = rag_pipeline.run(
        {
            "query_embedder": {"text": question},
            "prompt_builder": {"question": question},
            "answer_builder": {"query": question},
        }
    )

    rag_answers.append(results["answer_builder"]["answers"][0].data)
results = {
    "questions": questions,
    "ground_truth_answers": ground_truth_answers,
    "rag_answers": rag_answers,
}
import json

with open("gemma_2b_rag_results.json", "w") as fo:
    json.dump(results, fo)
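
If you would rather include the retrieved context in the evaluation, you could also keep the documents attached to each generated answer. Below is a minimal sketch of a variant of the loop above (retrieved_contexts is just an illustrative name; you would then add a matching input to your evaluator template):

# variant of the loop above that also saves the retrieved contexts
# (the GeneratedAnswer objects produced by AnswerBuilder keep a reference to the retrieved documents)
rag_answers, retrieved_contexts = [], []

for question in list(questions):
    results = rag_pipeline.run(
        {
            "query_embedder": {"text": question},
            "prompt_builder": {"question": question},
            "answer_builder": {"query": question},
        }
    )
    answer = results["answer_builder"]["answers"][0]
    rag_answers.append(answer.data)
    retrieved_contexts.append([doc.content for doc in answer.documents])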

Evaluation with Prometheus 2

After the preparation work, we can use Prometheus 2 to evaluate the responses generated along several desired axes.

This model expects a prompt like the one below and returns a text containing feedback and a score.

###Task Description:
An instruction (might include an Input inside it), a response to evaluate, a reference answer that gets a score of 5, and a score rubric representing a evaluation criteria are given.
1. Write a detailed feedback that assess the quality of the response strictly based on the given score rubric, not evaluating in general.
2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.
3. The output format should look as follows: \"Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)\"
4. Please do not generate any other opening, closing, and explanations.

###The instruction to evaluate:
{orig_instruction}

###Response to evaluate:
{orig_response}

###Reference Answer (Score 5):
{orig_reference_answer}

###Score Rubrics:
[{orig_criteria}]
Score 1: {orig_score1_description}
Score 2: {orig_score2_description}
Score 3: {orig_score3_description}
Score 4: {orig_score4_description}
Score 5: {orig_score5_description}

###Feedback:
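
For instance, given a completion such as "Feedback: The response addresses the question but omits key details. [RESULT] 3", the feedback and the score can be separated by splitting on the [RESULT] marker. Here is a minimal sketch with a made-up output string (the custom component below implements the same logic):

# illustrative parsing of a hypothetical Prometheus completion into feedback and score
sample_output = "Feedback: The response addresses the question but omits key details. [RESULT] 3"

feedback, _, score_str = sample_output.rpartition("[RESULT]")
feedback = feedback.replace("Feedback:", "").strip()
score = int(score_str.strip())
print(score, feedback)  # 3 The response addresses the question but omits key details.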

Create a Prometheus Evaluator component

To perform evaluation, we create a custom Haystack Evaluator component. In Haystack, it is easy to create custom components, and we can implement Prometheus Evaluator with just a few lines of code.

Design choices

Our implementation is hacky and geared toward experimentation, but some choices are worth explaining.

  • the component is inspired by and extends our LLMEvaluator, with specific adaptations for Prometheus

  • init parameters

    • template: Prometheus is highly customizable, so we can easily create different evaluators with different prompt templates
    • inputs: The inputs that the evaluator expects and that it evaluates. They should match those defined in the template.
    • generator: (hacky) allows passing different types of Haystack generators to use the Prometheus model. Examples: HuggingFaceLocalGenerator, LlamaCppGenerator, etc.
  • run method: for each example to evaluate, the inputs are integrated into the prompt and passed to the model; the model output is then parsed to extract the score and feedback. This method returns a dictionary containing an aggregate score, individual_scores and feedbacks (an illustrative return value is sketched below).
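
For reference, a hypothetical return value for two evaluated examples could look like this (illustrative values only):

# hypothetical return value of PrometheusLLMEvaluator.run() for two evaluated examples
example_result = {
    "score": 3.5,                 # mean of the parsable individual scores
    "individual_scores": [3, 4],  # one integer (1-5) per output the score could be parsed from
    "feedbacks": ["...", "..."],  # one textual feedback per evaluated example
}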

from typing import Any, Dict, List, Tuple, Type
from haystack import component
from haystack.components.evaluators import LLMEvaluator
from haystack.components.builders import PromptBuilder
from tqdm import tqdm
from numpy import mean as np_mean


ABS_SYSTEM_PROMPT = (
    "You are a fair judge assistant tasked with providing clear, objective feedback based on "
    "specific criteria, ensuring each assessment reflects the absolute standards set for performance."
)


@component
class PrometheusLLMEvaluator(LLMEvaluator):
    def __init__(
        self,
        generator,
        template: str,
        inputs: List[Tuple[str, Type[List]]],
        progress_bar: bool = True,
    ):
        outputs = ["feedback", "score"]
        self.validate_init_parameters(inputs, outputs, [])
        self.inputs = inputs
        self.outputs = outputs

        self._builder = PromptBuilder(template=template)
        self._generator = generator
        self.progress_bar = progress_bar

        component.set_input_types(self, **dict(inputs))

    def _parse_output(self, output):
        feedback, _, score_str = output.rpartition("[RESULT]")
        feedback = feedback.rpartition("###Feedback: [/INST]")[-1].strip()
        score_str = score_str.strip()

        score = None
        if score_str.isdigit() and score_str in ["1", "2", "3", "4", "5"]:
            score = int(score_str)
        return feedback, score

    @component.output_types(score=float, individual_scores=List[float], feedbacks=List[str])
    def run(self, **inputs) -> Dict[str, Any]:
        self.validate_input_parameters(dict(self.inputs), inputs)

        # inputs is a dictionary with keys being input names and values being a list of input values
        # We need to iterate through the lists in parallel for all keys of the dictionary
        input_names, values = inputs.keys(), list(zip(*inputs.values()))
        list_of_input_names_to_values = [dict(zip(input_names, v)) for v in values]

        individual_scores, feedbacks = [], []
        for input_names_to_values in tqdm(list_of_input_names_to_values, disable=not self.progress_bar):
            
            partial_prompt = self._builder.run(**input_names_to_values)["prompt"]
            prompt = f"[INST] {ABS_SYSTEM_PROMPT}\n{partial_prompt} [/INST]"
            
            output = self._generator.run(prompt=prompt)["replies"][0]

            feedback, individual_score = self._parse_output(output)
            if individual_score is not None:
                individual_scores.append(individual_score)
            feedbacks.append(feedback)
        score = np_mean(individual_scores)

        return {
            "score": score,
            "individual_scores": individual_scores,
            "feedbacks": feedbacks,
        }

Load the Prometheus 2 model

We are going to use prometheus-7b-v2.0: the smallest variant of Prometheus 2, which can run on a standard Colab notebook with 8-bit quantization.

In particular, we will use the model via HuggingFaceLocalGenerator, based on the Transformers library.

The generation_kwargs simply replicate those used in the prometheus-eval library. For practical applications, it would be worth experimenting to see whether a different combination of parameters gives better evaluation quality and reproducibility.

As mentioned earlier, there are several other options for running this open model with Haystack:

  • resource-constrained environments: LlamaCppGenerator (can run on CPU-only environments thanks to the GGUF quantized format; example commented below)
  • in production, with available GPU resources: TGI (via HuggingFaceAPIGenerator; a commented sketch follows the llama.cpp example below) or vLLM.

# if you have previously run the RAG pipeline, you will probably need to restart
# the kernel in order to free up GPU memory

from haystack.components.generators import HuggingFaceLocalGenerator

generator = HuggingFaceLocalGenerator(
    model="prometheus-eval/prometheus-7b-v2.0",
    task="text2text-generation",
    huggingface_pipeline_kwargs={
        "device_map": "auto",
        "model_kwargs": {"load_in_8bit": True},
    },
    generation_kwargs={
        "max_new_tokens": 512,
        "temperature": 1.0,
        "do_sample": True,
        "repetition_penalty": 1.03,
        "top_p": 0.9,
    },
)

generator.warm_up()
# UNCOMMENT THE FOLLOWING LINES TO USE llama.cpp
# You can also choose a model with a different quantization: you trade some quality for lower resource usage and faster inference

# ! pip install haystack-ai llama-cpp-haystack

# from haystack_integrations.components.generators.llama_cpp import LlamaCppGenerator
# from huggingface_hub import hf_hub_download

# prometheus_path = hf_hub_download(
#             repo_id="AlekseiPravdin/prometheus-7b-v2_0-gguf", filename="prometheus-7b-v2_0.q8_0.gguf", repo_type="model"
# )

# generator = LlamaCppGenerator(
#     model=prometheus_path,
#     n_ctx=8192,
#     n_batch=512,
# 	  generation_kwargs={"max_tokens": 512, "temperature": 1.0, "do_sample":True, "repeat_penalty": 1.03, "top_p": 0.9},
# )
# generator.warm_up()
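
If you prefer to serve the model with TGI, as mentioned in the list above, the configuration could look like the commented sketch below (the endpoint URL is an assumption to adapt to your deployment):

# UNCOMMENT THE FOLLOWING LINES TO USE A SELF-HOSTED TGI ENDPOINT via HuggingFaceAPIGenerator
# (sketch only; it assumes a TGI server is already serving prometheus-eval/prometheus-7b-v2.0
#  at the example URL below - adjust it to your deployment)

# from haystack.components.generators import HuggingFaceAPIGenerator

# generator = HuggingFaceAPIGenerator(
#     api_type="text_generation_inference",
#     api_params={"url": "http://localhost:8080"},
#     generation_kwargs={"max_new_tokens": 512, "temperature": 1.0, "do_sample": True, "top_p": 0.9},
# )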

Initialize different Prometheus Evaluators

We will define 3 prompt templates and corresponding Prometheus Evaluators:

  • Correctness: Evaluates the generated answer considering both relevance to the question and similarity to the ground truth answer.
  • Response Relevance: Evaluates the generated answer in terms of its relevance to the user’s question.
  • Logical Robustness: Evaluates the logical organization and progression of the response.

As shown, a diverse range of evaluators can be created by customizing the prompt template.

In general, the first section (Task Description) should be left intact. The only aspect to change, as illustrated in the following examples, is whether or not a reference answer is used.

⚠️ Although these evaluator names may be similar to evaluation metrics used in Haystack or other libraries, it is important to understand that they are created specifically for Prometheus and produce scores between 1 and 5. They are not comparable to conceptually similar but differently defined metrics.

correctness_prompt_template = """
###Task Description
An instruction (might include an Input inside it), a response to evaluate, a reference answer that gets a score of 5, and a score rubric representing a evaluation criteria are given.
1. Write a detailed feedback that assesses the quality of the response strictly based on the given score rubric, not evaluating in general.
2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.
3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (1 or 2 or 3 or 4 or 5)"
4. Please do not generate any other opening, closing, and explanations.

###The instruction to evaluate:
Your task is to evaluate the generated answer against the reference answer for the question: {{query}}

###Response to evaluate:
generated answer: {{generated_answer}}

###Reference Answer (Score 5): {{reference_answer}}

###Score Rubrics:
Score 1: The answer is not relevant to the question and does not align with the reference answer.
Score 2: The answer is relevant to the question but deviates significantly from the reference answer.
Score 3: The answer is relevant to the question and generally aligns with the reference answer but has errors or omissions.
Score 4: The answer is relevant to the question and closely matches the reference answer but is less concise or clear.
Score 5: The answer is highly relevant, fully accurate, and matches the reference answer in both content and clarity.

###Feedback:""".strip()

correctness_evaluator = PrometheusLLMEvaluator(
    template=correctness_prompt_template,
    generator=generator,
    inputs=[
        ("query", List[str]),
        ("generated_answer", List[str]),
        ("reference_answer", List[str]),
    ],
)



response_relevance_prompt_template = """
###Task Description
An instruction (might include an Input inside it), a response to evaluate, and a score rubric representing a evaluation criteria are given.
1. Write a detailed feedback that assess the quality of the response strictly based on the given score rubric, not evaluating in general.
2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.
3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"
4. Please do not generate any other opening, closing, and explanations.

###The instruction to evaluate:
Your task is to evaluate whether the generated answer is relevant to the question: {{query}}

###Response to evaluate:
generated answer: {{generated_answer}}

###Score Rubrics:
Score 1: The generated answer is off-topic or irrelevant to the question asked.
Score 2: The generated answer includes some relevant information but often contains unrelated details.
Score 3: The generated answer is generally relevant to the question but occasionally includes extraneous or off-topic details.
Score 4: The generated answer is mostly relevant to the question, with minimal unrelated information.
Score 5: The generated answer is highly relevant to the question, addressing it directly and thoroughly without including unnecessary information.

###Feedback:""".strip()

response_relevance_evaluator = PrometheusLLMEvaluator(
    template=response_relevance_prompt_template,
    generator=generator,
    inputs=[("query", List[str]), ("generated_answer", List[str])],
)



logical_robustness_prompt_template = """
###Task Description
An instruction (might include an Input inside it), a response to evaluate, and a score rubric representing a evaluation criteria are given.
1. Write a detailed feedback that assess the quality of the response strictly based on the given score rubric, not evaluating in general.
2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.
3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"
4. Please do not generate any other opening, closing, and explanations.

###The instruction to evaluate:
Your task is to evaluate how logically the generated answer for the question is organized, ensuring a clear progression of ideas and arguments that are easy to follow. question:{{query}}

###Response to evaluate:
generated answer: {{generated_answer}}

###Score Rubrics:
Score 1: Disorganized, lacks clear structure, and is difficult to follow.
Score 2: Some structure, but inconsistent and hard to follow due to abrupt transitions.
Score 3: Generally organized with minor flow issues and occasional unclear connections.
Score 4: Well-organized with clear and smooth transitions, easy to follow.
Score 5: Excellently organized with flawless logical flow and seamless transitions.

###Feedback:""".strip()

logical_robustness_evaluator = PrometheusLLMEvaluator(
    template=logical_robustness_prompt_template,
    generator=generator,
    inputs=[("query", List[str]), ("generated_answer", List[str])],
)

Let’s try the logical_robustness_evaluator:

query = [
    "Are group 2 innate lymphoid cells ( ILC2s ) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?",
    "Does poor sleep predict symptoms of depression and disability retirement due to depression?",
]
generated_answer = [
    "As ILC2s are elevated in patients with CRSwNP, they may drive nasal polyp formation in CRS. ILC2s are also linked with high tissue and blood eosinophilia and have a potential role in the activation and survival of eosinophils during the Th2 immune response. The association of innate lymphoid cells in CRS provides insights into its pathogenesis.",
    "Lack of baseline diagnostic interviews; sleep quality based on self-report.",
]


res = logical_robustness_evaluator.run(query=query, generated_answer=generated_answer)
res

Ok, nice!

Evaluation pipeline

We can now add our evaluators to an Evaluation pipeline and run the pipeline with our RAG results.

from haystack import Pipeline

eval_pipeline = Pipeline()
eval_pipeline.add_component("correctness_evaluator", correctness_evaluator)
eval_pipeline.add_component("response_relevance_evaluator", response_relevance_evaluator)
eval_pipeline.add_component("logical_robustness_evaluator", logical_robustness_evaluator)

Let’s download the RAG results. If you have run the RAG pipeline, you can skip the next cell.

# skip this cell if you have run the RAG pipeline before

!wget "https://raw.githubusercontent.com/deepset-ai/haystack-cookbook/main/data/prometheus2_evaluation/gemma_2b_rag_results.json"
import json

with open("gemma_2b_rag_results.json", "r") as fin:
    rag_results = json.load(fin)

questions = rag_results["questions"]
ground_truth_answers = rag_results["ground_truth_answers"]
rag_answers = rag_results["rag_answers"]
eval_results = eval_pipeline.run(
    {
        "correctness_evaluator": {
            "query": questions,
            "generated_answer": rag_answers,
            "reference_answer": ground_truth_answers,
        },
        "response_relevance_evaluator": {
            "query": questions,
            "generated_answer": rag_answers,
        },
        "logical_robustness_evaluator": {
            "query": questions,
            "generated_answer": rag_answers,
        },
    }
)

Evaluation results

Once we’ve run our evaluation pipeline, we can also create a full evaluation report. Haystack provides an EvaluationRunResult which we can use to display a score_report.

from haystack.evaluation.eval_run_result import EvaluationRunResult

inputs = {
    "question": questions,
    "answer": ground_truth_answers,
    "predicted_answer": rag_answers,
}

evaluation_result = EvaluationRunResult(run_name="pubmed_rag_pipeline", inputs=inputs, results=eval_results)
evaluation_result.score_report()

In general, on our small sample, Gemma-1.1-2b-it seems to generate relevant answers, but the responses often diverge from the ground truth answers and their logical organization is not optimal.

Let’s inspect the specific metrics in a dataframe.

import pandas as pd

# do not truncate text
pd.set_option("display.max_colwidth", None)

results_df = evaluation_result.to_pandas()
results_df

Since Prometheus provides textual feedback for each evaluation, it is worth taking a look at it.

eval_results["logical_robustness_evaluator"]["feedbacks"]
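
For easier side-by-side inspection, you can pair each feedback with the corresponding question. Here is a minimal sketch using the eval_results and questions objects defined above (feedback_df is just an illustrative name):

import pandas as pd

# pair each question with its logical robustness feedback
feedback_df = pd.DataFrame(
    {
        "question": questions,
        "logical_robustness_feedback": eval_results["logical_robustness_evaluator"]["feedbacks"],
    }
)
feedback_df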

📚 References