diff --git a/haystack/pipeline.py b/haystack/pipeline.py
index 507b99c1a..e326ebf9f 100644
--- a/haystack/pipeline.py
+++ b/haystack/pipeline.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from pathlib import Path
 from typing import List, Optional, Dict
 
@@ -280,19 +281,63 @@ class QueryNode:
 
 
 class JoinDocuments:
+    """
+    A node to join documents outputted by multiple retriever nodes.
+
+    The node allows multiple join modes:
+    * concatenate: combine the documents from multiple nodes. Any duplicate documents are discarded.
+    * merge: merge scores of documents from multiple nodes. Optionally, each input score can be given a different
+             `weight` & a `top_k` limit can be set. This mode can also be used for "reranking" retrieved documents.
+    """
+
     outgoing_edges = 1
 
-    def __init__(self, join_mode="concatenate"):
-        pass
+    def __init__(
+        self, join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None
+    ):
+        """
+        :param join_mode: `concatenate` to combine documents from multiple retrievers or `merge` to aggregate scores of
+                          individual documents.
+        :param weights: A node-wise list(length of list must be equal to the number of input nodes) of weights for
+                        adjusting document scores when using the `merge` join_mode. By default, equal weight is given
+                        to each retriever score. This param is not compatible with the `concatenate` join_mode.
+        :param top_k_join: Limit documents to top_k based on the resulting scores of the join.
+        """
+        assert join_mode in ["concatenate", "merge"], f"JoinDocuments node does not support '{join_mode}' join_mode."
+
+        assert not (
+            weights is not None and join_mode == "concatenate"
+        ), "Weights are not compatible with 'concatenate' join_mode."
+        self.join_mode = join_mode
+        self.weights = weights
+        self.top_k = top_k_join
 
     def run(self, **kwargs):
         inputs = kwargs["inputs"]
-        documents = []
-        for i, _ in inputs:
-            documents.extend(i["documents"])
-        output = {
-            "query": inputs[0][0]["query"],
-            "documents": documents
-        }
+        if self.join_mode == "concatenate":
+            document_map = {}
+            for input_from_node, _ in inputs:
+                for doc in input_from_node["documents"]:
+                    document_map[doc.id] = doc
+        elif self.join_mode == "merge":
+            document_map = {}
+            if self.weights:
+                weights = self.weights
+            else:
+                weights = [1/len(inputs)] * len(inputs)
+            for (input_from_node, _), weight in zip(inputs, weights):
+                for doc in input_from_node["documents"]:
+                    if document_map.get(doc.id):  # document already exists; update score
+                        document_map[doc.id].score += doc.score * weight
+                    else:  # add the document in map
+                        document_map[doc.id] = deepcopy(doc)
+                        document_map[doc.id].score *= weight
+        else:
+            raise Exception(f"Invalid join_mode: {self.join_mode}")
+
+        documents = sorted(document_map.values(), key=lambda d: d.score, reverse=True)
+        if self.top_k:
+            documents = documents[: self.top_k]
+        output = {"query": inputs[0][0]["query"], "documents": documents}
         return output, "output_1"
 
diff --git a/test/test_pipeline.py b/test/test_pipeline.py
index a59c76fdc..7109d99ae 100644
--- a/test/test_pipeline.py
+++ b/test/test_pipeline.py
@@ -1,7 +1,9 @@
 import pytest
 
 from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
-from haystack.pipeline import ExtractiveQAPipeline, Pipeline, FAQPipeline, DocumentSearchPipeline
+from haystack.pipeline import JoinDocuments, ExtractiveQAPipeline, Pipeline, FAQPipeline, DocumentSearchPipeline
+from haystack.retriever.dense import DensePassageRetriever
+from haystack.retriever.sparse import ElasticsearchRetriever
 
 
 @pytest.mark.slow
@@ -117,3 +119,56 @@ def test_document_search_pipeline(retriever, document_store):
     if isinstance(document_store, ElasticsearchDocumentStore):
         output = pipeline.run(query="How to test this?", filters={"source": ["wiki2"]}, top_k_retriever=5)
         assert len(output["documents"]) == 1
+
+
+@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+@pytest.mark.parametrize("reader", ["farm"], indirect=True)
+def test_join_document_pipeline(document_store_with_docs, reader):
+    es = ElasticsearchRetriever(document_store=document_store_with_docs)
+    dpr = DensePassageRetriever(
+        document_store=document_store_with_docs,
+        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+        use_gpu=False,
+    )
+    document_store_with_docs.update_embeddings(dpr)
+
+    query = "Where does Carla lives?"
+
+    # test merge without weights
+    join_node = JoinDocuments(join_mode="merge")
+    p = Pipeline()
+    p.add_node(component=es, name="R1", inputs=["Query"])
+    p.add_node(component=dpr, name="R2", inputs=["Query"])
+    p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
+    results = p.run(query=query)
+    assert len(results["documents"]) == 3
+
+    # test merge with weights
+    join_node = JoinDocuments(join_mode="merge", weights=[1000, 1], top_k_join=2)
+    p = Pipeline()
+    p.add_node(component=es, name="R1", inputs=["Query"])
+    p.add_node(component=dpr, name="R2", inputs=["Query"])
+    p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
+    results = p.run(query=query)
+    assert results["documents"][0].score > 1000
+    assert len(results["documents"]) == 2
+
+    # test concatenate
+    join_node = JoinDocuments(join_mode="concatenate")
+    p = Pipeline()
+    p.add_node(component=es, name="R1", inputs=["Query"])
+    p.add_node(component=dpr, name="R2", inputs=["Query"])
+    p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
+    results = p.run(query=query)
+    assert len(results["documents"]) == 3
+
+    # test join_node with reader
+    join_node = JoinDocuments()
+    p = Pipeline()
+    p.add_node(component=es, name="R1", inputs=["Query"])
+    p.add_node(component=dpr, name="R2", inputs=["Query"])
+    p.add_node(component=join_node, name="Join", inputs=["R1", "R2"])
+    p.add_node(component=reader, name="Reader", inputs=["Join"])
+    results = p.run(query=query)
+    assert results["answers"][0]["answer"] == "Berlin"