Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-06-26 22:00:13 +00:00
feat: Add RAG pipeline (#6461)
* add rag pipeline

* Update examples/getting_started/rag.py

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

---------

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
parent 4912f7cb58
commit a38f871dbd
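At a glance, this commit adds a `build_rag_pipeline` utility (defined in `haystack/pipeline_utils/rag.py` below) plus an example, a release note, and tests. A minimal sketch of the new API, assuming documents are already written to an `InMemoryDocumentStore` and `OPENAI_API_KEY` is exported so `llm_api_key` can be omitted:

```python
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipeline_utils import build_rag_pipeline

# Populate the only document store the utility currently supports.
store = InMemoryDocumentStore()
store.write_documents([Document(content="My name is Giorgio and I live in Rome.")])

# llm_api_key is omitted here; per the docstring below it falls back to OPENAI_API_KEY.
pipeline = build_rag_pipeline(document_store=store)
print(pipeline.run(query="Who lives in Rome?").data)
```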
22 examples/getting_started/rag.py Normal file
@@ -0,0 +1,22 @@
import os

from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipeline_utils import build_rag_pipeline

API_KEY = "SET YOUR OPENAI API KEY HERE"

# We support many different databases. Here we load a simple and lightweight in-memory document store.
document_store = InMemoryDocumentStore()

# Create some example documents and add them to the document store.
documents = [
    Document(content="My name is Jean and I live in Paris."),
    Document(content="My name is Mark and I live in Berlin."),
    Document(content="My name is Giorgio and I live in Rome."),
]
document_store.write_documents(documents)

# Let's now build a simple RAG pipeline that uses a generative model to answer questions.
rag_pipeline = build_rag_pipeline(llm_api_key=API_KEY, document_store=document_store)
answers = rag_pipeline.run(query="Who lives in Rome?")
print(answers.data)
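The same example can be switched to embedding retrieval by passing `embedding_model`. The sketch below is not part of this commit and the model name is only illustrative; note that `build_rag_pipeline` only embeds the query, so the stored documents would still need embeddings written by a document embedder before this path can retrieve anything.

```python
# Variant (not in this commit): embedding retrieval instead of BM25.
rag_pipeline = build_rag_pipeline(
    llm_api_key=API_KEY,
    document_store=document_store,
    embedding_model="sentence-transformers/all-mpnet-base-v2",  # example model name
)
# The documents written above carry no embeddings, so they would first need to be
# indexed with a document embedder for the InMemoryEmbeddingRetriever to find them.
answers = rag_pipeline.run(query="Who lives in Berlin?")
print(answers.data)
```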
3 haystack/pipeline_utils/__init__.py Normal file
@@ -0,0 +1,3 @@
from haystack.pipeline_utils.rag import build_rag_pipeline

__all__ = ["build_rag_pipeline"]
134 haystack/pipeline_utils/rag.py Normal file
@@ -0,0 +1,134 @@
from typing import Optional

from haystack import Pipeline
from haystack.dataclasses import Answer
from haystack.document_stores import InMemoryDocumentStore
from haystack.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import GPTGenerator
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.builders.prompt_builder import PromptBuilder


def build_rag_pipeline(
    document_store: "InMemoryDocumentStore",
    generation_model: str = "gpt-3.5-turbo",
    prompt_template: Optional[str] = None,
    embedding_model: Optional[str] = None,
    llm_api_key: Optional[str] = None,
):
    """
    Returns a prebuilt pipeline to perform retrieval augmented generation with or without an embedding model
    (without embeddings, it performs retrieval using BM25).

    Example usage:

    ```python
    from haystack.utils import build_rag_pipeline
    pipeline = build_rag_pipeline(document_store=your_document_store_instance)
    pipeline.run(query="What's the capital of France?")

    >>> Answer(data="The capital of France is Paris.")
    ```

    :param document_store: An instance of a DocumentStore to read from.
    :param generation_model: The name of the model to use for generation.
    :param prompt_template: The template to use for the prompt. If not given, a default template is used.
    :param embedding_model: The name of the model to use for embedding. If not given, BM25 is used.
    :param llm_api_key: The API key to use for the OpenAI Language Model. If not given, the value of the
        llm_api_key will be attempted to be read from the environment variable OPENAI_API_KEY.
    """
    return _RAGPipeline(
        document_store=document_store,
        generation_model=generation_model,
        prompt_template=prompt_template,
        embedding_model=embedding_model,
        llm_api_key=llm_api_key,
    )


class _RAGPipeline:
    """
    A simple ready-made pipeline for RAG. It requires a populated document store.

    If an embedding model is given, it uses embedding retrieval. Otherwise, it falls back to BM25 retrieval.

    Example usage:

    ```python
    rag_pipe = RAGPipeline(document_store=InMemoryDocumentStore())
    answers = rag_pipe.run(query="Who lives in Rome?")
    >>> Answer(data="Giorgio")
    ```
    """

    def __init__(
        self,
        document_store: InMemoryDocumentStore,
        generation_model: str = "gpt-3.5-turbo",
        prompt_template: Optional[str] = None,
        embedding_model: Optional[str] = None,
        llm_api_key: Optional[str] = None,
    ):
        """
        :param document_store: An instance of a DocumentStore to retrieve documents from.
        :param generation_model: The name of the model to use for generation.
        :param prompt_template: The template to use for the prompt. If not given, a default template is used.
        :param embedding_model: The name of the model to use for embedding. If not given, BM25 is used.
        :param llm_api_key: The API key to use for the OpenAI Language Model.
        """
        prompt_template = (
            prompt_template
            or """
        Given these documents, answer the question.

        Documents:
        {% for doc in documents %}
        {{ doc.content }}
        {% endfor %}

        Question: {{question}}

        Answer:
        """
        )
        if not isinstance(document_store, InMemoryDocumentStore):
            raise ValueError("RAGPipeline only works with an InMemoryDocumentStore.")

        self.pipeline = Pipeline()

        if embedding_model:
            self.pipeline.add_component(
                instance=SentenceTransformersTextEmbedder(model_name_or_path=embedding_model), name="text_embedder"
            )
            self.pipeline.add_component(
                instance=InMemoryEmbeddingRetriever(document_store=document_store), name="retriever"
            )
            self.pipeline.connect("text_embedder", "retriever")
        else:
            self.pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever")

        self.pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
        self.pipeline.add_component(instance=GPTGenerator(api_key=llm_api_key, model_name=generation_model), name="llm")
        self.pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
        self.pipeline.connect("retriever", "prompt_builder.documents")
        self.pipeline.connect("prompt_builder.prompt", "llm.prompt")
        self.pipeline.connect("llm.replies", "answer_builder.replies")
        self.pipeline.connect("llm.metadata", "answer_builder.metadata")
        self.pipeline.connect("retriever", "answer_builder.documents")

    def run(self, query: str) -> Answer:
        """
        Performs RAG using the given query.

        :param query: The query to ask.
        :return: An Answer object.
        """
        run_values = {"prompt_builder": {"question": query}, "answer_builder": {"query": query}}
        if self.pipeline.graph.nodes.get("text_embedder"):
            run_values["text_embedder"] = {"text": query}
        else:
            run_values["retriever"] = {"query": query}

        return self.pipeline.run(run_values)["answer_builder"]["answers"][0]
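For readers who want the wiring without the wrapper, here is a condensed sketch of the equivalent manual assembly for the BM25 path, lifted from `_RAGPipeline.__init__` above. It is an illustration only, not an additional file in this commit.

```python
from haystack import Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.generators import GPTGenerator

# Same default template as _RAGPipeline above.
template = """
Given these documents, answer the question.

Documents:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}

Question: {{question}}

Answer:
"""

store = InMemoryDocumentStore()  # populate with write_documents() before querying

pipeline = Pipeline()
pipeline.add_component(instance=InMemoryBM25Retriever(document_store=store), name="retriever")
pipeline.add_component(instance=PromptBuilder(template=template), name="prompt_builder")
# api_key=None mirrors what _RAGPipeline passes when llm_api_key is omitted;
# per the docstring above, the key is then read from OPENAI_API_KEY.
pipeline.add_component(instance=GPTGenerator(api_key=None, model_name="gpt-3.5-turbo"), name="llm")
pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
pipeline.connect("retriever", "prompt_builder.documents")
pipeline.connect("prompt_builder.prompt", "llm.prompt")
pipeline.connect("llm.replies", "answer_builder.replies")
pipeline.connect("llm.metadata", "answer_builder.metadata")
pipeline.connect("retriever", "answer_builder.documents")

query = "Who lives in Rome?"
result = pipeline.run(
    {
        "retriever": {"query": query},
        "prompt_builder": {"question": query},
        "answer_builder": {"query": query},
    }
)
print(result["answer_builder"]["answers"][0].data)
```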
4 releasenotes/notes/rag_pipeline-4e9dfc82a4402935.yaml Normal file
@@ -0,0 +1,4 @@
---
features:
  - |
    Add a `build_rag_pipeline` utility function
61 test/pipelines/test_rag_pipelines.py Normal file
@@ -0,0 +1,61 @@
from unittest.mock import patch, Mock
import pytest

from haystack.dataclasses import Answer
from haystack.testing.factory import document_store_class
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipeline_utils.rag import build_rag_pipeline


@pytest.fixture
def mock_chat_completion():
    """
    Mock the OpenAI API completion response and reuse it for tests
    """
    with patch("openai.ChatCompletion.create", autospec=True) as mock_chat_completion_create:
        # mimic the response from the OpenAI API
        mock_choice = Mock()
        mock_choice.index = 0
        mock_choice.finish_reason = "stop"

        mock_message = Mock()
        mock_message.content = "I'm fine, thanks. How are you?"
        mock_message.role = "user"

        mock_choice.message = mock_message

        mock_response = Mock()
        mock_response.model = "gpt-3.5-turbo"
        mock_response.usage = Mock()
        mock_response.usage.items.return_value = [
            ("prompt_tokens", 57),
            ("completion_tokens", 40),
            ("total_tokens", 97),
        ]
        mock_response.choices = [mock_choice]
        mock_chat_completion_create.return_value = mock_response
        yield mock_chat_completion_create


def test_rag_pipeline(mock_chat_completion):
    rag_pipe = build_rag_pipeline(document_store=InMemoryDocumentStore())
    answer = rag_pipe.run(query="question")
    assert isinstance(answer, Answer)


def test_rag_pipeline_other_docstore():
    FakeStore = document_store_class("FakeStore")
    with pytest.raises(ValueError, match="InMemoryDocumentStore"):
        assert build_rag_pipeline(document_store=FakeStore())


def test_rag_pipeline_no_embedder_if_no_model():
    rag_pipe = build_rag_pipeline(document_store=InMemoryDocumentStore())
    assert "text_embedder" not in rag_pipe.pipeline.graph.nodes


def test_rag_pipeline_embedder_exist_if_model_is_given():
    rag_pipe = build_rag_pipeline(
        document_store=InMemoryDocumentStore(), embedding_model="sentence-transformers/all-mpnet-base-v2"
    )
    assert "text_embedder" in rag_pipe.pipeline.graph.nodes
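To exercise only these tests locally, one option (assuming a dev install with pytest available and the repository root as the working directory) is to invoke pytest programmatically:

```python
import pytest

# Select only the RAG pipeline tests added by this commit.
raise SystemExit(pytest.main(["-v", "test/pipelines/test_rag_pipelines.py"]))
```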