haystack/e2e/pipelines/test_pipeline_topologies.py

191 lines
7.5 KiB
Python
Raw Normal View History

import math
import pytest
from haystack.pipelines import Pipeline, RootNode
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever, BM25Retriever, TransformersQueryClassifier, JoinDocuments, FARMReader
@pytest.mark.parametrize("classifier", [TransformersQueryClassifier()])
def test_query_keyword_statement_classifier(classifier):
class KeywordOutput(RootNode):
outgoing_edges = 2
def run(self, **kwargs):
kwargs["output"] = "keyword"
return kwargs, "output_1"
class QuestionOutput(RootNode):
outgoing_edges = 2
def run(self, **kwargs):
kwargs["output"] = "question"
return kwargs, "output_2"
pipeline = Pipeline()
pipeline.add_node(name="classifier", component=classifier, inputs=["Query"])
pipeline.add_node(name="KeywordNode", component=KeywordOutput(), inputs=["classifier.output_2"])
pipeline.add_node(name="QuestionNode", component=QuestionOutput(), inputs=["classifier.output_1"])
output = pipeline.run(query="morse code")
assert output["output"] == "keyword"
output = pipeline.run(query="How old is John?")
assert output["output"] == "question"
def test_join_merge_no_weights(docs):
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
document_store.write_documents(documents=docs)
bm25 = BM25Retriever(document_store=document_store)
dpr = DensePassageRetriever(
document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
)
document_store.update_embeddings(dpr)
query = "Where does Carla live?"
join_node = JoinDocuments(join_mode="merge")
p = Pipeline()
p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
results = p.run(query=query)
assert len(results["documents"]) == 5
def test_join_merge_with_weights(docs):
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
document_store.write_documents(documents=docs)
bm25 = BM25Retriever(document_store=document_store)
dpr = DensePassageRetriever(
document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
)
document_store.update_embeddings(dpr)
query = "Where does Carla live?"
join_node = JoinDocuments(join_mode="merge", weights=[1000, 1], top_k_join=2)
p = Pipeline()
p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
results = p.run(query=query)
assert math.isclose(results["documents"][0].score, 0.5336782589721345, rel_tol=0.0001)
assert len(results["documents"]) == 2
def test_join_concatenate(docs):
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
document_store.write_documents(documents=docs)
bm25 = BM25Retriever(document_store=document_store)
dpr = DensePassageRetriever(
document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
)
document_store.update_embeddings(dpr)
query = "Where does Carla live?"
join_node = JoinDocuments(join_mode="concatenate")
p = Pipeline()
p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
results = p.run(query=query)
assert len(results["documents"]) == 5
def test_join_concatenate_with_topk(docs):
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
document_store.write_documents(documents=docs)
bm25 = BM25Retriever(document_store=document_store)
dpr = DensePassageRetriever(
document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
)
document_store.update_embeddings(dpr)
query = "Where does Carla live?"
join_node = JoinDocuments(join_mode="concatenate")
p = Pipeline()
p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
one_result = p.run(query=query, params={"Join": {"top_k_join": 1}})
two_results = p.run(query=query, params={"Join": {"top_k_join": 2}})
assert len(one_result["documents"]) == 1
assert len(two_results["documents"]) == 2
def test_join_with_reader(docs):
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
document_store.write_documents(documents=docs)
bm25 = BM25Retriever(document_store=document_store)
dpr = DensePassageRetriever(
document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
)
document_store.update_embeddings(dpr)
reader = FARMReader(
model_name_or_path="deepset/bert-medium-squad2-distilled", use_gpu=False, top_k_per_sample=5, num_processes=0
)
query = "Where does Carla live?"
join_node = JoinDocuments()
p = Pipeline()
p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
p.add_node(component=reader, name="Reader", inputs=["Join"])
results = p.run(query=query)
# check whether correct answer is within top 2 predictions
assert results["answers"][0].answer == "Berlin" or results["answers"][1].answer == "Berlin"
def test_join_with_rrf(docs):
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
document_store.write_documents(documents=docs)
bm25 = BM25Retriever(document_store=document_store)
dpr = DensePassageRetriever(
document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
)
document_store.update_embeddings(dpr)
query = "Where does Carla live?"
join_node = JoinDocuments(join_mode="reciprocal_rank_fusion")
p = Pipeline()
p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
results = p.run(query=query)
# list of precalculated expected results
expected_scores = [
0.03278688524590164,
0.03200204813108039,
0.03200204813108039,
0.031009615384615385,
0.031009615384615385,
]
assert all(
doc.score == pytest.approx(expected_scores[idx], abs=1e-3) for idx, doc in enumerate(results["documents"])
)