feat: adding EmbeddingBasedDocumentSplitter (#10286)

* adding EmbeddingBasedDocumentSplitter

* adding release notes

* adding missing files

* release notes

* PR comments/fixes

* removing aux function and adding all validations to the init()
David S. Batista 2025-12-30 10:48:47 +01:00 committed by GitHub
parent a249bd32ca
commit eb0dc76d1d
6 changed files with 990 additions and 2 deletions

View File

@@ -2,6 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0
-from .protocol import TextEmbedder
+from .protocol import DocumentEmbedder, TextEmbedder
-__all__ = ["TextEmbedder"]
+__all__ = ["DocumentEmbedder", "TextEmbedder"]

View File

@@ -4,6 +4,8 @@
from typing import Any, Protocol
from haystack import Document
# See https://github.com/pylint-dev/pylint/issues/9319.
# pylint: disable=unnecessary-ellipsis
@@ -28,3 +30,25 @@ class TextEmbedder(Protocol):
- any optional keys such as 'metadata'.
"""
...
class DocumentEmbedder(Protocol):
"""
Protocol for Document Embedders.
"""
def run(self, documents: list[Document]) -> dict[str, Any]:
"""
Generate embeddings for the input documents.
Implementing classes may accept additional optional parameters in their run method.
For example: `def run(self, documents: List[Document], param_a="default", param_b="another_default")`.
:param documents:
The input documents to be embedded.
:returns:
A dictionary containing the keys:
- 'documents', which is expected to be a List[Document] with embeddings added to each document.
- any optional keys such as 'metadata'.
"""
...
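As a reference for implementers, here is a minimal sketch of a class that structurally satisfies the new `DocumentEmbedder` protocol. The class name and the fixed placeholder vector are illustrative only and not part of this commit:

```python
from typing import Any

from haystack import Document


class FakeDocumentEmbedder:
    """Illustrative embedder that satisfies the DocumentEmbedder protocol."""

    def run(self, documents: list[Document]) -> dict[str, Any]:
        # A real embedder would call a model here; this sketch assigns a fixed placeholder vector.
        for doc in documents:
            doc.embedding = [0.0, 0.0, 0.0]
        return {"documents": documents}
```

Since the protocol is purely structural, any object exposing such a `run` method can be passed as `document_embedder` to `EmbeddingBasedDocumentSplitter`; the unit tests below rely on this by using plain `Mock` objects.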

View File

@@ -13,6 +13,7 @@ _import_structure = {
"document_cleaner": ["DocumentCleaner"],
"document_preprocessor": ["DocumentPreprocessor"],
"document_splitter": ["DocumentSplitter"],
"embedding_based_document_splitter": ["EmbeddingBasedDocumentSplitter"],
"hierarchical_document_splitter": ["HierarchicalDocumentSplitter"],
"recursive_splitter": ["RecursiveDocumentSplitter"],
"text_cleaner": ["TextCleaner"],
@@ -24,6 +25,7 @@ if TYPE_CHECKING:
from .document_cleaner import DocumentCleaner as DocumentCleaner
from .document_preprocessor import DocumentPreprocessor as DocumentPreprocessor
from .document_splitter import DocumentSplitter as DocumentSplitter
from .embedding_based_document_splitter import EmbeddingBasedDocumentSplitter as EmbeddingBasedDocumentSplitter
from .hierarchical_document_splitter import HierarchicalDocumentSplitter as HierarchicalDocumentSplitter
from .recursive_splitter import RecursiveDocumentSplitter as RecursiveDocumentSplitter
from .text_cleaner import TextCleaner as TextCleaner

View File

@@ -0,0 +1,435 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from typing import Any
import numpy as np
from haystack import Document, component, logging
from haystack.components.embedders.types import DocumentEmbedder
from haystack.components.preprocessors.sentence_tokenizer import Language, SentenceSplitter
from haystack.core.serialization import component_to_dict, default_from_dict, default_to_dict
from haystack.utils.deserialization import deserialize_component_inplace
logger = logging.getLogger(__name__)
@component
class EmbeddingBasedDocumentSplitter:
"""
Splits documents based on embedding similarity using cosine distances between sequential sentence groups.
This component first splits text into sentences, optionally groups them, calculates embeddings for each group,
and then uses cosine distance between sequential embeddings to determine split points. Any distance above
the specified percentile is treated as a break point. The component also tracks page numbers based on form feed
characters (`\f`) in the original document.
This component is inspired by [5 Levels of Text Splitting](
https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/tutorials/LevelsOfTextSplitting/5_Levels_Of_Text_Splitting.ipynb
) by Greg Kamradt.
### Usage example
```python
from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.preprocessors import EmbeddingBasedDocumentSplitter
# Create a document with content that has a clear topic shift
doc = Document(
content="This is a first sentence. This is a second sentence. This is a third sentence. "
"Completely different topic. The same completely different topic."
)
# Initialize the embedder to calculate semantic similarities
embedder = SentenceTransformersDocumentEmbedder()
# Configure the splitter with parameters that control splitting behavior
splitter = EmbeddingBasedDocumentSplitter(
document_embedder=embedder,
sentences_per_group=2, # Group 2 sentences before calculating embeddings
percentile=0.95, # Split when cosine distance exceeds 95th percentile
min_length=50, # Merge splits shorter than 50 characters
max_length=1000 # Further split chunks longer than 1000 characters
)
splitter.warm_up()
result = splitter.run(documents=[doc])
# The result contains a list of Document objects, each representing a semantic chunk
# Each split document includes metadata: source_id, split_id, and page_number
print(f"Original document split into {len(result['documents'])} chunks")
for i, split_doc in enumerate(result['documents']):
print(f"Chunk {i}: {split_doc.content[:50]}...")
```
"""
def __init__(
self,
*,
document_embedder: DocumentEmbedder,
sentences_per_group: int = 3,
percentile: float = 0.95,
min_length: int = 50,
max_length: int = 1000,
language: Language = "en",
use_split_rules: bool = True,
extend_abbreviations: bool = True,
):
"""
Initialize EmbeddingBasedDocumentSplitter.
:param document_embedder: The DocumentEmbedder to use for calculating embeddings.
:param sentences_per_group: Number of sentences to group together before embedding.
:param percentile: Percentile threshold for cosine distance. Distances above this percentile
are treated as break points.
:param min_length: Minimum length of splits in characters. Splits below this length will be merged.
:param max_length: Maximum length of splits in characters. Splits above this length will be recursively split.
:param language: Language for sentence tokenization.
:param use_split_rules: Whether to use additional split rules for sentence tokenization. Applies additional
split rules from SentenceSplitter to the sentence spans.
:param extend_abbreviations: If True, the abbreviations used by NLTK's PunktTokenizer are extended by a list
of curated abbreviations. Currently supported languages are: en, de.
If False, the default abbreviations are used.
"""
self.document_embedder = document_embedder
if sentences_per_group <= 0:
raise ValueError("sentences_per_group must be greater than 0.")
self.sentences_per_group = sentences_per_group
if not 0.0 <= percentile <= 1.0:
raise ValueError("percentile must be between 0.0 and 1.0.")
self.percentile = percentile
if min_length < 0:
raise ValueError("min_length must be greater than or equal to 0.")
self.min_length = min_length
if max_length <= min_length:
raise ValueError("max_length must be greater than min_length.")
self.max_length = max_length
self.language = language
self.use_split_rules = use_split_rules
self.extend_abbreviations = extend_abbreviations
self.sentence_splitter: SentenceSplitter | None = None
self._is_warmed_up = False
def warm_up(self) -> None:
"""
Warm up the component by initializing the sentence splitter.
"""
self.sentence_splitter = SentenceSplitter(
language=self.language,
use_split_rules=self.use_split_rules,
extend_abbreviations=self.extend_abbreviations,
keep_white_spaces=True,
)
if hasattr(self.document_embedder, "warm_up"):
self.document_embedder.warm_up()
self._is_warmed_up = True
@component.output_types(documents=list[Document])
def run(self, documents: list[Document]) -> dict[str, list[Document]]:
"""
Split documents based on embedding similarity.
:param documents: The documents to split.
:returns: A dictionary with the following key:
- `documents`: List of documents with the split texts. Each document includes:
- A metadata field `source_id` to track the original document.
- A metadata field `split_id` to track the split number.
- A metadata field `page_number` to track the original page number.
- All other metadata copied from the original document.
:raises:
- `RuntimeError`: If the component wasn't warmed up.
- `TypeError`: If the input is not a list of Documents.
- `ValueError`: If the document content is None or empty.
"""
if not self._is_warmed_up:
raise RuntimeError(
"The component EmbeddingBasedDocumentSplitter wasn't warmed up. Run 'warm_up()' before calling 'run()'."
)
if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
raise TypeError("EmbeddingBasedDocumentSplitter expects a List of Documents as input.")
split_docs: list[Document] = []
for doc in documents:
if doc.content is None:
raise ValueError(
f"EmbeddingBasedDocumentSplitter only works with text documents but content for "
f"document ID {doc.id} is None."
)
if doc.content == "":
logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id)
continue
doc_splits = self._split_document(doc=doc)
split_docs.extend(doc_splits)
return {"documents": split_docs}
def _split_document(self, doc: Document) -> list[Document]:
"""
Split a single document based on embedding similarity.
"""
# Create an initial split of the document content into smaller chunks
# doc.content is validated in `run`
splits = self._split_text(text=doc.content) # type: ignore
# Merge splits smaller than min_length
merged_splits = self._merge_small_splits(splits=splits)
# Recursively split splits larger than max_length
final_splits = self._split_large_splits(splits=merged_splits)
# Create Document objects from the final splits
return EmbeddingBasedDocumentSplitter._create_documents_from_splits(splits=final_splits, original_doc=doc)
def _split_text(self, text: str) -> list[str]:
"""
Split a text into smaller chunks based on embedding similarity.
"""
# NOTE: `self.sentence_splitter.split_sentences` strips all trailing whitespace (e.g. new lines, page breaks)
# from the end of the provided text. To avoid losing it, we keep track of it and add it back to the
# last sentence.
rstripped_text = text.rstrip()
trailing_whitespaces = text[len(rstripped_text) :]
# Split the text into sentences
sentences_result = self.sentence_splitter.split_sentences(rstripped_text) # type: ignore[union-attr]
# Add back the stripped white spaces to the last sentence
if sentences_result and trailing_whitespaces:
sentences_result[-1]["sentence"] += trailing_whitespaces
sentences_result[-1]["end"] += len(trailing_whitespaces)
sentences = [sentence["sentence"] for sentence in sentences_result]
sentence_groups = self._group_sentences(sentences=sentences)
embeddings = self._calculate_embeddings(sentence_groups=sentence_groups)
split_points = self._find_split_points(embeddings=embeddings)
sub_splits = self._create_splits_from_points(sentence_groups=sentence_groups, split_points=split_points)
return sub_splits
def _group_sentences(self, sentences: list[str]) -> list[str]:
"""
Group sentences into groups of sentences_per_group.
"""
if self.sentences_per_group == 1:
return sentences
groups = []
for i in range(0, len(sentences), self.sentences_per_group):
group = sentences[i : i + self.sentences_per_group]
groups.append("".join(group))
return groups
def _calculate_embeddings(self, sentence_groups: list[str]) -> list[list[float]]:
"""
Calculate embeddings for each sentence group using the DocumentEmbedder.
"""
# Create Document objects for each group
group_docs = [Document(content=group) for group in sentence_groups]
result = self.document_embedder.run(group_docs)
embedded_docs = result["documents"]
embeddings = [doc.embedding for doc in embedded_docs]
return embeddings
def _find_split_points(self, embeddings: list[list[float]]) -> list[int]:
"""
Find split points based on cosine distances between sequential embeddings.
"""
if len(embeddings) <= 1:
return []
# Calculate cosine distances between sequential pairs
distances = []
for i in range(len(embeddings) - 1):
distance = EmbeddingBasedDocumentSplitter._cosine_distance(
embedding1=embeddings[i], embedding2=embeddings[i + 1]
)
distances.append(distance)
# Calculate threshold based on percentile
threshold = np.percentile(distances, self.percentile * 100)
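# For example, distances [0.1, 0.2, 0.8] with percentile=0.5 give a threshold of 0.2 (the median),
# so only the 0.8 gap exceeds it and a split point is recorded before the fourth sentence group.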
# Find indices where distance exceeds threshold
split_points = []
for i, distance in enumerate(distances):
if distance > threshold:
split_points.append(i + 1) # +1 because we want to split after this point
return split_points
@staticmethod
def _cosine_distance(embedding1: list[float], embedding2: list[float]) -> float:
"""
Calculate cosine distance between two embeddings.
"""
vec1 = np.array(embedding1)
vec2 = np.array(embedding2)
norm1 = float(np.linalg.norm(vec1))
norm2 = float(np.linalg.norm(vec2))
if norm1 == 0 or norm2 == 0:
return 1.0
cosine_sim = float(np.dot(vec1, vec2) / (norm1 * norm2))
return 1.0 - cosine_sim
@staticmethod
def _create_splits_from_points(sentence_groups: list[str], split_points: list[int]) -> list[str]:
"""
Create splits based on split points.
"""
if not split_points:
return ["".join(sentence_groups)]
splits = []
start = 0
for point in split_points:
split_text = "".join(sentence_groups[start:point])
if split_text:
splits.append(split_text)
start = point
# Add the last split
if start < len(sentence_groups):
split_text = "".join(sentence_groups[start:])
if split_text:
splits.append(split_text)
return splits
def _merge_small_splits(self, splits: list[str]) -> list[str]:
"""
Merge splits that are below min_length.
"""
if not splits:
return splits
merged = []
current_split = splits[0]
for split in splits[1:]:
# We merge splits that are smaller than min_length but only if the newly merged split is still below
# max_length.
if len(current_split) < self.min_length and len(current_split) + len(split) < self.max_length:
# Merge with next split
current_split += split
else:
# Current split is long enough, save it and start a new one
merged.append(current_split)
current_split = split
# Don't forget the last split
merged.append(current_split)
return merged
def _split_large_splits(self, splits: list[str]) -> list[str]:
"""
Recursively split splits that are above max_length.
This method checks each split and if it exceeds max_length, it attempts to split it further using the same
embedding-based approach. This is done recursively until all splits are within the max_length limit or no
further splitting is possible.
This works because the threshold for splits is calculated dynamically based on the provided embeddings.
"""
final_splits = []
for split in splits:
if len(split) <= self.max_length:
final_splits.append(split)
else:
# Recursively split large splits
# We can reuse the same _split_text method to split the text into smaller chunks because the threshold
# for splits is calculated dynamically based on embeddings from `split`.
sub_splits = self._split_text(text=split)
# Stop splitting if no further split is possible or continue with recursion
if len(sub_splits) == 1:
logger.warning(
f"Could not split a chunk further below max_length={self.max_length}. "
f"Returning chunk of length {len(split)}."
)
final_splits.append(split)
else:
final_splits.extend(self._split_large_splits(splits=sub_splits))
return final_splits
@staticmethod
def _create_documents_from_splits(splits: list[str], original_doc: Document) -> list[Document]:
"""
Create Document objects from splits.
"""
documents = []
metadata = deepcopy(original_doc.meta)
metadata["source_id"] = original_doc.id
# Calculate page numbers for each split
current_page = 1
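# For example, splits ["Page 1.\f", "Page 2.\f\f", "Page 4."] are assigned page_number 1, 2, and 4.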
for i, split_text in enumerate(splits):
split_meta = deepcopy(metadata)
split_meta["split_id"] = i
# Count page breaks in the split itself
page_breaks_in_split = split_text.count("\f")
# Assign the page number for this split
split_meta["page_number"] = current_page
doc = Document(content=split_text, meta=split_meta)
documents.append(doc)
# Update page counter for next split
current_page += page_breaks_in_split
return documents
def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Serialized dictionary representation of the component.
"""
return default_to_dict(
self,
document_embedder=component_to_dict(obj=self.document_embedder, name="document_embedder"),
sentences_per_group=self.sentences_per_group,
percentile=self.percentile,
min_length=self.min_length,
max_length=self.max_length,
language=self.language,
use_split_rules=self.use_split_rules,
extend_abbreviations=self.extend_abbreviations,
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "EmbeddingBasedDocumentSplitter":
"""
Deserializes the component from a dictionary.
:param data:
The dictionary to deserialize and create the component.
:returns:
The deserialized component.
"""
deserialize_component_inplace(data["init_parameters"], key="document_embedder")
return default_from_dict(cls, data)

View File

@@ -0,0 +1,4 @@
---
features:
- |
Added an ``EmbeddingBasedDocumentSplitter`` component, which splits documents based on embedding similarity, using cosine distances between sequential sentence groups.

View File

@@ -0,0 +1,523 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import Mock, patch
import pytest
from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.preprocessors import EmbeddingBasedDocumentSplitter
from haystack.utils import ComponentDevice
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# disable tqdm entirely for tests
from tqdm import tqdm
tqdm.disable = True
class TestEmbeddingBasedDocumentSplitter:
def test_init(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(
document_embedder=mock_embedder, sentences_per_group=2, percentile=0.9, min_length=50, max_length=1000
)
assert splitter.document_embedder == mock_embedder
assert splitter.sentences_per_group == 2
assert splitter.percentile == 0.9
assert splitter.min_length == 50
assert splitter.max_length == 1000
def test_init_invalid_sentences_per_group(self):
mock_embedder = Mock()
with pytest.raises(ValueError, match="sentences_per_group must be greater than 0"):
EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder, sentences_per_group=0)
def test_init_invalid_percentile(self):
mock_embedder = Mock()
with pytest.raises(ValueError, match="percentile must be between 0.0 and 1.0"):
EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder, percentile=1.5)
def test_init_invalid_min_length(self):
mock_embedder = Mock()
with pytest.raises(ValueError, match="min_length must be greater than or equal to 0"):
EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder, min_length=-1)
def test_init_invalid_max_length(self):
mock_embedder = Mock()
with pytest.raises(ValueError, match="max_length must be greater than min_length"):
EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder, min_length=100, max_length=50)
def test_warm_up(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
with patch(
"haystack.components.preprocessors.embedding_based_document_splitter.SentenceSplitter"
) as mock_splitter_class:
mock_splitter = Mock()
mock_splitter_class.return_value = mock_splitter
splitter.warm_up()
assert splitter.sentence_splitter == mock_splitter
mock_splitter_class.assert_called_once()
def test_run_not_warmed_up(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
with pytest.raises(RuntimeError, match="wasn't warmed up"):
splitter.run(documents=[Document(content="test")])
def test_run_invalid_input(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
splitter.sentence_splitter = Mock()
splitter._is_warmed_up = True
with pytest.raises(TypeError, match="expects a List of Documents"):
splitter.run(documents="not a list")
def test_run_document_with_none_content(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
splitter.sentence_splitter = Mock()
splitter._is_warmed_up = True
with pytest.raises(ValueError, match="content for document ID"):
splitter.run(documents=[Document(content=None)])
def test_run_empty_document(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
splitter.sentence_splitter = Mock()
splitter._is_warmed_up = True
result = splitter.run(documents=[Document(content="")])
assert result["documents"] == []
def test_group_sentences_single(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder, sentences_per_group=1)
sentences = ["Sentence 1.", "Sentence 2.", "Sentence 3."]
groups = splitter._group_sentences(sentences)
assert groups == sentences
def test_group_sentences_multiple(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder, sentences_per_group=2)
sentences = ["Sentence 1. ", "Sentence 2. ", "Sentence 3. ", "Sentence 4."]
groups = splitter._group_sentences(sentences)
assert groups == ["Sentence 1. Sentence 2. ", "Sentence 3. Sentence 4."]
def test_cosine_distance(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
# Test with identical vectors
embedding1 = [1.0, 0.0, 0.0]
embedding2 = [1.0, 0.0, 0.0]
distance = splitter._cosine_distance(embedding1, embedding2)
assert distance == 0.0
# Test with orthogonal vectors
embedding1 = [1.0, 0.0, 0.0]
embedding2 = [0.0, 1.0, 0.0]
distance = splitter._cosine_distance(embedding1, embedding2)
assert distance == 1.0
# Test with zero vectors
embedding1 = [0.0, 0.0, 0.0]
embedding2 = [1.0, 0.0, 0.0]
distance = splitter._cosine_distance(embedding1, embedding2)
assert distance == 1.0
def test_find_split_points_empty(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
split_points = splitter._find_split_points([])
assert split_points == []
split_points = splitter._find_split_points([[1.0, 0.0]])
assert split_points == []
def test_find_split_points(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder, percentile=0.5)
# Create embeddings where the second pair has high distance
embeddings = [
[1.0, 0.0, 0.0], # Similar to next
[0.9, 0.1, 0.0], # Similar to previous
[0.0, 1.0, 0.0], # Very different from next
[0.1, 0.9, 0.0], # Similar to previous
]
split_points = splitter._find_split_points(embeddings)
# Should find a split point after the second embedding (index 2)
assert 2 in split_points
def test_create_splits_from_points(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
sentence_groups = ["Group 1 ", "Group 2 ", "Group 3 ", "Group 4"]
split_points = [2] # Split after index 1
splits = splitter._create_splits_from_points(sentence_groups, split_points)
assert splits == ["Group 1 Group 2 ", "Group 3 Group 4"]
def test_create_splits_from_points_no_points(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
sentence_groups = ["Group 1 ", "Group 2 ", "Group 3"]
split_points = []
splits = splitter._create_splits_from_points(sentence_groups, split_points)
assert splits == ["Group 1 Group 2 Group 3"]
def test_merge_small_splits(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder, min_length=10)
splits = ["Short ", "Also short ", "Long enough text ", "Another short"]
merged = splitter._merge_small_splits(splits)
assert len(merged) == 3
assert merged[0] == "Short Also short "
assert merged[1] == "Long enough text "
assert merged[2] == "Another short"
def test_merge_small_splits_respect_max_length(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder, min_length=10, max_length=15)
splits = ["123456", "123456789", "1234"]
merged = splitter._merge_small_splits(splits=splits)
assert len(merged) == 2
# First split remains beneath min_length because merging it with the next split would reach max_length
assert merged[0] == "123456"
# Second split is merged with third split to get above min_length and still beneath max_length
assert merged[1] == "1234567891234"
def test_create_documents_from_splits(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
original_doc = Document(content="test", meta={"key": "value"})
splits = ["Split 1", "Split 2"]
documents = splitter._create_documents_from_splits(splits, original_doc)
assert len(documents) == 2
assert documents[0].content == "Split 1"
assert documents[0].meta["source_id"] == original_doc.id
assert documents[0].meta["split_id"] == 0
assert documents[0].meta["key"] == "value"
assert documents[1].content == "Split 2"
assert documents[1].meta["split_id"] == 1
def test_create_documents_from_splits_with_page_numbers(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
original_doc = Document(content="Page 1 content.\fPage 2 content.\f\fPage 4 content.", meta={"key": "value"})
splits = ["Page 1 content.\f", "Page 2 content.\f\f", "Page 4 content."]
documents = splitter._create_documents_from_splits(splits, original_doc)
assert len(documents) == 3
assert documents[0].content == "Page 1 content.\f"
assert documents[0].meta["page_number"] == 1
assert documents[1].content == "Page 2 content.\f\f"
assert documents[1].meta["page_number"] == 2
assert documents[2].content == "Page 4 content."
assert documents[2].meta["page_number"] == 4
def test_create_documents_from_splits_with_consecutive_page_breaks(self):
mock_embedder = Mock()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
# Test with consecutive page breaks at the end
original_doc = Document(content="Page 1 content.\fPage 2 content.\f\f\f", meta={"key": "value"})
splits = ["Page 1 content.\f", "Page 2 content.\f\f\f"]
documents = splitter._create_documents_from_splits(splits, original_doc)
assert len(documents) == 2
assert documents[0].content == "Page 1 content.\f"
assert documents[0].meta["page_number"] == 1
assert documents[1].content == "Page 2 content.\f\f\f"
# Should be page 2: page breaks inside a split only advance the page counter for subsequent splits, and there are none here
assert documents[1].meta["page_number"] == 2
def test_calculate_embeddings(self):
mock_embedder = Mock()
# Mock the document embedder to return documents with embeddings
def mock_run(documents):
for doc in documents:
doc.embedding = [1.0, 2.0, 3.0] # Simple mock embedding
return {"documents": documents}
mock_embedder.run = Mock(side_effect=mock_run)
splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder)
sentence_groups = ["Group 1", "Group 2", "Group 3"]
embeddings = splitter._calculate_embeddings(sentence_groups)
assert len(embeddings) == 3
assert all(embedding == [1.0, 2.0, 3.0] for embedding in embeddings)
mock_embedder.run.assert_called_once()
def test_to_dict(self):
mock_embedder = Mock()
mock_embedder.to_dict.return_value = {"type": "MockEmbedder"}
splitter = EmbeddingBasedDocumentSplitter(
document_embedder=mock_embedder, sentences_per_group=2, percentile=0.9, min_length=50, max_length=1000
)
result = splitter.to_dict()
assert "EmbeddingBasedDocumentSplitter" in result["type"]
assert result["init_parameters"]["sentences_per_group"] == 2
assert result["init_parameters"]["percentile"] == 0.9
assert result["init_parameters"]["min_length"] == 50
assert result["init_parameters"]["max_length"] == 1000
assert "document_embedder" in result["init_parameters"]
@pytest.mark.integration
def test_split_document_with_multiple_topics(self):
import os
import torch
# Force CPU usage to avoid MPS memory issues
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
torch.backends.mps.is_available = lambda: False
embedder = SentenceTransformersDocumentEmbedder(
model="sentence-transformers/all-MiniLM-L6-v2", device=ComponentDevice.from_str("cpu")
)
embedder.warm_up()
splitter = EmbeddingBasedDocumentSplitter(
document_embedder=embedder, sentences_per_group=2, percentile=0.9, min_length=30, max_length=300
)
splitter.warm_up()
# A document with multiple topics
text = (
"The weather today is beautiful. The sun is shining brightly. The temperature is perfect for a walk. "
"Machine learning has revolutionized many industries. Neural networks can process vast amounts of data. "
"Deep learning models achieve remarkable accuracy on complex tasks. "
"Cooking is both an art and a science. Fresh ingredients make all the difference. "
"Proper seasoning enhances the natural flavors of food. "
"The history of ancient civilizations fascinates researchers. Archaeological discoveries reveal new insights. " # noqa: E501
"Ancient texts provide valuable information about past societies."
)
doc = Document(content=text)
result = splitter.run(documents=[doc])
split_docs = result["documents"]
# There should be more than one split
assert len(split_docs) > 1
# Each split should be non-empty and respect min_length
for split_doc in split_docs:
assert split_doc.content.strip() != ""
assert len(split_doc.content) >= 30
# The splits should cover the original text
combined = "".join([d.content for d in split_docs])
original = text
assert combined in original or original in combined
@pytest.mark.integration
def test_trailing_whitespace_is_preserved(self):
embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
embedder.warm_up()
splitter = EmbeddingBasedDocumentSplitter(document_embedder=embedder, sentences_per_group=1)
splitter.warm_up()
# Normal trailing whitespace
text = "The weather today is beautiful. "
result = splitter.run(documents=[Document(content=text)])
assert result["documents"][0].content == text
# Newline at the end
text = "The weather today is beautiful.\n"
result = splitter.run(documents=[Document(content=text)])
assert result["documents"][0].content == text
# Page break at the end
text = "The weather today is beautiful.\f"
result = splitter.run(documents=[Document(content=text)])
assert result["documents"][0].content == text
@pytest.mark.integration
def test_no_extra_whitespaces_between_sentences(self):
embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
embedder.warm_up()
splitter = EmbeddingBasedDocumentSplitter(
document_embedder=embedder, sentences_per_group=1, percentile=0.9, min_length=10, max_length=500
)
splitter.warm_up()
text = (
"The weather today is beautiful. The sun is shining brightly. The temperature is perfect for a walk. "
"There are no clouds and no rain. Machine learning has revolutionized many industries. "
"Neural networks can process vast amounts of data. Deep learning models achieve remarkable accuracy on complex tasks." # noqa: E501
)
doc = Document(content=text)
result = splitter.run(documents=[doc])
split_docs = result["documents"]
assert len(split_docs) == 2
# Expect the original whitespace structure with trailing spaces where they exist
assert (
split_docs[0].content
== "The weather today is beautiful. The sun is shining brightly. The temperature is perfect for a walk. There are no clouds and no rain. " # noqa: E501
) # noqa: E501
assert (
split_docs[1].content
== "Machine learning has revolutionized many industries. Neural networks can process vast amounts of data. Deep learning models achieve remarkable accuracy on complex tasks." # noqa: E501
) # noqa: E501
@pytest.mark.integration
def test_split_large_splits_recursion(self):
"""
Test that _split_large_splits() works correctly without infinite loops.
This test uses a longer text that will trigger the recursive splitting logic.
If the chunk cannot be split further, it is allowed to be larger than max_length.
"""
embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2", batch_size=32)
semantic_chunker = EmbeddingBasedDocumentSplitter(
document_embedder=embedder, sentences_per_group=5, percentile=0.95, min_length=50, max_length=1000
)
semantic_chunker.warm_up()
text = """# Artificial intelligence and its Impact on Society
## Article from Wikipedia, the free encyclopedia
### Introduction to Artificial Intelligence
Artificial intelligence (AI) is the capability of computational systems to perform tasks typically associated with human intelligence, such as learning, reasoning, problem-solving, perception, and decision-making. It is a field of research in computer science that develops and studies methods and software that enable machines to perceive their environment and use learning and intelligence to take actions that maximize their chances of achieving defined goals.
### The History of Software
The history of software is closely tied to the development of digital computers in the mid-20th century. Early programs were written in the machine language specific to the hardware. The introduction of high-level programming languages in 1958 allowed for more human-readable instructions, making software development easier and more portable across different computer architectures. Software in a programming language is run through a compiler or interpreter to execute on the architecture's hardware. Over time, software has become complex, owing to developments in networking, operating systems, and databases.""" # noqa: E501
doc = Document(content=text)
result = semantic_chunker.run(documents=[doc])
split_docs = result["documents"]
assert len(split_docs) == 1
# If the chunk cannot be split further, it is allowed to be larger than max_length
# At least one split should be larger than max_length in this test case
assert any(len(split_doc.content) > 1000 for split_doc in split_docs)
# Verify that the splits cover the original content
combined_content = "".join([d.content for d in split_docs])
assert combined_content == text
for i, split_doc in enumerate(split_docs):
assert split_doc.meta["source_id"] == doc.id
assert split_doc.meta["split_id"] == i
assert "page_number" in split_doc.meta
@pytest.mark.integration
def test_split_large_splits_actually_splits(self):
"""
Test that _split_large_splits() actually works and can split long texts into multiple chunks.
This test uses a very long text that should be split into multiple chunks.
"""
embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2", batch_size=32)
semantic_chunker = EmbeddingBasedDocumentSplitter(
document_embedder=embedder,
sentences_per_group=3,
percentile=0.85, # Lower percentile to create more splits
min_length=100,
max_length=500, # Smaller max_length to force more splits
)
semantic_chunker.warm_up()
# Create a very long text with multiple paragraphs and topics
text = """# Comprehensive Guide to Machine Learning and Artificial Intelligence
## Introduction to Machine Learning
Machine learning is a subset of artificial intelligence that focuses on the development of computer programs that can access data and use it to learn for themselves. The process of learning begins with observations or data, such as examples, direct experience, or instruction, in order to look for patterns in data and make better decisions in the future based on the examples that we provide. The primary aim is to allow the computers learn automatically without human intervention or assistance and adjust actions accordingly.
## Types of Machine Learning
There are several types of machine learning algorithms, each with their own strengths and weaknesses. Supervised learning involves training a model on a labeled dataset, where the correct answers are provided. The model learns to map inputs to outputs based on these examples. Unsupervised learning, on the other hand, deals with unlabeled data and seeks to find hidden patterns or structures within the data. Reinforcement learning is a type of learning where an agent learns to behave in an environment by performing certain actions and receiving rewards or penalties.
## Deep Learning and Neural Networks
Deep learning is a subset of machine learning that uses neural networks with multiple layers to model and understand complex patterns. Neural networks are inspired by the human brain and consist of interconnected nodes or neurons. Each connection between neurons has a weight that is adjusted during training. The network learns by adjusting these weights based on the error between predicted and actual outputs. Deep learning has been particularly successful in areas such as computer vision, natural language processing, and speech recognition.
\f
## Natural Language Processing
Natural Language Processing (NLP) is a field of artificial intelligence that focuses on the interaction between computers and human language. It involves developing algorithms and models that can understand, interpret, and generate human language. NLP applications include machine translation, sentiment analysis, text summarization, and question answering systems. Recent advances in deep learning have significantly improved the performance of NLP systems, leading to more accurate and sophisticated language models.
## Computer Vision and Image Recognition
Computer vision is another important area of artificial intelligence that deals with how computers can gain high-level understanding from digital images or videos. It involves developing algorithms that can identify and understand visual information from the world. Applications include facial recognition, object detection, medical image analysis, and autonomous vehicle navigation. Deep learning models, particularly convolutional neural networks (CNNs), have revolutionized computer vision by achieving human-level performance on many tasks.
## The Future of Artificial Intelligence
The future of artificial intelligence holds immense potential for transforming various industries and aspects of human life. We can expect to see more sophisticated AI systems that can handle complex reasoning tasks, understand context better, and interact more naturally with humans. However, this rapid advancement also brings challenges related to ethics, privacy, and the impact on employment. It's crucial to develop AI systems that are not only powerful but also safe, fair, and beneficial to society as a whole.
\f
## Ethical Considerations in AI
As artificial intelligence becomes more prevalent, ethical considerations become increasingly important. Issues such as bias in AI systems, privacy concerns, and the potential for misuse need to be carefully addressed. AI systems can inherit biases from their training data, leading to unfair outcomes for certain groups. Privacy concerns arise from the vast amounts of data required to train AI systems. Additionally, there are concerns about the potential for AI to be used maliciously or to replace human workers in certain industries.
## Applications in Healthcare
Artificial intelligence has the potential to revolutionize healthcare by improving diagnosis, treatment planning, and patient care. Machine learning algorithms can analyze medical images to detect diseases earlier and more accurately than human doctors. AI systems can also help in drug discovery by predicting the effectiveness of potential treatments. In addition, AI-powered chatbots and virtual assistants can provide basic healthcare information and support to patients, reducing the burden on healthcare professionals.
## AI in Finance and Banking
The financial industry has been quick to adopt artificial intelligence for various applications. AI systems can analyze market data to make investment decisions, detect fraudulent transactions, and provide personalized financial advice. Machine learning algorithms can assess credit risk more accurately than traditional methods, leading to better lending decisions. Additionally, AI-powered chatbots can handle customer service inquiries, reducing costs and improving customer satisfaction.
\f
## Transportation and Autonomous Vehicles
Autonomous vehicles represent one of the most visible applications of artificial intelligence in transportation. Self-driving cars use a combination of sensors, cameras, and AI algorithms to navigate roads safely. These systems can detect obstacles, read traffic signs, and make decisions about speed and direction. Beyond autonomous cars, AI is also being used in logistics and supply chain management to optimize routes and reduce delivery times.
## Education and Personalized Learning
Artificial intelligence is transforming education by enabling personalized learning experiences. AI systems can adapt to individual student needs, providing customized content and pacing. Intelligent tutoring systems can provide immediate feedback and support to students, helping them learn more effectively. Additionally, AI can help educators by automating administrative tasks and providing insights into student performance and learning patterns.""" # noqa: E501
doc = Document(content=text)
result = semantic_chunker.run(documents=[doc])
split_docs = result["documents"]
assert len(split_docs) == 11
# Verify that the splits cover the original content
combined_content = "".join([d.content for d in split_docs])
assert combined_content == text
for i, split_doc in enumerate(split_docs):
assert split_doc.meta["source_id"] == doc.id
assert split_doc.meta["split_id"] == i
assert "page_number" in split_doc.meta
if i in [0, 1, 2, 3]:
assert split_doc.meta["page_number"] == 1
if i in [4, 5, 6]:
assert split_doc.meta["page_number"] == 2
if i in [7, 8]:
assert split_doc.meta["page_number"] == 3
if i in [9, 10]:
assert split_doc.meta["page_number"] == 4