mirror of https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
feat: adding AutoMergingRetriever and HierarchicalDocumentSplitter (#9067)
* adding Auto-Merging-Retriever
* adding release notes
* updating tests
* adding renamed file
* Update haystack/components/preprocessors/hierarchical_document_splitter.py
* Update haystack/components/retrievers/auto_merging_retriever.py
* fixing tests and imports
* adding pydoc
* adding to type checking

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
parent 9a046ed431
commit be2d1fb303
@@ -1,7 +1,14 @@
 loaders:
   - type: haystack_pydoc_tools.loaders.CustomPythonLoader
     search_path: [../../../haystack/components/preprocessors]
-    modules: ["csv_document_cleaner", "csv_document_splitter", "document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"]
+    modules: [
+        "csv_document_cleaner",
+        "csv_document_splitter",
+        "document_cleaner",
+        "document_splitter",
+        "hierarchical_document_splitter",
+        "recursive_splitter",
+        "text_cleaner"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
@@ -3,6 +3,7 @@ loaders:
     search_path: [../../../haystack/components/retrievers]
     modules:
       [
+        "auto_merging_retriever",
         "in_memory/bm25_retriever",
         "in_memory/embedding_retriever",
         "filter_retriever",
@@ -12,6 +12,7 @@ _import_structure = {
     "csv_document_splitter": ["CSVDocumentSplitter"],
     "document_cleaner": ["DocumentCleaner"],
     "document_splitter": ["DocumentSplitter"],
+    "hierarchical_document_splitter": ["HierarchicalDocumentSplitter"],
     "recursive_splitter": ["RecursiveDocumentSplitter"],
     "text_cleaner": ["TextCleaner"],
 }
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
     from .csv_document_splitter import CSVDocumentSplitter
     from .document_cleaner import DocumentCleaner
     from .document_splitter import DocumentSplitter
+    from .hierarchical_document_splitter import HierarchicalDocumentSplitter
     from .recursive_splitter import RecursiveDocumentSplitter
     from .text_cleaner import TextCleaner

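The two hunks above register the new splitter in the package's lazy-import machinery. The practical effect, sketched below under the assumption that `LazyImporter` resolves names from `_import_structure` on first access, is that the class is importable from the package root without paying the module's import cost up front:

```python
# The import path used by the tests in this commit; the module itself is only
# loaded when the name is first accessed, thanks to LazyImporter.
from haystack.components.preprocessors import HierarchicalDocumentSplitter

splitter = HierarchicalDocumentSplitter(block_sizes={10, 3})
```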
haystack/components/preprocessors/hierarchical_document_splitter.py (new file, 144 lines)
@@ -0,0 +1,144 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Literal, Set

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.components.preprocessors import DocumentSplitter


@component
class HierarchicalDocumentSplitter:
    """
    Splits a document into blocks of different sizes, building a hierarchical tree structure of blocks.

    The root node of the tree is the original document, and the leaf nodes are the smallest blocks. The blocks
    in between are connected such that smaller blocks are children of the larger parent blocks.

    ## Usage example
    ```python
    from haystack import Document
    from haystack.components.preprocessors import HierarchicalDocumentSplitter

    doc = Document(content="This is a simple test document")
    splitter = HierarchicalDocumentSplitter(block_sizes={3, 2}, split_overlap=0, split_by="word")
    splitter.run([doc])
    >> {'documents': [Document(id=3f7..., content: 'This is a simple test document', meta: {'block_size': 0, 'parent_id': None, 'children_ids': ['5ff..', '8dc..'], 'level': 0}),
    >> Document(id=5ff.., content: 'This is a ', meta: {'block_size': 3, 'parent_id': '3f7..', 'children_ids': ['f19..', '52c..'], 'level': 1, 'source_id': '3f7..', 'page_number': 1, 'split_id': 0, 'split_idx_start': 0}),
    >> Document(id=8dc.., content: 'simple test document', meta: {'block_size': 3, 'parent_id': '3f7..', 'children_ids': ['39d..', 'e23..'], 'level': 1, 'source_id': '3f7..', 'page_number': 1, 'split_id': 1, 'split_idx_start': 10}),
    >> Document(id=f19.., content: 'This is ', meta: {'block_size': 2, 'parent_id': '5ff..', 'children_ids': [], 'level': 2, 'source_id': '5ff..', 'page_number': 1, 'split_id': 0, 'split_idx_start': 0}),
    >> Document(id=52c.., content: 'a ', meta: {'block_size': 2, 'parent_id': '5ff..', 'children_ids': [], 'level': 2, 'source_id': '5ff..', 'page_number': 1, 'split_id': 1, 'split_idx_start': 8}),
    >> Document(id=39d.., content: 'simple test ', meta: {'block_size': 2, 'parent_id': '8dc..', 'children_ids': [], 'level': 2, 'source_id': '8dc..', 'page_number': 1, 'split_id': 0, 'split_idx_start': 0}),
    >> Document(id=e23.., content: 'document', meta: {'block_size': 2, 'parent_id': '8dc..', 'children_ids': [], 'level': 2, 'source_id': '8dc..', 'page_number': 1, 'split_id': 1, 'split_idx_start': 12})]}
    ```
    """  # noqa: E501

    def __init__(
        self,
        block_sizes: Set[int],
        split_overlap: int = 0,
        split_by: Literal["word", "sentence", "page", "passage"] = "word",
    ):
        """
        Initialize HierarchicalDocumentSplitter.

        :param block_sizes: Set of block sizes to split the document into. The blocks are split in descending order.
        :param split_overlap: The number of overlapping units for each split.
        :param split_by: The unit for splitting your documents.
        """
        self.block_sizes = sorted(set(block_sizes), reverse=True)
        self.splitters: Dict[int, DocumentSplitter] = {}
        self.split_overlap = split_overlap
        self.split_by = split_by
        self._build_block_sizes()

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Builds a hierarchical document structure for each document in a list of documents.

        :param documents: List of Documents to split into hierarchical blocks.
        :returns: A dictionary with the key "documents", containing the hierarchical documents.
        """
        hierarchical_docs = []
        for doc in documents:
            hierarchical_docs.extend(self.build_hierarchy_from_doc(doc))
        return {"documents": hierarchical_docs}

    def _build_block_sizes(self):
        # one DocumentSplitter per block size, reused across documents
        for block_size in self.block_sizes:
            self.splitters[block_size] = DocumentSplitter(
                split_length=block_size, split_overlap=self.split_overlap, split_by=self.split_by
            )
            self.splitters[block_size].warm_up()

    @staticmethod
    def _add_meta_data(document: Document):
        document.meta["__block_size"] = 0
        document.meta["__parent_id"] = None
        document.meta["__children_ids"] = []
        document.meta["__level"] = 0
        return document

    def build_hierarchy_from_doc(self, document: Document) -> List[Document]:
        """
        Build a hierarchical tree document structure from a single document.

        Given a document, this function splits the document into hierarchical blocks of different sizes represented
        as HierarchicalDocument objects.

        :param document: Document to split into hierarchical blocks.
        :returns:
            List of HierarchicalDocument
        """
        root = self._add_meta_data(document)
        current_level_nodes = [root]
        all_docs = []

        for block in self.block_sizes:
            next_level_nodes = []
            for doc in current_level_nodes:
                split_docs = self.splitters[block].run([doc])
                child_docs = split_docs["documents"]
                # if the split produces a single document, no smaller blocks exist; keep the node as-is
                if len(child_docs) == 1:
                    next_level_nodes.append(doc)
                    continue
                for child_doc in child_docs:
                    child_doc = self._add_meta_data(child_doc)
                    child_doc.meta["__level"] = doc.meta["__level"] + 1
                    child_doc.meta["__block_size"] = block
                    child_doc.meta["__parent_id"] = doc.id
                    all_docs.append(child_doc)
                    doc.meta["__children_ids"].append(child_doc.id)
                    next_level_nodes.append(child_doc)
            current_level_nodes = next_level_nodes

        return [root] + all_docs

    def to_dict(self) -> Dict[str, Any]:
        """
        Returns a dictionary representation of the component.

        :returns:
            Serialized dictionary representation of the component.
        """
        return default_to_dict(
            self, block_sizes=self.block_sizes, split_overlap=self.split_overlap, split_by=self.split_by
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HierarchicalDocumentSplitter":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary to deserialize and create the component.

        :returns:
            The deserialized component.
        """
        return default_from_dict(cls, data)
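The splitter keeps the tree implicit: every returned document carries `__parent_id`, `__children_ids`, `__level`, and `__block_size` in its meta, and the root is always the first element of the output. As a quick illustration (our sketch, not part of the component; the `print_tree` helper is hypothetical), the tree can be rebuilt from that flat list:

```python
from collections import defaultdict

from haystack import Document
from haystack.components.preprocessors import HierarchicalDocumentSplitter

splitter = HierarchicalDocumentSplitter(block_sizes={3, 2}, split_overlap=0, split_by="word")
docs = splitter.run([Document(content="This is a simple test document")])["documents"]

# group documents by parent id; only the root has __parent_id == None
children_by_parent = defaultdict(list)
for d in docs:
    if d.meta["__parent_id"] is not None:
        children_by_parent[d.meta["__parent_id"]].append(d)

def print_tree(doc: Document, indent: int = 0) -> None:
    # print each node with its level, then recurse into its children
    print(" " * indent + f"level={doc.meta['__level']} {doc.content!r}")
    for child in children_by_parent[doc.id]:
        print_tree(child, indent + 2)

print_tree(docs[0])  # docs[0] is the root (level 0)
```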
@@ -8,12 +8,14 @@ from typing import TYPE_CHECKING
 from lazy_imports import LazyImporter

 _import_structure = {
+    "auto_merging_retriever": ["AutoMergingRetriever"],
     "filter_retriever": ["FilterRetriever"],
     "in_memory": ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"],
     "sentence_window_retriever": ["SentenceWindowRetriever"],
 }

 if TYPE_CHECKING:
+    from .auto_merging_retriever import AutoMergingRetriever
     from .filter_retriever import FilterRetriever
     from .in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
     from .sentence_window_retriever import SentenceWindowRetriever
haystack/components/retrievers/auto_merging_retriever.py (new file, 169 lines)
@@ -0,0 +1,169 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict
from typing import Any, Dict, List

from haystack import Document, component, default_to_dict
from haystack.core.serialization import default_from_dict
from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_params_inplace


@component
class AutoMergingRetriever:
    """
    A retriever which returns the parent documents of matched leaf documents, based on a threshold setting.

    The AutoMergingRetriever assumes you have a hierarchical tree structure of documents, where the leaf nodes
    are indexed in a document store. See the HierarchicalDocumentSplitter for more information on how to create
    such a structure. During retrieval, if the number of matched leaf documents below the same parent is
    higher than a defined threshold, the retriever will return the parent document instead of the individual leaf
    documents.

    The rationale is that if a paragraph is split into multiple chunks represented as leaf documents, and for
    a given query multiple of those chunks are matched, the whole paragraph might be more informative than the
    individual chunks alone.

    Currently the AutoMergingRetriever can only be used with the following DocumentStores:
    - [AstraDB](https://haystack.deepset.ai/integrations/astradb)
    - [ElasticSearch](https://haystack.deepset.ai/docs/latest/documentstore/elasticsearch)
    - [OpenSearch](https://haystack.deepset.ai/docs/latest/documentstore/opensearch)
    - [PGVector](https://haystack.deepset.ai/docs/latest/documentstore/pgvector)
    - [Qdrant](https://haystack.deepset.ai/docs/latest/documentstore/qdrant)

    ```python
    from haystack import Document
    from haystack.components.preprocessors import HierarchicalDocumentSplitter
    from haystack.components.retrievers.auto_merging_retriever import AutoMergingRetriever
    from haystack.document_stores.in_memory import InMemoryDocumentStore

    # create a hierarchical document structure with 3 levels, where the parent document has 3 children
    text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing."
    original_document = Document(content=text)
    builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word")
    docs = builder.run([original_document])["documents"]

    # store level-1 parent documents and initialize the retriever
    doc_store_parents = InMemoryDocumentStore()
    for doc in docs:
        if doc.meta["__children_ids"] and doc.meta["__level"] == 1:
            doc_store_parents.write_documents([doc])
    retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5)

    # assume we retrieved 2 leaf docs from the same parent: the parent document should be returned,
    # since it has 3 children and we matched 2 of them (2/3 ≈ 0.67 > threshold of 0.5)
    leaf_docs = [doc for doc in docs if not doc.meta["__children_ids"]]
    docs = retriever.run(leaf_docs[4:6])
    >> {'documents': [Document(id=538..),
    >> content: 'warm glow over the trees. Birds began to sing.',
    >> meta: {'block_size': 10, 'parent_id': '835..', 'children_ids': ['c17...', '3ff...', '352...'], 'level': 1, 'source_id': '835...',
    >> 'page_number': 1, 'split_id': 1, 'split_idx_start': 45})]}
    ```
    """  # noqa: E501

    def __init__(self, document_store: DocumentStore, threshold: float = 0.5):
        """
        Initialize the AutoMergingRetriever.

        :param document_store: DocumentStore from which to retrieve the parent documents
        :param threshold: Threshold to decide whether the parent instead of the individual documents is returned
        """
        if not 0 < threshold < 1:
            raise ValueError("The threshold parameter must be between 0 and 1.")

        self.document_store = document_store
        self.threshold = threshold

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        docstore = self.document_store.to_dict()
        return default_to_dict(self, document_store=docstore, threshold=self.threshold)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AutoMergingRetriever":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary with serialized data.
        :returns:
            An instance of the component.
        """
        deserialize_document_store_in_init_params_inplace(data)
        return default_from_dict(cls, data)

    @staticmethod
    def _check_valid_documents(matched_leaf_documents: List[Document]):
        # check if the matched leaf documents have the required meta fields
        if not all(doc.meta.get("__parent_id") for doc in matched_leaf_documents):
            raise ValueError("The matched leaf documents do not have the required meta field '__parent_id'")

        if not all(doc.meta.get("__level") for doc in matched_leaf_documents):
            raise ValueError("The matched leaf documents do not have the required meta field '__level'")

        if not all(doc.meta.get("__block_size") for doc in matched_leaf_documents):
            raise ValueError("The matched leaf documents do not have the required meta field '__block_size'")

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Run the AutoMergingRetriever.

        Recursively groups documents by their parents and merges them if they meet the threshold,
        continuing up the hierarchy until no more merges are possible.

        :param documents: List of leaf documents that were matched by a retriever
        :returns:
            List of documents (could be a mix of different hierarchy levels)
        """
        AutoMergingRetriever._check_valid_documents(documents)

        def _get_parent_doc(parent_id: str) -> Document:
            parent_docs = self.document_store.filter_documents({"field": "id", "operator": "==", "value": parent_id})
            if len(parent_docs) != 1:
                raise ValueError(f"Expected 1 parent document with id {parent_id}, found {len(parent_docs)}")

            parent_doc = parent_docs[0]
            if not parent_doc.meta.get("__children_ids"):
                raise ValueError(f"Parent document with id {parent_id} does not have any children.")

            return parent_doc

        def _try_merge_level(docs_to_merge: List[Document], docs_to_return: List[Document]) -> List[Document]:
            parent_doc_id_to_child_docs: Dict[str, List[Document]] = defaultdict(list)  # to group documents by parent

            for doc in docs_to_merge:
                if doc.meta.get("__parent_id"):  # only docs that have parents
                    parent_doc_id_to_child_docs[doc.meta["__parent_id"]].append(doc)
                else:
                    docs_to_return.append(doc)  # keep docs that have no parents

            # process each parent group
            merged_docs = []
            for parent_doc_id, child_docs in parent_doc_id_to_child_docs.items():
                parent_doc = _get_parent_doc(parent_doc_id)

                # calculate the merge score: the fraction of the parent's children that were matched
                score = len(child_docs) / len(parent_doc.meta["__children_ids"])
                if score > self.threshold:
                    merged_docs.append(parent_doc)  # merge into the parent
                else:
                    docs_to_return.extend(child_docs)  # keep the children separate

            # if no new merges were made, we're done
            if not merged_docs:
                return merged_docs + docs_to_return

            # recursively try to merge the next level
            return _try_merge_level(merged_docs, docs_to_return)

        return {"documents": _try_merge_level(documents, [])}
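To make the merge decision concrete: for each parent group, `_try_merge_level` computes the fraction of that parent's children that were matched and merges when the fraction exceeds the threshold, then re-checks the merged parents one level up. A worked example of the arithmetic, in plain Python mirroring the docstring scenario (our annotation, not part of the file):

```python
# A parent produced 3 children; a first-stage retriever matched 2 of them.
matched, total, threshold = 2, 3, 0.5

score = matched / total   # 2/3 ≈ 0.67
assert score > threshold  # so the parent replaces its two matched children
# The merged parent then joins the next recursion and is checked against
# *its* parent in the same way, until a level produces no new merges.
```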
@@ -0,0 +1,4 @@
---
features:
  - |
    We added the `AutoMergingRetriever`, which together with the `HierarchicalDocumentSplitter` implements an auto-merging retrieval technique.
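For reference, a minimal end-to-end sketch of the technique this note describes: split, index leaves and parents in separate stores, retrieve leaves with BM25, then merge. The two-store setup mirrors this commit's tests; the exact wiring is our illustration, not a prescribed recipe.

```python
from haystack import Document
from haystack.components.preprocessors import HierarchicalDocumentSplitter
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.retrievers.auto_merging_retriever import AutoMergingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore

text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing."
splitter = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word")
docs = splitter.run([Document(content=text)])["documents"]

# leaves are searched directly; parents are only consulted for merging
leaf_store, parent_store = InMemoryDocumentStore(), InMemoryDocumentStore()
leaf_store.write_documents([d for d in docs if not d.meta["__children_ids"]])
parent_store.write_documents([d for d in docs if d.meta["__children_ids"]])

bm25 = InMemoryBM25Retriever(document_store=leaf_store)
merger = AutoMergingRetriever(parent_store, threshold=0.5)

leaves = bm25.run(query="warm glow over the trees")["documents"]
merged = merger.run(documents=leaves)["documents"]
print([d.content for d in merged])
```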
test/components/preprocessors/test_hierarchical_doc_splitter.py (new file, 246 lines)
@@ -0,0 +1,246 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import pytest

from haystack import Document, Pipeline
from haystack.components.preprocessors import HierarchicalDocumentSplitter
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore


class TestHierarchicalDocumentSplitter:
    def test_init_with_default_params(self):
        builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300})
        assert builder.block_sizes == [300, 200, 100]
        assert builder.split_overlap == 0
        assert builder.split_by == "word"

    def test_init_with_custom_params(self):
        builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300}, split_overlap=25, split_by="word")
        assert builder.block_sizes == [300, 200, 100]
        assert builder.split_overlap == 25
        assert builder.split_by == "word"

    def test_to_dict(self):
        builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300}, split_overlap=25, split_by="word")
        expected = builder.to_dict()
        assert expected == {
            "type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",
            "init_parameters": {"block_sizes": [300, 200, 100], "split_overlap": 25, "split_by": "word"},
        }

    def test_from_dict(self):
        data = {
            "type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",
            "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"},
        }

        builder = HierarchicalDocumentSplitter.from_dict(data)
        assert builder.block_sizes == [10, 5, 2]
        assert builder.split_overlap == 0
        assert builder.split_by == "word"

    def test_run(self):
        builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word")
        text = "one two three four five six seven eight nine ten"
        doc = Document(content=text)
        output = builder.run([doc])
        docs = output["documents"]

        assert len(docs) == 9
        assert docs[0].content == "one two three four five six seven eight nine ten"

        # level 0 - root node
        assert docs[0].meta["__level"] == 0
        assert len(docs[0].meta["__children_ids"]) == 2

        # level 1 - left branch
        assert docs[1].meta["__parent_id"] == docs[0].id
        assert docs[1].meta["__level"] == 1
        assert len(docs[1].meta["__children_ids"]) == 3

        # level 1 - right branch
        assert docs[2].meta["__parent_id"] == docs[0].id
        assert docs[2].meta["__level"] == 1
        assert len(docs[2].meta["__children_ids"]) == 3

        # level 2 - left branch - leaf nodes
        assert docs[3].meta["__parent_id"] == docs[1].id
        assert docs[4].meta["__parent_id"] == docs[1].id
        assert docs[5].meta["__parent_id"] == docs[1].id
        assert docs[3].meta["__level"] == 2
        assert docs[4].meta["__level"] == 2
        assert docs[5].meta["__level"] == 2
        assert len(docs[3].meta["__children_ids"]) == 0
        assert len(docs[4].meta["__children_ids"]) == 0
        assert len(docs[5].meta["__children_ids"]) == 0

        # level 2 - right branch - leaf nodes
        assert docs[6].meta["__parent_id"] == docs[2].id
        assert docs[7].meta["__parent_id"] == docs[2].id
        assert docs[8].meta["__parent_id"] == docs[2].id
        assert docs[6].meta["__level"] == 2
        assert docs[7].meta["__level"] == 2
        assert docs[8].meta["__level"] == 2
        assert len(docs[6].meta["__children_ids"]) == 0
        assert len(docs[7].meta["__children_ids"]) == 0
        assert len(docs[8].meta["__children_ids"]) == 0

    def test_to_dict_in_pipeline(self):
        pipeline = Pipeline()
        hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2})
        doc_store = InMemoryDocumentStore()
        doc_writer = DocumentWriter(document_store=doc_store)
        pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder)
        pipeline.add_component(name="doc_writer", instance=doc_writer)
        pipeline.connect("hierarchical_doc_splitter", "doc_writer")
        expected = pipeline.to_dict()

        assert expected.keys() == {
            "connections",
            "connection_type_validation",
            "components",
            "max_runs_per_component",
            "metadata",
        }

        assert expected["components"].keys() == {"hierarchical_doc_splitter", "doc_writer"}

        assert expected["components"]["hierarchical_doc_splitter"] == {
            "type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",
            "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"},
        }

    def test_from_dict_in_pipeline(self):
        data = {
            "metadata": {},
            "max_runs_per_component": 100,
            "components": {
                "hierarchical_document_splitter": {
                    "type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",
                    "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"},
                },
                "doc_writer": {
                    "type": "haystack.components.writers.document_writer.DocumentWriter",
                    "init_parameters": {
                        "document_store": {
                            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                            "init_parameters": {
                                "bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
                                "bm25_algorithm": "BM25L",
                                "bm25_parameters": {},
                                "embedding_similarity_function": "dot_product",
                                "index": "f32ad5bf-43cb-4035-9823-1de1ae9853c1",
                            },
                        },
                        "policy": "NONE",
                    },
                },
            },
            "connections": [{"sender": "hierarchical_document_splitter.documents", "receiver": "doc_writer.documents"}],
        }

        assert Pipeline.from_dict(data)

    @pytest.mark.integration
    def test_example_in_pipeline(self):
        pipeline = Pipeline()
        hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word")
        doc_store = InMemoryDocumentStore()
        doc_writer = DocumentWriter(document_store=doc_store)

        pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder)
        pipeline.add_component(name="doc_writer", instance=doc_writer)
        pipeline.connect("hierarchical_doc_splitter.documents", "doc_writer")

        text = "one two three four five six seven eight nine ten"
        doc = Document(content=text)
        docs = pipeline.run({"hierarchical_doc_splitter": {"documents": [doc]}})

        assert docs["doc_writer"]["documents_written"] == 9
        assert len(doc_store.storage.values()) == 9

    def test_serialization_deserialization_pipeline(self):
        pipeline = Pipeline()
        hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word")
        doc_store = InMemoryDocumentStore()
        doc_writer = DocumentWriter(document_store=doc_store)

        pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder)
        pipeline.add_component(name="doc_writer", instance=doc_writer)
        pipeline.connect("hierarchical_doc_splitter.documents", "doc_writer")
        pipeline_dict = pipeline.to_dict()

        new_pipeline = Pipeline.from_dict(pipeline_dict)
        assert new_pipeline == pipeline

    def test_split_by_sentence_assure_warm_up_was_called(self):
        pipeline = Pipeline()
        hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="sentence")
        doc_store = InMemoryDocumentStore()
        doc_writer = DocumentWriter(document_store=doc_store)

        pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder)
        pipeline.add_component(name="doc_writer", instance=doc_writer)
        pipeline.connect("hierarchical_doc_splitter.documents", "doc_writer")

        text = "This is one sentence. This is another sentence. This is the third sentence."
        doc = Document(content=text)
        docs = pipeline.run({"hierarchical_doc_splitter": {"documents": [doc]}})

        assert docs["doc_writer"]["documents_written"] == 3
        assert len(doc_store.storage.values()) == 3

    def test_hierarchical_splitter_multiple_block_sizes(self):
        # test with three different block sizes
        doc = Document(
            content="This is a simple test document with multiple sentences. It should be split into various sizes. This helps test the hierarchy."
        )

        # using three block sizes: 10, 5, 2 words
        splitter = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word")
        result = splitter.run([doc])

        documents = result["documents"]

        # verify the root document
        assert len(documents) > 1
        root = documents[0]
        assert root.meta["__level"] == 0
        assert root.meta["__parent_id"] is None

        # verify level 1 documents (block_size=10)
        level_1_docs = [d for d in documents if d.meta["__level"] == 1]
        for doc in level_1_docs:
            assert doc.meta["__block_size"] == 10
            assert doc.meta["__parent_id"] == root.id

        # verify level 2 documents (block_size=5)
        level_2_docs = [d for d in documents if d.meta["__level"] == 2]
        for doc in level_2_docs:
            assert doc.meta["__block_size"] == 5
            assert doc.meta["__parent_id"] in [d.id for d in level_1_docs]

        # verify level 3 documents (block_size=2)
        level_3_docs = [d for d in documents if d.meta["__level"] == 3]
        for doc in level_3_docs:
            assert doc.meta["__block_size"] == 2
            assert doc.meta["__parent_id"] in [d.id for d in level_2_docs]

        # verify children references
        for doc in documents:
            if doc.meta["__children_ids"]:
                child_ids = doc.meta["__children_ids"]
                children = [d for d in documents if d.id in child_ids]
                for child in children:
                    assert child.meta["__parent_id"] == doc.id
                    assert child.meta["__level"] == doc.meta["__level"] + 1
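A note on the arithmetic behind `test_run`'s `assert len(docs) == 9` (our annotation, not part of the test file):

```python
# block_sizes={10, 5, 2} are applied in descending order to the 10-word text:
#   size 10 -> a single block == the whole text, so the root is kept as-is
#   size  5 -> 2 children of the root           (level 1)
#   size  2 -> 2 + 2 + 1 = 3 leaves per child   (level 2), 6 in total
root, level_1, level_2 = 1, 2, 6
assert root + level_1 + level_2 == 9
```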
test/components/retrievers/test_auto_merging_retriever.py (new file, 258 lines)
@@ -0,0 +1,258 @@
import pytest

from haystack import Document, Pipeline
from haystack.components.preprocessors import HierarchicalDocumentSplitter
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.retrievers.auto_merging_retriever import AutoMergingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore


class TestAutoMergingRetriever:
    def test_init_default(self):
        retriever = AutoMergingRetriever(InMemoryDocumentStore())
        assert retriever.threshold == 0.5

    def test_init_with_parameters(self):
        retriever = AutoMergingRetriever(InMemoryDocumentStore(), threshold=0.7)
        assert retriever.threshold == 0.7

    def test_init_with_invalid_threshold(self):
        with pytest.raises(ValueError):
            AutoMergingRetriever(InMemoryDocumentStore(), threshold=-2)

    def test_run_missing_parent_id(self):
        docs = [Document(content="test", meta={"__level": 1, "__block_size": 10})]
        retriever = AutoMergingRetriever(InMemoryDocumentStore())
        with pytest.raises(
            ValueError, match="The matched leaf documents do not have the required meta field '__parent_id'"
        ):
            retriever.run(documents=docs)

    def test_run_missing_level(self):
        docs = [Document(content="test", meta={"__parent_id": "parent1", "__block_size": 10})]

        retriever = AutoMergingRetriever(InMemoryDocumentStore())
        with pytest.raises(
            ValueError, match="The matched leaf documents do not have the required meta field '__level'"
        ):
            retriever.run(documents=docs)

    def test_run_missing_block_size(self):
        docs = [Document(content="test", meta={"__parent_id": "parent1", "__level": 1})]

        retriever = AutoMergingRetriever(InMemoryDocumentStore())
        with pytest.raises(
            ValueError, match="The matched leaf documents do not have the required meta field '__block_size'"
        ):
            retriever.run(documents=docs)

    def test_run_mixed_valid_and_invalid_documents(self):
        docs = [
            Document(content="valid", meta={"__parent_id": "parent1", "__level": 1, "__block_size": 10}),
            Document(content="invalid", meta={"__level": 1, "__block_size": 10}),
        ]
        retriever = AutoMergingRetriever(InMemoryDocumentStore())
        with pytest.raises(
            ValueError, match="The matched leaf documents do not have the required meta field '__parent_id'"
        ):
            retriever.run(documents=docs)

    def test_to_dict(self):
        retriever = AutoMergingRetriever(InMemoryDocumentStore(), threshold=0.7)
        expected = retriever.to_dict()
        assert expected["type"] == "haystack.components.retrievers.auto_merging_retriever.AutoMergingRetriever"
        assert expected["init_parameters"]["threshold"] == 0.7
        assert (
            expected["init_parameters"]["document_store"]["type"]
            == "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore"
        )

    def test_from_dict(self):
        data = {
            "type": "haystack.components.retrievers.auto_merging_retriever.AutoMergingRetriever",
            "init_parameters": {
                "document_store": {
                    "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                    "init_parameters": {
                        "bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
                        "bm25_algorithm": "BM25L",
                        "bm25_parameters": {},
                        "embedding_similarity_function": "dot_product",
                        "index": "6b122bb4-211b-465e-804d-77c5857bf4c5",
                    },
                },
                "threshold": 0.7,
            },
        }
        retriever = AutoMergingRetriever.from_dict(data)
        assert retriever.threshold == 0.7

    def test_serialization_deserialization_pipeline(self):
        pipeline = Pipeline()
        doc_store_parents = InMemoryDocumentStore()
        bm_25_retriever = InMemoryBM25Retriever(doc_store_parents)
        auto_merging_retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5)

        pipeline.add_component(name="bm_25_retriever", instance=bm_25_retriever)
        pipeline.add_component(name="auto_merging_retriever", instance=auto_merging_retriever)
        pipeline.connect("bm_25_retriever.documents", "auto_merging_retriever.documents")
        pipeline_dict = pipeline.to_dict()

        new_pipeline = Pipeline.from_dict(pipeline_dict)
        assert new_pipeline == pipeline

    def test_run_parent_not_found(self):
        doc_store = InMemoryDocumentStore()
        retriever = AutoMergingRetriever(doc_store, threshold=0.5)

        # a leaf document with a non-existent parent_id
        leaf_doc = Document(
            content="test", meta={"__parent_id": "non_existent_parent", "__level": 1, "__block_size": 10}
        )

        with pytest.raises(ValueError, match="Expected 1 parent document with id non_existent_parent, found 0"):
            retriever.run([leaf_doc])

    def test_run_parent_without_children_metadata(self):
        """Test case where a parent document exists but doesn't have the __children_ids metadata field."""
        doc_store = InMemoryDocumentStore()

        # create and store a parent document without __children_ids metadata
        parent_doc = Document(
            content="parent content",
            id="parent1",
            meta={
                "__level": 1,  # add the other required metadata
                "__block_size": 10,
            },
        )
        doc_store.write_documents([parent_doc])

        retriever = AutoMergingRetriever(doc_store, threshold=0.5)

        # create a leaf document that points to this parent
        leaf_doc = Document(content="leaf content", meta={"__parent_id": "parent1", "__level": 2, "__block_size": 5})

        with pytest.raises(ValueError, match="Parent document with id parent1 does not have any children"):
            retriever.run([leaf_doc])

    def test_run_empty_documents(self):
        retriever = AutoMergingRetriever(InMemoryDocumentStore())
        assert retriever.run([]) == {"documents": []}

    def test_run_return_parent_document(self):
        text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing."

        docs = [Document(content=text)]
        builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word")
        docs = builder.run(docs)

        # store all non-leaf documents
        doc_store_parents = InMemoryDocumentStore()
        for doc in docs["documents"]:
            if doc.meta["__children_ids"]:
                doc_store_parents.write_documents([doc])
        retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5)

        # assume we retrieved 2 leaf docs from the same parent: the parent document should be returned,
        # since it has 3 children and we matched 2 of them (2/3 ≈ 0.67 > threshold of 0.5)
        leaf_docs = [doc for doc in docs["documents"] if not doc.meta["__children_ids"]]
        docs = retriever.run(leaf_docs[4:6])
        assert len(docs["documents"]) == 1
        assert docs["documents"][0].content == "warm glow over the trees. Birds began to sing."
        assert len(docs["documents"][0].meta["__children_ids"]) == 3

    def test_run_return_leafs_document(self):
        docs = [Document(content="The monarch of the wild blue yonder rises from the eastern side of the horizon.")]
        builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word")
        docs = builder.run(docs)

        doc_store_parents = InMemoryDocumentStore()
        for doc in docs["documents"]:
            if doc.meta["__level"] == 1:
                doc_store_parents.write_documents([doc])

        leaf_docs = [doc for doc in docs["documents"] if not doc.meta["__children_ids"]]
        retriever = AutoMergingRetriever(doc_store_parents, threshold=0.6)
        result = retriever.run([leaf_docs[4]])

        assert len(result["documents"]) == 1
        assert result["documents"][0].content == "eastern side of "
        assert result["documents"][0].meta["__parent_id"] == docs["documents"][2].id

    def test_run_return_leafs_document_different_parents(self):
        docs = [Document(content="The monarch of the wild blue yonder rises from the eastern side of the horizon.")]
        builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word")
        docs = builder.run(docs)

        doc_store_parents = InMemoryDocumentStore()
        for doc in docs["documents"]:
            if doc.meta["__level"] == 1:
                doc_store_parents.write_documents([doc])

        leaf_docs = [doc for doc in docs["documents"] if not doc.meta["__children_ids"]]
        retriever = AutoMergingRetriever(doc_store_parents, threshold=0.6)
        result = retriever.run([leaf_docs[4], leaf_docs[3]])

        assert len(result["documents"]) == 2
        assert result["documents"][0].meta["__parent_id"] != result["documents"][1].meta["__parent_id"]

    def test_run_go_up_hierarchy_multiple_levels(self):
        """
        Test if the retriever can go up the hierarchy multiple levels to find the parent document.

        Simulate a scenario where we have 4 leaf documents that matched some initial query. The leaf documents
        are continuously merged up the hierarchy until the threshold is no longer met.
        In this case it goes from the 4th level in the hierarchy up to the 1st level.
        """
        text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing."

        docs = [Document(content=text)]
        builder = HierarchicalDocumentSplitter(block_sizes={6, 4, 2, 1}, split_overlap=0, split_by="word")
        docs = builder.run(docs)

        # store all non-leaf documents
        doc_store_parents = InMemoryDocumentStore()
        for doc in docs["documents"]:
            if doc.meta["__children_ids"]:
                doc_store_parents.write_documents([doc])
        retriever = AutoMergingRetriever(doc_store_parents, threshold=0.4)

        # simulate a scenario where we have 4 leaf documents that matched some initial query
        retrieved_leaf_docs = [d for d in docs["documents"] if d.content in {"The ", "sun ", "rose ", "early "}]

        result = retriever.run(retrieved_leaf_docs)

        assert len(result["documents"]) == 1
        assert result["documents"][0].content == "The sun rose early in the "

    def test_run_go_up_hierarchy_multiple_levels_hit_root_document(self):
        """
        Test case where we go up the hierarchy until the root document, so the root document is returned.

        It's the only document in the hierarchy which has no parent.
        """
        text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing."

        docs = [Document(content=text)]
        builder = HierarchicalDocumentSplitter(block_sizes={6, 4}, split_overlap=0, split_by="word")
        docs = builder.run(docs)

        # store all non-leaf documents
        doc_store_parents = InMemoryDocumentStore()
        for doc in docs["documents"]:
            if doc.meta["__children_ids"]:
                doc_store_parents.write_documents([doc])
        retriever = AutoMergingRetriever(doc_store_parents, threshold=0.1)  # set a low threshold to hit the root document

        # simulate a scenario where we have 4 documents that matched some initial query
        retrieved_leaf_docs = [
            d
            for d in docs["documents"]
            if d.content in {"The sun rose early ", "in the ", "morning. It cast a ", "over the trees. Birds "}
        ]

        result = retriever.run(retrieved_leaf_docs)

        assert len(result["documents"]) == 1
        assert result["documents"][0].meta["__level"] == 0  # hit the root document