feat: Add DocumentJoiner component 2.0 (#6105)

* draft DocumentJoiner

* implement merge and rrf

* draft end-to-end test with DocumentJoiner in hybrid doc search pipeline

* adjust for variadics Canals PR #122

* fix text_embedder input

* adapt to the new Document class

* adapt to new doc id

* specify documents input as Variadic in run method

* compare doc ids instead of full docs

* rename text_file_converter input to sources

* update docstring

* Update haystack/preview/components/routers/document_joiner.py

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>

* Apply suggestions from docstring review

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>

* capitalize Documents and Retrievers in docstrings

* fix log message in test

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
This commit is contained in:
Julian Risch 2023-11-20 10:56:56 +01:00 committed by GitHub
parent e905066458
commit 4ef2a680bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 349 additions and 2 deletions

View File

@ -0,0 +1,57 @@
import json
from haystack.preview import Pipeline, Document
from haystack.preview.components.embedders import SentenceTransformersTextEmbedder
from haystack.preview.components.rankers import TransformersSimilarityRanker
from haystack.preview.components.routers.document_joiner import DocumentJoiner
from haystack.preview.document_stores import InMemoryDocumentStore
from haystack.preview.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
def test_hybrid_doc_search_pipeline(tmp_path):
    """
    End-to-end test of a hybrid (BM25 + embedding) document search pipeline.

    The pipeline is built, drawn, serialized to JSON, loaded back, and then run
    against four small Documents. The two best-ranked results are checked.
    """
    # Assemble the components: two retrievers sharing one store, a joiner and a ranker.
    store = InMemoryDocumentStore()
    pipeline = Pipeline()
    pipeline.add_component(instance=InMemoryBM25Retriever(document_store=store), name="bm25_retriever")
    pipeline.add_component(
        instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
        name="text_embedder",
    )
    pipeline.add_component(instance=InMemoryEmbeddingRetriever(document_store=store), name="embedding_retriever")
    pipeline.add_component(instance=DocumentJoiner(), name="joiner")
    pipeline.add_component(instance=TransformersSimilarityRanker(top_k=20), name="ranker")

    # Both retrievers feed the joiner; the joiner feeds the ranker.
    pipeline.connect("bm25_retriever", "joiner")
    pipeline.connect("text_embedder", "embedding_retriever")
    pipeline.connect("embedding_retriever", "joiner")
    pipeline.connect("joiner", "ranker")

    # Render the pipeline graph to an image file.
    pipeline.draw(tmp_path / "test_hybrid_doc_search_pipeline.png")

    # Round-trip the pipeline through JSON: dump it, then rebuild it from the file.
    json_path = tmp_path / "test_hybrid_doc_search_pipeline.json"
    with open(json_path, "w") as f:
        print(json.dumps(pipeline.to_dict(), indent=4))
        json.dump(pipeline.to_dict(), f)
    with open(json_path, "r") as f:
        pipeline = Pipeline.from_dict(json.load(f))

    # Fill the store of the reloaded pipeline through its BM25 retriever.
    docs = [
        Document(content="My name is Jean and I live in Paris."),
        Document(content="My name is Mark and I live in Berlin."),
        Document(content="My name is Mario and I live in the capital of Italy."),
        Document(content="My name is Giorgio and I live in Rome."),
    ]
    pipeline.get_component("bm25_retriever").document_store.write_documents(docs)

    # The same query text is handed to every component that needs it.
    question = "Who lives in Rome?"
    pipeline_inputs = {
        "bm25_retriever": {"query": question},
        "text_embedder": {"text": question},
        "ranker": {"query": question},
    }
    result = pipeline.run(pipeline_inputs)

    ranked = result["ranker"]["documents"]
    assert ranked[0].content == "My name is Giorgio and I live in Rome."
    assert ranked[1].content == "My name is Mario and I live in the capital of Italy."

View File

@ -29,7 +29,7 @@ def test_preprocessing_pipeline(tmp_path):
name="embedder",
)
preprocessing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="writer")
preprocessing_pipeline.connect("file_type_router.text/plain", "text_file_converter.paths")
preprocessing_pipeline.connect("file_type_router.text/plain", "text_file_converter.sources")
preprocessing_pipeline.connect("text_file_converter.documents", "language_classifier.documents")
preprocessing_pipeline.connect("language_classifier.documents", "router.documents")
preprocessing_pipeline.connect("router.en", "cleaner.documents")
@ -75,7 +75,7 @@ def test_preprocessing_pipeline(tmp_path):
filled_document_store = preprocessing_pipeline.get_component("writer").document_store
assert filled_document_store.count_documents() == 6
# Check preprocessed texts and mime_types
# Check preprocessed texts
stored_documents = filled_document_store.filter_documents()
expected_texts = [
"This is an english sentence.",

View File

@ -0,0 +1,145 @@
import itertools
import logging
from collections import defaultdict
from math import inf
from typing import List, Optional
from canals.component.types import Variadic
from haystack.preview import component, Document
logger = logging.getLogger(__name__)
@component
class DocumentJoiner:
    """
    A component that joins input lists of Documents from multiple connections and outputs them as one list.

    The component allows multiple join modes:
    * concatenate: Combine Documents from multiple components. Discards duplicate Documents, keeping the
      duplicate with the highest score. This join mode doesn't influence Document scores.
    * merge: Merge scores of duplicate Documents coming from multiple components.
      Optionally, you can assign a weight to the scores and set the top_k limit for this join mode.
      You can also use this join mode to rerank retrieved Documents.
    * reciprocal_rank_fusion: Combine Documents into a single list based on their ranking received from
      multiple components.

    Example usage in a hybrid retrieval pipeline:
    ```python
    document_store = InMemoryDocumentStore()
    p = Pipeline()
    p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
    p.add_component(instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever")
    p.add_component(instance=DocumentJoiner(), name="joiner")
    p.connect("bm25_retriever", "joiner")
    p.connect("embedding_retriever", "joiner")
    ```
    """

    def __init__(
        self,
        join_mode: str = "concatenate",
        weights: Optional[List[float]] = None,
        top_k: Optional[int] = None,
        sort_by_score: bool = True,
    ):
        """
        Initialize the DocumentJoiner.

        :param join_mode: Specifies the join mode to use. Available modes: `concatenate` to combine Documents from
            multiple Retrievers, `merge` to aggregate the scores of individual Documents,
            `reciprocal_rank_fusion` to apply rank-based scoring.
        :param weights: A component-wise list (the length of the list must be equal to the number of input
            components) of weights for adjusting Document scores when using the `merge` join_mode.
            The weights are normalized so that they sum to 1. By default, equal weight is given to each
            input. The weights are only read by the `merge` join_mode.
        :param top_k: The maximum number of Documents to be returned as output. By default (or when set to 0),
            returns all Documents.
        :param sort_by_score: Whether the output list of Documents should be sorted by Document scores in
            descending order. By default, the output is sorted.
            Documents without score are handled as if their score was -infinity.
        :raises ValueError: If `join_mode` is not supported or if the weights sum to zero.
        """
        if join_mode not in ["concatenate", "merge", "reciprocal_rank_fusion"]:
            raise ValueError(f"DocumentJoiner component does not support '{join_mode}' join_mode.")
        self.join_mode = join_mode

        if weights:
            # Compute the normalization constant once instead of once per weight.
            total_weight = sum(weights)
            if total_weight == 0:
                raise ValueError("DocumentJoiner weights must not sum to zero, they cannot be normalized.")
            self.weights = [float(weight) / total_weight for weight in weights]
        else:
            self.weights = None

        self.top_k = top_k
        self.sort_by_score = sort_by_score

    @component.output_types(documents=List[Document])
    def run(self, documents: Variadic[List[Document]]):
        """
        Run the DocumentJoiner. This method joins the input lists of Documents into one output list based on
        the join_mode specified during initialization.

        :param documents: An arbitrary number of lists of Documents to join.
        :return: A dictionary with a single key, `documents`, holding the joined list of Documents.
        """
        # Materialize the input once so helpers can both iterate over it and take its length.
        document_lists = list(documents)

        output_documents = []
        if self.join_mode == "concatenate":
            output_documents = self._concatenate(document_lists)
        elif self.join_mode == "merge":
            output_documents = self._merge(document_lists)
        elif self.join_mode == "reciprocal_rank_fusion":
            output_documents = self._reciprocal_rank_fusion(document_lists)

        if self.sort_by_score:
            # Documents without a score are ranked last by treating None as -infinity.
            output_documents = sorted(
                output_documents, key=lambda doc: doc.score if doc.score is not None else -inf, reverse=True
            )
            if any(doc.score is None for doc in output_documents):
                logger.info(
                    "Some of the Documents DocumentJoiner got have score=None. It was configured to sort Documents by "
                    "score, so those with score=None were sorted as if they had a score of -infinity."
                )

        if self.top_k:
            output_documents = output_documents[: self.top_k]
        return {"documents": output_documents}

    def _concatenate(self, document_lists):
        """
        Concatenate multiple lists of Documents and return only the Document with the highest score for
        duplicate Documents.

        Duplicates are detected by Document id. A missing score (None) counts as -infinity, while a score of
        0 is a regular score.
        """
        docs_per_id = defaultdict(list)
        for doc in itertools.chain.from_iterable(document_lists):
            docs_per_id[doc.id].append(doc)

        # Compare against None explicitly: a score of 0.0 is falsy but must still beat negative scores.
        return [
            max(docs, key=lambda doc: doc.score if doc.score is not None else -inf) for docs in docs_per_id.values()
        ]

    def _merge(self, document_lists):
        """
        Merge multiple lists of Documents and calculate a weighted sum of the scores of duplicate Documents.

        Documents without a score contribute 0 to the sum. The merged score is written onto the Document
        objects themselves; for duplicates, the last Document seen for an id is the one returned.
        """
        # Without inputs there is nothing to merge; this also avoids dividing by zero below.
        if not document_lists:
            return []

        scores_map = defaultdict(float)
        documents_map = {}
        weights = self.weights if self.weights else [1 / len(document_lists)] * len(document_lists)
        if len(weights) != len(document_lists):
            # zip() below stops at the shorter sequence, so make the truncation visible.
            logger.warning(
                "DocumentJoiner was configured with %s weights but received %s input lists. "
                "Only the first %s input lists are merged.",
                len(weights),
                len(document_lists),
                min(len(weights), len(document_lists)),
            )

        for documents, weight in zip(document_lists, weights):
            for doc in documents:
                scores_map[doc.id] += (doc.score if doc.score else 0) * weight
                documents_map[doc.id] = doc

        for doc in documents_map.values():
            doc.score = scores_map[doc.id]

        # Return a real list: a dict view cannot be sliced when top_k is applied without sorting.
        return list(documents_map.values())

    def _reciprocal_rank_fusion(self, document_lists):
        """
        Merge multiple lists of Documents and assign scores based on reciprocal rank fusion.

        The constant k is set to 61 (60 was suggested by the original paper,
        plus 1 as python lists are 0-based and the paper used 1-based ranking).
        The fused score is written onto the Document objects themselves; for duplicates, the last Document
        seen for an id is the one returned.
        """
        k = 61

        scores_map = defaultdict(float)
        documents_map = {}
        for documents in document_lists:
            for rank, doc in enumerate(documents):
                scores_map[doc.id] += 1 / (k + rank)
                documents_map[doc.id] = doc

        for doc in documents_map.values():
            doc.score = scores_map[doc.id]

        # Return a real list: a dict view cannot be sliced when top_k is applied without sorting.
        return list(documents_map.values())

View File

@ -0,0 +1,5 @@
---
preview:
- |
Added a new DocumentJoiner component so that hybrid retrieval pipelines can merge the document result lists of multiple retrievers.
Similarly, indexing pipelines can use DocumentJoiner to merge multiple lists of documents created by different file converters.

View File

@ -0,0 +1,140 @@
import logging
import pytest
from haystack.preview import Document
from haystack.preview.components.routers.document_joiner import DocumentJoiner
class TestDocumentJoiner:
    """Unit tests for the DocumentJoiner component and its three join modes."""

    @pytest.mark.unit
    def test_init(self):
        """Default construction uses `concatenate`, no weights, no top_k, and sorting enabled."""
        joiner = DocumentJoiner()
        assert joiner.join_mode == "concatenate"
        assert joiner.weights is None
        assert joiner.top_k is None
        assert joiner.sort_by_score

    @pytest.mark.unit
    def test_init_with_custom_parameters(self):
        """Custom parameters are stored; weights that already sum to 1 are unchanged by normalization."""
        joiner = DocumentJoiner(join_mode="merge", weights=[0.4, 0.6], top_k=5, sort_by_score=False)
        assert joiner.join_mode == "merge"
        assert joiner.weights == [0.4, 0.6]
        assert joiner.top_k == 5
        assert not joiner.sort_by_score

    @pytest.mark.unit
    def test_empty_list(self):
        """No input lists at all yields an empty output."""
        joiner = DocumentJoiner()
        result = joiner.run([])
        assert result == {"documents": []}

    @pytest.mark.unit
    def test_list_of_empty_lists(self):
        """Only empty input lists yield an empty output."""
        joiner = DocumentJoiner()
        result = joiner.run([[], []])
        assert result == {"documents": []}

    @pytest.mark.unit
    def test_list_with_one_empty_list(self):
        """An empty input list next to a filled one does not affect the result."""
        joiner = DocumentJoiner()
        documents = [Document(content="a"), Document(content="b"), Document(content="c")]
        result = joiner.run([[], documents])
        assert result == {"documents": documents}

    @pytest.mark.unit
    def test_unsupported_join_mode(self):
        """An unknown join_mode is rejected at construction time."""
        with pytest.raises(ValueError, match="DocumentJoiner component does not support 'unsupported_mode' join_mode."):
            DocumentJoiner(join_mode="unsupported_mode")

    @pytest.mark.unit
    def test_run_with_concatenate_join_mode_and_top_k(self):
        """With seven unique Documents and top_k=6, the last input Document is cut off."""
        joiner = DocumentJoiner(top_k=6)
        documents_1 = [Document(content="a"), Document(content="b"), Document(content="c")]
        documents_2 = [
            Document(content="d"),
            Document(content="e"),
            Document(content="f", meta={"key": "value"}),
            Document(content="g"),
        ]
        output = joiner.run([documents_1, documents_2])
        assert len(output["documents"]) == 6
        assert sorted(documents_1 + documents_2[:-1], key=lambda d: d.id) == sorted(
            output["documents"], key=lambda d: d.id
        )

    @pytest.mark.unit
    def test_run_with_concatenate_join_mode_and_duplicate_documents(self):
        """Among duplicates, only the Document with the highest score survives."""
        joiner = DocumentJoiner()
        documents_1 = [Document(content="a", score=0.3), Document(content="b"), Document(content="c")]
        documents_2 = [
            Document(content="a", score=0.2),
            Document(content="a"),
            Document(content="f", meta={"key": "value"}),
        ]
        output = joiner.run([documents_1, documents_2])
        assert len(output["documents"]) == 4
        assert sorted(documents_1 + [documents_2[-1]], key=lambda d: d.id) == sorted(
            output["documents"], key=lambda d: d.id
        )

    @pytest.mark.unit
    def test_run_with_merge_join_mode(self):
        """Merged scores are the weighted sum of the input scores, using normalized weights."""
        # The weights [1.5, 0.5] are normalized to [0.75, 0.25].
        joiner = DocumentJoiner(join_mode="merge", weights=[1.5, 0.5])
        documents_1 = [Document(content="a", score=1.0), Document(content="b", score=2.0)]
        documents_2 = [
            Document(content="a", score=0.5),
            Document(content="b", score=3.0),
            Document(content="f", score=4.0, meta={"key": "value"}),
        ]
        output = joiner.run([documents_1, documents_2])
        assert len(output["documents"]) == 3

        # Expected Documents in descending score order:
        #   b: 2.0 * 0.75 + 3.0 * 0.25 = 2.25
        #   f: 4.0 * 0.25 = 1.0
        #   a: 1.0 * 0.75 + 0.5 * 0.25 = 0.875
        expected_documents = [
            Document(content="b", score=2.25),
            Document(content="f", score=1.0, meta={"key": "value"}),
            Document(content="a", score=0.875),
        ]
        assert [doc.id for doc in output["documents"]] == [doc.id for doc in expected_documents]
        assert [doc.score for doc in output["documents"]] == pytest.approx(
            [doc.score for doc in expected_documents]
        )

    @pytest.mark.unit
    def test_run_with_reciprocal_rank_fusion_join_mode(self):
        """Reciprocal rank fusion orders Documents by their ranks only, ignoring the input scores."""
        joiner = DocumentJoiner(join_mode="reciprocal_rank_fusion")
        documents_1 = [Document(content="a"), Document(content="b"), Document(content="c")]
        documents_2 = [
            Document(content="b", score=1000.0),
            Document(content="c"),
            Document(content="a"),
            Document(content="f", meta={"key": "value"}),
        ]
        output = joiner.run([documents_1, documents_2])
        assert len(output["documents"]) == 4

        # Fused scores with k=61: b = 1/62 + 1/61, a = 1/61 + 1/63, c = 1/63 + 1/62, f = 1/64,
        # so the descending order is b, a, c, f.
        expected_document_ids = [
            doc.id
            for doc in [
                Document(content="b"),
                Document(content="a"),
                Document(content="c"),
                Document(content="f", meta={"key": "value"}),
            ]
        ]
        assert [doc.id for doc in output["documents"]] == expected_document_ids

    @pytest.mark.unit
    def test_sort_by_score_without_scores(self, caplog):
        """Documents without a score are sorted last and an info message is logged."""
        joiner = DocumentJoiner()
        with caplog.at_level(logging.INFO):
            documents = [Document(content="a"), Document(content="b", score=0.5)]
            output = joiner.run([documents])
            assert "those with score=None were sorted as if they had a score of -infinity" in caplog.text
            assert output["documents"] == documents[::-1]

    @pytest.mark.unit
    def test_output_documents_not_sorted_by_score(self):
        """With sorting disabled, the input order is kept even if scores are ascending."""
        joiner = DocumentJoiner(sort_by_score=False)
        documents_1 = [Document(content="a", score=0.1)]
        documents_2 = [Document(content="d", score=0.2)]
        output = joiner.run([documents_1, documents_2])
        assert output["documents"] == documents_1 + documents_2