Advanced RAG: Query Expansion
Last Updated: September 20, 2024
by Tuana Celik (LinkedIn, Twitter/X)
In this cookbook, you’ll learn how to implement query expansion for RAG. Query expansion consists of asking an LLM to produce a number of similar queries to a user query. We are then able to use each of these queries in the retrieval process, increasing the number and relevance of retrieved documents.
!pip install haystack-ai wikipedia
import wikipedia
import json
from typing import List, Optional
from haystack import Pipeline, component
from haystack.components.builders import PromptBuilder
from haystack.components.generators import OpenAIGenerator
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.writers import DocumentWriter
from haystack.dataclasses import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
import os
from getpass import getpass
# Ask for the OpenAI API key interactively unless it is already set in the environment.
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("Your OpenAI API Key: ")
The Process of Query Expansion
First, let's create a `QueryExpander`. This component is able to create a number (defaulting to 5) of additional queries that are similar to the original user query. It returns `queries`, which contains the original query plus the `number` of similar queries.
@component
class QueryExpander:
    """Haystack component that expands a user query into several similar queries.

    An internal Pipeline (PromptBuilder -> OpenAIGenerator) asks the LLM to
    produce ``number`` reformulations of the query as a JSON list of strings.
    ``run`` returns ``{"queries": [...]}`` containing the expansions followed
    by the original query.

    :param prompt: Optional custom prompt template. When ``None``, a built-in
        template with few-shot examples is used. It must reference the
        ``{{ number }}`` and ``{{ query }}`` template variables.
    :param model: Name of the OpenAI model to use (default ``gpt-3.5-turbo``).
    """

    def __init__(self, prompt: Optional[str] = None, model: str = "gpt-3.5-turbo"):
        self.model = model
        # Use `is None`, not `== None`, for the sentinel check; fall back to
        # the built-in few-shot prompt when the caller supplies no template.
        if prompt is None:
            prompt = """
You are part of an information system that processes users queries.
You expand a given query into {{ number }} queries that are similar in meaning.
Structure:
Follow the structure shown below in examples to generate expanded queries.
Examples:
1. Example Query 1: "climate change effects"
Example Expanded Queries: ["impact of climate change", "consequences of global warming", "effects of environmental changes"]
2. Example Query 2: ""machine learning algorithms""
Example Expanded Queries: ["neural networks", "clustering", "supervised learning", "deep learning"]
Your Task:
Query: "{{query}}"
Example Expanded Queries:
"""
        self.query_expansion_prompt = prompt
        builder = PromptBuilder(self.query_expansion_prompt)
        llm = OpenAIGenerator(model=self.model)
        self.pipeline = Pipeline()
        self.pipeline.add_component(name="builder", instance=builder)
        self.pipeline.add_component(name="llm", instance=llm)
        self.pipeline.connect("builder", "llm")

    @component.output_types(queries=List[str])
    def run(self, query: str, number: int = 5):
        """Expand ``query`` into ``number`` similar queries.

        :param query: The original user query.
        :param number: How many expansions to request from the LLM.
        :returns: ``{"queries": [...]}`` — the expansions plus the original
            query (appended last). If the LLM reply is not valid JSON, only
            the original query is returned instead of raising.
        """
        result = self.pipeline.run({'builder': {'query': query, 'number': number}})
        reply = result['llm']['replies'][0]
        try:
            expanded = json.loads(reply)
        except json.JSONDecodeError:
            # The model occasionally returns prose instead of a JSON list;
            # degrade gracefully to just the original query.
            expanded = []
        return {"queries": list(expanded) + [query]}
