feat: Add RAG pipeline (#6461)

* add rag pipeline

* Update examples/getting_started/rag.py

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

---------

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
ZanSara 2023-12-04 14:25:29 +00:00 committed by GitHub
parent 4912f7cb58
commit a38f871dbd
5 changed files with 224 additions and 0 deletions

examples/getting_started/rag.py

@@ -0,0 +1,22 @@
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipeline_utils import build_rag_pipeline

API_KEY = "SET YOUR OPENAI API KEY HERE"

# We support many different databases. Here we load a simple and lightweight in-memory document store.
document_store = InMemoryDocumentStore()

# Create some example documents and add them to the document store.
documents = [
    Document(content="My name is Jean and I live in Paris."),
    Document(content="My name is Mark and I live in Berlin."),
    Document(content="My name is Giorgio and I live in Rome."),
]
document_store.write_documents(documents)

# Let's now build a simple RAG pipeline that uses a generative model to answer questions.
rag_pipeline = build_rag_pipeline(llm_api_key=API_KEY, document_store=document_store)
answers = rag_pipeline.run(query="Who lives in Rome?")
print(answers.data)
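
The example above retrieves with BM25, so documents can be written as plain text. Passing an `embedding_model` switches `build_rag_pipeline` to embedding retrieval, which needs documents stored with embeddings first. A minimal sketch, assuming a fresh document store and the beta's `SentenceTransformersDocumentEmbedder` component (the model names are only examples):

```python
from haystack import Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.pipeline_utils import build_rag_pipeline

document_store = InMemoryDocumentStore()
documents = [Document(content="My name is Giorgio and I live in Rome.")]

# Embed the documents before writing them, so InMemoryEmbeddingRetriever can search them.
doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-mpnet-base-v2")
doc_embedder.warm_up()
document_store.write_documents(doc_embedder.run(documents=documents)["documents"])

# Passing embedding_model switches the pipeline from BM25 to embedding retrieval.
rag_pipeline = build_rag_pipeline(
    llm_api_key="SET YOUR OPENAI API KEY HERE",
    document_store=document_store,
    embedding_model="sentence-transformers/all-mpnet-base-v2",
)
print(rag_pipeline.run(query="Who lives in Rome?").data)
```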

haystack/pipeline_utils/__init__.py

@@ -0,0 +1,3 @@
from haystack.pipeline_utils.rag import build_rag_pipeline

__all__ = ["build_rag_pipeline"]

haystack/pipeline_utils/rag.py

@@ -0,0 +1,134 @@
from typing import Optional

from haystack import Pipeline
from haystack.dataclasses import Answer
from haystack.document_stores import InMemoryDocumentStore
from haystack.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import GPTGenerator
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.builders.prompt_builder import PromptBuilder

def build_rag_pipeline(
    document_store: InMemoryDocumentStore,
    generation_model: str = "gpt-3.5-turbo",
    prompt_template: Optional[str] = None,
    embedding_model: Optional[str] = None,
    llm_api_key: Optional[str] = None,
):
"""
Returns a prebuilt pipeline to perform retrieval augmented generation with or without an embedding model
(without embeddings, it performs retrieval using BM25).
Example usage:
```python
from haystack.utils import build_rag_pipeline
pipeline = build_rag_pipeline(document_store=your_document_store_instance)
pipeline.run(query="What's the capital of France?")
>>> Answer(data="The capital of France is Paris.")
```
:param document_store: An instance of a DocumentStore to read from.
:param generation_model: The name of the model to use for generation.
:param prompt_template: The template to use for the prompt. If not given, a default template is used.
:param embedding_model: The name of the model to use for embedding. If not given, BM25 is used.
:param llm_api_key: The API key to use for the OpenAI Language Model. If not given, the value of the
llm_api_key will be attempted to be read from the environment variable OPENAI_API_KEY.
"""
    return _RAGPipeline(
        document_store=document_store,
        generation_model=generation_model,
        prompt_template=prompt_template,
        embedding_model=embedding_model,
        llm_api_key=llm_api_key,
    )


class _RAGPipeline:
"""
A simple ready-made pipeline for RAG. It requires a populated document store.
If an embedding model is given, it uses embedding retrieval. Otherwise, it falls back to BM25 retrieval.
Example usage:
```python
rag_pipe = RAGPipeline(document_store=InMemoryDocumentStore())
answers = rag_pipe.run(query="Who lives in Rome?")
>>> Answer(data="Giorgio")
```
"""
def __init__(
self,
document_store: InMemoryDocumentStore,
generation_model: str = "gpt-3.5-turbo",
prompt_template: Optional[str] = None,
embedding_model: Optional[str] = None,
llm_api_key: Optional[str] = None,
):
"""
:param document_store: An instance of a DocumentStore to retrieve documents from.
:param generation_model: The name of the model to use for generation.
:param prompt_template: The template to use for the prompt. If not given, a default template is used.
:param embedding_model: The name of the model to use for embedding. If not given, BM25 is used.
:param llm_api_key: The API key to use for the OpenAI Language Model.
"""
prompt_template = (
prompt_template
or """
Given these documents, answer the question.
Documents:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}
Question: {{question}}
Answer:
"""
)
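        # Illustrative rendering (not executed): with a single document
        # "My name is Giorgio and I live in Rome." and the question
        # "Who lives in Rome?", the default template above expands to:
        #
        #   Given these documents, answer the question.
        #   Documents:
        #   My name is Giorgio and I live in Rome.
        #   Question: Who lives in Rome?
        #   Answer: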
if not isinstance(document_store, InMemoryDocumentStore):
raise ValueError("RAGPipeline only works with an InMemoryDocumentStore.")
self.pipeline = Pipeline()
if embedding_model:
self.pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path=embedding_model), name="text_embedder"
)
self.pipeline.add_component(
instance=InMemoryEmbeddingRetriever(document_store=document_store), name="retriever"
)
self.pipeline.connect("text_embedder", "retriever")
else:
self.pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever")
self.pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
self.pipeline.add_component(instance=GPTGenerator(api_key=llm_api_key, model_name=generation_model), name="llm")
self.pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
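        # Wiring: the retriever's documents fill the prompt template, the rendered
        # prompt goes to the LLM, and the LLM's replies and metadata (together with
        # the retrieved documents) are assembled into Answer objects by the AnswerBuilder.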
self.pipeline.connect("retriever", "prompt_builder.documents")
self.pipeline.connect("prompt_builder.prompt", "llm.prompt")
self.pipeline.connect("llm.replies", "answer_builder.replies")
self.pipeline.connect("llm.metadata", "answer_builder.metadata")
self.pipeline.connect("retriever", "answer_builder.documents")

    def run(self, query: str) -> Answer:
        """
        Performs RAG using the given query.

        :param query: The question to answer.
        :return: An Answer object.
        """
run_values = {"prompt_builder": {"question": query}, "answer_builder": {"query": query}}
if self.pipeline.graph.nodes.get("text_embedder"):
run_values["text_embedder"] = {"text": query}
else:
run_values["retriever"] = {"query": query}
return self.pipeline.run(run_values)["answer_builder"]["answers"][0]
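
For reference, `build_rag_pipeline` is only wiring; the BM25 variant can be assembled by hand from the same components. A minimal sketch restating the connections made in `_RAGPipeline` (the API key placeholder and shortened template are illustrative):

```python
from haystack import Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.generators import GPTGenerator

store = InMemoryDocumentStore()  # assumed to be populated elsewhere
template = "Given these documents, answer the question.\n{% for doc in documents %}{{ doc.content }}\n{% endfor %}Question: {{question}}\nAnswer:"

pipe = Pipeline()
pipe.add_component(instance=InMemoryBM25Retriever(document_store=store), name="retriever")
pipe.add_component(instance=PromptBuilder(template=template), name="prompt_builder")
pipe.add_component(instance=GPTGenerator(api_key="YOUR_API_KEY", model_name="gpt-3.5-turbo"), name="llm")
pipe.add_component(instance=AnswerBuilder(), name="answer_builder")
pipe.connect("retriever", "prompt_builder.documents")
pipe.connect("prompt_builder.prompt", "llm.prompt")
pipe.connect("llm.replies", "answer_builder.replies")
pipe.connect("llm.metadata", "answer_builder.metadata")
pipe.connect("retriever", "answer_builder.documents")

# Every component that needs the query receives it explicitly, mirroring _RAGPipeline.run().
question = "Who lives in Rome?"
result = pipe.run(
    {"retriever": {"query": question}, "prompt_builder": {"question": question}, "answer_builder": {"query": question}}
)
print(result["answer_builder"]["answers"][0].data)
```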


@@ -0,0 +1,4 @@
---
features:
  - |
    Add a `build_rag_pipeline` utility function to create ready-made RAG pipelines.


@@ -0,0 +1,61 @@
from unittest.mock import patch, Mock

import pytest

from haystack.dataclasses import Answer
from haystack.testing.factory import document_store_class
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipeline_utils.rag import build_rag_pipeline


@pytest.fixture
def mock_chat_completion():
    """
    Mock the OpenAI API completion response and reuse it across tests.
    """
    with patch("openai.ChatCompletion.create", autospec=True) as mock_chat_completion_create:
# mimic the response from the OpenAI API
mock_choice = Mock()
mock_choice.index = 0
mock_choice.finish_reason = "stop"
mock_message = Mock()
mock_message.content = "I'm fine, thanks. How are you?"
mock_message.role = "user"
mock_choice.message = mock_message
mock_response = Mock()
mock_response.model = "gpt-3.5-turbo"
mock_response.usage = Mock()
mock_response.usage.items.return_value = [
("prompt_tokens", 57),
("completion_tokens", 40),
("total_tokens", 97),
]
mock_response.choices = [mock_choice]
mock_chat_completion_create.return_value = mock_response
yield mock_chat_completion_create


def test_rag_pipeline(mock_chat_completion):
    rag_pipe = build_rag_pipeline(document_store=InMemoryDocumentStore())
    answer = rag_pipe.run(query="question")
    assert isinstance(answer, Answer)


def test_rag_pipeline_other_docstore():
    FakeStore = document_store_class("FakeStore")
    with pytest.raises(ValueError, match="InMemoryDocumentStore"):
        build_rag_pipeline(document_store=FakeStore())


def test_rag_pipeline_no_embedder_if_no_model():
    rag_pipe = build_rag_pipeline(document_store=InMemoryDocumentStore())
    assert "text_embedder" not in rag_pipe.pipeline.graph.nodes


def test_rag_pipeline_embedder_exists_if_model_is_given():
    rag_pipe = build_rag_pipeline(
        document_store=InMemoryDocumentStore(), embedding_model="sentence-transformers/all-mpnet-base-v2"
    )
    assert "text_embedder" in rag_pipe.pipeline.graph.nodes
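
A natural follow-up check (not part of this commit) is to render the default Jinja2 template directly. This sketch assumes `PromptBuilder.run` accepts the template variables as keyword arguments and returns a dict with a `prompt` key, as the pipeline wiring above implies:

```python
from haystack import Document
from haystack.components.builders.prompt_builder import PromptBuilder


def test_prompt_builder_renders_documents_and_question():
    # Hypothetical extra test: verify the template loop inserts document content.
    builder = PromptBuilder(template="{% for doc in documents %}{{ doc.content }}\n{% endfor %}Question: {{question}}")
    result = builder.run(documents=[Document(content="Giorgio lives in Rome.")], question="Who lives in Rome?")
    assert "Giorgio lives in Rome." in result["prompt"]
    assert "Who lives in Rome?" in result["prompt"]
```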