mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-31 11:56:35 +00:00
feat: Add DocumentJoiner component 2.0 (#6105)
* draft DocumentJoiner * implement merge and rrf * draft end-to-end test with DocumentJoiner in hybrid doc search pipeline * adjust for variadics Canals PR #122 * fix text_embedder input * adapt to the new Document class * adapt to new doc id * specify documents input as Variadic in run method * compare doc ids instead of full docs * rename text_file_converter input to sources * update docstring * Update haystack/preview/components/routers/document_joiner.py Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * Apply suggestions from docstring review Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * capitalize Documents and Retrievers in docstrings * fix log message in test --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Co-authored-by: anakin87 <stefanofiorucci@gmail.com> Co-authored-by: ZanSara <sara.zanzottera@deepset.ai> Co-authored-by: Massimiliano Pippi <mpippi@gmail.com> Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
This commit is contained in:
parent
e905066458
commit
4ef2a680bb
57
e2e/preview/pipelines/test_hybrid_doc_search_pipeline.py
Normal file
57
e2e/preview/pipelines/test_hybrid_doc_search_pipeline.py
Normal file
@ -0,0 +1,57 @@
|
||||
import json
|
||||
|
||||
from haystack.preview import Pipeline, Document
|
||||
from haystack.preview.components.embedders import SentenceTransformersTextEmbedder
|
||||
from haystack.preview.components.rankers import TransformersSimilarityRanker
|
||||
from haystack.preview.components.routers.document_joiner import DocumentJoiner
|
||||
from haystack.preview.document_stores import InMemoryDocumentStore
|
||||
from haystack.preview.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
|
||||
|
||||
|
||||
def test_hybrid_doc_search_pipeline(tmp_path):
    """
    End-to-end check of a hybrid (BM25 + embedding) document search pipeline.

    The pipeline is built, drawn, serialized to JSON, loaded back, fed with a few
    Documents and finally queried. The two top-ranked Documents are verified.
    """
    # Create the pipeline
    store = InMemoryDocumentStore()
    pipeline = Pipeline()
    pipeline.add_component(instance=InMemoryBM25Retriever(document_store=store), name="bm25_retriever")
    pipeline.add_component(
        instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
        name="text_embedder",
    )
    pipeline.add_component(instance=InMemoryEmbeddingRetriever(document_store=store), name="embedding_retriever")
    pipeline.add_component(instance=DocumentJoiner(), name="joiner")
    pipeline.add_component(instance=TransformersSimilarityRanker(top_k=20), name="ranker")

    # Wire the components together, in the same order as they are listed here
    connections = (
        ("bm25_retriever", "joiner"),
        ("text_embedder", "embedding_retriever"),
        ("embedding_retriever", "joiner"),
        ("joiner", "ranker"),
    )
    for sender, receiver in connections:
        pipeline.connect(sender, receiver)

    # Draw the pipeline
    pipeline.draw(tmp_path / "test_hybrid_doc_search_pipeline.png")

    # Serialize the pipeline to JSON
    json_path = tmp_path / "test_hybrid_doc_search_pipeline.json"
    with open(json_path, "w") as f:
        print(json.dumps(pipeline.to_dict(), indent=4))
        json.dump(pipeline.to_dict(), f)

    # Load the pipeline back
    with open(json_path, "r") as f:
        pipeline = Pipeline.from_dict(json.load(f))

    # Populate the document store
    contents = [
        "My name is Jean and I live in Paris.",
        "My name is Mark and I live in Berlin.",
        "My name is Mario and I live in the capital of Italy.",
        "My name is Giorgio and I live in Rome.",
    ]
    documents = [Document(content=text) for text in contents]
    pipeline.get_component("bm25_retriever").document_store.write_documents(documents)

    # The same query is sent to both retrieval branches and to the ranker
    query = "Who lives in Rome?"
    run_inputs = {"bm25_retriever": {"query": query}, "text_embedder": {"text": query}, "ranker": {"query": query}}
    result = pipeline.run(run_inputs)

    ranked_documents = result["ranker"]["documents"]
    assert ranked_documents[0].content == "My name is Giorgio and I live in Rome."
    assert ranked_documents[1].content == "My name is Mario and I live in the capital of Italy."
|
@ -29,7 +29,7 @@ def test_preprocessing_pipeline(tmp_path):
|
||||
name="embedder",
|
||||
)
|
||||
preprocessing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="writer")
|
||||
preprocessing_pipeline.connect("file_type_router.text/plain", "text_file_converter.paths")
|
||||
preprocessing_pipeline.connect("file_type_router.text/plain", "text_file_converter.sources")
|
||||
preprocessing_pipeline.connect("text_file_converter.documents", "language_classifier.documents")
|
||||
preprocessing_pipeline.connect("language_classifier.documents", "router.documents")
|
||||
preprocessing_pipeline.connect("router.en", "cleaner.documents")
|
||||
@ -75,7 +75,7 @@ def test_preprocessing_pipeline(tmp_path):
|
||||
filled_document_store = preprocessing_pipeline.get_component("writer").document_store
|
||||
assert filled_document_store.count_documents() == 6
|
||||
|
||||
# Check preprocessed texts and mime_types
|
||||
# Check preprocessed texts
|
||||
stored_documents = filled_document_store.filter_documents()
|
||||
expected_texts = [
|
||||
"This is an english sentence.",
|
||||
|
145
haystack/preview/components/routers/document_joiner.py
Normal file
145
haystack/preview/components/routers/document_joiner.py
Normal file
@ -0,0 +1,145 @@
|
||||
import itertools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from math import inf
|
||||
from typing import List, Optional
|
||||
from canals.component.types import Variadic
|
||||
|
||||
from haystack.preview import component, Document
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
class DocumentJoiner:
    """
    A component that joins input lists of Documents from multiple connections and outputs them as one list.

    The component allows multiple join modes:
    * concatenate: Combine Documents from multiple components. Discards duplicate Documents.
                    Documents get their scores from the last component in the pipeline that assigns scores.
                    This join mode doesn't influence Document scores.
    * merge: Merge scores of duplicate Documents coming from multiple components.
            Optionally, you can assign a weight to the scores and set the top_k limit for this join mode.
            You can also use this join mode to rerank retrieved Documents.
    * reciprocal_rank_fusion: Combine Documents into a single list based on their ranking received from multiple components.

    Example usage in a hybrid retrieval pipeline:
    ```python
    document_store = InMemoryDocumentStore()
    p = Pipeline()
    p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
    p.add_component(instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever")
    p.add_component(instance=DocumentJoiner(), name="joiner")
    p.connect("bm25_retriever", "joiner")
    p.connect("embedding_retriever", "joiner")
    ```
    """

    def __init__(
        self,
        join_mode: str = "concatenate",
        weights: Optional[List[float]] = None,
        top_k: Optional[int] = None,
        sort_by_score: bool = True,
    ):
        """
        Initialize the DocumentJoiner.

        :param join_mode: Specifies the join mode to use. Available modes: `concatenate` to combine Documents from multiple Retrievers, `merge` to aggregate the scores of
                          individual Documents, `reciprocal_rank_fusion` to apply rank-based scoring.
        :param weights: A component-wise list (the length of the list must be equal to the number of input components) of weights for
                        adjusting Document scores when using the `merge` join_mode. By default, equal weight is given
                        to each Retriever score. The weights are normalized so that they sum to 1.
                        This param is not compatible with the `concatenate` join_mode.
        :param top_k: The maximum number of Documents to be returned as output. By default, returns all Documents.
        :param sort_by_score: Whether the output list of Documents should be sorted by Document scores in descending order.
                              By default, the output is sorted.
                              Documents without score are handled as if their score was -infinity.
        :raises ValueError: If `join_mode` is not one of the supported modes.
        """
        if join_mode not in ["concatenate", "merge", "reciprocal_rank_fusion"]:
            raise ValueError(f"DocumentJoiner component does not support '{join_mode}' join_mode.")
        self.join_mode = join_mode
        if weights:
            # Compute the normalization constant once instead of once per weight.
            total_weight = sum(weights)
            self.weights = [float(weight) / total_weight for weight in weights]
        else:
            self.weights = None
        self.top_k = top_k
        self.sort_by_score = sort_by_score

    @component.output_types(documents=List[Document])
    def run(self, documents: Variadic[List[Document]]):
        """
        Run the DocumentJoiner. This method joins the input lists of Documents into one output list based on the join_mode specified during initialization.

        :param documents: An arbitrary number of lists of Documents to join.
        :return: A dictionary with a single key, `documents`, holding the joined list of Documents.
        """
        output_documents = []
        if self.join_mode == "concatenate":
            output_documents = self._concatenate(documents)
        elif self.join_mode == "merge":
            output_documents = self._merge(documents)
        elif self.join_mode == "reciprocal_rank_fusion":
            output_documents = self._reciprocal_rank_fusion(documents)

        if self.sort_by_score:
            output_documents = sorted(
                output_documents, key=lambda doc: doc.score if doc.score is not None else -inf, reverse=True
            )
            if any(doc.score is None for doc in output_documents):
                logger.info(
                    "Some of the Documents DocumentJoiner got have score=None. It was configured to sort Documents by "
                    "score, so those with score=None were sorted as if they had a score of -infinity."
                )

        if self.top_k:
            output_documents = output_documents[: self.top_k]
        return {"documents": output_documents}

    def _concatenate(self, document_lists):
        """
        Concatenate multiple lists of Documents and return only the Document with the highest score for duplicate Documents.

        Documents are considered duplicates when they share the same id.
        """
        docs_per_id = defaultdict(list)
        for doc in itertools.chain.from_iterable(document_lists):
            docs_per_id[doc.id].append(doc)

        output = []
        for docs in docs_per_id.values():
            # Compare against None explicitly: a score of 0.0 is a valid score and
            # must not be treated as if the Document had no score at all.
            doc_with_best_score = max(docs, key=lambda doc: doc.score if doc.score is not None else -inf)
            output.append(doc_with_best_score)
        return output

    def _merge(self, document_lists):
        """
        Merge multiple lists of Documents and calculate a weighted sum of the scores of duplicate Documents.

        When no weights were configured, every input list gets the same weight.
        Documents without a score contribute 0 to the sum.
        The `score` attribute of the returned Documents is overwritten in place with the merged score.
        """
        scores_map = defaultdict(int)
        documents_map = {}

        list_count = len(document_lists)
        if self.weights:
            weights = self.weights
            if len(weights) != list_count:
                # zip() below stops at the shorter sequence, so a mismatch silently drops
                # either some weights or some input lists. Make that visible.
                logger.warning(
                    "DocumentJoiner got %s weights but %s lists of Documents. "
                    "Only the first %s lists are merged.",
                    len(weights),
                    list_count,
                    min(len(weights), list_count),
                )
        else:
            # Guard against a division by zero when there is no input at all.
            weights = [1 / list_count] * list_count if list_count else []

        for documents, weight in zip(document_lists, weights):
            for doc in documents:
                scores_map[doc.id] += (doc.score if doc.score else 0) * weight
                documents_map[doc.id] = doc

        for doc in documents_map.values():
            doc.score = scores_map[doc.id]

        # Return a real list: run() may slice the result, which a dict view does not support.
        return list(documents_map.values())

    def _reciprocal_rank_fusion(self, document_lists):
        """
        Merge multiple lists of Documents and assign scores based on reciprocal rank fusion.
        The constant k is set to 61 (60 was suggested by the original paper,
        plus 1 as python lists are 0-based and the paper used 1-based ranking).

        The `score` attribute of the returned Documents is overwritten in place with the fused score.
        """
        k = 61

        scores_map = defaultdict(int)
        documents_map = {}
        for documents in document_lists:
            for rank, doc in enumerate(documents):
                scores_map[doc.id] += 1 / (k + rank)
                documents_map[doc.id] = doc

        for doc in documents_map.values():
            doc.score = scores_map[doc.id]

        # Return a real list: run() may slice the result, which a dict view does not support.
        return list(documents_map.values())
|
@ -0,0 +1,5 @@
|
||||
---
|
||||
preview:
|
||||
- |
|
||||
Added a new DocumentJoiner component so that hybrid retrieval pipelines can merge the document result lists of multiple retrievers.
|
||||
Similarly, indexing pipelines can use DocumentJoiner to merge multiple lists of documents created by different file converters.
|
140
test/preview/components/routers/test_document_joiner.py
Normal file
140
test/preview/components/routers/test_document_joiner.py
Normal file
@ -0,0 +1,140 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Document
|
||||
from haystack.preview.components.routers.document_joiner import DocumentJoiner
|
||||
|
||||
|
||||
class TestDocumentJoiner:
    """Unit tests for the DocumentJoiner component and its three join modes."""

    @pytest.mark.unit
    def test_init(self):
        joiner = DocumentJoiner()
        assert joiner.join_mode == "concatenate"
        assert joiner.weights is None
        assert joiner.top_k is None
        assert joiner.sort_by_score

    @pytest.mark.unit
    def test_init_with_custom_parameters(self):
        joiner = DocumentJoiner(join_mode="merge", weights=[0.4, 0.6], top_k=5, sort_by_score=False)
        assert joiner.join_mode == "merge"
        assert joiner.weights == [0.4, 0.6]
        assert joiner.top_k == 5
        assert not joiner.sort_by_score

    @pytest.mark.unit
    def test_empty_list(self):
        joiner = DocumentJoiner()
        result = joiner.run([])
        assert result == {"documents": []}

    @pytest.mark.unit
    def test_list_of_empty_lists(self):
        joiner = DocumentJoiner()
        result = joiner.run([[], []])
        assert result == {"documents": []}

    @pytest.mark.unit
    def test_list_with_one_empty_list(self):
        joiner = DocumentJoiner()
        documents = [Document(content="a"), Document(content="b"), Document(content="c")]
        result = joiner.run([[], documents])
        assert result == {"documents": documents}

    @pytest.mark.unit
    def test_unsupported_join_mode(self):
        with pytest.raises(ValueError, match="DocumentJoiner component does not support 'unsupported_mode' join_mode."):
            DocumentJoiner(join_mode="unsupported_mode")

    @pytest.mark.unit
    def test_run_with_concatenate_join_mode_and_top_k(self):
        joiner = DocumentJoiner(top_k=6)
        documents_1 = [Document(content="a"), Document(content="b"), Document(content="c")]
        documents_2 = [
            Document(content="d"),
            Document(content="e"),
            Document(content="f", meta={"key": "value"}),
            Document(content="g"),
        ]
        output = joiner.run([documents_1, documents_2])
        assert len(output["documents"]) == 6
        assert sorted(documents_1 + documents_2[:-1], key=lambda d: d.id) == sorted(
            output["documents"], key=lambda d: d.id
        )

    @pytest.mark.unit
    def test_run_with_concatenate_join_mode_and_duplicate_documents(self):
        joiner = DocumentJoiner()
        documents_1 = [Document(content="a", score=0.3), Document(content="b"), Document(content="c")]
        documents_2 = [
            Document(content="a", score=0.2),
            Document(content="a"),
            Document(content="f", meta={"key": "value"}),
        ]
        output = joiner.run([documents_1, documents_2])
        assert len(output["documents"]) == 4
        assert sorted(documents_1 + [documents_2[-1]], key=lambda d: d.id) == sorted(
            output["documents"], key=lambda d: d.id
        )

    @pytest.mark.unit
    def test_run_with_merge_join_mode(self):
        joiner = DocumentJoiner(join_mode="merge", weights=[1.5, 0.5])
        documents_1 = [Document(content="a", score=1.0), Document(content="b", score=2.0)]
        documents_2 = [
            Document(content="a", score=0.5),
            Document(content="b", score=3.0),
            Document(content="f", score=4.0, meta={"key": "value"}),
        ]
        output = joiner.run([documents_1, documents_2])
        assert len(output["documents"]) == 3

        # The weights [1.5, 0.5] are normalized to [0.75, 0.25], hence:
        #   b: 2.0 * 0.75 + 3.0 * 0.25 = 2.25
        #   f:              4.0 * 0.25 = 1.0
        #   a: 1.0 * 0.75 + 0.5 * 0.25 = 0.875
        # The output is sorted by score in descending order.
        expected_documents = [
            Document(content="b", score=2.25),
            Document(content="f", score=1.0, meta={"key": "value"}),
            Document(content="a", score=0.875),
        ]
        assert [doc.id for doc in output["documents"]] == [doc.id for doc in expected_documents]
        assert [doc.score for doc in output["documents"]] == pytest.approx(
            [doc.score for doc in expected_documents]
        )

    @pytest.mark.unit
    def test_run_with_reciprocal_rank_fusion_join_mode(self):
        joiner = DocumentJoiner(join_mode="reciprocal_rank_fusion")
        documents_1 = [Document(content="a"), Document(content="b"), Document(content="c")]
        documents_2 = [
            Document(content="b", score=1000.0),
            Document(content="c"),
            Document(content="a"),
            Document(content="f", meta={"key": "value"}),
        ]
        output = joiner.run([documents_1, documents_2])
        assert len(output["documents"]) == 4

        # Fused scores depend on the ranks only (input scores are ignored):
        #   b: 1/62 + 1/61, a: 1/61 + 1/63, c: 1/63 + 1/62, f: 1/64
        # so the expected descending order is b, a, c, f.
        expected_document_ids = [
            doc.id
            for doc in [
                Document(content="b"),
                Document(content="a"),
                Document(content="c"),
                Document(content="f", meta={"key": "value"}),
            ]
        ]
        assert [doc.id for doc in output["documents"]] == expected_document_ids

    @pytest.mark.unit
    def test_sort_by_score_without_scores(self, caplog):
        joiner = DocumentJoiner()
        with caplog.at_level(logging.INFO):
            documents = [Document(content="a"), Document(content="b", score=0.5)]
            output = joiner.run([documents])
            assert "those with score=None were sorted as if they had a score of -infinity" in caplog.text
            assert output["documents"] == documents[::-1]

    @pytest.mark.unit
    def test_output_documents_not_sorted_by_score(self):
        joiner = DocumentJoiner(sort_by_score=False)
        documents_1 = [Document(content="a", score=0.1)]
        documents_2 = [Document(content="d", score=0.2)]
        output = joiner.run([documents_1, documents_2])
        assert output["documents"] == documents_1 + documents_2
|
Loading…
x
Reference in New Issue
Block a user