test: replace ElasticsearchDS with InMemoryDS when it makes sense; support scale_score in InMemoryDS (#4283)

* replace elasticds with imds - first draft

* fix

* fix tests and implement scale_score in imds bm25

* add docstrings for scale_score
This commit is contained in:
Stefano Fiorucci 2023-03-01 11:35:10 +01:00 committed by GitHub
parent ee74421212
commit e8f9b1b65d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 82 additions and 95 deletions

View File

@ -16,6 +16,7 @@ import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
import rank_bm25 import rank_bm25
import pandas as pd import pandas as pd
from scipy.special import expit
from haystack.schema import Document, FilterType, Label from haystack.schema import Document, FilterType, Label
from haystack.errors import DuplicateDocumentError, DocumentStoreError from haystack.errors import DuplicateDocumentError, DocumentStoreError
@ -953,7 +954,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
index: Optional[str] = None, index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
all_terms_must_match: bool = False, all_terms_must_match: bool = False,
scale_score: bool = False, scale_score: bool = True,
) -> List[Document]: ) -> List[Document]:
""" """
Scan through documents in DocumentStore and return a small number documents Scan through documents in DocumentStore and return a small number documents
@ -961,6 +962,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
:param query: The query. :param query: The query.
:param top_k: How many documents to return per query. :param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents. :param index: The name of the index in the DocumentStore from which to retrieve documents.
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
""" """
if headers: if headers:
@ -973,10 +975,6 @@ class InMemoryDocumentStore(KeywordDocumentStore):
logger.warning( logger.warning(
"InMemoryDocumentStore does not support filters for BM25 retrieval. This parameter is ignored." "InMemoryDocumentStore does not support filters for BM25 retrieval. This parameter is ignored."
) )
if scale_score is True:
logger.warning(
"InMemoryDocumentStore does not support scale_score for BM25 retrieval. This parameter is ignored."
)
index = index or self.index index = index or self.index
if index not in self.bm25: if index not in self.bm25:
@ -989,6 +987,9 @@ class InMemoryDocumentStore(KeywordDocumentStore):
tokenized_query = self.bm25_tokenization_regex(query.lower()) tokenized_query = self.bm25_tokenization_regex(query.lower())
docs_scores = self.bm25[index].get_scores(tokenized_query) docs_scores = self.bm25[index].get_scores(tokenized_query)
if scale_score is True:
# scaling probability from BM25
docs_scores = [float(expit(np.asarray(score / 8))) for score in docs_scores]
top_docs_positions = np.argsort(docs_scores)[::-1][:top_k] top_docs_positions = np.argsort(docs_scores)[::-1][:top_k]
textual_docs_list = [doc for doc in self.indexes[index].values() if doc.content_type in ["text", "table"]] textual_docs_list = [doc for doc in self.indexes[index].values() if doc.content_type in ["text", "table"]]
@ -1009,7 +1010,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
index: Optional[str] = None, index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
all_terms_must_match: bool = False, all_terms_must_match: bool = False,
scale_score: bool = False, scale_score: bool = True,
) -> List[List[Document]]: ) -> List[List[Document]]:
""" """
Scan through documents in DocumentStore and return a small number documents Scan through documents in DocumentStore and return a small number documents
@ -1018,6 +1019,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
:param query: The query. :param query: The query.
:param top_k: How many documents to return per query. :param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents. :param index: The name of the index in the DocumentStore from which to retrieve documents.
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
""" """
if headers: if headers:
@ -1030,10 +1032,6 @@ class InMemoryDocumentStore(KeywordDocumentStore):
logger.warning( logger.warning(
"InMemoryDocumentStore does not support filters for BM25 retrieval. This parameter is ignored." "InMemoryDocumentStore does not support filters for BM25 retrieval. This parameter is ignored."
) )
if scale_score is True:
logger.warning(
"InMemoryDocumentStore does not support scale_score for BM25 retrieval. This parameter is ignored."
)
index = index or self.index index = index or self.index
if index not in self.bm25: if index not in self.bm25:
@ -1043,6 +1041,6 @@ class InMemoryDocumentStore(KeywordDocumentStore):
result_documents = [] result_documents = []
for query in queries: for query in queries:
result_documents.append(self.query(query=query, top_k=top_k, index=index)) result_documents.append(self.query(query=query, top_k=top_k, index=index, scale_score=scale_score))
return result_documents return result_documents

