From 4ef2a680bbfa10ddd770a26966d44d661d2ffd9c Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Mon, 20 Nov 2023 10:56:56 +0100 Subject: [PATCH] feat: Add DocumentJoiner component 2.0 (#6105) * draft DocumentJoiner * implement merge and rrf * draft end-to-end test with DocumentJoiner in hybrid doc search pipeline * adjust for variadics Canals PR #122 * fix text_embedder input * adapt to the new Document class * adapt to new doc id * specify documents input as Variadic in run method * compare doc ids instead of full docs * rename text_file_converter input to sources * update docstring * Update haystack/preview/components/routers/document_joiner.py Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * Apply suggestions from docstring review Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * capitalize Documents and Retrievers in docstrings * fix log message in test --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Co-authored-by: anakin87 Co-authored-by: ZanSara Co-authored-by: Massimiliano Pippi Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> --- .../test_hybrid_doc_search_pipeline.py | 57 +++++++ .../pipelines/test_preprocessing_pipeline.py | 4 +- .../components/routers/document_joiner.py | 145 ++++++++++++++++++ .../document-joiner-2-126bd60c84be6efa.yaml | 5 + .../routers/test_document_joiner.py | 140 +++++++++++++++++ 5 files changed, 349 insertions(+), 2 deletions(-) create mode 100644 e2e/preview/pipelines/test_hybrid_doc_search_pipeline.py create mode 100644 haystack/preview/components/routers/document_joiner.py create mode 100644 releasenotes/notes/document-joiner-2-126bd60c84be6efa.yaml create mode 100644 test/preview/components/routers/test_document_joiner.py diff --git a/e2e/preview/pipelines/test_hybrid_doc_search_pipeline.py b/e2e/preview/pipelines/test_hybrid_doc_search_pipeline.py new file mode 100644 index 
000000000..e85db341d
--- /dev/null
+++ b/e2e/preview/pipelines/test_hybrid_doc_search_pipeline.py
@@ -0,0 +1,57 @@
+import json
+
+from haystack.preview import Pipeline, Document
+from haystack.preview.components.embedders import SentenceTransformersTextEmbedder
+from haystack.preview.components.rankers import TransformersSimilarityRanker
+from haystack.preview.components.routers.document_joiner import DocumentJoiner
+from haystack.preview.document_stores import InMemoryDocumentStore
+from haystack.preview.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
+
+
+def test_hybrid_doc_search_pipeline(tmp_path):
+    """Build, serialize, reload and run a hybrid (BM25 + embedding retrieval) document search pipeline."""
+    # Create the pipeline
+    document_store = InMemoryDocumentStore()
+    hybrid_pipeline = Pipeline()
+    hybrid_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
+    hybrid_pipeline.add_component(
+        instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
+        name="text_embedder",
+    )
+    hybrid_pipeline.add_component(
+        instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever"
+    )
+    hybrid_pipeline.add_component(instance=DocumentJoiner(), name="joiner")
+    hybrid_pipeline.add_component(instance=TransformersSimilarityRanker(top_k=20), name="ranker")
+
+    hybrid_pipeline.connect("bm25_retriever", "joiner")
+    hybrid_pipeline.connect("text_embedder", "embedding_retriever")
+    hybrid_pipeline.connect("embedding_retriever", "joiner")
+    hybrid_pipeline.connect("joiner", "ranker")
+
+    # Draw the pipeline
+    hybrid_pipeline.draw(tmp_path / "test_hybrid_doc_search_pipeline.png")
+
+    # Serialize the pipeline to JSON
+    with open(tmp_path / "test_hybrid_doc_search_pipeline.json", "w") as f:
+        json.dump(hybrid_pipeline.to_dict(), f)
+
+    # Load the pipeline back
+    with open(tmp_path / "test_hybrid_doc_search_pipeline.json", "r") as f:
+        hybrid_pipeline = Pipeline.from_dict(json.load(f))
+
+    # Populate the document store
+    documents = [
+        Document(content="My name is Jean and I live in Paris."),
+        Document(content="My name is Mark and I live in Berlin."),
+        Document(content="My name is Mario and I live in the capital of Italy."),
+        Document(content="My name is Giorgio and I live in Rome."),
+    ]
+    hybrid_pipeline.get_component("bm25_retriever").document_store.write_documents(documents)
+
+    query = "Who lives in Rome?"
+    result = hybrid_pipeline.run(
+        {"bm25_retriever": {"query": query}, "text_embedder": {"text": query}, "ranker": {"query": query}}
+    )
+    assert result["ranker"]["documents"][0].content == "My name is Giorgio and I live in Rome."
+    assert result["ranker"]["documents"][1].content == "My name is Mario and I live in the capital of Italy."
diff --git a/e2e/preview/pipelines/test_preprocessing_pipeline.py b/e2e/preview/pipelines/test_preprocessing_pipeline.py
index 724d7bc8c..d7b17909c 100644
--- a/e2e/preview/pipelines/test_preprocessing_pipeline.py
+++ b/e2e/preview/pipelines/test_preprocessing_pipeline.py
@@ -29,7 +29,7 @@ def test_preprocessing_pipeline(tmp_path):
         name="embedder",
     )
     preprocessing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="writer")
