feat: adding AutoMergingRetriever and HierarchicalDocumentSplitter (#9067)

* adding Auto-Merging-Retriever

* adding release notes

* updating tests

* adding renamed file

* Update haystack/components/preprocessors/hierarchical_document_splitter.py

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

* Update haystack/components/retrievers/auto_merging_retriever.py

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

* fixing tests and imports

* adding pydoc

* adding to type checking

---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
David S. Batista 2025-03-19 19:25:23 +01:00 committed by GitHub
parent 9a046ed431
commit be2d1fb303
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 834 additions and 1 deletions

View File

@ -1,7 +1,14 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/preprocessors]
modules: ["csv_document_cleaner", "csv_document_splitter", "document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"]
modules: [
"csv_document_cleaner",
"csv_document_splitter",
"document_cleaner",
"document_splitter",
"hierarchical_document_splitter",
"recursive_splitter",
"text_cleaner"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter

View File

@ -3,6 +3,7 @@ loaders:
search_path: [../../../haystack/components/retrievers]
modules:
[
"auto_merging_retriever",
"in_memory/bm25_retriever",
"in_memory/embedding_retriever",
"filter_retriever",

View File

@ -12,6 +12,7 @@ _import_structure = {
"csv_document_splitter": ["CSVDocumentSplitter"],
"document_cleaner": ["DocumentCleaner"],
"document_splitter": ["DocumentSplitter"],
"hierarchical_document_splitter": ["HierarchicalDocumentSplitter"],
"recursive_splitter": ["RecursiveDocumentSplitter"],
"text_cleaner": ["TextCleaner"],
}
@ -21,6 +22,7 @@ if TYPE_CHECKING:
from .csv_document_splitter import CSVDocumentSplitter
from .document_cleaner import DocumentCleaner
from .document_splitter import DocumentSplitter
from .hierarchical_document_splitter import HierarchicalDocumentSplitter
from .recursive_splitter import RecursiveDocumentSplitter
from .text_cleaner import TextCleaner

View File

@ -0,0 +1,144 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Literal, Set
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.components.preprocessors import DocumentSplitter
@component
class HierarchicalDocumentSplitter:
"""
Splits a documents into different block sizes building a hierarchical tree structure of blocks of different sizes.
The root node of the tree is the original document, the leaf nodes are the smallest blocks. The blocks in between
are connected such that the smaller blocks are children of the parent-larger blocks.
## Usage example
```python
from haystack import Document
from haystack.components.preprocessors import HierarchicalDocumentSplitter
doc = Document(content="This is a simple test document")
splitter = HierarchicalDocumentSplitter(block_sizes={3, 2}, split_overlap=0, split_by="word")
splitter.run([doc])
>> {'documents': [Document(id=3f7..., content: 'This is a simple test document', meta: {'block_size': 0, 'parent_id': None, 'children_ids': ['5ff..', '8dc..'], 'level': 0}),
>> Document(id=5ff.., content: 'This is a ', meta: {'block_size': 3, 'parent_id': '3f7..', 'children_ids': ['f19..', '52c..'], 'level': 1, 'source_id': '3f7..', 'page_number': 1, 'split_id': 0, 'split_idx_start': 0}),
>> Document(id=8dc.., content: 'simple test document', meta: {'block_size': 3, 'parent_id': '3f7..', 'children_ids': ['39d..', 'e23..'], 'level': 1, 'source_id': '3f7..', 'page_number': 1, 'split_id': 1, 'split_idx_start': 10}),
>> Document(id=f19.., content: 'This is ', meta: {'block_size': 2, 'parent_id': '5ff..', 'children_ids': [], 'level': 2, 'source_id': '5ff..', 'page_number': 1, 'split_id': 0, 'split_idx_start': 0}),
>> Document(id=52c.., content: 'a ', meta: {'block_size': 2, 'parent_id': '5ff..', 'children_ids': [], 'level': 2, 'source_id': '5ff..', 'page_number': 1, 'split_id': 1, 'split_idx_start': 8}),
>> Document(id=39d.., content: 'simple test ', meta: {'block_size': 2, 'parent_id': '8dc..', 'children_ids': [], 'level': 2, 'source_id': '8dc..', 'page_number': 1, 'split_id': 0, 'split_idx_start': 0}),
>> Document(id=e23.., content: 'document', meta: {'block_size': 2, 'parent_id': '8dc..', 'children_ids': [], 'level': 2, 'source_id': '8dc..', 'page_number': 1, 'split_id': 1, 'split_idx_start': 12})]}
```
""" # noqa: E501
def __init__(
self,
block_sizes: Set[int],
split_overlap: int = 0,
split_by: Literal["word", "sentence", "page", "passage"] = "word",
):
"""
Initialize HierarchicalDocumentSplitter.
:param block_sizes: Set of block sizes to split the document into. The blocks are split in descending order.
:param split_overlap: The number of overlapping units for each split.
:param split_by: The unit for splitting your documents.
"""
self.block_sizes = sorted(set(block_sizes), reverse=True)
self.splitters: Dict[int, DocumentSplitter] = {}
self.split_overlap = split_overlap
self.split_by = split_by
self._build_block_sizes()
@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
"""
Builds a hierarchical document structure for each document in a list of documents.
:param documents: List of Documents to split into hierarchical blocks.
:returns: List of HierarchicalDocument
"""
hierarchical_docs = []
for doc in documents:
hierarchical_docs.extend(self.build_hierarchy_from_doc(doc))
return {"documents": hierarchical_docs}
def _build_block_sizes(self):
for block_size in self.block_sizes:
self.splitters[block_size] = DocumentSplitter(
split_length=block_size, split_overlap=self.split_overlap, split_by=self.split_by
)
self.splitters[block_size].warm_up()
@staticmethod
def _add_meta_data(document: Document):
document.meta["__block_size"] = 0
document.meta["__parent_id"] = None
document.meta["__children_ids"] = []
document.meta["__level"] = 0
return document
def build_hierarchy_from_doc(self, document: Document) -> List[Document]:
"""
Build a hierarchical tree document structure from a single document.
Given a document, this function splits the document into hierarchical blocks of different sizes represented
as HierarchicalDocument objects.
:param document: Document to split into hierarchical blocks.
:returns:
List of HierarchicalDocument
"""
root = self._add_meta_data(document)
current_level_nodes = [root]
all_docs = []
for block in self.block_sizes:
next_level_nodes = []
for doc in current_level_nodes:
splitted_docs = self.splitters[block].run([doc])
child_docs = splitted_docs["documents"]
# if it's only one document skip
if len(child_docs) == 1:
next_level_nodes.append(doc)
continue
for child_doc in child_docs:
child_doc = self._add_meta_data(child_doc)
child_doc.meta["__level"] = doc.meta["__level"] + 1
child_doc.meta["__block_size"] = block
child_doc.meta["__parent_id"] = doc.id
all_docs.append(child_doc)
doc.meta["__children_ids"].append(child_doc.id)
next_level_nodes.append(child_doc)
current_level_nodes = next_level_nodes
return [root] + all_docs
def to_dict(self) -> Dict[str, Any]:
"""
Returns a dictionary representation of the component.
:returns:
Serialized dictionary representation of the component.
"""
return default_to_dict(
self, block_sizes=self.block_sizes, split_overlap=self.split_overlap, split_by=self.split_by
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HierarchicalDocumentSplitter":
"""
Deserialize this component from a dictionary.
:param data:
The dictionary to deserialize and create the component.
:returns:
The deserialized component.
"""
return default_from_dict(cls, data)

View File

@ -8,12 +8,14 @@ from typing import TYPE_CHECKING
from lazy_imports import LazyImporter
_import_structure = {
"auto_merging_retriever": ["AutoMergingRetriever"],
"filter_retriever": ["FilterRetriever"],
"in_memory": ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"],
"sentence_window_retriever": ["SentenceWindowRetriever"],
}
if TYPE_CHECKING:
from .auto_merging_retriever import AutoMergingRetriever
from .filter_retriever import FilterRetriever
from .in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from .sentence_window_retriever import SentenceWindowRetriever

View File

@ -0,0 +1,169 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from typing import Any, Dict, List
from haystack import Document, component, default_to_dict
from haystack.core.serialization import default_from_dict
from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_params_inplace
@component
class AutoMergingRetriever:
"""
A retriever which returns parent documents of the matched leaf nodes documents, based on a threshold setting.
The AutoMergingRetriever assumes you have a hierarchical tree structure of documents, where the leaf nodes
are indexed in a document store. See the HierarchicalDocumentSplitter for more information on how to create
such a structure. During retrieval, if the number of matched leaf documents below the same parent is
higher than a defined threshold, the retriever will return the parent document instead of the individual leaf
documents.
The rational is, given that a paragraph is split into multiple chunks represented as leaf documents, and if for
a given query, multiple chunks are matched, the whole paragraph might be more informative than the individual
chunks alone.
Currently the AutoMergingRetriever can only be used by the following DocumentStores:
- [AstraDB](https://haystack.deepset.ai/integrations/astradb)
- [ElasticSearch](https://haystack.deepset.ai/docs/latest/documentstore/elasticsearch)
- [OpenSearch](https://haystack.deepset.ai/docs/latest/documentstore/opensearch)
- [PGVector](https://haystack.deepset.ai/docs/latest/documentstore/pgvector)
- [Qdrant](https://haystack.deepset.ai/docs/latest/documentstore/qdrant)
```python
from haystack import Document
from haystack.components.preprocessors import HierarchicalDocumentSplitter
from haystack.components.retrievers.auto_merging_retriever import AutoMergingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
# create a hierarchical document structure with 3 levels, where the parent document has 3 children
text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing."
original_document = Document(content=text)
builder = HierarchicalDocumentSplitter(block_sizes=[10, 3], split_overlap=0, split_by="word")
docs = builder.run([original_document])["documents"]
# store level-1 parent documents and initialize the retriever
doc_store_parents = InMemoryDocumentStore()
for doc in docs["documents"]:
if doc.meta["children_ids"] and doc.meta["level"] == 1:
doc_store_parents.write_documents([doc])
retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5)
# assume we retrieved 2 leaf docs from the same parent, the parent document should be returned,
# since it has 3 children and the threshold=0.5, and we retrieved 2 children (2/3 > 0.66(6))
leaf_docs = [doc for doc in docs["documents"] if not doc.meta["children_ids"]]
docs = retriever.run(leaf_docs[4:6])
>> {'documents': [Document(id=538..),
>> content: 'warm glow over the trees. Birds began to sing.',
>> meta: {'block_size': 10, 'parent_id': '835..', 'children_ids': ['c17...', '3ff...', '352...'], 'level': 1, 'source_id': '835...',
>> 'page_number': 1, 'split_id': 1, 'split_idx_start': 45})]}
```
""" # noqa: E501
def __init__(self, document_store: DocumentStore, threshold: float = 0.5):
"""
Initialize the AutoMergingRetriever.
:param document_store: DocumentStore from which to retrieve the parent documents
:param threshold: Threshold to decide whether the parent instead of the individual documents is returned
"""
if not 0 < threshold < 1:
raise ValueError("The threshold parameter must be between 0 and 1.")
self.document_store = document_store
self.threshold = threshold
def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
docstore = self.document_store.to_dict()
return default_to_dict(self, document_store=docstore, threshold=self.threshold)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AutoMergingRetriever":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary with serialized data.
:returns:
An instance of the component.
"""
deserialize_document_store_in_init_params_inplace(data)
return default_from_dict(cls, data)
@staticmethod
def _check_valid_documents(matched_leaf_documents: List[Document]):
# check if the matched leaf documents have the required meta fields
if not all(doc.meta.get("__parent_id") for doc in matched_leaf_documents):
raise ValueError("The matched leaf documents do not have the required meta field '__parent_id'")
if not all(doc.meta.get("__level") for doc in matched_leaf_documents):
raise ValueError("The matched leaf documents do not have the required meta field '__level'")
if not all(doc.meta.get("__block_size") for doc in matched_leaf_documents):
raise ValueError("The matched leaf documents do not have the required meta field '__block_size'")
@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
"""
Run the AutoMergingRetriever.
Recursively groups documents by their parents and merges them if they meet the threshold,
continuing up the hierarchy until no more merges are possible.
:param documents: List of leaf documents that were matched by a retriever
:returns:
List of documents (could be a mix of different hierarchy levels)
"""
AutoMergingRetriever._check_valid_documents(documents)
def _get_parent_doc(parent_id: str) -> Document:
parent_docs = self.document_store.filter_documents({"field": "id", "operator": "==", "value": parent_id})
if len(parent_docs) != 1:
raise ValueError(f"Expected 1 parent document with id {parent_id}, found {len(parent_docs)}")
parent_doc = parent_docs[0]
if not parent_doc.meta.get("__children_ids"):
raise ValueError(f"Parent document with id {parent_id} does not have any children.")
return parent_doc
def _try_merge_level(docs_to_merge: List[Document], docs_to_return: List[Document]) -> List[Document]:
parent_doc_id_to_child_docs: Dict[str, List[Document]] = defaultdict(list) # to group documents by parent
for doc in docs_to_merge:
if doc.meta.get("__parent_id"): # only docs that have parents
parent_doc_id_to_child_docs[doc.meta["__parent_id"]].append(doc)
else:
docs_to_return.append(doc) # keep docs that have no parents
# Process each parent group
merged_docs = []
for parent_doc_id, child_docs in parent_doc_id_to_child_docs.items():
parent_doc = _get_parent_doc(parent_doc_id)
# Calculate merge score
score = len(child_docs) / len(parent_doc.meta["__children_ids"])
if score > self.threshold:
merged_docs.append(parent_doc) # Merge into parent
else:
docs_to_return.extend(child_docs) # Keep children separate
# if no new merges were made, we're done
if not merged_docs:
return merged_docs + docs_to_return
# Recursively try to merge the next level
return _try_merge_level(merged_docs, docs_to_return)
return {"documents": _try_merge_level(documents, [])}

View File

@ -0,0 +1,4 @@
---
features:
- |
We added a new retrieval technique, `AutoMergingRetriever` which together with the `HierarchicalDocumentSplitter` implement a auto-merging retrieval technique.

View File

@ -0,0 +1,246 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from haystack import Document, Pipeline
from haystack.components.preprocessors import HierarchicalDocumentSplitter
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
class TestHierarchicalDocumentSplitter:
def test_init_with_default_params(self):
builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300})
assert builder.block_sizes == [300, 200, 100]
assert builder.split_overlap == 0
assert builder.split_by == "word"
def test_init_with_custom_params(self):
builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300}, split_overlap=25, split_by="word")
assert builder.block_sizes == [300, 200, 100]
assert builder.split_overlap == 25
assert builder.split_by == "word"
def test_to_dict(self):
builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300}, split_overlap=25, split_by="word")
expected = builder.to_dict()
assert expected == {
"type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",
"init_parameters": {"block_sizes": [300, 200, 100], "split_overlap": 25, "split_by": "word"},
}
def test_from_dict(self):
data = {
"type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",
"init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"},
}
builder = HierarchicalDocumentSplitter.from_dict(data)
assert builder.block_sizes == [10, 5, 2]
assert builder.split_overlap == 0
assert builder.split_by == "word"
def test_run(self):
builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word")
text = "one two three four five six seven eight nine ten"
doc = Document(content=text)
output = builder.run([doc])
docs = output["documents"]
builder.run([doc])
assert len(docs) == 9
assert docs[0].content == "one two three four five six seven eight nine ten"
# level 1 - root node
assert docs[0].meta["__level"] == 0
assert len(docs[0].meta["__children_ids"]) == 2
# level 2 -left branch
assert docs[1].meta["__parent_id"] == docs[0].id
assert docs[1].meta["__level"] == 1
assert len(docs[1].meta["__children_ids"]) == 3
# level 2 - right branch
assert docs[2].meta["__parent_id"] == docs[0].id
assert docs[2].meta["__level"] == 1
assert len(docs[2].meta["__children_ids"]) == 3
# level 3 - left branch - leaf nodes
assert docs[3].meta["__parent_id"] == docs[1].id
assert docs[4].meta["__parent_id"] == docs[1].id
assert docs[5].meta["__parent_id"] == docs[1].id
assert docs[3].meta["__level"] == 2
assert docs[4].meta["__level"] == 2
assert docs[5].meta["__level"] == 2
assert len(docs[3].meta["__children_ids"]) == 0
assert len(docs[4].meta["__children_ids"]) == 0
assert len(docs[5].meta["__children_ids"]) == 0
# level 3 - right branch - leaf nodes
assert docs[6].meta["__parent_id"] == docs[2].id
assert docs[7].meta["__parent_id"] == docs[2].id
assert docs[8].meta["__parent_id"] == docs[2].id
assert docs[6].meta["__level"] == 2
assert docs[7].meta["__level"] == 2
assert docs[8].meta["__level"] == 2
assert len(docs[6].meta["__children_ids"]) == 0
assert len(docs[7].meta["__children_ids"]) == 0
assert len(docs[8].meta["__children_ids"]) == 0
def test_to_dict_in_pipeline(self):
pipeline = Pipeline()
hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2})
doc_store = InMemoryDocumentStore()
doc_writer = DocumentWriter(document_store=doc_store)
pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder)
pipeline.add_component(name="doc_writer", instance=doc_writer)
pipeline.connect("hierarchical_doc_splitter", "doc_writer")
expected = pipeline.to_dict()
assert expected.keys() == {
"connections",
"connection_type_validation",
"components",
"max_runs_per_component",
"metadata",
}
assert expected["components"].keys() == {"hierarchical_doc_splitter", "doc_writer"}
assert expected["components"]["hierarchical_doc_splitter"] == {
"type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",
"init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"},
}
def test_from_dict_in_pipeline(self):
data = {
"metadata": {},
"max_runs_per_component": 100,
"components": {
"hierarchical_document_splitter": {
"type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",
"init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"},
},
"doc_writer": {
"type": "haystack.components.writers.document_writer.DocumentWriter",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
"bm25_algorithm": "BM25L",
"bm25_parameters": {},
"embedding_similarity_function": "dot_product",
"index": "f32ad5bf-43cb-4035-9823-1de1ae9853c1",
},
},
"policy": "NONE",
},
},
},
"connections": [{"sender": "hierarchical_document_splitter.documents", "receiver": "doc_writer.documents"}],
}
assert Pipeline.from_dict(data)
@pytest.mark.integration
def test_example_in_pipeline(self):
pipeline = Pipeline()
hierarchical_doc_builder = HierarchicalDocumentSplitter(
block_sizes={10, 5, 2}, split_overlap=0, split_by="word"
)
doc_store = InMemoryDocumentStore()
doc_writer = DocumentWriter(document_store=doc_store)
pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder)
pipeline.add_component(name="doc_writer", instance=doc_writer)
pipeline.connect("hierarchical_doc_splitter.documents", "doc_writer")
text = "one two three four five six seven eight nine ten"
doc = Document(content=text)
docs = pipeline.run({"hierarchical_doc_splitter": {"documents": [doc]}})
assert docs["doc_writer"]["documents_written"] == 9
assert len(doc_store.storage.values()) == 9
def test_serialization_deserialization_pipeline(self):
pipeline = Pipeline()
hierarchical_doc_builder = HierarchicalDocumentSplitter(
block_sizes={10, 5, 2}, split_overlap=0, split_by="word"
)
doc_store = InMemoryDocumentStore()
doc_writer = DocumentWriter(document_store=doc_store)
pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder)
pipeline.add_component(name="doc_writer", instance=doc_writer)
pipeline.connect("hierarchical_doc_splitter.documents", "doc_writer")
pipeline_dict = pipeline.to_dict()
new_pipeline = Pipeline.from_dict(pipeline_dict)
assert new_pipeline == pipeline
def test_split_by_sentence_assure_warm_up_was_called(self):
pipeline = Pipeline()
hierarchical_doc_builder = HierarchicalDocumentSplitter(
block_sizes={10, 5, 2}, split_overlap=0, split_by="sentence"
)
doc_store = InMemoryDocumentStore()
doc_writer = DocumentWriter(document_store=doc_store)
pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder)
pipeline.add_component(name="doc_writer", instance=doc_writer)
pipeline.connect("hierarchical_doc_splitter.documents", "doc_writer")
text = "This is one sentence. This is another sentence. This is the third sentence."
doc = Document(content=text)
docs = pipeline.run({"hierarchical_doc_splitter": {"documents": [doc]}})
assert docs["doc_writer"]["documents_written"] == 3
assert len(doc_store.storage.values()) == 3
def test_hierarchical_splitter_multiple_block_sizes(self):
# Test with three different block sizes
doc = Document(
content="This is a simple test document with multiple sentences. It should be split into various sizes. This helps test the hierarchy."
)
# Using three block sizes: 10, 5, 2 words
splitter = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word")
result = splitter.run([doc])
documents = result["documents"]
# Verify root document
assert len(documents) > 1
root = documents[0]
assert root.meta["__level"] == 0
assert root.meta["__parent_id"] is None
# Verify level 1 documents (block_size=10)
level_1_docs = [d for d in documents if d.meta["__level"] == 1]
for doc in level_1_docs:
assert doc.meta["__block_size"] == 10
assert doc.meta["__parent_id"] == root.id
# Verify level 2 documents (block_size=5)
level_2_docs = [d for d in documents if d.meta["__level"] == 2]
for doc in level_2_docs:
assert doc.meta["__block_size"] == 5
assert doc.meta["__parent_id"] in [d.id for d in level_1_docs]
# Verify level 3 documents (block_size=2)
level_3_docs = [d for d in documents if d.meta["__level"] == 3]
for doc in level_3_docs:
assert doc.meta["__block_size"] == 2
assert doc.meta["__parent_id"] in [d.id for d in level_2_docs]
# Verify children references
for doc in documents:
if doc.meta["__children_ids"]:
child_ids = doc.meta["__children_ids"]
children = [d for d in documents if d.id in child_ids]
for child in children:
assert child.meta["__parent_id"] == doc.id
assert child.meta["__level"] == doc.meta["__level"] + 1

View File

@ -0,0 +1,258 @@
import pytest
from haystack import Document, Pipeline
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.preprocessors import HierarchicalDocumentSplitter
from haystack.components.retrievers.auto_merging_retriever import AutoMergingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
class TestAutoMergingRetriever:
def test_init_default(self):
retriever = AutoMergingRetriever(InMemoryDocumentStore())
assert retriever.threshold == 0.5
def test_init_with_parameters(self):
retriever = AutoMergingRetriever(InMemoryDocumentStore(), threshold=0.7)
assert retriever.threshold == 0.7
def test_init_with_invalid_threshold(self):
with pytest.raises(ValueError):
AutoMergingRetriever(InMemoryDocumentStore(), threshold=-2)
def test_run_missing_parent_id(self):
docs = [Document(content="test", meta={"__level": 1, "__block_size": 10})]
retriever = AutoMergingRetriever(InMemoryDocumentStore())
with pytest.raises(
ValueError, match="The matched leaf documents do not have the required meta field '__parent_id'"
):
retriever.run(documents=docs)
def test_run_missing_level(self):
docs = [Document(content="test", meta={"__parent_id": "parent1", "__block_size": 10})]
retriever = AutoMergingRetriever(InMemoryDocumentStore())
with pytest.raises(
ValueError, match="The matched leaf documents do not have the required meta field '__level'"
):
retriever.run(documents=docs)
def test_run_missing_block_size(self):
docs = [Document(content="test", meta={"__parent_id": "parent1", "__level": 1})]
retriever = AutoMergingRetriever(InMemoryDocumentStore())
with pytest.raises(
ValueError, match="The matched leaf documents do not have the required meta field '__block_size'"
):
retriever.run(documents=docs)
def test_run_mixed_valid_and_invalid_documents(self):
docs = [
Document(content="valid", meta={"__parent_id": "parent1", "__level": 1, "__block_size": 10}),
Document(content="invalid", meta={"__level": 1, "__block_size": 10}),
]
retriever = AutoMergingRetriever(InMemoryDocumentStore())
with pytest.raises(
ValueError, match="The matched leaf documents do not have the required meta field '__parent_id'"
):
retriever.run(documents=docs)
def test_to_dict(self):
retriever = AutoMergingRetriever(InMemoryDocumentStore(), threshold=0.7)
expected = retriever.to_dict()
assert expected["type"] == "haystack.components.retrievers.auto_merging_retriever.AutoMergingRetriever"
assert expected["init_parameters"]["threshold"] == 0.7
assert (
expected["init_parameters"]["document_store"]["type"]
== "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore"
)
def test_from_dict(self):
data = {
"type": "haystack.components.retrievers.auto_merging_retriever.AutoMergingRetriever",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
"bm25_algorithm": "BM25L",
"bm25_parameters": {},
"embedding_similarity_function": "dot_product",
"index": "6b122bb4-211b-465e-804d-77c5857bf4c5",
},
},
"threshold": 0.7,
},
}
retriever = AutoMergingRetriever.from_dict(data)
assert retriever.threshold == 0.7
def test_serialization_deserialization_pipeline(self):
pipeline = Pipeline()
doc_store_parents = InMemoryDocumentStore()
bm_25_retriever = InMemoryBM25Retriever(doc_store_parents)
auto_merging_retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5)
pipeline.add_component(name="bm_25_retriever", instance=bm_25_retriever)
pipeline.add_component(name="auto_merging_retriever", instance=auto_merging_retriever)
pipeline.connect("bm_25_retriever.documents", "auto_merging_retriever.documents")
pipeline_dict = pipeline.to_dict()
new_pipeline = Pipeline.from_dict(pipeline_dict)
assert new_pipeline == pipeline
def test_run_parent_not_found(self):
doc_store = InMemoryDocumentStore()
retriever = AutoMergingRetriever(doc_store, threshold=0.5)
# a leaf document with a non-existent parent_id
leaf_doc = Document(
content="test", meta={"__parent_id": "non_existent_parent", "__level": 1, "__block_size": 10}
)
with pytest.raises(ValueError, match="Expected 1 parent document with id non_existent_parent, found 0"):
retriever.run([leaf_doc])
def test_run_parent_without_children_metadata(self):
"""Test case where a parent document exists but doesn't have the __children_ids metadata field"""
doc_store = InMemoryDocumentStore()
# Create and store a parent document without __children_ids metadata
parent_doc = Document(
content="parent content",
id="parent1",
meta={
"__level": 1, # Add other required metadata
"__block_size": 10,
},
)
doc_store.write_documents([parent_doc])
retriever = AutoMergingRetriever(doc_store, threshold=0.5)
# Create a leaf document that points to this parent
leaf_doc = Document(content="leaf content", meta={"__parent_id": "parent1", "__level": 2, "__block_size": 5})
with pytest.raises(ValueError, match="Parent document with id parent1 does not have any children"):
retriever.run([leaf_doc])
def test_run_empty_documents(self):
retriever = AutoMergingRetriever(InMemoryDocumentStore())
assert retriever.run([]) == {"documents": []}
def test_run_return_parent_document(self):
text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing."
docs = [Document(content=text)]
builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word")
docs = builder.run(docs)
# store all non-leaf documents
doc_store_parents = InMemoryDocumentStore()
for doc in docs["documents"]:
if doc.meta["__children_ids"]:
doc_store_parents.write_documents([doc])
retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5)
# assume we retrieved 2 leaf docs from the same parent, the parent document should be returned,
# since it has 3 children and the threshold=0.5, and we retrieved 2 children (2/3 > 0.66(6))
leaf_docs = [doc for doc in docs["documents"] if not doc.meta["__children_ids"]]
docs = retriever.run(leaf_docs[4:6])
assert len(docs["documents"]) == 1
assert docs["documents"][0].content == "warm glow over the trees. Birds began to sing."
assert len(docs["documents"][0].meta["__children_ids"]) == 3
def test_run_return_leafs_document(self):
docs = [Document(content="The monarch of the wild blue yonder rises from the eastern side of the horizon.")]
builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word")
docs = builder.run(docs)
doc_store_parents = InMemoryDocumentStore()
for doc in docs["documents"]:
if doc.meta["__level"] == 1:
doc_store_parents.write_documents([doc])
leaf_docs = [doc for doc in docs["documents"] if not doc.meta["__children_ids"]]
retriever = AutoMergingRetriever(doc_store_parents, threshold=0.6)
result = retriever.run([leaf_docs[4]])
assert len(result["documents"]) == 1
assert result["documents"][0].content == "eastern side of "
assert result["documents"][0].meta["__parent_id"] == docs["documents"][2].id
def test_run_return_leafs_document_different_parents(self):
docs = [Document(content="The monarch of the wild blue yonder rises from the eastern side of the horizon.")]
builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word")
docs = builder.run(docs)
doc_store_parents = InMemoryDocumentStore()
for doc in docs["documents"]:
if doc.meta["__level"] == 1:
doc_store_parents.write_documents([doc])
leaf_docs = [doc for doc in docs["documents"] if not doc.meta["__children_ids"]]
retriever = AutoMergingRetriever(doc_store_parents, threshold=0.6)
result = retriever.run([leaf_docs[4], leaf_docs[3]])
assert len(result["documents"]) == 2
assert result["documents"][0].meta["__parent_id"] != result["documents"][1].meta["__parent_id"]
def test_run_go_up_hierarchy_multiple_levels(self):
"""
Test if the retriever can go up the hierarchy multiple levels to find the parent document.
Simulate a scenario where we have 4 leaf-documents that matched some initial query. The leaf-documents
are continuously merged up the hierarchy until the threshold is no longer met.
In this case it goes from the 4th level in the hierarchy up the 1st level.
"""
text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing."
docs = [Document(content=text)]
builder = HierarchicalDocumentSplitter(block_sizes={6, 4, 2, 1}, split_overlap=0, split_by="word")
docs = builder.run(docs)
# store all non-leaf documents
doc_store_parents = InMemoryDocumentStore()
for doc in docs["documents"]:
if doc.meta["__children_ids"]:
doc_store_parents.write_documents([doc])
retriever = AutoMergingRetriever(doc_store_parents, threshold=0.4)
# simulate a scenario where we have 4 leaf-documents that matched some initial query
retrieved_leaf_docs = [d for d in docs["documents"] if d.content in {"The ", "sun ", "rose ", "early "}]
result = retriever.run(retrieved_leaf_docs)
assert len(result["documents"]) == 1
assert result["documents"][0].content == "The sun rose early in the "
def test_run_go_up_hierarchy_multiple_levels_hit_root_document(self):
"""
Test case where we go up hierarchy until the root document, so the root document is returned.
It's the only document in the hierarchy which has no parent.
"""
text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing."
docs = [Document(content=text)]
builder = HierarchicalDocumentSplitter(block_sizes={6, 4}, split_overlap=0, split_by="word")
docs = builder.run(docs)
# store all non-leaf documents
doc_store_parents = InMemoryDocumentStore()
for doc in docs["documents"]:
if doc.meta["__children_ids"]:
doc_store_parents.write_documents([doc])
retriever = AutoMergingRetriever(doc_store_parents, threshold=0.1) # set a low threshold to hit root document
# simulate a scenario where we have 4 leaf-documents that matched some initial query
retrieved_leaf_docs = [
d
for d in docs["documents"]
if d.content in {"The sun rose early ", "in the ", "morning. It cast a ", "over the trees. Birds "}
]
result = retriever.run(retrieved_leaf_docs)
assert len(result["documents"]) == 1
assert result["documents"][0].meta["__level"] == 0 # hit root document