2023-03-06 17:26:19 +01:00
|
|
|
import math
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from haystack.pipelines import Pipeline, RootNode
|
|
|
|
from haystack.document_stores import InMemoryDocumentStore
|
2023-09-13 12:55:33 +02:00
|
|
|
from haystack.nodes import DensePassageRetriever, BM25Retriever, TransformersQueryClassifier, JoinDocuments, FARMReader
|
2023-03-06 17:26:19 +01:00
|
|
|
|
|
|
|
|
2023-09-13 12:55:33 +02:00
|
|
|
@pytest.mark.parametrize("classifier", [TransformersQueryClassifier()])
|
2023-03-06 17:26:19 +01:00
|
|
|
def test_query_keyword_statement_classifier(classifier):
|
|
|
|
class KeywordOutput(RootNode):
|
|
|
|
outgoing_edges = 2
|
|
|
|
|
|
|
|
def run(self, **kwargs):
|
|
|
|
kwargs["output"] = "keyword"
|
|
|
|
return kwargs, "output_1"
|
|
|
|
|
|
|
|
class QuestionOutput(RootNode):
|
|
|
|
outgoing_edges = 2
|
|
|
|
|
|
|
|
def run(self, **kwargs):
|
|
|
|
kwargs["output"] = "question"
|
|
|
|
return kwargs, "output_2"
|
|
|
|
|
|
|
|
pipeline = Pipeline()
|
|
|
|
pipeline.add_node(name="classifier", component=classifier, inputs=["Query"])
|
|
|
|
pipeline.add_node(name="KeywordNode", component=KeywordOutput(), inputs=["classifier.output_2"])
|
|
|
|
pipeline.add_node(name="QuestionNode", component=QuestionOutput(), inputs=["classifier.output_1"])
|
|
|
|
output = pipeline.run(query="morse code")
|
|
|
|
assert output["output"] == "keyword"
|
|
|
|
|
|
|
|
output = pipeline.run(query="How old is John?")
|
|
|
|
assert output["output"] == "question"
|
|
|
|
|
|
|
|
|
|
|
|
def test_join_merge_no_weights(docs):
|
2023-08-30 12:24:03 +02:00
|
|
|
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
|
2023-03-06 17:26:19 +01:00
|
|
|
document_store.write_documents(documents=docs)
|
|
|
|
bm25 = BM25Retriever(document_store=document_store)
|
|
|
|
dpr = DensePassageRetriever(
|
|
|
|
document_store=document_store,
|
|
|
|
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
|
|
|
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
|
|
|
use_gpu=False,
|
|
|
|
)
|
|
|
|
document_store.update_embeddings(dpr)
|
|
|
|
|
|
|
|
query = "Where does Carla live?"
|
|
|
|
|
|
|
|
join_node = JoinDocuments(join_mode="merge")
|
|
|
|
p = Pipeline()
|
|
|
|
p.add_node(component=bm25, name="R1", inputs=["Query"])
|
|
|
|
p.add_node(component=dpr, name="R2", inputs=["Query"])
|
|
|
|
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
|
|
|
|
results = p.run(query=query)
|
|
|
|
assert len(results["documents"]) == 5
|
|
|
|
|
|
|
|
|
|
|
|
def test_join_merge_with_weights(docs):
|
2023-08-30 12:24:03 +02:00
|
|
|
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
|
2023-03-06 17:26:19 +01:00
|
|
|
document_store.write_documents(documents=docs)
|
|
|
|
bm25 = BM25Retriever(document_store=document_store)
|
|
|
|
dpr = DensePassageRetriever(
|
|
|
|
document_store=document_store,
|
|
|
|
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
|
|
|
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
|
|
|
use_gpu=False,
|
|
|
|
)
|
|
|
|
document_store.update_embeddings(dpr)
|
|
|
|
|
|
|
|
query = "Where does Carla live?"
|
|
|
|
|
|
|
|
join_node = JoinDocuments(join_mode="merge", weights=[1000, 1], top_k_join=2)
|
|
|
|
p = Pipeline()
|
|
|
|
p.add_node(component=bm25, name="R1", inputs=["Query"])
|
|
|
|
p.add_node(component=dpr, name="R2", inputs=["Query"])
|
|
|
|
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
|
|
|
|
results = p.run(query=query)
|
|
|
|
assert math.isclose(results["documents"][0].score, 0.5336782589721345, rel_tol=0.0001)
|
|
|
|
assert len(results["documents"]) == 2
|
|
|
|
|
|
|
|
|
|
|
|
def test_join_concatenate(docs):
|
2023-08-30 12:24:03 +02:00
|
|
|
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
|
2023-03-06 17:26:19 +01:00
|
|
|
document_store.write_documents(documents=docs)
|
|
|
|
bm25 = BM25Retriever(document_store=document_store)
|
|
|
|
dpr = DensePassageRetriever(
|
|
|
|
document_store=document_store,
|
|
|
|
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
|
|
|
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
|
|
|
use_gpu=False,
|
|
|
|
)
|
|
|
|
document_store.update_embeddings(dpr)
|
|
|
|
|
|
|
|
query = "Where does Carla live?"
|
|
|
|
|
|
|
|
join_node = JoinDocuments(join_mode="concatenate")
|
|
|
|
p = Pipeline()
|
|
|
|
p.add_node(component=bm25, name="R1", inputs=["Query"])
|
|
|
|
p.add_node(component=dpr, name="R2", inputs=["Query"])
|
|
|
|
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
|
|
|
|
results = p.run(query=query)
|
|
|
|
assert len(results["documents"]) == 5
|
|
|
|
|
|
|
|
|
|
|
|
def test_join_concatenate_with_topk(docs):
|
2023-08-30 12:24:03 +02:00
|
|
|
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
|
2023-03-06 17:26:19 +01:00
|
|
|
document_store.write_documents(documents=docs)
|
|
|
|
bm25 = BM25Retriever(document_store=document_store)
|
|
|
|
dpr = DensePassageRetriever(
|
|
|
|
document_store=document_store,
|
|
|
|
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
|
|
|
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
|
|
|
use_gpu=False,
|
|
|
|
)
|
|
|
|
document_store.update_embeddings(dpr)
|
|
|
|
|
|
|
|
query = "Where does Carla live?"
|
|
|
|
|
|
|
|
join_node = JoinDocuments(join_mode="concatenate")
|
|
|
|
p = Pipeline()
|
|
|
|
p.add_node(component=bm25, name="R1", inputs=["Query"])
|
|
|
|
p.add_node(component=dpr, name="R2", inputs=["Query"])
|
|
|
|
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
|
|
|
|
one_result = p.run(query=query, params={"Join": {"top_k_join": 1}})
|
|
|
|
two_results = p.run(query=query, params={"Join": {"top_k_join": 2}})
|
|
|
|
assert len(one_result["documents"]) == 1
|
|
|
|
assert len(two_results["documents"]) == 2
|
|
|
|
|
|
|
|
|
2023-08-30 12:24:03 +02:00
|
|
|
def test_join_with_reader(docs):
|
|
|
|
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
|
2023-03-06 17:26:19 +01:00
|
|
|
document_store.write_documents(documents=docs)
|
|
|
|
bm25 = BM25Retriever(document_store=document_store)
|
|
|
|
dpr = DensePassageRetriever(
|
|
|
|
document_store=document_store,
|
|
|
|
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
|
|
|
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
|
|
|
use_gpu=False,
|
|
|
|
)
|
|
|
|
document_store.update_embeddings(dpr)
|
|
|
|
reader = FARMReader(
|
|
|
|
model_name_or_path="deepset/bert-medium-squad2-distilled", use_gpu=False, top_k_per_sample=5, num_processes=0
|
|
|
|
)
|
|
|
|
|
|
|
|
query = "Where does Carla live?"
|
|
|
|
|
|
|
|
join_node = JoinDocuments()
|
|
|
|
p = Pipeline()
|
|
|
|
p.add_node(component=bm25, name="R1", inputs=["Query"])
|
|
|
|
p.add_node(component=dpr, name="R2", inputs=["Query"])
|
|
|
|
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
|
|
|
|
p.add_node(component=reader, name="Reader", inputs=["Join"])
|
|
|
|
results = p.run(query=query)
|
|
|
|
# check whether correct answer is within top 2 predictions
|
|
|
|
assert results["answers"][0].answer == "Berlin" or results["answers"][1].answer == "Berlin"
|
|
|
|
|
|
|
|
|
|
|
|
def test_join_with_rrf(docs):
|
2023-08-30 12:24:03 +02:00
|
|
|
document_store = InMemoryDocumentStore(embedding_dim=768, similarity="dot_product", use_bm25=True)
|
2023-03-06 17:26:19 +01:00
|
|
|
document_store.write_documents(documents=docs)
|
|
|
|
bm25 = BM25Retriever(document_store=document_store)
|
|
|
|
dpr = DensePassageRetriever(
|
|
|
|
document_store=document_store,
|
|
|
|
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
|
|
|
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
|
|
|
use_gpu=False,
|
|
|
|
)
|
|
|
|
document_store.update_embeddings(dpr)
|
|
|
|
|
|
|
|
query = "Where does Carla live?"
|
|
|
|
|
|
|
|
join_node = JoinDocuments(join_mode="reciprocal_rank_fusion")
|
|
|
|
p = Pipeline()
|
|
|
|
p.add_node(component=bm25, name="R1", inputs=["Query"])
|
|
|
|
p.add_node(component=dpr, name="R2", inputs=["Query"])
|
|
|
|
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
|
|
|
|
results = p.run(query=query)
|
|
|
|
|
|
|
|
# list of precalculated expected results
|
|
|
|
expected_scores = [
|
|
|
|
0.03278688524590164,
|
|
|
|
0.03200204813108039,
|
|
|
|
0.03200204813108039,
|
|
|
|
0.031009615384615385,
|
|
|
|
0.031009615384615385,
|
|
|
|
]
|
|
|
|
assert all(
|
|
|
|
doc.score == pytest.approx(expected_scores[idx], abs=1e-3) for idx, doc in enumerate(results["documents"])
|
|
|
|
)
|