mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-27 15:08:43 +00:00
test: enhance e2e tests to also draw and serialize/deserialize the test pipelines (#5910)
* Add draw and serialization/deserialization to the e2e pipeline examples
* Add a comment about JSON serialization
* Fix a small GPTGenerator bug and move indexing in tests
* To JSON
* Review feedback
This commit is contained in:
parent
40b83d8a47
commit
71f2430fd1
@ -1,25 +1,39 @@
|
||||
import json
|
||||
|
||||
from haystack.preview import Pipeline, Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
from haystack.preview.components.retrievers import MemoryBM25Retriever
|
||||
from haystack.preview.components.readers import ExtractiveReader
|
||||
|
||||
|
||||
def test_extractive_qa_pipeline():
|
||||
document_store = MemoryDocumentStore()
|
||||
def test_extractive_qa_pipeline(tmp_path):
|
||||
# Create the pipeline
|
||||
qa_pipeline = Pipeline()
|
||||
qa_pipeline.add_component(instance=MemoryBM25Retriever(document_store=MemoryDocumentStore()), name="retriever")
|
||||
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
|
||||
qa_pipeline.connect("retriever", "reader")
|
||||
|
||||
# Draw the pipeline
|
||||
qa_pipeline.draw(tmp_path / "test_extractive_qa_pipeline.png")
|
||||
|
||||
# Serialize the pipeline to JSON
|
||||
with open(tmp_path / "test_bm25_rag_pipeline.json", "w") as f:
|
||||
print(json.dumps(qa_pipeline.to_dict(), indent=4))
|
||||
json.dump(qa_pipeline.to_dict(), f)
|
||||
|
||||
# Load the pipeline back
|
||||
with open(tmp_path / "test_bm25_rag_pipeline.json", "r") as f:
|
||||
qa_pipeline = Pipeline.from_dict(json.load(f))
|
||||
|
||||
# Populate the document store
|
||||
documents = [
|
||||
Document(text="My name is Jean and I live in Paris."),
|
||||
Document(text="My name is Mark and I live in Berlin."),
|
||||
Document(text="My name is Giorgio and I live in Rome."),
|
||||
]
|
||||
qa_pipeline.get_component("retriever").document_store.write_documents(documents)
|
||||
|
||||
document_store.write_documents(documents)
|
||||
|
||||
qa_pipeline = Pipeline()
|
||||
qa_pipeline.add_component(instance=MemoryBM25Retriever(document_store=document_store), name="retriever")
|
||||
qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
|
||||
qa_pipeline.connect("retriever", "reader")
|
||||
|
||||
# Query and assert
|
||||
questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"]
|
||||
answers_spywords = ["Jean", "Mark", "Giorgio"]
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Pipeline, Document
|
||||
@ -15,15 +16,8 @@ from haystack.preview.components.builders.prompt_builder import PromptBuilder
|
||||
not os.environ.get("OPENAI_API_KEY", None),
|
||||
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
||||
)
|
||||
def test_bm25_rag_pipeline():
|
||||
document_store = MemoryDocumentStore()
|
||||
|
||||
documents = [
|
||||
Document(text="My name is Jean and I live in Paris."),
|
||||
Document(text="My name is Mark and I live in Berlin."),
|
||||
Document(text="My name is Giorgio and I live in Rome."),
|
||||
]
|
||||
|
||||
def test_bm25_rag_pipeline(tmp_path):
|
||||
# Create the RAG pipeline
|
||||
prompt_template = """
|
||||
Given these documents, answer the question.\nDocuments:
|
||||
{% for doc in documents %}
|
||||
@ -33,11 +27,8 @@ def test_bm25_rag_pipeline():
|
||||
\nQuestion: {{question}}
|
||||
\nAnswer:
|
||||
"""
|
||||
|
||||
document_store.write_documents(documents)
|
||||
|
||||
rag_pipeline = Pipeline()
|
||||
rag_pipeline.add_component(instance=MemoryBM25Retriever(document_store=document_store), name="retriever")
|
||||
rag_pipeline.add_component(instance=MemoryBM25Retriever(document_store=MemoryDocumentStore()), name="retriever")
|
||||
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
|
||||
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
|
||||
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
|
||||
@ -47,6 +38,26 @@ def test_bm25_rag_pipeline():
|
||||
rag_pipeline.connect("llm.metadata", "answer_builder.metadata")
|
||||
rag_pipeline.connect("retriever", "answer_builder.documents")
|
||||
|
||||
# Draw the pipeline
|
||||
rag_pipeline.draw(tmp_path / "test_bm25_rag_pipeline.png")
|
||||
|
||||
# Serialize the pipeline to JSON
|
||||
with open(tmp_path / "test_bm25_rag_pipeline.json", "w") as f:
|
||||
json.dump(rag_pipeline.to_dict(), f)
|
||||
|
||||
# Load the pipeline back
|
||||
with open(tmp_path / "test_bm25_rag_pipeline.json", "r") as f:
|
||||
rag_pipeline = Pipeline.from_dict(json.load(f))
|
||||
|
||||
# Populate the document store
|
||||
documents = [
|
||||
Document(text="My name is Jean and I live in Paris."),
|
||||
Document(text="My name is Mark and I live in Berlin."),
|
||||
Document(text="My name is Giorgio and I live in Rome."),
|
||||
]
|
||||
rag_pipeline.get_component("retriever").document_store.write_documents(documents)
|
||||
|
||||
# Query and assert
|
||||
questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"]
|
||||
answers_spywords = ["Jean", "Mark", "Giorgio"]
|
||||
|
||||
@ -71,15 +82,8 @@ def test_bm25_rag_pipeline():
|
||||
not os.environ.get("OPENAI_API_KEY", None),
|
||||
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
||||
)
|
||||
def test_embedding_retrieval_rag_pipeline():
|
||||
document_store = MemoryDocumentStore()
|
||||
|
||||
documents = [
|
||||
Document(text="My name is Jean and I live in Paris."),
|
||||
Document(text="My name is Mark and I live in Berlin."),
|
||||
Document(text="My name is Giorgio and I live in Rome."),
|
||||
]
|
||||
|
||||
def test_embedding_retrieval_rag_pipeline(tmp_path):
|
||||
# Create the RAG pipeline
|
||||
prompt_template = """
|
||||
Given these documents, answer the question.\nDocuments:
|
||||
{% for doc in documents %}
|
||||
@ -89,22 +93,14 @@ def test_embedding_retrieval_rag_pipeline():
|
||||
\nQuestion: {{question}}
|
||||
\nAnswer:
|
||||
"""
|
||||
|
||||
indexing_pipeline = Pipeline()
|
||||
indexing_pipeline.add_component(
|
||||
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-mpnet-base-v2"),
|
||||
name="document_embedder",
|
||||
)
|
||||
indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer")
|
||||
indexing_pipeline.connect("document_embedder", "document_writer")
|
||||
indexing_pipeline.run({"document_embedder": {"documents": documents}})
|
||||
|
||||
rag_pipeline = Pipeline()
|
||||
rag_pipeline.add_component(
|
||||
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-mpnet-base-v2"),
|
||||
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
|
||||
name="text_embedder",
|
||||
)
|
||||
rag_pipeline.add_component(instance=MemoryEmbeddingRetriever(document_store=document_store), name="retriever")
|
||||
rag_pipeline.add_component(
|
||||
instance=MemoryEmbeddingRetriever(document_store=MemoryDocumentStore()), name="retriever"
|
||||
)
|
||||
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
|
||||
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
|
||||
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
|
||||
@ -115,6 +111,34 @@ def test_embedding_retrieval_rag_pipeline():
|
||||
rag_pipeline.connect("llm.metadata", "answer_builder.metadata")
|
||||
rag_pipeline.connect("retriever", "answer_builder.documents")
|
||||
|
||||
# Draw the pipeline
|
||||
rag_pipeline.draw(tmp_path / "test_embedding_rag_pipeline.png")
|
||||
|
||||
# Serialize the pipeline to JSON
|
||||
with open(tmp_path / "test_bm25_rag_pipeline.json", "w") as f:
|
||||
json.dump(rag_pipeline.to_dict(), f)
|
||||
|
||||
# Load the pipeline back
|
||||
with open(tmp_path / "test_bm25_rag_pipeline.json", "r") as f:
|
||||
rag_pipeline = Pipeline.from_dict(json.load(f))
|
||||
|
||||
# Populate the document store
|
||||
documents = [
|
||||
Document(text="My name is Jean and I live in Paris."),
|
||||
Document(text="My name is Mark and I live in Berlin."),
|
||||
Document(text="My name is Giorgio and I live in Rome."),
|
||||
]
|
||||
document_store = rag_pipeline.get_component("retriever").document_store
|
||||
indexing_pipeline = Pipeline()
|
||||
indexing_pipeline.add_component(
|
||||
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
|
||||
name="document_embedder",
|
||||
)
|
||||
indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer")
|
||||
indexing_pipeline.connect("document_embedder", "document_writer")
|
||||
indexing_pipeline.run({"document_embedder": {"documents": documents}})
|
||||
|
||||
# Query and assert
|
||||
questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"]
|
||||
answers_spywords = ["Jean", "Mark", "Giorgio"]
|
||||
|
||||
@ -129,7 +153,6 @@ def test_embedding_retrieval_rag_pipeline():
|
||||
|
||||
assert len(result["answer_builder"]["answers"]) == 1
|
||||
generated_answer = result["answer_builder"]["answers"][0]
|
||||
print(generated_answer)
|
||||
assert spyword in generated_answer.data
|
||||
assert generated_answer.query == question
|
||||
assert hasattr(generated_answer, "documents")
|
||||
|
||||
@ -121,7 +121,7 @@ class GPTGenerator:
|
||||
"""
|
||||
init_params = data.get("init_parameters", {})
|
||||
streaming_callback = None
|
||||
if "streaming_callback" in init_params:
|
||||
if "streaming_callback" in init_params and init_params["streaming_callback"]:
|
||||
parts = init_params["streaming_callback"].split(".")
|
||||
module_name = ".".join(parts[:-1])
|
||||
function_name = parts[-1]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user