-    preprocessing_pipeline.connect("file_type_router.text/plain", "text_file_converter.paths")
+    preprocessing_pipeline.connect("file_type_router.text/plain", "text_file_converter.sources")
     preprocessing_pipeline.connect("text_file_converter.documents", "language_classifier.documents")
     preprocessing_pipeline.connect("language_classifier.documents", "router.documents")
     preprocessing_pipeline.connect("router.en", "cleaner.documents")
@@ -75,7 +75,7 @@ def test_preprocessing_pipeline(tmp_path):
     filled_document_store = preprocessing_pipeline.get_component("writer").document_store
     assert filled_document_store.count_documents() == 6
 
-    # Check preprocessed texts and mime_types
+    # Check preprocessed texts
     stored_documents =
filled_document_store.filter_documents()
     expected_texts = [
         "This is an english sentence.",
diff --git a/haystack/preview/components/routers/document_joiner.py b/haystack/preview/components/routers/document_joiner.py
new file mode 100644
index 000000000..1a966c209
--- /dev/null
+++ b/haystack/preview/components/routers/document_joiner.py
@@ -0,0 +1,182 @@
+import itertools
+import logging
+from collections import defaultdict
+from math import inf
+from typing import List, Optional
+from canals.component.types import Variadic
+
+from haystack.preview import component, Document
+
+
+logger = logging.getLogger(__name__)
+
+
+@component
+class DocumentJoiner:
+    """
+    A component that joins input lists of Documents from multiple connections and outputs them as one list.
+
+    The component allows multiple join modes:
+    * concatenate: Combine Documents from multiple components. Of duplicate Documents (same id), only the one with
+      the highest score is kept. This join mode doesn't influence Document scores.
+    * merge: Merge scores of duplicate Documents coming from multiple components.
+      Optionally, you can assign a weight to the scores and set the top_k limit for this join mode.
+      You can also use this join mode to rerank retrieved Documents.
+    * reciprocal_rank_fusion: Combine Documents into a single list based on their ranking received from multiple
+      components.
+
+    The `merge` and `reciprocal_rank_fusion` join modes overwrite the score of the Documents they output.
+
+    Example usage in a hybrid retrieval pipeline:
+    ```python
+    document_store = InMemoryDocumentStore()
+    p = Pipeline()
+    p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
+    p.add_component(instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever")
+    p.add_component(instance=DocumentJoiner(), name="joiner")
+    p.connect("bm25_retriever", "joiner")
+    p.connect("embedding_retriever", "joiner")
+    ```
+    """
+
+    def __init__(
+        self,
+        join_mode: str = "concatenate",
+        weights: Optional[List[float]] = None,
+        top_k: Optional[int] = None,
+        sort_by_score: bool = True,
+    ):
+        """
+        Initialize the DocumentJoiner.
+
+        :param join_mode: Specifies the join mode to use. Available modes: `concatenate` to combine Documents from
+            multiple Retrievers, `merge` to aggregate the scores of individual Documents, `reciprocal_rank_fusion`
+            to apply rank-based scoring.
+        :param weights: A component-wise list (the length of the list must be equal to the number of input
+            components) of weights for adjusting Document scores when using the `merge` join_mode. The weights are
+            normalized so that they sum up to 1. By default, equal weight is given to each Retriever score.
+            This param is not compatible with the `concatenate` join_mode.
+        :param top_k: The maximum number of Documents to be returned as output. By default, returns all Documents.
+        :param sort_by_score: Whether the output list of Documents should be sorted by Document scores in descending
+            order. By default, the output is sorted.
+            Documents without score are handled as if their score was -infinity.
+        :raises ValueError: If `join_mode` is not supported or if the `weights` sum up to 0.
+        """
+        if join_mode not in ["concatenate", "merge", "reciprocal_rank_fusion"]:
+            raise ValueError(f"DocumentJoiner component does not support '{join_mode}' join_mode.")
+        self.join_mode = join_mode
+        self.weights = None
+        if weights:
+            total_weight = sum(weights)
+            if total_weight == 0:
+                raise ValueError("The weights of DocumentJoiner must not sum up to 0.")
+            # Normalize the weights so that they sum up to 1.
+            self.weights = [float(weight) / total_weight for weight in weights]
+        self.top_k = top_k
+        self.sort_by_score = sort_by_score
+
+    @component.output_types(documents=List[Document])
+    def run(self, documents: Variadic[List[Document]]):
+        """
+        Run the DocumentJoiner. This method joins the input lists of Documents into one output list based on the
+        join_mode specified during initialization.
+
+        :param documents: An arbitrary number of lists of Documents to join.
+        :returns: A dictionary with the joined list of Documents under the `documents` key.
+        :raises ValueError: If the join_mode is `merge` and the number of weights differs from the number of
+            input lists.
+        """
+        # Materialize the input once: the join strategies iterate over it and need to know its length.
+        document_lists = list(documents)
+
+        output_documents = []
+        if self.join_mode == "concatenate":
+            output_documents = self._concatenate(document_lists)
+        elif self.join_mode == "merge":
+            output_documents = self._merge(document_lists)
+        elif self.join_mode == "reciprocal_rank_fusion":
+            output_documents = self._reciprocal_rank_fusion(document_lists)
+
+        if self.sort_by_score:
+            output_documents = sorted(
+                output_documents, key=lambda doc: doc.score if doc.score is not None else -inf, reverse=True
+            )
+            if any(doc.score is None for doc in output_documents):
+                logger.info(
+                    "Some of the Documents DocumentJoiner got have score=None. It was configured to sort Documents by "
+                    "score, so those with score=None were sorted as if they had a score of -infinity."
+                )
+
+        # A top_k of None (or 0) means that the output is not truncated.
+        if self.top_k:
+            output_documents = output_documents[: self.top_k]
+        return {"documents": output_documents}
+
+    def _concatenate(self, document_lists):
+        """
+        Concatenate multiple lists of Documents and return only the Document with the highest score for duplicate
+        Documents.
+        """
+        docs_per_id = defaultdict(list)
+        for doc in itertools.chain.from_iterable(document_lists):
+            docs_per_id[doc.id].append(doc)
+        # Check for None explicitly: a score of 0 is a valid score and must not be ranked like a missing one.
+        return [
+            max(docs, key=lambda doc: doc.score if doc.score is not None else -inf) for docs in docs_per_id.values()
+        ]
+
+    def _merge(self, document_lists):
+        """
+        Merge multiple lists of Documents and calculate a weighted sum of the scores of duplicate Documents.
+        Documents without score contribute a score of 0. The scores of the returned Documents are overwritten.
+
+        :raises ValueError: If weights were configured and their number differs from the number of lists.
+        """
+        if not document_lists:
+            # Nothing to merge; also avoids dividing by zero when calculating the default weights.
+            return []
+
+        if self.weights:
+            if len(self.weights) != len(document_lists):
+                raise ValueError(
+                    f"DocumentJoiner got {len(document_lists)} lists of Documents to merge "
+                    f"but was initialized with {len(self.weights)} weights."
+                )
+            weights = self.weights
+        else:
+            weights = [1 / len(document_lists)] * len(document_lists)
+
+        scores_map = defaultdict(int)
+        documents_map = {}
+        for documents, weight in zip(document_lists, weights):
+            for doc in documents:
+                scores_map[doc.id] += (doc.score if doc.score else 0) * weight
+                documents_map[doc.id] = doc
+
+        for doc in documents_map.values():
+            doc.score = scores_map[doc.id]
+
+        # Return a list, as declared by the output type of run(), and not a dict view that can't be sliced.
+        return list(documents_map.values())
+
+    def _reciprocal_rank_fusion(self, document_lists):
+        """
+        Merge multiple lists of Documents and assign scores based on reciprocal rank fusion.
+        The constant k is set to 61 (60 was suggested by the original paper,
+        plus 1 as python lists are 0-based and the paper used 1-based ranking).
+        The scores of the returned Documents are overwritten.
+        """
+        k = 61
+
+        scores_map = defaultdict(int)
+        documents_map = {}
+        for documents in document_lists:
+            for rank, doc in enumerate(documents):
+                scores_map[doc.id] += 1 / (k + rank)
+                documents_map[doc.id] = doc
+
+        for doc in documents_map.values():
+            doc.score = scores_map[doc.id]
+
+        # Return a list, as declared by the output type of run(), and not a dict view that can't be sliced.
+        return list(documents_map.values())
diff --git a/releasenotes/notes/document-joiner-2-126bd60c84be6efa.yaml b/releasenotes/notes/document-joiner-2-126bd60c84be6efa.yaml
new file mode 100644
index 000000000..13029d0ed
--- /dev/null
+++ b/releasenotes/notes/document-joiner-2-126bd60c84be6efa.yaml
@@ -0,0 +1,5 @@
+---
+preview:
+  - |
+    Added a new DocumentJoiner component so that hybrid retrieval pipelines can merge the document result lists of multiple retrievers.
+    Similarly, indexing pipelines can use DocumentJoiner to merge multiple lists of documents created by different file converters.
diff --git a/test/preview/components/routers/test_document_joiner.py b/test/preview/components/routers/test_document_joiner.py
new file mode 100644
index 000000000..9b4ab7bf2
--- /dev/null
+++ b/test/preview/components/routers/test_document_joiner.py
@@ -0,0 +1,144 @@
+import logging
+
+import pytest
+
+from haystack.preview import Document
+from haystack.preview.components.routers.document_joiner import DocumentJoiner
+
+
+class TestDocumentJoiner:
+    """Unit tests for the DocumentJoiner component."""
+
+    @pytest.mark.unit
+    def test_init(self):
+        joiner = DocumentJoiner()
+        assert joiner.join_mode == "concatenate"
+        assert joiner.weights is None
+        assert joiner.top_k is None
+        assert joiner.sort_by_score
+
+    @pytest.mark.unit
+    def test_init_with_custom_parameters(self):
+        joiner = DocumentJoiner(join_mode="merge", weights=[0.4, 0.6], top_k=5, sort_by_score=False)
+        assert joiner.join_mode == "merge"
+        assert joiner.weights == [0.4, 0.6]
+        assert joiner.top_k == 5
+        assert not joiner.sort_by_score
+
+    @pytest.mark.unit
+    def test_empty_list(self):
+        joiner = DocumentJoiner()
+        result = joiner.run([])
+        assert result == {"documents": []}
+
+    @pytest.mark.unit
+    def test_list_of_empty_lists(self):
+        joiner = DocumentJoiner()
+        result = joiner.run([[], []])
+        assert result == {"documents": []}
+
+    @pytest.mark.unit
+    def test_list_with_one_empty_list(self):
+        joiner = DocumentJoiner()
+        documents = [Document(content="a"), Document(content="b"), Document(content="c")]
+        result = joiner.run([[], documents])
+        assert result == {"documents": documents}
+
+    @pytest.mark.unit
+    def test_unsupported_join_mode(self):
+        with pytest.raises(ValueError, match="DocumentJoiner component does not support 'unsupported_mode' join_mode."):
+            DocumentJoiner(join_mode="unsupported_mode")
+
+    @pytest.mark.unit
+    def test_run_with_concatenate_join_mode_and_top_k(self):
+        joiner = DocumentJoiner(top_k=6)
+        documents_1 = [Document(content="a"), Document(content="b"), Document(content="c")]
+        documents_2 = [
+            Document(content="d"),
+            Document(content="e"),
+            Document(content="f", meta={"key": "value"}),
+            Document(content="g"),
+        ]
+        output = joiner.run([documents_1, documents_2])
+        assert len(output["documents"]) == 6
+        assert sorted(documents_1 + documents_2[:-1], key=lambda d: d.id) == sorted(
+            output["documents"], key=lambda d: d.id
+        )
+
+    @pytest.mark.unit
+    def test_run_with_concatenate_join_mode_and_duplicate_documents(self):
+        joiner = DocumentJoiner()
+        documents_1 = [Document(content="a", score=0.3), Document(content="b"), Document(content="c")]
+        documents_2 = [
+            Document(content="a", score=0.2),
+            Document(content="a"),
+            Document(content="f", meta={"key": "value"}),
+        ]
+        output = joiner.run([documents_1, documents_2])
+        assert len(output["documents"]) == 4
+        assert sorted(documents_1 + [documents_2[-1]], key=lambda d: d.id) == sorted(
+            output["documents"], key=lambda d: d.id
+        )
+
+    @pytest.mark.unit
+    def test_run_with_merge_join_mode(self):
+        joiner = DocumentJoiner(join_mode="merge", weights=[1.5, 0.5])
+        documents_1 = [Document(content="a", score=1.0), Document(content="b", score=2.0)]
+        documents_2 = [
+            Document(content="a", score=0.5),
+            Document(content="b", score=3.0),
+            Document(content="f", score=4.0, meta={"key": "value"}),
+        ]
+        output = joiner.run([documents_1, documents_2])
+        assert len(output["documents"]) == 3
+        # The weights [1.5, 0.5] are normalized to [0.75, 0.25] before the weighted sum of the scores is calculated.
+        expected_documents = [
+            Document(content="b", score=2.25),  # 2.0 * 0.75 + 3.0 * 0.25
+            Document(content="f", score=1.0, meta={"key": "value"}),  # 4.0 * 0.25, only in the second list
+            Document(content="a", score=0.875),  # 1.0 * 0.75 + 0.5 * 0.25
+        ]
+        # The output is sorted by score in descending order by default.
+        assert [doc.id for doc in output["documents"]] == [doc.id for doc in expected_documents]
+        assert [doc.score for doc in output["documents"]] == pytest.approx([doc.score for doc in expected_documents])
+
+    @pytest.mark.unit
+    def test_run_with_reciprocal_rank_fusion_join_mode(self):
+        joiner = DocumentJoiner(join_mode="reciprocal_rank_fusion")
+        documents_1 = [Document(content="a"), Document(content="b"), Document(content="c")]
+        documents_2 = [
+            Document(content="b", score=1000.0),
+            Document(content="c"),
+            Document(content="a"),
+            Document(content="f", meta={"key": "value"}),
+        ]
+        output = joiner.run([documents_1, documents_2])
+        assert len(output["documents"]) == 4
+        # Reciprocal rank fusion scores: b = 1/62 + 1/61, a = 1/61 + 1/63, c = 1/63 + 1/62, f = 1/64.
+        # The input scores are ignored and the output is sorted by the fused score in descending order.
+        expected_document_ids = [
+            doc.id
+            for doc in [
+                Document(content="b"),
+                Document(content="a"),
+                Document(content="c"),
+                Document(content="f", meta={"key": "value"}),
+            ]
+        ]
+        assert [doc.id for doc in output["documents"]] == expected_document_ids
+
+    @pytest.mark.unit
+    def test_sort_by_score_without_scores(self, caplog):
+        joiner = DocumentJoiner()
+        with caplog.at_level(logging.INFO):
+            documents = [Document(content="a"), Document(content="b", score=0.5)]
+            output = joiner.run([documents])
+            assert "those with score=None were sorted as if they had a score of -infinity" in caplog.text
+            assert output["documents"] == documents[::-1]
+
+    @pytest.mark.unit
+    def test_output_documents_not_sorted_by_score(self):
+        joiner = DocumentJoiner(sort_by_score=False)
+        documents_1 = [Document(content="a", score=0.1)]
+        documents_2 = [Document(content="d", score=0.2)]
+        output = joiner.run([documents_1, documents_2])
+        assert output["documents"] == documents_1 + documents_2