mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 09:49:48 +00:00 
			
		
		
		
	 e1ec4e5e4d
			
		
	
	
		e1ec4e5e4d
		
			
		
	
	
	
	
		
			
			* remove symbols under the haystack.document_stores namespace * Update haystack/document_stores/types/protocol.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * fix * same for retrievers * leftovers * more leftovers * add relnote * leftovers * one more * fix examples --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
		
			
				
	
	
		
			153 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			153 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import List, Any, Optional, Dict
 | |
| 
 | |
| import logging
 | |
| from pprint import pprint
 | |
| 
 | |
| from haystack import Pipeline, Document, component, default_to_dict, default_from_dict, DeserializationError
 | |
| from haystack.document_stores.in_memory import InMemoryDocumentStore
 | |
| from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
 | |
| from haystack.components.generators import OpenAIGenerator
 | |
| from haystack.components.builders.prompt_builder import PromptBuilder
 | |
| from haystack.components.others import Multiplexer
 | |
| from haystack.components.routers.conditional_router import ConditionalRouter
 | |
| from haystack.core.component.types import Variadic
 | |
| 
 | |
| 
 | |
# Lower the root logger's threshold to DEBUG for the duration of this example.
logging.getLogger().setLevel(level=logging.DEBUG)
 | |
| 
 | |
| 
 | |
