Add support for aggregating scores in JoinDocuments node (#683)

This commit is contained in:
Tanay Soni 2020-12-16 15:54:58 +01:00 committed by GitHub
parent 143da4cb3f
commit 4c2804e38e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 110 additions and 10 deletions

View File

@@ -1,3 +1,4 @@
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Dict
@@ -280,19 +281,63 @@ class QueryNode:
class JoinDocuments:
    """
    A node to join documents outputted by multiple retriever nodes.

    The node allows multiple join modes:

    * concatenate: combine the documents from multiple nodes. Any duplicate documents are discarded.
    * merge: merge scores of documents from multiple nodes. Optionally, each input score can be given a different
      `weight` & a `top_k` limit can be set. This mode can also be used for "reranking" retrieved documents.
    """

    outgoing_edges = 1

    def __init__(
        self, join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None
    ):
        """
        :param join_mode: `concatenate` to combine documents from multiple retrievers or `merge` to aggregate scores of
                          individual documents.
        :param weights: A node-wise list(length of list must be equal to the number of input nodes) of weights for
                        adjusting document scores when using the `merge` join_mode. By default, equal weight is given
                        to each retriever score. This param is not compatible with the `concatenate` join_mode.
        :param top_k_join: Limit documents to top_k based on the resulting scores of the join.
        """
        assert join_mode in ["concatenate", "merge"], f"JoinDocuments node does not support '{join_mode}' join_mode."
        assert not (
            weights is not None and join_mode == "concatenate"
        ), "Weights are not compatible with 'concatenate' join_mode."
        self.join_mode = join_mode
        self.weights = weights
        self.top_k = top_k_join

    def run(self, **kwargs):
        """
        Join the `documents` lists of all upstream inputs and return a single output.

        :param kwargs: must contain `inputs`, a list of `(node_output_dict, edge_name)` tuples where each
                       `node_output_dict` has `query` and `documents` keys.
        :return: tuple of (output dict with `query` and joined `documents`, outgoing edge name)
        :raises Exception: if `self.join_mode` is not a supported mode.
        """
        inputs = kwargs["inputs"]

        if self.join_mode == "concatenate":
            # Deduplicate by document id; a later duplicate simply replaces the earlier one.
            document_map = {}
            for input_from_node, _ in inputs:
                for doc in input_from_node["documents"]:
                    document_map[doc.id] = doc
        elif self.join_mode == "merge":
            document_map = {}
            if self.weights:
                weights = self.weights
            else:
                # No explicit weights: give each input node an equal share.
                weights = [1 / len(inputs)] * len(inputs)
            for (input_from_node, _), weight in zip(inputs, weights):
                for doc in input_from_node["documents"]:
                    if document_map.get(doc.id):  # document already exists; update score
                        document_map[doc.id].score += doc.score * weight
                    else:  # add the document in map
                        # deepcopy so the weighted score does not mutate the upstream node's document
                        document_map[doc.id] = deepcopy(doc)
                        document_map[doc.id].score *= weight
        else:
            raise Exception(f"Invalid join_mode: {self.join_mode}")

        documents = sorted(document_map.values(), key=lambda d: d.score, reverse=True)
        if self.top_k:
            documents = documents[: self.top_k]
        output = {"query": inputs[0][0]["query"], "documents": documents}
        return output, "output_1"

View File

@@ -1,7 +1,9 @@
import pytest
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.pipeline import ExtractiveQAPipeline, Pipeline, FAQPipeline, DocumentSearchPipeline
from haystack.pipeline import JoinDocuments, ExtractiveQAPipeline, Pipeline, FAQPipeline, DocumentSearchPipeline
from haystack.retriever.dense import DensePassageRetriever
from haystack.retriever.sparse import ElasticsearchRetriever
@pytest.mark.slow
@@ -117,3 +119,56 @@ def test_document_search_pipeline(retriever, document_store):
if isinstance(document_store, ElasticsearchDocumentStore):
output = pipeline.run(query="How to test this?", filters={"source": ["wiki2"]}, top_k_retriever=5)
assert len(output["documents"]) == 1
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_join_document_pipeline(document_store_with_docs, reader):
    """Exercise JoinDocuments in every join mode inside a two-retriever pipeline."""
    sparse_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
    dense_retriever = DensePassageRetriever(
        document_store=document_store_with_docs,
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
        use_gpu=False,
    )
    document_store_with_docs.update_embeddings(dense_retriever)
    query = "Where does Carla lives?"

    def build_pipeline(join_node):
        # Two retrievers (sparse + dense) fan in to a single join node.
        pipeline = Pipeline()
        pipeline.add_node(component=sparse_retriever, name="R1", inputs=["Query"])
        pipeline.add_node(component=dense_retriever, name="R2", inputs=["Query"])
        pipeline.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
        return pipeline

    # merge without weights
    results = build_pipeline(JoinDocuments(join_mode="merge")).run(query=query)
    assert len(results["documents"]) == 3

    # merge with weights and a top_k limit on the joined result
    results = build_pipeline(JoinDocuments(join_mode="merge", weights=[1000, 1], top_k_join=2)).run(query=query)
    assert results["documents"][0].score > 1000
    assert len(results["documents"]) == 2

    # concatenate (duplicate documents discarded)
    results = build_pipeline(JoinDocuments(join_mode="concatenate")).run(query=query)
    assert len(results["documents"]) == 3

    # joined documents feeding a reader
    pipeline = build_pipeline(JoinDocuments())
    pipeline.add_node(component=reader, name="Reader", inputs=["Join"])
    results = pipeline.run(query=query)
    assert results["answers"][0]["answer"] == "Berlin"