View File

@ -863,6 +863,7 @@ def get_document_store(
index=index, index=index,
similarity=similarity, similarity=similarity,
use_bm25=True, use_bm25=True,
bm25_parameters={"k1": 1.2, "b": 0.75}, # parameters similar to those of Elasticsearch
) )
elif document_store_type == "elasticsearch": elif document_store_type == "elasticsearch":

View File

@ -1,21 +1,17 @@
import pytest import pytest
from haystack.pipelines import TranslationWrapperPipeline, ExtractiveQAPipeline from haystack.pipelines import TranslationWrapperPipeline, ExtractiveQAPipeline
from haystack.nodes import DensePassageRetriever, EmbeddingRetriever
from .test_summarizer import SPLIT_DOCS from .test_summarizer import SPLIT_DOCS
# Keeping few (retriever,document_store,reader) combination to reduce test time # Keeping few (retriever,document_store,reader) combination to reduce test time
@pytest.mark.integration @pytest.mark.integration
@pytest.mark.elasticsearch
@pytest.mark.summarizer @pytest.mark.summarizer
@pytest.mark.parametrize("retriever,document_store,reader", [("embedding", "memory", "farm")], indirect=True) @pytest.mark.parametrize("retriever,document_store,reader", [("embedding", "memory", "farm")], indirect=True)
def test_extractive_qa_pipeline_with_translator( def test_extractive_qa_pipeline_with_translator(
document_store, retriever, reader, en_to_de_translator, de_to_en_translator document_store, retriever, reader, en_to_de_translator, de_to_en_translator
): ):
document_store.write_documents(SPLIT_DOCS) document_store.write_documents(SPLIT_DOCS)
if isinstance(retriever, EmbeddingRetriever) or isinstance(retriever, DensePassageRetriever):
document_store.update_embeddings(retriever=retriever) document_store.update_embeddings(retriever=retriever)
query = "Wo steht der Eiffelturm?" query = "Wo steht der Eiffelturm?"

View File

@ -161,6 +161,7 @@ def test_eval_reader(reader, document_store, use_confidence_scores):
assert reader_eval_results["top_n_accuracy"] == 100.0 assert reader_eval_results["top_n_accuracy"] == 100.0
# using ElasticsearchDocumentStore, since InMemoryDocumentStore doesn't return meaningful BM25 scores when there are very few documents
@pytest.mark.elasticsearch @pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True) @pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("open_domain", [True, False]) @pytest.mark.parametrize("open_domain", [True, False])

View File

