haystack/examples/rag/rag_self_correction.py

153 lines
5.5 KiB
Python
Raw Normal View History

from typing import List, Any, Optional, Dict
import logging
from pprint import pprint
from canals.component.types import Variadic
from haystack import Pipeline, Document, component, default_to_dict, default_from_dict, DeserializationError
from haystack.document_stores import InMemoryDocumentStore
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.generators import OpenAIGenerator
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.others import Multiplexer
from haystack.components.routers.conditional_router import ConditionalRouter
# Make the pipeline's DEBUG records visible. Raising the root logger's level alone is not
# enough: with no handler attached anywhere, the logging module falls back to its last-resort
# handler, which only emits WARNING and above. `basicConfig()` attaches a stderr handler to the
# root logger (it does nothing if the root logger already has handlers).
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
@component
class PaginatedRetriever:
    """
    Wrap a retriever and hand its results out one page at a time.

    The first call for a given query runs the wrapped retriever once and caches up to ``top_k``
    documents. Each call, including that first one, then returns the next ``page_size``
    documents from the cache. This suits loops where an LLM with a limited context window sees
    a small batch of documents and asks for the next batch if the current one was not enough,
    instead of receiving every retrieved document at once.

    :param retriever: Retriever component to paginate. It must provide
        ``run(query=..., filters=..., top_k=..., scale_score=...)`` returning a dict with a
        ``"documents"`` key, and ``to_dict()`` for serialization.
    :param page_size: Number of documents returned per call.
    :param top_k: Default number of documents fetched from the wrapped retriever.
    """

    def __init__(self, retriever: Any, page_size: int = 1, top_k: int = 100):
        self.retriever = retriever
        self.page_size = page_size
        self.top_k = top_k
        # Documents retrieved but not yet returned; None means "retrieve on the next run".
        self.retrieved_documents: Optional[List[Document]] = None
        # Retrieval arguments the cache was built for. A different query invalidates the cache.
        self._cache_key: Optional[tuple] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the component, including the wrapped retriever, ``page_size`` and ``top_k``."""
        return default_to_dict(
            self,
            retriever=self.retriever.to_dict(),
            page_size=self.page_size,
            top_k=self.top_k,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PaginatedRetriever":
        """
        Deserialize the component, rebuilding the wrapped retriever from its serialized form.

        The input dictionary is not modified.

        :param data: Dictionary produced by ``to_dict``.
        :returns: A new ``PaginatedRetriever``.
        :raises DeserializationError: If the retriever data is missing, has no ``type``, or
            names a component type that is not in ``component.registry``.
        """
        init_params = data.get("init_parameters", {})
        if "retriever" not in init_params:
            raise DeserializationError("Missing required field 'retriever' in PaginatedRetriever")
        retriever_data = init_params["retriever"]
        if "type" not in retriever_data:
            raise DeserializationError("Missing 'type' in retriever's serialization data")
        if retriever_data["type"] not in component.registry:
            raise DeserializationError(f"Component type '{retriever_data['type']}' not found")
        retriever_class = component.registry[retriever_data["type"]]
        # Build a new dict instead of overwriting the caller's nested "retriever" entry.
        data = {
            **data,
            "init_parameters": {**init_params, "retriever": retriever_class.from_dict(retriever_data)},
        }
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        query: Variadic[str],
        top_k: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None,
        scale_score: Optional[bool] = None,
    ):
        """
        Return the next page of documents for ``query``.

        :param query: Variadic input; only its first value is used as the query text.
        :param top_k: Number of documents fetched from the wrapped retriever when a new retrieval
            happens. Falsy values (``None`` or ``0``) fall back to ``self.top_k``.
        :param filters: Passed through to the wrapped retriever.
        :param scale_score: Passed through to the wrapped retriever.
        :returns: ``{"documents": [...]}`` with at most ``page_size`` documents.
        :raises ValueError: If every retrieved document has already been returned. The cache is
            cleared first, so a later call starts a fresh retrieval.
        """
        top_k = top_k or self.top_k
        query_text = query[0]
        cache_key = (query_text, filters, top_k, scale_score)
        # Retrieve once per query. Later calls with the same arguments (e.g. further iterations
        # of a self-correction loop) page through the cache. Different arguments mean the cached
        # documents belong to another request, so they are replaced.
        if self.retrieved_documents is None or cache_key != self._cache_key:
            self.retrieved_documents = self.retriever.run(
                query=query_text, filters=filters, top_k=top_k, scale_score=scale_score
            )["documents"]
            self._cache_key = cache_key
        if not self.retrieved_documents:
            # Reset so the component does not stay stuck in the exhausted state forever.
            self.retrieved_documents = None
            self._cache_key = None
            raise ValueError("No more documents available :(")
        next_page = self.retrieved_documents[: self.page_size]
        self.retrieved_documents = self.retrieved_documents[self.page_size :]
        return {"documents": next_page}
# Kept at module level so the template text is reproduced exactly (no indentation is added
# to the prompt the LLM receives).
_PROMPT_TEMPLATE = """
Given these documents, answer the question.
If the documents don't provide enough information to answer the question, answer with the string "UNKNOWN".
Documents:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}
Question: {{question}}
Answer:
"""


def _build_pipeline(document_store: InMemoryDocumentStore) -> Pipeline:
    """
    Assemble the self-correcting RAG pipeline over ``document_store``.

    Flow:
    - The query multiplexer sends the question to the paginated retriever, the prompt builder
      and the answer checker.
    - The retriever's current page of documents is rendered into the prompt, and the LLM answers.
    - The answer checker emits the replies as ``replies`` unless they contain "UNKNOWN".
    - If they do contain "UNKNOWN", the checker sends the query back to the multiplexer. That
      starts another pass with the retriever's next page of documents.
    """
    # NOTE(review): max_loops_allowed presumably caps how often the loop may repeat — confirm
    # against the Pipeline documentation.
    rag_pipeline = Pipeline(max_loops_allowed=10)
    rag_pipeline.add_component(instance=Multiplexer(str), name="query_multiplexer")
    rag_pipeline.add_component(
        instance=PaginatedRetriever(InMemoryBM25Retriever(document_store=document_store)), name="retriever"
    )
    rag_pipeline.add_component(instance=PromptBuilder(template=_PROMPT_TEMPLATE), name="prompt_builder")
    rag_pipeline.add_component(instance=OpenAIGenerator(), name="llm")
    rag_pipeline.add_component(
        instance=ConditionalRouter(
            routes=[
                {
                    "condition": "{{ 'UNKNOWN' in replies|join(' ') }}",
                    "output": "{{ query }}",
                    "output_name": "unanswered_query",
                    "output_type": str,
                },
                {
                    "condition": "{{ 'UNKNOWN' not in replies|join(' ') }}",
                    "output": "{{ replies }}",
                    "output_name": "replies",
                    "output_type": List[str],
                },
            ]
        ),
        name="answer_checker",
    )

    rag_pipeline.connect("query_multiplexer", "retriever")
    rag_pipeline.connect("query_multiplexer", "prompt_builder.question")
    rag_pipeline.connect("query_multiplexer", "answer_checker.query")
    rag_pipeline.connect("retriever", "prompt_builder.documents")
    rag_pipeline.connect("prompt_builder", "llm")
    rag_pipeline.connect("llm.replies", "answer_checker.replies")
    # Loop edge: unanswered queries go back to the multiplexer for another retrieval round.
    rag_pipeline.connect("answer_checker.unanswered_query", "query_multiplexer")
    return rag_pipeline


def self_correcting_pipeline(question: str = "Who lives in Germany?") -> Dict[str, Any]:
    """
    Build, draw, populate and run the self-correcting RAG demo.

    :param question: Question to ask the pipeline; defaults to the original demo question.
    :returns: The dictionary returned by ``Pipeline.run``, which is also pretty-printed.
    """
    document_store = InMemoryDocumentStore()
    rag_pipeline = _build_pipeline(document_store)

    # Draw the pipeline graph (output path: self_correcting_pipeline.png).
    rag_pipeline.draw("self_correcting_pipeline.png")

    # Populate the document store the retriever reads from.
    documents = [
        Document(content="My name is Jean and I live in Paris."),
        Document(content="My name is Mark and I live in Berlin."),
        Document(content="My name is Giorgio and I live in Rome."),
        Document(content="My name is Juan and I live in Madrid."),
    ]
    document_store.write_documents(documents)

    # Run the query and show the result (nothing is asserted here).
    result = rag_pipeline.run({"query_multiplexer": {"value": question}})
    pprint(result)
    return result
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not when it is imported.
    self_correcting_pipeline()