mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-09 05:37:25 +00:00
Add support for aggregating scores in JoinDocuments node (#683)
This commit is contained in:
parent
143da4cb3f
commit
4c2804e38e
@ -1,3 +1,4 @@
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
@ -280,19 +281,63 @@ class QueryNode:
|
||||
|
||||
|
||||
class JoinDocuments:
    """
    A node to join documents outputted by multiple retriever nodes.

    The node allows multiple join modes:
    * concatenate: combine the documents from multiple nodes. Any duplicate documents are discarded.
    * merge: merge scores of documents from multiple nodes. Optionally, each input score can be given a different
             `weight` & a `top_k` limit can be set. This mode can also be used for "reranking" retrieved documents.
    """

    outgoing_edges = 1

    def __init__(
        self, join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None
    ):
        """
        :param join_mode: `concatenate` to combine documents from multiple retrievers or `merge` to aggregate scores of
                          individual documents.
        :param weights: A node-wise list(length of list must be equal to the number of input nodes) of weights for
                        adjusting document scores when using the `merge` join_mode. By default, equal weight is given
                        to each retriever score. This param is not compatible with the `concatenate` join_mode.
        :param top_k_join: Limit documents to top_k based on the resulting scores of the join.
        """
        # NOTE(review): asserts vanish under `python -O`; raising ValueError would survive optimization.
        assert join_mode in ["concatenate", "merge"], f"JoinDocuments node does not support '{join_mode}' join_mode."

        assert not (
            weights is not None and join_mode == "concatenate"
        ), "Weights are not compatible with 'concatenate' join_mode."
        self.join_mode = join_mode
        self.weights = weights
        self.top_k = top_k_join

    def run(self, **kwargs):
        """
        Join the `documents` lists of all incoming node outputs and emit the result on edge "output_1".

        `kwargs["inputs"]` is a list of `(node_output, node_name)` tuples; each `node_output` is a dict
        carrying at least `query` and `documents`. Returns `({"query": ..., "documents": [...]}, "output_1")`
        with documents sorted by descending score and optionally truncated to `top_k`.
        """
        inputs = kwargs["inputs"]

        if self.join_mode == "concatenate":
            # Deduplicate by document id; for duplicate ids, the copy from the later input node wins.
            document_map = {}
            for input_from_node, _ in inputs:
                for doc in input_from_node["documents"]:
                    document_map[doc.id] = doc
        elif self.join_mode == "merge":
            document_map = {}
            # Default to equal weighting across the input nodes when no weights were supplied.
            weights = self.weights if self.weights else [1 / len(inputs)] * len(inputs)
            for (input_from_node, _), weight in zip(inputs, weights):
                for doc in input_from_node["documents"]:
                    if doc.id in document_map:  # document already exists; accumulate its weighted score
                        document_map[doc.id].score += doc.score * weight
                    else:  # first occurrence: copy so the input node's document is not mutated
                        document_map[doc.id] = deepcopy(doc)
                        document_map[doc.id].score *= weight
        else:
            raise Exception(f"Invalid join_mode: {self.join_mode}")

        documents = sorted(document_map.values(), key=lambda d: d.score, reverse=True)
        if self.top_k:
            documents = documents[: self.top_k]
        output = {"query": inputs[0][0]["query"], "documents": documents}
        return output, "output_1"
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.pipeline import ExtractiveQAPipeline, Pipeline, FAQPipeline, DocumentSearchPipeline
|
||||
from haystack.pipeline import JoinDocuments, ExtractiveQAPipeline, Pipeline, FAQPipeline, DocumentSearchPipeline
|
||||
from haystack.retriever.dense import DensePassageRetriever
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@ -117,3 +119,56 @@ def test_document_search_pipeline(retriever, document_store):
|
||||
if isinstance(document_store, ElasticsearchDocumentStore):
|
||||
output = pipeline.run(query="How to test this?", filters={"source": ["wiki2"]}, top_k_retriever=5)
|
||||
assert len(output["documents"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_join_document_pipeline(document_store_with_docs, reader):
    """
    Integration test for JoinDocuments: wires a sparse (Elasticsearch) and a dense (DPR)
    retriever into one Pipeline and joins their outputs under each join_mode, then
    feeds the joined documents into a reader.
    """
    es = ElasticsearchRetriever(document_store=document_store_with_docs)
    dpr = DensePassageRetriever(
        document_store=document_store_with_docs,
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
        use_gpu=False,
    )
    # DPR retrieval requires document embeddings to be present in the store.
    document_store_with_docs.update_embeddings(dpr)

    query = "Where does Carla lives?"

    # test merge without weights
    join_node = JoinDocuments(join_mode="merge")
    p = Pipeline()
    p.add_node(component=es, name="R1", inputs=["Query"])
    p.add_node(component=dpr, name="R2", inputs=["Query"])
    p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
    results = p.run(query=query)
    # presumably the fixture store holds 3 docs, so the merged result covers all of them — TODO confirm
    assert len(results["documents"]) == 3

    # test merge with weights
    # The extreme weight on R1 inflates its scores (top score must exceed 1000),
    # and top_k_join caps the joined result at 2 documents.
    join_node = JoinDocuments(join_mode="merge", weights=[1000, 1], top_k_join=2)
    p = Pipeline()
    p.add_node(component=es, name="R1", inputs=["Query"])
    p.add_node(component=dpr, name="R2", inputs=["Query"])
    p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
    results = p.run(query=query)
    assert results["documents"][0].score > 1000
    assert len(results["documents"]) == 2

    # test concatenate
    join_node = JoinDocuments(join_mode="concatenate")
    p = Pipeline()
    p.add_node(component=es, name="R1", inputs=["Query"])
    p.add_node(component=dpr, name="R2", inputs=["Query"])
    p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
    results = p.run(query=query)
    # duplicates across the two retrievers are discarded, leaving the 3 unique docs
    assert len(results["documents"]) == 3

    # test join_node with reader
    join_node = JoinDocuments()  # default join_mode="concatenate"
    p = Pipeline()
    p.add_node(component=es, name="R1", inputs=["Query"])
    p.add_node(component=dpr, name="R2", inputs=["Query"])
    p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
    p.add_node(component=reader, name="Reader", inputs=["Join"])
    results = p.run(query=query)
    assert results["answers"][0]["answer"] == "Berlin"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user