# Quick sanity check: ask for 4 expansions; the output also contains the
# original query, so 5 queries total are expected.
expander = QueryExpander()
expander.run(query="open source nlp frameworks", number=4)
Retrieval Without Query Expansion
# A small toy corpus about climate topics to compare retrieval with and
# without query expansion.
documents = [
Document(content="The effects of climate are many including loss of biodiversity"),
Document(content="The impact of climate change is evident in the melting of the polar ice caps."),
Document(content="Consequences of global warming include the rise in sea levels."),
Document(content="One of the effects of environmental changes is the change in weather patterns."),
Document(content="There is a global call to reduce the amount of air travel people take."),
Document(content="Air travel is one of the core contributors to climate change."),
Document(content="Expect warm climates in Turkey during the summer period."),
]
doc_store = InMemoryDocumentStore(embedding_similarity_function="cosine")
doc_store.write_documents(documents)
# Plain keyword (BM25) retrieval over the toy corpus, top 3 hits only.
retriever = InMemoryBM25Retriever(document_store=doc_store, top_k=3)
retrieval_pipeline = Pipeline()
retrieval_pipeline.add_component("keyword_retriever", retriever)
# Baseline: a single query — only documents lexically matching "climate change"
# can be retrieved.
query = "climate change"
retrieval_pipeline.run({"keyword_retriever":{ "query": query, "top_k": 3}})
Retrieval With Query Expansion
Now let's have a look at what documents we are able to retrieve if we include query expansion in the process. For this step, let's create a `MultiQueryInMemoryBM25Retriever`
that is able to use BM25 retrieval for each (expanded) query in turn.
This component also handles the same document being retrieved for multiple queries and will not return duplicates.
@component
class MultiQueryInMemoryBM25Retriever:
    """Runs BM25 retrieval once per query and merges the results.

    Documents retrieved for several queries are deduplicated by document id,
    and the merged list is sorted by BM25 score (highest first).

    :param retriever: The single-query BM25 retriever to delegate to.
    :param top_k: Default number of documents to retrieve per query.
    """

    def __init__(self, retriever: InMemoryBM25Retriever, top_k: int = 3):
        self.retriever = retriever
        self.top_k = top_k

    @component.output_types(documents=List[Document])
    def run(self, queries: List[str], top_k: int = None):
        """Retrieve for each query and return the merged, deduplicated documents.

        :param queries: The (expanded) queries to retrieve for.
        :param top_k: Per-query limit for this call; falls back to the value
            given at construction time. Note: unlike before, passing ``top_k``
            no longer permanently overwrites the instance default.
        """
        # Use `is not None` and a local variable: the original mutated
        # self.top_k, so one call's override leaked into all later calls.
        effective_top_k = top_k if top_k is not None else self.top_k
        # Accumulators are local so repeated run() calls start fresh; the
        # original kept them on the instance and accumulated stale results.
        results: List[Document] = []
        seen_ids = set()
        for query in queries:
            result = self.retriever.run(query=query, top_k=effective_top_k)
            for doc in result['documents']:
                if doc.id not in seen_ids:
                    results.append(doc)
                    seen_ids.add(doc.id)
        results.sort(key=lambda x: x.score, reverse=True)
        return {"documents": results}
# Expansion-aware retrieval: the expander produces N similar queries, the
# multi-query retriever runs BM25 for each and merges/deduplicates the hits.
query_expander = QueryExpander()
retriever = MultiQueryInMemoryBM25Retriever(InMemoryBM25Retriever(document_store=doc_store))
expanded_retrieval_pipeline = Pipeline()
expanded_retrieval_pipeline.add_component("expander", query_expander)
expanded_retrieval_pipeline.add_component("keyword_retriever", retriever)
# Route the expander's list of queries into the multi-query retriever.
expanded_retrieval_pipeline.connect("expander.queries", "keyword_retriever.queries")
# include_outputs_from exposes the expanded queries alongside the documents.
expanded_retrieval_pipeline.run({"expander": {"query": query}}, include_outputs_from=["expander"])
Query Expansion for RAG
Let’s start off by populating a document store with chunks of context from various Wikipedia pages.
def get_doc_store():
    """Fetch a set of Wikipedia pages and index them into an in-memory store.

    Each page becomes one Document (content plus title/url metadata); an
    indexing pipeline cleans the documents, splits them into single-passage
    chunks and writes them to the store, skipping duplicates.
    """
    page_titles = ["Electric_vehicle", "Dam", "Electric_battery", "Tree", "Solar_panel", "Nuclear_power",
                   "Wind_power", "Hydroelectricity", "Coal", "Natural_gas", "Greenhouse_gas", "Renewable_energy",
                   "Fossil_fuel"]
    # auto_suggest=False keeps wikipedia from silently resolving to a
    # different article than the exact title requested.
    pages = (wikipedia.page(title=title, auto_suggest=False) for title in page_titles)
    raw_docs = [
        Document(content=page.content, meta={"title": page.title, "url": page.url})
        for page in pages
    ]
    store = InMemoryDocumentStore(embedding_similarity_function="cosine")
    indexing_pipeline = Pipeline()
    indexing_pipeline.add_component("cleaner", DocumentCleaner())
    indexing_pipeline.add_component("splitter", DocumentSplitter(split_by="passage", split_length=1))
    indexing_pipeline.add_component("writer", DocumentWriter(document_store=store, policy=DuplicatePolicy.SKIP))
    indexing_pipeline.connect("cleaner", "splitter")
    indexing_pipeline.connect("splitter", "writer")
    indexing_pipeline.run({"cleaner": {"documents": raw_docs}})
    return store