@ -21,8 +21,7 @@ class MockRetriever(BaseMockRetriever):
raise ValueError("TEST ERROR!") raise ValueError("TEST ERROR!")
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_node_names_validation(document_store_with_docs, tmp_path): def test_node_names_validation(document_store_with_docs, tmp_path):
pipeline = Pipeline() pipeline = Pipeline()
pipeline.add_node( pipeline.add_node(
@ -52,28 +51,27 @@ def test_node_names_validation(document_store_with_docs, tmp_path):
assert "top_k" not in exception_raised assert "top_k" not in exception_raised
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_debug_attributes_global(document_store_with_docs, tmp_path): def test_debug_attributes_global(document_store_with_docs, tmp_path):
es_retriever = BM25Retriever(document_store=document_store_with_docs) bm25_retriever = BM25Retriever(document_store=document_store_with_docs)
reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0) reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)
pipeline = Pipeline() pipeline = Pipeline()
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"]) pipeline.add_node(component=bm25_retriever, name="BM25Retriever", inputs=["Query"])
pipeline.add_node(component=reader, name="Reader", inputs=["ESRetriever"]) pipeline.add_node(component=reader, name="Reader", inputs=["BM25Retriever"])
prediction = pipeline.run( prediction = pipeline.run(
query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 10}, "Reader": {"top_k": 3}}, debug=True query="Who lives in Berlin?", params={"BM25Retriever": {"top_k": 10}, "Reader": {"top_k": 3}}, debug=True
) )
assert "_debug" in prediction.keys() assert "_debug" in prediction.keys()
assert "ESRetriever" in prediction["_debug"].keys() assert "BM25Retriever" in prediction["_debug"].keys()
assert "Reader" in prediction["_debug"].keys() assert "Reader" in prediction["_debug"].keys()
assert "input" in prediction["_debug"]["ESRetriever"].keys() assert "input" in prediction["_debug"]["BM25Retriever"].keys()
assert "output" in prediction["_debug"]["ESRetriever"].keys() assert "output" in prediction["_debug"]["BM25Retriever"].keys()
assert "input" in prediction["_debug"]["Reader"].keys() assert "input" in prediction["_debug"]["Reader"].keys()
assert "output" in prediction["_debug"]["Reader"].keys() assert "output" in prediction["_debug"]["Reader"].keys()
assert prediction["_debug"]["ESRetriever"]["input"] assert prediction["_debug"]["BM25Retriever"]["input"]
assert prediction["_debug"]["ESRetriever"]["output"] assert prediction["_debug"]["BM25Retriever"]["output"]
assert prediction["_debug"]["Reader"]["input"] assert prediction["_debug"]["Reader"]["input"]
assert prediction["_debug"]["Reader"]["output"] assert prediction["_debug"]["Reader"]["output"]
@ -81,57 +79,55 @@ def test_debug_attributes_global(document_store_with_docs, tmp_path):
json.dumps(prediction, default=str) json.dumps(prediction, default=str)
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_debug_attributes_per_node(document_store_with_docs, tmp_path): def test_debug_attributes_per_node(document_store_with_docs, tmp_path):
es_retriever = BM25Retriever(document_store=document_store_with_docs) bm25_retriever = BM25Retriever(document_store=document_store_with_docs)
reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0) reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)
pipeline = Pipeline() pipeline = Pipeline()
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"]) pipeline.add_node(component=bm25_retriever, name="BM25Retriever", inputs=["Query"])
pipeline.add_node(component=reader, name="Reader", inputs=["ESRetriever"]) pipeline.add_node(component=reader, name="Reader", inputs=["BM25Retriever"])
prediction = pipeline.run( prediction = pipeline.run(
query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 10, "debug": True}, "Reader": {"top_k": 3}} query="Who lives in Berlin?", params={"BM25Retriever": {"top_k": 10, "debug": True}, "Reader": {"top_k": 3}}
) )
assert "_debug" in prediction.keys() assert "_debug" in prediction.keys()
assert "ESRetriever" in prediction["_debug"].keys() assert "BM25Retriever" in prediction["_debug"].keys()
assert "Reader" not in prediction["_debug"].keys() assert "Reader" not in prediction["_debug"].keys()
assert "input" in prediction["_debug"]["ESRetriever"].keys() assert "input" in prediction["_debug"]["BM25Retriever"].keys()
assert "output" in prediction["_debug"]["ESRetriever"].keys() assert "output" in prediction["_debug"]["BM25Retriever"].keys()
assert prediction["_debug"]["ESRetriever"]["input"] assert prediction["_debug"]["BM25Retriever"]["input"]
assert prediction["_debug"]["ESRetriever"]["output"] assert prediction["_debug"]["BM25Retriever"]["output"]
# Avoid circular reference: easiest way to detect those is to use json.dumps # Avoid circular reference: easiest way to detect those is to use json.dumps
json.dumps(prediction, default=str) json.dumps(prediction, default=str)
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_debug_attributes_for_join_nodes(document_store_with_docs, tmp_path): def test_debug_attributes_for_join_nodes(document_store_with_docs, tmp_path):
es_retriever_1 = BM25Retriever(document_store=document_store_with_docs) bm25_retriever_1 = BM25Retriever(document_store=document_store_with_docs)
es_retriever_2 = BM25Retriever(document_store=document_store_with_docs) bm25_retriever_2 = BM25Retriever(document_store=document_store_with_docs)
pipeline = Pipeline() pipeline = Pipeline()
pipeline.add_node(component=es_retriever_1, name="ESRetriever1", inputs=["Query"]) pipeline.add_node(component=bm25_retriever_1, name="BM25Retriever1", inputs=["Query"])
pipeline.add_node(component=es_retriever_2, name="ESRetriever2", inputs=["Query"]) pipeline.add_node(component=bm25_retriever_2, name="BM25Retriever2", inputs=["Query"])
pipeline.add_node(component=JoinDocuments(), name="JoinDocuments", inputs=["ESRetriever1", "ESRetriever2"]) pipeline.add_node(component=JoinDocuments(), name="JoinDocuments", inputs=["BM25Retriever1", "BM25Retriever2"])
prediction = pipeline.run(query="Who lives in Berlin?", debug=True) prediction = pipeline.run(query="Who lives in Berlin?", debug=True)
assert "_debug" in prediction.keys() assert "_debug" in prediction.keys()
assert "ESRetriever1" in prediction["_debug"].keys() assert "BM25Retriever1" in prediction["_debug"].keys()
assert "ESRetriever2" in prediction["_debug"].keys() assert "BM25Retriever2" in prediction["_debug"].keys()
assert "JoinDocuments" in prediction["_debug"].keys() assert "JoinDocuments" in prediction["_debug"].keys()
assert "input" in prediction["_debug"]["ESRetriever1"].keys() assert "input" in prediction["_debug"]["BM25Retriever1"].keys()
assert "output" in prediction["_debug"]["ESRetriever1"].keys() assert "output" in prediction["_debug"]["BM25Retriever1"].keys()
assert "input" in prediction["_debug"]["ESRetriever2"].keys() assert "input" in prediction["_debug"]["BM25Retriever2"].keys()
assert "output" in prediction["_debug"]["ESRetriever2"].keys() assert "output" in prediction["_debug"]["BM25Retriever2"].keys()
assert "input" in prediction["_debug"]["JoinDocuments"].keys() assert "input" in prediction["_debug"]["JoinDocuments"].keys()
assert "output" in prediction["_debug"]["JoinDocuments"].keys() assert "output" in prediction["_debug"]["JoinDocuments"].keys()
assert prediction["_debug"]["ESRetriever1"]["input"] assert prediction["_debug"]["BM25Retriever1"]["input"]
assert prediction["_debug"]["ESRetriever1"]["output"] assert prediction["_debug"]["BM25Retriever1"]["output"]
assert prediction["_debug"]["ESRetriever2"]["input"] assert prediction["_debug"]["BM25Retriever2"]["input"]
assert prediction["_debug"]["ESRetriever2"]["output"] assert prediction["_debug"]["BM25Retriever2"]["output"]
assert prediction["_debug"]["JoinDocuments"]["input"] assert prediction["_debug"]["JoinDocuments"]["input"]
assert prediction["_debug"]["JoinDocuments"]["output"] assert prediction["_debug"]["JoinDocuments"]["output"]
@ -139,37 +135,36 @@ def test_debug_attributes_for_join_nodes(document_store_with_docs, tmp_path):
json.dumps(prediction, default=str) json.dumps(prediction, default=str)
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_global_debug_attributes_override_node_ones(document_store_with_docs, tmp_path): def test_global_debug_attributes_override_node_ones(document_store_with_docs, tmp_path):
es_retriever = BM25Retriever(document_store=document_store_with_docs) bm25_retriever = BM25Retriever(document_store=document_store_with_docs)
reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0) reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)
pipeline = Pipeline() pipeline = Pipeline()
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"]) pipeline.add_node(component=bm25_retriever, name="BM25Retriever", inputs=["Query"])
pipeline.add_node(component=reader, name="Reader", inputs=["ESRetriever"]) pipeline.add_node(component=reader, name="Reader", inputs=["BM25Retriever"])
prediction = pipeline.run( prediction = pipeline.run(
query="Who lives in Berlin?", query="Who lives in Berlin?",
params={"ESRetriever": {"top_k": 10, "debug": True}, "Reader": {"top_k": 3, "debug": True}}, params={"BM25Retriever": {"top_k": 10, "debug": True}, "Reader": {"top_k": 3, "debug": True}},
debug=False, debug=False,
) )
assert "_debug" not in prediction.keys() assert "_debug" not in prediction.keys()
prediction = pipeline.run( prediction = pipeline.run(
query="Who lives in Berlin?", query="Who lives in Berlin?",
params={"ESRetriever": {"top_k": 10, "debug": False}, "Reader": {"top_k": 3, "debug": False}}, params={"BM25Retriever": {"top_k": 10, "debug": False}, "Reader": {"top_k": 3, "debug": False}},
debug=True, debug=True,
) )
assert "_debug" in prediction.keys() assert "_debug" in prediction.keys()
assert "ESRetriever" in prediction["_debug"].keys() assert "BM25Retriever" in prediction["_debug"].keys()
assert "Reader" in prediction["_debug"].keys() assert "Reader" in prediction["_debug"].keys()
assert "input" in prediction["_debug"]["ESRetriever"].keys() assert "input" in prediction["_debug"]["BM25Retriever"].keys()
assert "output" in prediction["_debug"]["ESRetriever"].keys() assert "output" in prediction["_debug"]["BM25Retriever"].keys()
assert "input" in prediction["_debug"]["Reader"].keys() assert "input" in prediction["_debug"]["Reader"].keys()
assert "output" in prediction["_debug"]["Reader"].keys() assert "output" in prediction["_debug"]["Reader"].keys()
assert prediction["_debug"]["ESRetriever"]["input"] assert prediction["_debug"]["BM25Retriever"]["input"]
assert prediction["_debug"]["ESRetriever"]["output"] assert prediction["_debug"]["BM25Retriever"]["output"]
assert prediction["_debug"]["Reader"]["input"] assert prediction["_debug"]["Reader"]["input"]
assert prediction["_debug"]["Reader"]["output"] assert prediction["_debug"]["Reader"]["output"]

