test: replace ElasticsearchDS with InMemoryDS when it makes sense; support scale_score in InMemoryDS (#4283)

* replace elasticds with imds - first draft
* fix
* fix tests and implement scale_score in imds bm25
* add docstrings for scale_score

parent ee74421212
commit e8f9b1b65d
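
The heart of the change is mapping unbounded BM25 scores onto the unit interval with the logistic sigmoid. A minimal standalone sketch of that scaling (the divisor 8 matches the diff below; the sample scores are invented for illustration):

    import numpy as np
    from scipy.special import expit  # logistic sigmoid: 1 / (1 + exp(-x))

    # Raw BM25 scores are unbounded; dividing by 8 before the sigmoid spreads
    # typical score ranges across [0, 1] without changing the ranking.
    raw_scores = [0.0, 2.5, 8.0, 20.0]  # hypothetical BM25 scores
    scaled = [float(expit(np.asarray(score / 8))) for score in raw_scores]
    print(scaled)  # ~[0.5, 0.58, 0.73, 0.92]; monotonic, so document order is preserved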
@@ -16,6 +16,7 @@ import torch
 from tqdm.auto import tqdm
 import rank_bm25
 import pandas as pd
+from scipy.special import expit

 from haystack.schema import Document, FilterType, Label
 from haystack.errors import DuplicateDocumentError, DocumentStoreError
@@ -953,7 +954,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         index: Optional[str] = None,
         headers: Optional[Dict[str, str]] = None,
         all_terms_must_match: bool = False,
-        scale_score: bool = False,
+        scale_score: bool = True,
     ) -> List[Document]:
         """
         Scan through documents in DocumentStore and return a small number documents
@@ -961,6 +962,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         :param query: The query.
         :param top_k: How many documents to return per query.
         :param index: The name of the index in the DocumentStore from which to retrieve documents.
+        :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
         """

         if headers:
@@ -973,10 +975,6 @@ class InMemoryDocumentStore(KeywordDocumentStore):
             logger.warning(
                 "InMemoryDocumentStore does not support filters for BM25 retrieval. This parameter is ignored."
             )
-        if scale_score is True:
-            logger.warning(
-                "InMemoryDocumentStore does not support scale_score for BM25 retrieval. This parameter is ignored."
-            )

         index = index or self.index
         if index not in self.bm25:
@@ -989,6 +987,9 @@ class InMemoryDocumentStore(KeywordDocumentStore):

         tokenized_query = self.bm25_tokenization_regex(query.lower())
         docs_scores = self.bm25[index].get_scores(tokenized_query)
+        if scale_score is True:
+            # scaling probability from BM25
+            docs_scores = [float(expit(np.asarray(score / 8))) for score in docs_scores]
         top_docs_positions = np.argsort(docs_scores)[::-1][:top_k]

         textual_docs_list = [doc for doc in self.indexes[index].values() if doc.content_type in ["text", "table"]]
@@ -1009,7 +1010,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         index: Optional[str] = None,
         headers: Optional[Dict[str, str]] = None,
         all_terms_must_match: bool = False,
-        scale_score: bool = False,
+        scale_score: bool = True,
     ) -> List[List[Document]]:
         """
         Scan through documents in DocumentStore and return a small number documents
@@ -1018,6 +1019,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         :param query: The query.
         :param top_k: How many documents to return per query.
         :param index: The name of the index in the DocumentStore from which to retrieve documents.
+        :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
         """

         if headers:
@@ -1030,10 +1032,6 @@ class InMemoryDocumentStore(KeywordDocumentStore):
             logger.warning(
                 "InMemoryDocumentStore does not support filters for BM25 retrieval. This parameter is ignored."
             )
-        if scale_score is True:
-            logger.warning(
-                "InMemoryDocumentStore does not support scale_score for BM25 retrieval. This parameter is ignored."
-            )

         index = index or self.index
         if index not in self.bm25:
@@ -1043,6 +1041,6 @@ class InMemoryDocumentStore(KeywordDocumentStore):

         result_documents = []
         for query in queries:
-            result_documents.append(self.query(query=query, top_k=top_k, index=index))
+            result_documents.append(self.query(query=query, top_k=top_k, index=index, scale_score=scale_score))

         return result_documents
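
For context, a usage sketch of the new default on InMemoryDocumentStore (documents and query are invented; use_bm25=True mirrors the conftest change below):

    from haystack.document_stores import InMemoryDocumentStore
    from haystack.schema import Document

    store = InMemoryDocumentStore(use_bm25=True)
    store.write_documents(
        [Document(content="Berlin is the capital of Germany."), Document(content="Paris is the capital of France.")]
    )

    # scale_score now defaults to True, so BM25 scores come back in [0, 1].
    docs = store.query(query="capital of Germany", top_k=2)
    # Pass scale_score=False to keep the raw, unbounded BM25 scores.
    raw_docs = store.query(query="capital of Germany", top_k=2, scale_score=False)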
@@ -863,6 +863,7 @@ def get_document_store(
             index=index,
             similarity=similarity,
             use_bm25=True,
+            bm25_parameters={"k1": 1.2, "b": 0.75},  # parameters similar to those of Elasticsearch
         )

     elif document_store_type == "elasticsearch":
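
The k1/b values above are Elasticsearch's BM25 defaults: k1 controls term-frequency saturation, b controls document-length normalization. A minimal sketch of those parameters on the underlying rank_bm25 ranker (toy corpus invented; the store is assumed to forward bm25_parameters as constructor kwargs):

    from rank_bm25 import BM25Okapi

    # Whitespace tokenization for illustration only.
    corpus = [text.lower().split() for text in
              ("Berlin is the capital of Germany", "Paris is the capital of France")]

    bm25 = BM25Okapi(corpus, k1=1.2, b=0.75)  # same values as the conftest above
    print(bm25.get_scores("capital of germany".split()))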
@@ -1,22 +1,18 @@
 import pytest

 from haystack.pipelines import TranslationWrapperPipeline, ExtractiveQAPipeline
 from haystack.nodes import DensePassageRetriever, EmbeddingRetriever
 from .test_summarizer import SPLIT_DOCS


 # Keeping few (retriever,document_store,reader) combination to reduce test time
 @pytest.mark.integration
-@pytest.mark.elasticsearch
 @pytest.mark.summarizer
 @pytest.mark.parametrize("retriever,document_store,reader", [("embedding", "memory", "farm")], indirect=True)
 def test_extractive_qa_pipeline_with_translator(
     document_store, retriever, reader, en_to_de_translator, de_to_en_translator
 ):
     document_store.write_documents(SPLIT_DOCS)

-    if isinstance(retriever, EmbeddingRetriever) or isinstance(retriever, DensePassageRetriever):
-        document_store.update_embeddings(retriever=retriever)
+    document_store.update_embeddings(retriever=retriever)

     query = "Wo steht der Eiffelturm?"
     base_pipeline = ExtractiveQAPipeline(retriever=retriever, reader=reader)
@@ -161,6 +161,7 @@ def test_eval_reader(reader, document_store, use_confidence_scores):
     assert reader_eval_results["top_n_accuracy"] == 100.0


+# using ElasticsearchDocumentStore, since InMemoryDocumentStore doesn't return meaningful BM25 scores when there are very few documents
 @pytest.mark.elasticsearch
 @pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
 @pytest.mark.parametrize("open_domain", [True, False])
@@ -21,8 +21,7 @@ class MockRetriever(BaseMockRetriever):
         raise ValueError("TEST ERROR!")


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 def test_node_names_validation(document_store_with_docs, tmp_path):
     pipeline = Pipeline()
     pipeline.add_node(
@@ -52,28 +51,27 @@ def test_node_names_validation(document_store_with_docs, tmp_path):
     assert "top_k" not in exception_raised


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 def test_debug_attributes_global(document_store_with_docs, tmp_path):
-    es_retriever = BM25Retriever(document_store=document_store_with_docs)
+    bm25_retriever = BM25Retriever(document_store=document_store_with_docs)
     reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)

     pipeline = Pipeline()
-    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["ESRetriever"])
+    pipeline.add_node(component=bm25_retriever, name="BM25Retriever", inputs=["Query"])
+    pipeline.add_node(component=reader, name="Reader", inputs=["BM25Retriever"])

     prediction = pipeline.run(
-        query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 10}, "Reader": {"top_k": 3}}, debug=True
+        query="Who lives in Berlin?", params={"BM25Retriever": {"top_k": 10}, "Reader": {"top_k": 3}}, debug=True
     )
     assert "_debug" in prediction.keys()
-    assert "ESRetriever" in prediction["_debug"].keys()
+    assert "BM25Retriever" in prediction["_debug"].keys()
     assert "Reader" in prediction["_debug"].keys()
-    assert "input" in prediction["_debug"]["ESRetriever"].keys()
-    assert "output" in prediction["_debug"]["ESRetriever"].keys()
+    assert "input" in prediction["_debug"]["BM25Retriever"].keys()
+    assert "output" in prediction["_debug"]["BM25Retriever"].keys()
     assert "input" in prediction["_debug"]["Reader"].keys()
     assert "output" in prediction["_debug"]["Reader"].keys()
-    assert prediction["_debug"]["ESRetriever"]["input"]
-    assert prediction["_debug"]["ESRetriever"]["output"]
+    assert prediction["_debug"]["BM25Retriever"]["input"]
+    assert prediction["_debug"]["BM25Retriever"]["output"]
     assert prediction["_debug"]["Reader"]["input"]
     assert prediction["_debug"]["Reader"]["output"]
@@ -81,57 +79,55 @@ def test_debug_attributes_global(document_store_with_docs, tmp_path):
     json.dumps(prediction, default=str)


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 def test_debug_attributes_per_node(document_store_with_docs, tmp_path):
-    es_retriever = BM25Retriever(document_store=document_store_with_docs)
+    bm25_retriever = BM25Retriever(document_store=document_store_with_docs)
     reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)

     pipeline = Pipeline()
-    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["ESRetriever"])
+    pipeline.add_node(component=bm25_retriever, name="BM25Retriever", inputs=["Query"])
+    pipeline.add_node(component=reader, name="Reader", inputs=["BM25Retriever"])

     prediction = pipeline.run(
-        query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 10, "debug": True}, "Reader": {"top_k": 3}}
+        query="Who lives in Berlin?", params={"BM25Retriever": {"top_k": 10, "debug": True}, "Reader": {"top_k": 3}}
     )
     assert "_debug" in prediction.keys()
-    assert "ESRetriever" in prediction["_debug"].keys()
+    assert "BM25Retriever" in prediction["_debug"].keys()
     assert "Reader" not in prediction["_debug"].keys()
-    assert "input" in prediction["_debug"]["ESRetriever"].keys()
-    assert "output" in prediction["_debug"]["ESRetriever"].keys()
-    assert prediction["_debug"]["ESRetriever"]["input"]
-    assert prediction["_debug"]["ESRetriever"]["output"]
+    assert "input" in prediction["_debug"]["BM25Retriever"].keys()
+    assert "output" in prediction["_debug"]["BM25Retriever"].keys()
+    assert prediction["_debug"]["BM25Retriever"]["input"]
+    assert prediction["_debug"]["BM25Retriever"]["output"]

     # Avoid circular reference: easiest way to detect those is to use json.dumps
     json.dumps(prediction, default=str)


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 def test_debug_attributes_for_join_nodes(document_store_with_docs, tmp_path):
-    es_retriever_1 = BM25Retriever(document_store=document_store_with_docs)
-    es_retriever_2 = BM25Retriever(document_store=document_store_with_docs)
+    bm25_retriever_1 = BM25Retriever(document_store=document_store_with_docs)
+    bm25_retriever_2 = BM25Retriever(document_store=document_store_with_docs)

     pipeline = Pipeline()
-    pipeline.add_node(component=es_retriever_1, name="ESRetriever1", inputs=["Query"])
-    pipeline.add_node(component=es_retriever_2, name="ESRetriever2", inputs=["Query"])
-    pipeline.add_node(component=JoinDocuments(), name="JoinDocuments", inputs=["ESRetriever1", "ESRetriever2"])
+    pipeline.add_node(component=bm25_retriever_1, name="BM25Retriever1", inputs=["Query"])
+    pipeline.add_node(component=bm25_retriever_2, name="BM25Retriever2", inputs=["Query"])
+    pipeline.add_node(component=JoinDocuments(), name="JoinDocuments", inputs=["BM25Retriever1", "BM25Retriever2"])

     prediction = pipeline.run(query="Who lives in Berlin?", debug=True)
     assert "_debug" in prediction.keys()
-    assert "ESRetriever1" in prediction["_debug"].keys()
-    assert "ESRetriever2" in prediction["_debug"].keys()
+    assert "BM25Retriever1" in prediction["_debug"].keys()
+    assert "BM25Retriever2" in prediction["_debug"].keys()
     assert "JoinDocuments" in prediction["_debug"].keys()
-    assert "input" in prediction["_debug"]["ESRetriever1"].keys()
-    assert "output" in prediction["_debug"]["ESRetriever1"].keys()
-    assert "input" in prediction["_debug"]["ESRetriever2"].keys()
-    assert "output" in prediction["_debug"]["ESRetriever2"].keys()
+    assert "input" in prediction["_debug"]["BM25Retriever1"].keys()
+    assert "output" in prediction["_debug"]["BM25Retriever1"].keys()
+    assert "input" in prediction["_debug"]["BM25Retriever2"].keys()
+    assert "output" in prediction["_debug"]["BM25Retriever2"].keys()
     assert "input" in prediction["_debug"]["JoinDocuments"].keys()
     assert "output" in prediction["_debug"]["JoinDocuments"].keys()
-    assert prediction["_debug"]["ESRetriever1"]["input"]
-    assert prediction["_debug"]["ESRetriever1"]["output"]
-    assert prediction["_debug"]["ESRetriever2"]["input"]
-    assert prediction["_debug"]["ESRetriever2"]["output"]
+    assert prediction["_debug"]["BM25Retriever1"]["input"]
+    assert prediction["_debug"]["BM25Retriever1"]["output"]
+    assert prediction["_debug"]["BM25Retriever2"]["input"]
+    assert prediction["_debug"]["BM25Retriever2"]["output"]
     assert prediction["_debug"]["JoinDocuments"]["input"]
     assert prediction["_debug"]["JoinDocuments"]["output"]
@@ -139,37 +135,36 @@ def test_debug_attributes_for_join_nodes(document_store_with_docs, tmp_path):
     json.dumps(prediction, default=str)


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 def test_global_debug_attributes_override_node_ones(document_store_with_docs, tmp_path):
-    es_retriever = BM25Retriever(document_store=document_store_with_docs)
+    bm25_retriever = BM25Retriever(document_store=document_store_with_docs)
     reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)

     pipeline = Pipeline()
-    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["ESRetriever"])
+    pipeline.add_node(component=bm25_retriever, name="BM25Retriever", inputs=["Query"])
+    pipeline.add_node(component=reader, name="Reader", inputs=["BM25Retriever"])

     prediction = pipeline.run(
         query="Who lives in Berlin?",
-        params={"ESRetriever": {"top_k": 10, "debug": True}, "Reader": {"top_k": 3, "debug": True}},
+        params={"BM25Retriever": {"top_k": 10, "debug": True}, "Reader": {"top_k": 3, "debug": True}},
         debug=False,
     )
     assert "_debug" not in prediction.keys()

     prediction = pipeline.run(
         query="Who lives in Berlin?",
-        params={"ESRetriever": {"top_k": 10, "debug": False}, "Reader": {"top_k": 3, "debug": False}},
+        params={"BM25Retriever": {"top_k": 10, "debug": False}, "Reader": {"top_k": 3, "debug": False}},
         debug=True,
     )
     assert "_debug" in prediction.keys()
-    assert "ESRetriever" in prediction["_debug"].keys()
+    assert "BM25Retriever" in prediction["_debug"].keys()
     assert "Reader" in prediction["_debug"].keys()
-    assert "input" in prediction["_debug"]["ESRetriever"].keys()
-    assert "output" in prediction["_debug"]["ESRetriever"].keys()
+    assert "input" in prediction["_debug"]["BM25Retriever"].keys()
+    assert "output" in prediction["_debug"]["BM25Retriever"].keys()
     assert "input" in prediction["_debug"]["Reader"].keys()
     assert "output" in prediction["_debug"]["Reader"].keys()
-    assert prediction["_debug"]["ESRetriever"]["input"]
-    assert prediction["_debug"]["ESRetriever"]["output"]
+    assert prediction["_debug"]["BM25Retriever"]["input"]
+    assert prediction["_debug"]["BM25Retriever"]["output"]
     assert prediction["_debug"]["Reader"]["input"]
     assert prediction["_debug"]["Reader"]["output"]
@@ -293,10 +293,9 @@ def test_most_similar_documents_pipeline_with_filters_batch(retriever, document_
         assert document.meta["source"] in ["wiki3", "wiki4", "wiki5"]


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
 def test_join_merge_no_weights(document_store_dot_product_with_docs):
-    es = BM25Retriever(document_store=document_store_dot_product_with_docs)
+    bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
     dpr = DensePassageRetriever(
         document_store=document_store_dot_product_with_docs,
         query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@@ -309,17 +308,16 @@ def test_join_merge_no_weights(document_store_dot_product_with_docs):

     join_node = JoinDocuments(join_mode="merge")
     p = Pipeline()
-    p.add_node(component=es, name="R1", inputs=["Query"])
+    p.add_node(component=bm25, name="R1", inputs=["Query"])
     p.add_node(component=dpr, name="R2", inputs=["Query"])
     p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
     results = p.run(query=query)
     assert len(results["documents"]) == 5


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
 def test_join_merge_with_weights(document_store_dot_product_with_docs):
-    es = BM25Retriever(document_store=document_store_dot_product_with_docs)
+    bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
     dpr = DensePassageRetriever(
         document_store=document_store_dot_product_with_docs,
         query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@@ -332,18 +330,17 @@ def test_join_merge_with_weights(document_store_dot_product_with_docs):

     join_node = JoinDocuments(join_mode="merge", weights=[1000, 1], top_k_join=2)
     p = Pipeline()
-    p.add_node(component=es, name="R1", inputs=["Query"])
+    p.add_node(component=bm25, name="R1", inputs=["Query"])
     p.add_node(component=dpr, name="R2", inputs=["Query"])
     p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
     results = p.run(query=query)
-    assert math.isclose(results["documents"][0].score, 0.5481393431183286, rel_tol=0.0001)
+    assert math.isclose(results["documents"][0].score, 0.5336782589721345, rel_tol=0.0001)
     assert len(results["documents"]) == 2


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
 def test_join_concatenate(document_store_dot_product_with_docs):
-    es = BM25Retriever(document_store=document_store_dot_product_with_docs)
+    bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
     dpr = DensePassageRetriever(
         document_store=document_store_dot_product_with_docs,
         query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@@ -356,17 +353,16 @@ def test_join_concatenate(document_store_dot_product_with_docs):

     join_node = JoinDocuments(join_mode="concatenate")
     p = Pipeline()
-    p.add_node(component=es, name="R1", inputs=["Query"])
+    p.add_node(component=bm25, name="R1", inputs=["Query"])
     p.add_node(component=dpr, name="R2", inputs=["Query"])
     p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
     results = p.run(query=query)
     assert len(results["documents"]) == 5


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
 def test_join_concatenate_with_topk(document_store_dot_product_with_docs):
-    es = BM25Retriever(document_store=document_store_dot_product_with_docs)
+    bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
     dpr = DensePassageRetriever(
         document_store=document_store_dot_product_with_docs,
         query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@@ -379,7 +375,7 @@ def test_join_concatenate_with_topk(document_store_dot_product_with_docs):

     join_node = JoinDocuments(join_mode="concatenate")
     p = Pipeline()
-    p.add_node(component=es, name="R1", inputs=["Query"])
+    p.add_node(component=bm25, name="R1", inputs=["Query"])
     p.add_node(component=dpr, name="R2", inputs=["Query"])
     p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
     one_result = p.run(query=query, params={"Join": {"top_k_join": 1}})
@@ -388,11 +384,10 @@ def test_join_concatenate_with_topk(document_store_dot_product_with_docs):
     assert len(two_results["documents"]) == 2


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
 def test_join_with_reader(document_store_dot_product_with_docs, reader):
-    es = BM25Retriever(document_store=document_store_dot_product_with_docs)
+    bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
     dpr = DensePassageRetriever(
         document_store=document_store_dot_product_with_docs,
         query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@@ -405,7 +400,7 @@ def test_join_with_reader(document_store_dot_product_with_docs, reader):

     join_node = JoinDocuments()
     p = Pipeline()
-    p.add_node(component=es, name="R1", inputs=["Query"])
+    p.add_node(component=bm25, name="R1", inputs=["Query"])
     p.add_node(component=dpr, name="R2", inputs=["Query"])
     p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
     p.add_node(component=reader, name="Reader", inputs=["Join"])
@@ -414,10 +409,9 @@ def test_join_with_reader(document_store_dot_product_with_docs, reader):
     assert results["answers"][0].answer == "Berlin" or results["answers"][1].answer == "Berlin"


-@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("document_store_dot_product_with_docs", ["memory"], indirect=True)
 def test_join_with_rrf(document_store_dot_product_with_docs):
-    es = BM25Retriever(document_store=document_store_dot_product_with_docs)
+    bm25 = BM25Retriever(document_store=document_store_dot_product_with_docs)
     dpr = DensePassageRetriever(
         document_store=document_store_dot_product_with_docs,
         query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
@@ -430,7 +424,7 @@ def test_join_with_rrf(document_store_dot_product_with_docs):

     join_node = JoinDocuments(join_mode="reciprocal_rank_fusion")
     p = Pipeline()
-    p.add_node(component=es, name="R1", inputs=["Query"])
+    p.add_node(component=bm25, name="R1", inputs=["Query"])
     p.add_node(component=dpr, name="R2", inputs=["Query"])
     p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
     results = p.run(query=query)
@@ -444,7 +438,9 @@ def test_join_with_rrf(document_store_dot_product_with_docs):
         0.031009615384615385,
     ]

-    assert all([doc.score == expected_scores[idx] for idx, doc in enumerate(results["documents"])])
+    assert all(
+        doc.score == pytest.approx(expected_scores[idx], abs=1e-3) for idx, doc in enumerate(results["documents"])
+    )


 def test_query_keyword_statement_classifier():
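
The expected_scores in the last hunk come from reciprocal rank fusion, and the assertion now uses pytest.approx because the fused floats differ slightly across document stores. A sketch of the textbook RRF formulation (each retriever contributes 1/(k + rank), commonly with k=60); Haystack's JoinDocuments may weight or normalize differently, so this is illustrative only:

    from collections import defaultdict

    def reciprocal_rank_fusion(rankings, k=60):
        # Sum 1 / (k + rank) over every ranked list a document appears in.
        fused = defaultdict(float)
        for ranking in rankings:
            for rank, doc_id in enumerate(ranking, start=1):
                fused[doc_id] += 1.0 / (k + rank)
        return sorted(fused.items(), key=lambda item: item[1], reverse=True)

    # Hypothetical ranked ids from a BM25 retriever and a dense retriever.
    print(reciprocal_rank_fusion([["d1", "d2", "d3"], ["d2", "d1", "d4"]]))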