2023-09-26 11:49:50 +02:00
|
|
|
import os
|
2023-10-09 12:54:17 +01:00
|
|
|
import json
|
2023-09-26 11:49:50 +02:00
|
|
|
import pytest
|
|
|
|
|
2023-11-24 14:48:43 +01:00
|
|
|
from haystack import Pipeline, Document
|
2024-01-10 21:20:42 +01:00
|
|
|
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
2023-11-24 14:48:43 +01:00
|
|
|
from haystack.components.writers import DocumentWriter
|
2024-01-10 21:20:42 +01:00
|
|
|
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
|
2023-11-24 14:48:43 +01:00
|
|
|
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
|
2023-12-22 19:37:29 +01:00
|
|
|
from haystack.components.generators import OpenAIGenerator
|
2023-11-24 14:48:43 +01:00
|
|
|
from haystack.components.builders.answer_builder import AnswerBuilder
|
|
|
|
from haystack.components.builders.prompt_builder import PromptBuilder
|
2023-09-26 11:49:50 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_bm25_rag_pipeline(tmp_path):
    """End-to-end test of a BM25-retrieval RAG pipeline.

    Builds ``retriever -> prompt_builder -> llm -> answer_builder``, draws it,
    round-trips it through JSON, writes three toy documents into the
    deserialized retriever's store, and checks that each question is answered
    with the expected name.

    :param tmp_path: pytest-provided temporary directory for the drawing and
        the serialized pipeline.
    """
    # Create the RAG pipeline
    prompt_template = """
    Given these documents, answer the question.\nDocuments:
    {% for doc in documents %}
        {{ doc.content }}
    {% endfor %}

    \nQuestion: {{question}}
    \nAnswer:
    """
    rag_pipeline = Pipeline()
    rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
    rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
    rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
    rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
    rag_pipeline.connect("retriever", "prompt_builder.documents")
    rag_pipeline.connect("prompt_builder", "llm")
    rag_pipeline.connect("llm.replies", "answer_builder.replies")
    rag_pipeline.connect("llm.meta", "answer_builder.meta")
    rag_pipeline.connect("retriever", "answer_builder.documents")

    # Draw the pipeline
    rag_pipeline.draw(tmp_path / "test_bm25_rag_pipeline.png")

    # Serialize the pipeline to JSON and load it back; everything below runs
    # against the deserialized pipeline, so this also exercises to_dict/from_dict.
    json_path = tmp_path / "test_bm25_rag_pipeline.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(rag_pipeline.to_dict(), f)
    with open(json_path, "r", encoding="utf-8") as f:
        rag_pipeline = Pipeline.from_dict(json.load(f))

    # Populate the document store owned by the deserialized retriever
    documents = [
        Document(content="My name is Jean and I live in Paris."),
        Document(content="My name is Mark and I live in Berlin."),
        Document(content="My name is Giorgio and I live in Rome."),
    ]
    rag_pipeline.get_component("retriever").document_store.write_documents(documents)

    # Query and assert
    questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"]
    answers_spywords = ["Jean", "Mark", "Giorgio"]

    for question, spyword in zip(questions, answers_spywords):
        result = rag_pipeline.run(
            {
                "retriever": {"query": question},
                "prompt_builder": {"question": question},
                "answer_builder": {"query": question},
            }
        )

        assert len(result["answer_builder"]["answers"]) == 1
        generated_answer = result["answer_builder"]["answers"][0]
        assert spyword in generated_answer.data
        assert generated_answer.query == question
        # Check the values wired into answer_builder, not merely that the
        # attributes exist: the retriever's documents must be forwarded and
        # include the one that contains the expected name...
        assert any(spyword in doc.content for doc in generated_answer.documents)
        # ...and the metadata connected from "llm.meta" must be populated.
        assert generated_answer.meta
|
2023-09-26 11:49:50 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_embedding_retrieval_rag_pipeline(tmp_path):
    """End-to-end test of an embedding-retrieval RAG pipeline.

    Builds ``text_embedder -> retriever -> prompt_builder -> llm ->
    answer_builder``, draws it, and round-trips it through JSON. It then
    indexes three toy documents with a separate embedding pipeline that writes
    into the deserialized retriever's store. Finally it checks that each
    question is answered with the expected name.

    :param tmp_path: pytest-provided temporary directory for the drawing and
        the serialized pipeline.
    """
    # Create the RAG pipeline
    prompt_template = """
    Given these documents, answer the question.\nDocuments:
    {% for doc in documents %}
        {{ doc.content }}
    {% endfor %}

    \nQuestion: {{question}}
    \nAnswer:
    """
    rag_pipeline = Pipeline()
    rag_pipeline.add_component(
        instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="text_embedder"
    )
    rag_pipeline.add_component(
        instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever"
    )
    rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
    rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
    rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
    rag_pipeline.connect("text_embedder", "retriever")
    rag_pipeline.connect("retriever", "prompt_builder.documents")
    rag_pipeline.connect("prompt_builder", "llm")
    rag_pipeline.connect("llm.replies", "answer_builder.replies")
    rag_pipeline.connect("llm.meta", "answer_builder.meta")
    rag_pipeline.connect("retriever", "answer_builder.documents")

    # Draw the pipeline
    rag_pipeline.draw(tmp_path / "test_embedding_rag_pipeline.png")

    # Serialize the pipeline to JSON and load it back; everything below runs
    # against the deserialized pipeline, so this also exercises to_dict/from_dict.
    json_path = tmp_path / "test_embedding_rag_pipeline.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(rag_pipeline.to_dict(), f)
    with open(json_path, "r", encoding="utf-8") as f:
        rag_pipeline = Pipeline.from_dict(json.load(f))

    # Populate the document store: embed the documents with an indexing
    # pipeline that writes into the store owned by the deserialized retriever.
    documents = [
        Document(content="My name is Jean and I live in Paris."),
        Document(content="My name is Mark and I live in Berlin."),
        Document(content="My name is Giorgio and I live in Rome."),
    ]
    document_store = rag_pipeline.get_component("retriever").document_store
    indexing_pipeline = Pipeline()
    indexing_pipeline.add_component(
        instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
        name="document_embedder",
    )
    indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer")
    indexing_pipeline.connect("document_embedder", "document_writer")
    indexing_pipeline.run({"document_embedder": {"documents": documents}})

    # Query and assert
    questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"]
    answers_spywords = ["Jean", "Mark", "Giorgio"]

    for question, spyword in zip(questions, answers_spywords):
        result = rag_pipeline.run(
            {
                "text_embedder": {"text": question},
                "prompt_builder": {"question": question},
                "answer_builder": {"query": question},
            }
        )

        assert len(result["answer_builder"]["answers"]) == 1
        generated_answer = result["answer_builder"]["answers"][0]
        assert spyword in generated_answer.data
        assert generated_answer.query == question
        # Check the values wired into answer_builder, not merely that the
        # attributes exist: the retriever's documents must be forwarded and
        # include the one that contains the expected name...
        assert any(spyword in doc.content for doc in generated_answer.documents)
        # ...and the metadata connected from "llm.meta" must be populated.
        assert generated_answer.meta
|