View File

@ -293,10 +293,9 @@ def test_most_similar_documents_pipeline_with_filters_batch(retriever, document_
assert document.meta["source"] in ["wiki3", "wiki4", "wiki5"] assert document.meta["source"] in ["wiki3", "wiki4", "wiki5"]
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
def test_join_merge_no_weights(document_store_dot_product_with_docs): def test_join_merge_no_weights(document_store_dot_product_with_docs):
es = BM25Retriever(document_store=document_store_dot_product_with_docs) bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
dpr = DensePassageRetriever( dpr = DensePassageRetriever(
document_store=document_store_dot_product_with_docs, document_store=document_store_dot_product_with_docs,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base", query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@ -309,17 +308,16 @@ def test_join_merge_no_weights(document_store_dot_product_with_docs):
join_node = JoinDocuments(join_mode="merge") join_node = JoinDocuments(join_mode="merge")
p = Pipeline() p = Pipeline()
p.add_node(component=es, name="R1", inputs=["Query"]) p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"]) p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"]) p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
results = p.run(query=query) results = p.run(query=query)
assert len(results["documents"]) == 5 assert len(results["documents"]) == 5
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
def test_join_merge_with_weights(document_store_dot_product_with_docs): def test_join_merge_with_weights(document_store_dot_product_with_docs):
es = BM25Retriever(document_store=document_store_dot_product_with_docs) bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
dpr = DensePassageRetriever( dpr = DensePassageRetriever(
document_store=document_store_dot_product_with_docs, document_store=document_store_dot_product_with_docs,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base", query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@ -332,18 +330,17 @@ def test_join_merge_with_weights(document_store_dot_product_with_docs):
join_node = JoinDocuments(join_mode="merge", weights=[1000, 1], top_k_join=2) join_node = JoinDocuments(join_mode="merge", weights=[1000, 1], top_k_join=2)
p = Pipeline() p = Pipeline()
p.add_node(component=es, name="R1", inputs=["Query"]) p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"]) p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"]) p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
results = p.run(query=query) results = p.run(query=query)
assert math.isclose(results["documents"][0].score, 0.5481393431183286, rel_tol=0.0001) assert math.isclose(results["documents"][0].score, 0.5336782589721345, rel_tol=0.0001)
assert len(results["documents"]) == 2 assert len(results["documents"]) == 2
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
def test_join_concatenate(document_store_dot_product_with_docs): def test_join_concatenate(document_store_dot_product_with_docs):
es = BM25Retriever(document_store=document_store_dot_product_with_docs) bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
dpr = DensePassageRetriever( dpr = DensePassageRetriever(
document_store=document_store_dot_product_with_docs, document_store=document_store_dot_product_with_docs,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base", query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@ -356,17 +353,16 @@ def test_join_concatenate(document_store_dot_product_with_docs):
join_node = JoinDocuments(join_mode="concatenate") join_node = JoinDocuments(join_mode="concatenate")
p = Pipeline() p = Pipeline()
p.add_node(component=es, name="R1", inputs=["Query"]) p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"]) p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"]) p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
results = p.run(query=query) results = p.run(query=query)
assert len(results["documents"]) == 5 assert len(results["documents"]) == 5
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
def test_join_concatenate_with_topk(document_store_dot_product_with_docs): def test_join_concatenate_with_topk(document_store_dot_product_with_docs):
es = BM25Retriever(document_store=document_store_dot_product_with_docs) bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
dpr = DensePassageRetriever( dpr = DensePassageRetriever(
document_store=document_store_dot_product_with_docs, document_store=document_store_dot_product_with_docs,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base", query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@ -379,7 +375,7 @@ def test_join_concatenate_with_topk(document_store_dot_product_with_docs):
join_node = JoinDocuments(join_mode="concatenate") join_node = JoinDocuments(join_mode="concatenate")
p = Pipeline() p = Pipeline()
p.add_node(component=es, name="R1", inputs=["Query"]) p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"]) p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"]) p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
one_result = p.run(query=query, params={"Join": {"top_k_join": 1}}) one_result = p.run(query=query, params={"Join": {"top_k_join": 1}})
@ -388,11 +384,10 @@ def test_join_concatenate_with_topk(document_store_dot_product_with_docs):
assert len(two_results["documents"]) == 2 assert len(two_results["documents"]) == 2
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True) @pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_join_with_reader(document_store_dot_product_with_docs, reader): def test_join_with_reader(document_store_dot_product_with_docs, reader):
es = BM25Retriever(document_store=document_store_dot_product_with_docs) bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
dpr = DensePassageRetriever( dpr = DensePassageRetriever(
document_store=document_store_dot_product_with_docs, document_store=document_store_dot_product_with_docs,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base", query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@ -405,7 +400,7 @@ def test_join_with_reader(document_store_dot_product_with_docs, reader):
join_node = JoinDocuments() join_node = JoinDocuments()
p = Pipeline() p = Pipeline()
p.add_node(component=es, name="R1", inputs=["Query"]) p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"]) p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"]) p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
p.add_node(component=reader, name="Reader", inputs=["Join"]) p.add_node(component=reader, name="Reader", inputs=["Join"])
@ -414,10 +409,9 @@ def test_join_with_reader(document_store_dot_product_with_docs, reader):
assert results["answers"][0].answer == "Berlin" or results["answers"][1].answer == "Berlin" assert results["answers"][0].answer == "Berlin" or results["answers"][1].answer == "Berlin"
@pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
def test_join_with_rrf(document_store_dot_product_with_docs): def test_join_with_rrf(document_store_dot_product_with_docs):
es = BM25Retriever(document_store=document_store_dot_product_with_docs) bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
dpr = DensePassageRetriever( dpr = DensePassageRetriever(
document_store=document_store_dot_product_with_docs, document_store=document_store_dot_product_with_docs,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base", query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@ -430,7 +424,7 @@ def test_join_with_rrf(document_store_dot_product_with_docs):
join_node = JoinDocuments(join_mode="reciprocal_rank_fusion") join_node = JoinDocuments(join_mode="reciprocal_rank_fusion")
p = Pipeline() p = Pipeline()
p.add_node(component=es, name="R1", inputs=["Query"]) p.add_node(component=bm25, name="R1", inputs=["Query"])
p.add_node(component=dpr, name="R2", inputs=["Query"]) p.add_node(component=dpr, name="R2", inputs=["Query"])
p.add_node(component=join_node, name="Join", inputs=["R1", "R2"]) p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
results = p.run(query=query) results = p.run(query=query)
@ -444,7 +438,9 @@ def test_join_with_rrf(document_store_dot_product_with_docs):
0.031009615384615385, 0.031009615384615385,
] ]
assert all([doc.score == expected_scores[idx] for idx, doc in enumerate(results["documents"])]) assert all(
doc.score == pytest.approx(expected_scores[idx], abs=1e-3) for idx, doc in enumerate(results["documents"])
)
def test_query_keyword_statement_classifier(): def test_query_keyword_statement_classifier():