mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-13 20:10:45 +00:00

* remove symbols under the haystack.document_stores namespace * Update haystack/document_stores/types/protocol.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * fix * same for retrievers * leftovers * more leftovers * add relnote * leftovers * one more * fix examples --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
153 lines
5.6 KiB
Python
153 lines
5.6 KiB
Python
from typing import List, Any, Optional, Dict
|
|
|
|
import logging
|
|
from pprint import pprint
|
|
|
|
from haystack import Pipeline, Document, component, default_to_dict, default_from_dict, DeserializationError
|
|
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
|
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
|
|
from haystack.components.generators import OpenAIGenerator
|
|
from haystack.components.builders.prompt_builder import PromptBuilder
|
|
from haystack.components.others import Multiplexer
|
|
from haystack.components.routers.conditional_router import ConditionalRouter
|
|
from haystack.core.component.types import Variadic
|
|
|
|
|
|
# Enable DEBUG on the root logger so pipeline internals are visible when this example runs.
logging.getLogger().setLevel(logging.DEBUG)
|
|
|
|
|
|
@component
class PaginatedRetriever:
    """
    Wraps a retriever and serves its results one page at a time.

    On the first call to ``run`` the wrapped retriever is queried once and the
    resulting document list is cached; each subsequent call returns the next
    ``page_size`` documents from that cache. This is useful when the retriever
    returns many documents and the LLM's context length is limited, so we want
    to avoid passing too many documents to it at once.
    """

    def __init__(self, retriever: Any, page_size: int = 1, top_k: int = 100):
        """
        :param retriever: Any retriever component exposing ``run(query=..., ...)``
            and returning a dict with a ``documents`` key.
        :param page_size: Number of documents returned per ``run`` call.
        :param top_k: Default maximum number of documents fetched from the
            wrapped retriever when ``run`` is not given an explicit ``top_k``.
        """
        self.retriever = retriever
        self.page_size = page_size
        self.top_k = top_k
        # Cache of documents not yet served; None means "retriever not queried yet".
        self.retrieved_documents: Optional[List[Document]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this component, including the wrapped retriever's own serialization."""
        # BUGFIX: top_k was previously omitted here, so a to_dict/from_dict
        # round trip silently reset it to the default value.
        return default_to_dict(
            self, retriever=self.retriever.to_dict(), page_size=self.page_size, top_k=self.top_k
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PaginatedRetriever":
        """
        Deserialize, rebuilding the wrapped retriever from its serialized data.

        :raises DeserializationError: if the retriever data is missing, lacks a
            ``type`` field, or names a component type that is not registered.
        """
        if "retriever" not in data["init_parameters"]:
            # BUGFIX: message previously named "SlidingWindowRetriever" (stale copy-paste).
            raise DeserializationError("Missing required field 'retriever' in PaginatedRetriever")

        retriever_data = data["init_parameters"]["retriever"]
        if "type" not in retriever_data:
            raise DeserializationError("Missing 'type' in retriever's serialization data")
        if retriever_data["type"] not in component.registry:
            raise DeserializationError(f"Component type '{retriever_data['type']}' not found")
        retriever_class = component.registry[retriever_data["type"]]

        # Replace the serialized retriever with a live instance before the generic deserialization.
        data["init_parameters"]["retriever"] = retriever_class.from_dict(retriever_data)
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        query: Variadic[str],
        top_k: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None,
        scale_score: Optional[bool] = None,
    ):
        """
        Return the next page of documents for ``query``.

        The wrapped retriever is invoked only on the first call; later calls
        serve subsequent pages from the cached result list.

        :raises ValueError: once all cached documents have been served.
        """
        if top_k is None:
            # BUGFIX: was `if not top_k`, which also clobbered an explicit top_k=0.
            top_k = self.top_k

        if self.retrieved_documents is None:
            # query is Variadic: it arrives as a list; this component uses the first value.
            self.retrieved_documents = self.retriever.run(
                query=query[0], filters=filters, top_k=top_k, scale_score=scale_score  # type: ignore
            )["documents"]

        if not self.retrieved_documents:
            raise ValueError("No more documents available :(")

        next_page = self.retrieved_documents[: self.page_size]
        self.retrieved_documents = self.retrieved_documents[self.page_size :]
        return {"documents": next_page}
|
|
|
|
|
|
def self_correcting_pipeline():
    """
    Build and run a self-correcting RAG pipeline.

    The LLM is instructed to answer "UNKNOWN" when the supplied documents are
    insufficient; a conditional router then loops the unanswered query back to
    the paginated retriever, which feeds the next page of documents into the
    prompt, until an answer is produced or the loop limit is hit.
    """
    pipeline = Pipeline(max_loops_allowed=10)

    # Query fan-out: distributes the (possibly re-injected) query to consumers.
    pipeline.add_component(instance=Multiplexer(str), name="query_multiplexer")
    # Retriever that serves its results one page per loop iteration.
    pipeline.add_component(
        instance=PaginatedRetriever(InMemoryBM25Retriever(document_store=InMemoryDocumentStore())), name="retriever"
    )
    pipeline.add_component(
        instance=PromptBuilder(
            template="""
            Given these documents, answer the question.
            If the documents don't provide enough information to answer the question, answer with the string "UNKNOWN".

            Documents:
            {% for doc in documents %}
                {{ doc.content }}
            {% endfor %}

            Question: {{question}}
            Answer:
            """
        ),
        name="prompt_builder",
    )
    pipeline.add_component(instance=OpenAIGenerator(), name="llm")
    # Router: re-emit the query when the LLM gave up, otherwise emit the replies.
    pipeline.add_component(
        instance=ConditionalRouter(
            routes=[
                {
                    "condition": "{{ 'UNKNOWN' in replies|join(' ') }}",
                    "output": "{{ query }}",
                    "output_name": "unanswered_query",
                    "output_type": str,
                },
                {
                    "condition": "{{ 'UNKNOWN' not in replies|join(' ') }}",
                    "output": "{{ replies }}",
                    "output_name": "replies",
                    "output_type": List[str],
                },
            ]
        ),
        name="answer_checker",
    )

    # Wire the graph; the last connection closes the correction loop.
    pipeline.connect("query_multiplexer", "retriever")
    pipeline.connect("query_multiplexer", "prompt_builder.question")
    pipeline.connect("query_multiplexer", "answer_checker.query")
    pipeline.connect("retriever", "prompt_builder.documents")
    pipeline.connect("prompt_builder", "llm")
    pipeline.connect("llm.replies", "answer_checker.replies")
    pipeline.connect("answer_checker.unanswered_query", "query_multiplexer")

    # Render the pipeline graph to an image file.
    pipeline.draw("self_correcting_pipeline.png")

    # Seed the document store through the retriever wrapper.
    seed_docs = [
        Document(content=text)
        for text in (
            "My name is Jean and I live in Paris.",
            "My name is Mark and I live in Berlin.",
            "My name is Giorgio and I live in Rome.",
            "My name is Juan and I live in Madrid.",
        )
    ]
    pipeline.get_component("retriever").retriever.document_store.write_documents(seed_docs)

    # Run the pipeline with a question whose answer sits past the first page.
    query_text = "Who lives in Germany?"

    outcome = pipeline.run({"query_multiplexer": {"value": query_text}})

    pprint(outcome)
|
# Script entry point: build and execute the example pipeline.
if __name__ == "__main__":
    self_correcting_pipeline()