doc_store = get_doc_store()
RAG without Query Expansion
# Prompt for the summarization step: answer strictly by quoting the
# retrieved documents and cite which documents were used.
template = """
You are part of an information system that summarises related documents.
You answer a query using the textual content from the documents retrieved for the
following query.
You build the summary answer based only on quoting information from the documents.
You should reference the documents you used to support your answer.
###
Original Query: "{{query}}"
Retrieved Documents: {{documents}}
Summary Answer:
"""
# Baseline RAG: single-query BM25 retrieval -> prompt -> LLM.
retriever = InMemoryBM25Retriever(document_store=doc_store)
prompt_builder = PromptBuilder(template = template)
llm = OpenAIGenerator()
keyword_rag_pipeline = Pipeline()
keyword_rag_pipeline.add_component("keyword_retriever", retriever)
keyword_rag_pipeline.add_component("prompt", prompt_builder)
keyword_rag_pipeline.add_component("llm", llm)
keyword_rag_pipeline.connect("keyword_retriever.documents", "prompt.documents")
keyword_rag_pipeline.connect("prompt", "llm")
# Flat inputs ("query", "top_k") are routed by Haystack to the components
# that accept them; include_outputs_from also surfaces the retrieved docs.
keyword_rag_pipeline.run({"query": "green energy sources", "top_k": 3}, include_outputs_from=["keyword_retriever"])
RAG with Query Expansion
# Same summarization prompt as the baseline RAG pipeline above.
template = """
You are part of an information system that summarises related documents.
You answer a query using the textual content from the documents retrieved for the
following query.
You build the summary answer based only on quoting information from the documents.
You should reference the documents you used to support your answer.
###
Original Query: "{{query}}"
Retrieved Documents: {{documents}}
Summary Answer:
"""
# Expansion-aware RAG: expand the query, retrieve with BM25 per expanded
# query (merged/deduplicated), then build the prompt and generate.
query_expander = QueryExpander()
retriever = MultiQueryInMemoryBM25Retriever(InMemoryBM25Retriever(document_store=doc_store))
prompt_builder = PromptBuilder(template = template)
llm = OpenAIGenerator()
query_expanded_rag_pipeline = Pipeline()
query_expanded_rag_pipeline.add_component("expander", query_expander)
query_expanded_rag_pipeline.add_component("keyword_retriever", retriever)
query_expanded_rag_pipeline.add_component("prompt", prompt_builder)
query_expanded_rag_pipeline.add_component("llm", llm)
query_expanded_rag_pipeline.connect("expander.queries", "keyword_retriever.queries")
query_expanded_rag_pipeline.connect("keyword_retriever.documents", "prompt.documents")
query_expanded_rag_pipeline.connect("prompt", "llm")
# Render the pipeline graph (notebook-only visualization).
query_expanded_rag_pipeline.show()
query_expanded_rag_pipeline.run({"query": "green energy sources", "top_k": 3}, include_outputs_from=["keyword_retriever", "expander"])