@component
class PaginatedRetriever:
    """
    Serve the results of a wrapped retriever one page at a time.

    The wrapped retriever is queried once, on the first call to `run()`. The documents it returns are cached on
    the instance, and every call to `run()` hands out the next `page_size` documents from that cache.

    This is useful when the LLM's context length is limited and passing all the retrieved documents at once
    must be avoided.

    NOTE: the cache is never reset. Once an instance has retrieved documents for a query, later calls keep
    paging through that same result list, whatever `query`, `filters` or `top_k` they receive.
    """

    def __init__(self, retriever: Any, page_size: int = 1, top_k: int = 100):
        """
        :param retriever: The retriever to paginate. It must expose `run(query, filters, top_k, scale_score)`
            returning a dict with a "documents" key, and `to_dict()`.
        :param page_size: How many documents each call to `run()` returns. Must be at least 1.
        :param top_k: Default number of documents requested from the wrapped retriever.
        :raises ValueError: If `page_size` is smaller than 1.
        """
        # A page size of 0 would yield empty pages forever and a negative one would slice from the wrong end.
        if page_size < 1:
            raise ValueError(f"page_size must be at least 1, got {page_size}")

        self.retriever = retriever
        self.page_size = page_size
        self.top_k = top_k
        # None means "retriever not queried yet"; an empty list means "all pages consumed".
        self.retrieved_documents: Optional[List[Document]] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component, including the wrapped retriever, to a dictionary.

        :returns: The dictionary produced by `default_to_dict` with every init parameter included.
        """
        return default_to_dict(
            self, retriever=self.retriever.to_dict(), page_size=self.page_size, top_k=self.top_k
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PaginatedRetriever":
        """
        Rebuild a component from the dictionary produced by `to_dict()`.

        The nested retriever is deserialized first, by looking its type up in `component.registry`.
        Note that `data` is modified in place: the serialized retriever is replaced by the retriever instance.

        :param data: The serialized component.
        :returns: The deserialized component.
        :raises DeserializationError: If the retriever entry is missing, has no "type", or its type is not
            registered.
        """
        init_parameters = data.get("init_parameters", {})
        if "retriever" not in init_parameters:
            raise DeserializationError("Missing required field 'retriever' in PaginatedRetriever")

        retriever_data = init_parameters["retriever"]
        if "type" not in retriever_data:
            raise DeserializationError("Missing 'type' in retriever's serialization data")
        if retriever_data["type"] not in component.registry:
            raise DeserializationError(f"Component type '{retriever_data['type']}' not found")
        retriever_class = component.registry[retriever_data["type"]]

        init_parameters["retriever"] = retriever_class.from_dict(retriever_data)
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        query: Variadic[str],
        top_k: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None,
        scale_score: Optional[bool] = None,
    ):
        """
        Return the next page of documents.

        :param query: The incoming queries; only the first one is forwarded to the wrapped retriever, and only
            on the call that fills the cache.
        :param top_k: Number of documents to request from the wrapped retriever. Falls back to the value given
            at init time when not set.
        :param filters: Filters forwarded to the wrapped retriever.
        :param scale_score: Forwarded to the wrapped retriever.
        :returns: A dictionary with the key "documents" holding at most `page_size` documents.
        :raises ValueError: If there are no documents left to return.
        """
        if not top_k:
            top_k = self.top_k

        # Query the wrapped retriever only once; later calls just consume the cached list.
        if self.retrieved_documents is None:
            self.retrieved_documents = self.retriever.run(
                query=query[0], filters=filters, top_k=top_k, scale_score=scale_score  # type: ignore
            )["documents"]

        if not self.retrieved_documents:
            raise ValueError("No more documents available :(")

        # Hand out the head of the cache and keep the remainder for the next call.
        next_page = self.retrieved_documents[: self.page_size]
        self.retrieved_documents = self.retrieved_documents[self.page_size :]
        return {"documents": next_page}
 | |
| 
 | |
| 
 | |
def self_correcting_pipeline():
    """
    Build, draw and run a looping RAG pipeline that keeps fetching documents until the LLM can answer.

    The retriever hands out one page of documents per iteration. When the LLM's replies contain the string
    "UNKNOWN", the router sends the query back to the multiplexer so another iteration runs with the next page;
    otherwise the replies are emitted as the pipeline's output. The final result is pretty-printed.
    """
    prompt_template = """
    Given these documents, answer the question.
    If the documents don't provide enough information to answer the question, answer with the string "UNKNOWN".

    Documents:
    {% for doc in documents %}
        {{ doc.content }}
    {% endfor %}

    Question: {{question}}
    Answer:
    """

    # First route loops the query back for another attempt, second route lets a usable answer through.
    answer_routes = [
        {
            "condition": "{{ 'UNKNOWN' in replies|join(' ') }}",
            "output": "{{ query }}",
            "output_name": "unanswered_query",
            "output_type": str,
        },
        {
            "condition": "{{ 'UNKNOWN' not in replies|join(' ') }}",
            "output": "{{ replies }}",
            "output_name": "replies",
            "output_type": List[str],
        },
    ]

    # Register the components
    pipe = Pipeline(max_loops_allowed=10)
    pipe.add_component(name="query_multiplexer", instance=Multiplexer(str))

    store = InMemoryDocumentStore()
    bm25_retriever = InMemoryBM25Retriever(document_store=store)
    pipe.add_component(name="retriever", instance=PaginatedRetriever(bm25_retriever))

    pipe.add_component(name="prompt_builder", instance=PromptBuilder(template=prompt_template))
    pipe.add_component(name="llm", instance=OpenAIGenerator())
    pipe.add_component(name="answer_checker", instance=ConditionalRouter(routes=answer_routes))

    # Wire them together; the last connection closes the loop
    connections = (
        ("query_multiplexer", "retriever"),
        ("query_multiplexer", "prompt_builder.question"),
        ("query_multiplexer", "answer_checker.query"),
        ("retriever", "prompt_builder.documents"),
        ("prompt_builder", "llm"),
        ("llm.replies", "answer_checker.replies"),
        ("answer_checker.unanswered_query", "query_multiplexer"),
    )
    for sender, receiver in connections:
        pipe.connect(sender, receiver)

    # Render the pipeline graph
    pipe.draw("self_correcting_pipeline.png")

    # Fill the document store backing the retriever
    store.write_documents(
        [
            Document(content="My name is Jean and I live in Paris."),
            Document(content="My name is Mark and I live in Berlin."),
            Document(content="My name is Giorgio and I live in Rome."),
            Document(content="My name is Juan and I live in Madrid."),
        ]
    )

    # Ask the question and show whatever the pipeline produced
    question = "Who lives in Germany?"
    pprint(pipe.run({"query_multiplexer": {"value": question}}))
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Run the example only when this file is executed as a script, not when it is imported.
    self_correcting_pipeline()
 |