feat: Port NLTKDocumentSplitter from dC to Haystack (#8350)

* Port NLTKDocumentSplitter from dC to Haystack

* Improve pydocs

* Use haystack logging

* Add NLTKDocumentSplitter to __init__.py

* Use haystack logging, rename test classes

* Fixing _needs_join return

* Linting

* PR feedback

* More static methods

* Increase test coverage

* Compile pattern

---------

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
Vladimir Blagojevic 2024-09-17 13:59:19 +02:00 committed by GitHub
parent d76e654c63
commit badd0594cc
6 changed files with 843 additions and 1 deletion


@@ -4,6 +4,7 @@
from .document_cleaner import DocumentCleaner
from .document_splitter import DocumentSplitter
from .nltk_document_splitter import NLTKDocumentSplitter
from .text_cleaner import TextCleaner
__all__ = ["DocumentSplitter", "DocumentCleaner", "TextCleaner"]
__all__ = ["DocumentSplitter", "DocumentCleaner", "TextCleaner", "NLTKDocumentSplitter"]


@@ -0,0 +1,239 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from typing import Dict, List, Literal, Tuple
from haystack import Document, component, logging
from haystack.components.preprocessors.document_splitter import DocumentSplitter
from haystack.components.preprocessors.utils import Language, SentenceSplitter
logger = logging.getLogger(__name__)
@component
class NLTKDocumentSplitter(DocumentSplitter):
def __init__(
self,
split_by: Literal["word", "sentence", "page", "passage"] = "word",
split_length: int = 200,
split_overlap: int = 0,
split_threshold: int = 0,
respect_sentence_boundary: bool = False,
language: Language = "en",
use_split_rules: bool = True,
extend_abbreviations: bool = True,
):
"""
Splits your documents using NLTK to respect sentence boundaries.
Initialize the NLTKDocumentSplitter.
:param split_by: Select the unit for splitting your documents. Choose from `word` for splitting by spaces (" "),
`sentence` for splitting by NLTK sentence tokenizer, `page` for splitting by the form feed ("\\f") or
`passage` for splitting by double line breaks ("\\n\\n").
:param split_length: The maximum number of units in each split.
:param split_overlap: The number of overlapping units for each split.
:param split_threshold: The minimum number of units per split. If a split has fewer units
than the threshold, it's attached to the previous split.
:param respect_sentence_boundary: Choose whether to respect sentence boundaries when splitting by "word".
If True, uses NLTK to detect sentence boundaries, ensuring splits occur only between sentences.
:param language: Choose the language for the NLTK tokenizer. The default is English ("en").
:param use_split_rules: Choose whether to use additional split rules when splitting by `sentence`.
:param extend_abbreviations: Choose whether to extend NLTK's PunktTokenizer abbreviations with a list
of curated abbreviations, if available.
This is currently supported for English ("en") and German ("de").
"""
super(NLTKDocumentSplitter, self).__init__(
split_by=split_by, split_length=split_length, split_overlap=split_overlap, split_threshold=split_threshold
)
if respect_sentence_boundary and split_by != "word":
logger.warning(
"The 'respect_sentence_boundary' option is only supported for `split_by='word'`. "
"The option `respect_sentence_boundary` will be set to `False`."
)
respect_sentence_boundary = False
self.respect_sentence_boundary = respect_sentence_boundary
self.sentence_splitter = SentenceSplitter(
language=language,
use_split_rules=use_split_rules,
extend_abbreviations=extend_abbreviations,
keep_white_spaces=True,
)
self.language = language
def _split_into_units(self, text: str, split_by: Literal["word", "sentence", "passage", "page"]) -> List[str]:
"""
Splits the text into units based on the specified split_by parameter.
:param text: The text to split.
:param split_by: The unit to split the text by. Choose from "word", "sentence", "passage", or "page".
:returns: A list of units.
"""
if split_by == "page":
self.split_at = "\f"
units = text.split(self.split_at)
elif split_by == "passage":
self.split_at = "\n\n"
units = text.split(self.split_at)
elif split_by == "sentence":
# whitespace is preserved while splitting text into sentences when using keep_white_spaces=True
# so split_at is set to an empty string
self.split_at = ""
result = self.sentence_splitter.split_sentences(text)
units = [sentence["sentence"] for sentence in result]
elif split_by == "word":
self.split_at = " "
units = text.split(self.split_at)
else:
raise NotImplementedError(
"DocumentSplitter only supports 'word', 'sentence', 'page' or 'passage' split_by options."
)
# Add the delimiter back to all units except the last one
for i in range(len(units) - 1):
units[i] += self.split_at
return units
@component.output_types(documents=List[Document])
def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
"""
Split documents into smaller parts.
Splits documents by the unit expressed in `split_by`, with a length of `split_length`
and an overlap of `split_overlap`.
:param documents: The documents to split.
:returns: A dictionary with the following key:
- `documents`: List of documents with the split texts. Each document includes:
- A metadata field source_id to track the original document.
- A metadata field page_number to track the original page number.
- All other metadata copied from the original document.
:raises TypeError: if the input is not a list of Documents.
:raises ValueError: if the content of a document is None.
"""
if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
raise TypeError("DocumentSplitter expects a List of Documents as input.")
split_docs = []
for doc in documents:
if doc.content is None:
raise ValueError(
f"DocumentSplitter only works with text documents but content for document ID {doc.id} is None."
)
if self.respect_sentence_boundary:
units = self._split_into_units(doc.content, "sentence")
text_splits, splits_pages, splits_start_idxs = self._concatenate_sentences_based_on_word_amount(
sentences=units, split_length=self.split_length, split_overlap=self.split_overlap
)
else:
units = self._split_into_units(doc.content, self.split_by)
text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
elements=units,
split_length=self.split_length,
split_overlap=self.split_overlap,
split_threshold=self.split_threshold,
)
metadata = deepcopy(doc.meta)
metadata["source_id"] = doc.id
split_docs += self._create_docs_from_splits(
text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
)
return {"documents": split_docs}
@staticmethod
def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int:
"""
Returns the number of sentences to keep in the next chunk based on the `split_overlap` and `split_length`.
:param sentences: The list of sentences to split.
:param split_length: The maximum number of words in each split.
:param split_overlap: The number of overlapping words in each split.
:returns: The number of sentences to keep in the next chunk.
"""
# If the split_overlap is 0, we don't need to keep any sentences
if split_overlap == 0:
return 0
num_sentences_to_keep = 0
num_words = 0
for sent in reversed(sentences):
num_words += len(sent.split())
# If the number of words is larger than the split_length then don't add any more sentences
if num_words > split_length:
break
num_sentences_to_keep += 1
if num_words > split_overlap:
break
return num_sentences_to_keep
def _concatenate_sentences_based_on_word_amount(
self, sentences: List[str], split_length: int, split_overlap: int
) -> Tuple[List[str], List[int], List[int]]:
"""
Groups the sentences into chunks of `split_length` words while respecting sentence boundaries.
:param sentences: The list of sentences to split.
:param split_length: The maximum number of words in each split.
:param split_overlap: The number of overlapping words in each split.
:returns: A tuple containing the concatenated sentences, the start page numbers, and the start indices.
"""
# Chunk information
chunk_word_count = 0
chunk_starting_page_number = 1
chunk_start_idx = 0
current_chunk: List[str] = []
# Output lists
split_start_page_numbers = []
list_of_splits: List[List[str]] = []
split_start_indices = []
for sentence_idx, sentence in enumerate(sentences):
current_chunk.append(sentence)
chunk_word_count += len(sentence.split())
next_sentence_word_count = (
len(sentences[sentence_idx + 1].split()) if sentence_idx < len(sentences) - 1 else 0
)
# Number of words in the current chunk plus the next sentence is larger than the split_length
# or we reached the last sentence
if (chunk_word_count + next_sentence_word_count) > split_length or sentence_idx == len(sentences) - 1:
# Save current chunk and start a new one
list_of_splits.append(current_chunk)
split_start_page_numbers.append(chunk_starting_page_number)
split_start_indices.append(chunk_start_idx)
# Get the number of sentences that overlap with the next chunk
num_sentences_to_keep = self._number_of_sentences_to_keep(
sentences=current_chunk, split_length=split_length, split_overlap=split_overlap
)
# Set up information for the new chunk
if num_sentences_to_keep > 0:
# Processed sentences are the ones that are not overlapping with the next chunk
processed_sentences = current_chunk[:-num_sentences_to_keep]
chunk_starting_page_number += sum(sent.count("\f") for sent in processed_sentences)
chunk_start_idx += len("".join(processed_sentences))
# Next chunk starts with the sentences that were overlapping with the previous chunk
current_chunk = current_chunk[-num_sentences_to_keep:]
chunk_word_count = sum(len(s.split()) for s in current_chunk)
else:
# Here processed_sentences is the same as current_chunk since there is no overlap
chunk_starting_page_number += sum(sent.count("\f") for sent in current_chunk)
chunk_start_idx += len("".join(current_chunk))
current_chunk = []
chunk_word_count = 0
# Concatenate the sentences together within each split
text_splits = []
for split in list_of_splits:
text = "".join(split)
if len(text) > 0:
text_splits.append(text)
return text_splits, split_start_page_numbers, split_start_indices
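For orientation only (not part of the diff): a minimal sketch of how the splitter above behaves when word-based splits must respect sentence boundaries. It assumes Haystack is installed with the `nltk` extra, and the sample text is borrowed from the tests at the end of this commit.

from haystack import Document
from haystack.components.preprocessors import NLTKDocumentSplitter

# Word-based splitting that never cuts a sentence in half: whole sentences are
# packed into each split until the next sentence would push it past split_length;
# split_overlap is expressed in words but carried over as whole sentences.
splitter = NLTKDocumentSplitter(
    split_by="word",
    split_length=13,
    split_overlap=3,
    respect_sentence_boundary=True,
)

doc = Document(
    content=(
        "This is a test sentence with many many words that exceeds the split length "
        "and should not be repeated. This is another test sentence. "
        "This is the last test sentence."
    )
)

for split in splitter.run(documents=[doc])["documents"]:
    # Each split keeps source_id, page_number, split_id and split_idx_start in meta.
    print(split.meta["split_id"], repr(split.content))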


@@ -0,0 +1,233 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import re
from pathlib import Path
from typing import Any, Dict, List, Literal, Tuple
from haystack import logging
from haystack.lazy_imports import LazyImport
with LazyImport("Run 'pip install nltk'") as nltk_imports:
import nltk
nltk_imports.check()
logger = logging.getLogger(__name__)
Language = Literal[
"ru", "sl", "es", "sv", "tr", "cs", "da", "nl", "en", "et", "fi", "fr", "de", "el", "it", "no", "pl", "pt", "ml"
]
ISO639_TO_NLTK = {
"ru": "russian",
"sl": "slovene",
"es": "spanish",
"sv": "swedish",
"tr": "turkish",
"cs": "czech",
"da": "danish",
"nl": "dutch",
"en": "english",
"et": "estonian",
"fi": "finnish",
"fr": "french",
"de": "german",
"el": "greek",
"it": "italian",
"no": "norwegian",
"pl": "polish",
"pt": "portuguese",
"ml": "malayalam",
}
QUOTE_SPANS_RE = re.compile(r"\W(\"+|\'+).*?\1")
class CustomPunktLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
# The following adjustment of PunktSentenceTokenizer is inspired by:
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
# It is needed for preserving whitespace while splitting text into sentences.
_period_context_fmt = r"""
%(SentEndChars)s # a potential sentence ending
\s* # match potential whitespace [ \t\n\x0B\f\r]
(?=(?P<after_tok>
%(NonWord)s # either other punctuation
|
(?P<next_tok>\S+) # or some other token - original version: \s+(?P<next_tok>\S+)
))"""
def period_context_re(self) -> re.Pattern:
"""
Compiles and returns a regular expression to find contexts including possible sentence boundaries.
:returns: A compiled regular expression pattern.
"""
try:
return self._re_period_context # type: ignore
except: # noqa: E722
self._re_period_context = re.compile(
self._period_context_fmt
% {
"NonWord": self._re_non_word_chars,
# SentEndChars might be followed by closing brackets, so we match them here.
"SentEndChars": self._re_sent_end_chars + r"[\)\]}]*",
},
re.UNICODE | re.VERBOSE,
)
return self._re_period_context
def load_sentence_tokenizer(
language: Language, keep_white_spaces: bool = False
) -> nltk.tokenize.punkt.PunktSentenceTokenizer:
"""
Utility function to load the nltk sentence tokenizer.
:param language: The language for the tokenizer.
:param keep_white_spaces: If True, the tokenizer will keep white spaces between sentences.
:returns: nltk sentence tokenizer.
"""
try:
nltk.data.find("tokenizers/punkt_tab")
except LookupError:
try:
nltk.download("punkt_tab")
except FileExistsError as error:
logger.debug("NLTK punkt tokenizer seems to be already downloaded. Error message: {error}", error=error)
language_name = ISO639_TO_NLTK.get(language)
if language_name is not None:
sentence_tokenizer = nltk.data.load(f"tokenizers/punkt_tab/{language_name}.pickle")
else:
logger.warning(
"PreProcessor couldn't find the default sentence tokenizer model for {language}. "
" Using English instead. You may train your own model and use the 'tokenizer_model_folder' parameter.",
language=language,
)
sentence_tokenizer = nltk.data.load("tokenizers/punkt_tab/english.pickle")
if keep_white_spaces:
sentence_tokenizer._lang_vars = CustomPunktLanguageVars()
return sentence_tokenizer
class SentenceSplitter: # pylint: disable=too-few-public-methods
"""
SentenceSplitter splits a text into sentences using the nltk sentence tokenizer
"""
def __init__(
self,
language: Language = "en",
use_split_rules: bool = True,
extend_abbreviations: bool = True,
keep_white_spaces: bool = False,
) -> None:
"""
Initializes the SentenceSplitter with the specified language, split rules, and abbreviation handling.
:param language: The language for the tokenizer. Default is "en".
:param use_split_rules: If True, the additional split rules are used. If False, the rules are not used.
:param extend_abbreviations: If True, the abbreviations used by NLTK's PunktTokenizer are extended by a list
of curated abbreviations if available. If False, the default abbreviations are used.
Currently supported languages are: en, de.
:param keep_white_spaces: If True, the tokenizer will keep white spaces between sentences.
"""
self.language = language
self.sentence_tokenizer = load_sentence_tokenizer(language, keep_white_spaces=keep_white_spaces)
self.use_split_rules = use_split_rules
if extend_abbreviations:
abbreviations = SentenceSplitter._read_abbreviations(language)
self.sentence_tokenizer._params.abbrev_types.update(abbreviations)
self.keep_white_spaces = keep_white_spaces
def split_sentences(self, text: str) -> List[Dict[str, Any]]:
"""
Splits a text into sentences including references to original char positions for each split.
:param text: The text to split.
:returns: list of sentences with positions.
"""
sentence_spans = list(self.sentence_tokenizer.span_tokenize(text))
if self.use_split_rules:
sentence_spans = SentenceSplitter._apply_split_rules(text, sentence_spans)
sentences = [{"sentence": text[start:end], "start": start, "end": end} for start, end in sentence_spans]
return sentences
@staticmethod
def _apply_split_rules(text: str, sentence_spans: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
"""
Applies additional split rules to the sentence spans.
:param text: The text to split.
:param sentence_spans: The list of sentence spans to split.
:returns: The list of sentence spans after applying the split rules.
"""
new_sentence_spans = []
quote_spans = [match.span() for match in QUOTE_SPANS_RE.finditer(text)]
while sentence_spans:
span = sentence_spans.pop(0)
next_span = sentence_spans[0] if len(sentence_spans) > 0 else None
while next_span and SentenceSplitter._needs_join(text, span, next_span, quote_spans):
sentence_spans.pop(0)
span = (span[0], next_span[1])
next_span = sentence_spans[0] if len(sentence_spans) > 0 else None
start, end = span
new_sentence_spans.append((start, end))
return new_sentence_spans
@staticmethod
def _needs_join(
text: str, span: Tuple[int, int], next_span: Tuple[int, int], quote_spans: List[Tuple[int, int]]
) -> bool:
"""
Checks if the spans need to be joined as parts of one sentence.
:param text: The text containing the spans.
:param span: The current sentence span within text.
:param next_span: The next sentence span within text.
:param quote_spans: All quoted spans within text.
:returns: True if the spans need to be joined.
"""
start, end = span
next_start, next_end = next_span
# sentence. sentence"\nsentence -> no split (end << quote_end)
# sentence.", sentence -> no split (end < quote_end)
# sentence?", sentence -> no split (end < quote_end)
if any(quote_start < end < quote_end for quote_start, quote_end in quote_spans):
# sentence boundary is inside a quote
return True
# sentence." sentence -> split (end == quote_end)
# sentence?" sentence -> no split (end == quote_end)
if any(quote_start < end == quote_end and text[quote_end - 2] == "?" for quote_start, quote_end in quote_spans):
# question is cited
return True
if re.search(r"(^|\n)\s*\d{1,2}\.$", text[start:end]) is not None:
# sentence ends with a numeration
return True
# next sentence starts with a bracket or we return False
return re.search(r"^\s*[\(\[]", text[next_start:next_end]) is not None
@staticmethod
def _read_abbreviations(language: Language) -> List[str]:
"""
Reads the abbreviations for a given language from the abbreviations file.
:param language: The language to read the abbreviations for.
:returns: List of abbreviations.
"""
abbreviations_file = Path(__file__).parent.parent / f"data/abbreviations/{language}.txt"
if not abbreviations_file.exists():
logger.warning("No abbreviations file found for {language}.Using default abbreviations.", language=language)
return []
abbreviations = abbreviations_file.read_text().split("\n")
return abbreviations
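A minimal sketch (not part of the diff) of using SentenceSplitter on its own, assuming `nltk` is installed; load_sentence_tokenizer downloads the `punkt_tab` data on first use. The example sentence is illustrative.

from haystack.components.preprocessors.utils import SentenceSplitter

splitter = SentenceSplitter(language="en", use_split_rules=True, keep_white_spaces=True)

# split_sentences returns one dict per sentence with the text and its character
# offsets; with keep_white_spaces=True trailing whitespace stays attached to
# each sentence, so the offsets cover the original text without gaps.
for span in splitter.split_sentences('He said, "Is it time?" It was. (Probably.)'):
    print(span["start"], span["end"], repr(span["sentence"]))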


@@ -103,6 +103,8 @@ extra-dependencies = [
"python-pptx", # PPTXToDocument
"python-docx", # DocxToDocument
"nltk", # NLTKDocumentSplitter
# OpenAPI
"jsonref", # OpenAPIServiceConnector, OpenAPIServiceToFunctions
"openapi3",


@@ -0,0 +1,4 @@
---
features:
- |
Added the NLTKDocumentSplitter, a new preprocessing component that splits documents into smaller parts by word count, NLTK-detected sentence boundaries, passages, or page breaks. It supports multiple languages and offers options for respecting sentence boundaries when splitting by word and for extending NLTK's abbreviation handling, giving finer control over how documents are prepared for downstream processing.
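For illustration only, a hedged sketch of dropping the new component into a Haystack 2.x pipeline; the component name "splitter" and the sample document are just examples.

from haystack import Document, Pipeline
from haystack.components.preprocessors import NLTKDocumentSplitter

pipeline = Pipeline()
pipeline.add_component(
    "splitter",
    NLTKDocumentSplitter(split_by="sentence", split_length=2, split_overlap=0, language="en"),
)

docs = [Document(content="Sentence one. Sentence two. Sentence three.")]
result = pipeline.run({"splitter": {"documents": docs}})
# The splitter is the terminal component, so its output is returned by the pipeline.
print(result["splitter"]["documents"])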


@@ -0,0 +1,363 @@
from typing import List
import pytest
from haystack import Document
from pytest import LogCaptureFixture
from haystack.components.preprocessors.nltk_document_splitter import NLTKDocumentSplitter
from haystack.components.preprocessors.utils import SentenceSplitter
def test_init_warning_message(caplog: LogCaptureFixture) -> None:
_ = NLTKDocumentSplitter(split_by="page", respect_sentence_boundary=True)
assert "The 'respect_sentence_boundary' option is only supported for" in caplog.text
class TestNLTKDocumentSplitterSplitIntoUnits:
def test_document_splitter_split_into_units_word(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="word", split_length=3, split_overlap=0, split_threshold=0, language="en"
)
text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything."
units = document_splitter._split_into_units(text=text, split_by="word")
assert units == [
"Moonlight ",
"shimmered ",
"softly, ",
"wolves ",
"howled ",
"nearby, ",
"night ",
"enveloped ",
"everything.",
]
def test_document_splitter_split_into_units_sentence(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="sentence", split_length=2, split_overlap=0, split_threshold=0, language="en"
)
text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night."
units = document_splitter._split_into_units(text=text, split_by="sentence")
assert units == [
"Moonlight shimmered softly, wolves howled nearby, night enveloped everything. ",
"It was a dark night.",
]
def test_document_splitter_split_into_units_passage(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="passage", split_length=2, split_overlap=0, split_threshold=0, language="en"
)
text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\n\nIt was a dark night."
units = document_splitter._split_into_units(text=text, split_by="passage")
assert units == [
"Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\n\n",
"It was a dark night.",
]
def test_document_splitter_split_into_units_page(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="page", split_length=2, split_overlap=0, split_threshold=0, language="en"
)
text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\fIt was a dark night."
units = document_splitter._split_into_units(text=text, split_by="page")
assert units == [
"Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\f",
"It was a dark night.",
]
def test_document_splitter_split_into_units_raise_error(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="word", split_length=3, split_overlap=0, split_threshold=0, language="en"
)
text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything."
with pytest.raises(NotImplementedError):
document_splitter._split_into_units(text=text, split_by="invalid") # type: ignore
class TestNLTKDocumentSplitterNumberOfSentencesToKeep:
@pytest.mark.parametrize(
"sentences, expected_num_sentences",
[
(["Moonlight shimmered softly, wolves howled nearby, night enveloped everything."], 0),
([" It was a dark night ..."], 0),
([" The moon was full."], 1),
],
)
def test_number_of_sentences_to_keep(self, sentences: List[str], expected_num_sentences: int) -> None:
num_sentences = NLTKDocumentSplitter._number_of_sentences_to_keep(
sentences=sentences, split_length=5, split_overlap=2
)
assert num_sentences == expected_num_sentences
def test_number_of_sentences_to_keep_split_overlap_zero(self) -> None:
sentences = [
"Moonlight shimmered softly, wolves howled nearby, night enveloped everything.",
" It was a dark night ...",
" The moon was full.",
]
num_sentences = NLTKDocumentSplitter._number_of_sentences_to_keep(
sentences=sentences, split_length=5, split_overlap=0
)
assert num_sentences == 0
class TestNLTKDocumentSplitterRun:
def test_run_type_error(self) -> None:
document_splitter = NLTKDocumentSplitter()
with pytest.raises(TypeError):
document_splitter.run(documents=Document(content="Moonlight shimmered softly.")) # type: ignore
def test_run_value_error(self) -> None:
document_splitter = NLTKDocumentSplitter()
with pytest.raises(ValueError):
document_splitter.run(documents=[Document(content=None)])
def test_run_split_by_sentence_1(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="sentence",
split_length=2,
split_overlap=0,
split_threshold=0,
language="en",
use_split_rules=True,
extend_abbreviations=True,
)
text = (
"Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night ... "
"The moon was full."
)
documents = document_splitter.run(documents=[Document(content=text)])["documents"]
assert len(documents) == 2
assert (
documents[0].content == "Moonlight shimmered softly, wolves howled nearby, night enveloped "
"everything. It was a dark night ... "
)
assert documents[1].content == "The moon was full."
def test_run_split_by_sentence_2(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="sentence",
split_length=1,
split_overlap=0,
split_threshold=0,
language="en",
use_split_rules=False,
extend_abbreviations=True,
)
text = (
"This is a test sentence with many many words that exceeds the split length and should not be repeated. "
"This is another test sentence. (This is a third test sentence.) "
"This is the last test sentence."
)
documents = document_splitter.run(documents=[Document(content=text)])["documents"]
assert len(documents) == 4
assert (
documents[0].content
== "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
)
assert documents[0].meta["page_number"] == 1
assert documents[0].meta["split_id"] == 0
assert documents[0].meta["split_idx_start"] == text.index(documents[0].content)
assert documents[1].content == "This is another test sentence. "
assert documents[1].meta["page_number"] == 1
assert documents[1].meta["split_id"] == 1
assert documents[1].meta["split_idx_start"] == text.index(documents[1].content)
assert documents[2].content == "(This is a third test sentence.) "
assert documents[2].meta["page_number"] == 1
assert documents[2].meta["split_id"] == 2
assert documents[2].meta["split_idx_start"] == text.index(documents[2].content)
assert documents[3].content == "This is the last test sentence."
assert documents[3].meta["page_number"] == 1
assert documents[3].meta["split_id"] == 3
assert documents[3].meta["split_idx_start"] == text.index(documents[3].content)
def test_run_split_by_sentence_3(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="sentence",
split_length=1,
split_overlap=0,
split_threshold=0,
language="en",
use_split_rules=True,
extend_abbreviations=True,
)
text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5."
documents = document_splitter.run(documents=[Document(content=text)])["documents"]
assert len(documents) == 4
assert documents[0].content == "Sentence on page 1.\f"
assert documents[0].meta["page_number"] == 1
assert documents[0].meta["split_id"] == 0
assert documents[0].meta["split_idx_start"] == text.index(documents[0].content)
assert documents[1].content == "Sentence on page 2. \f"
assert documents[1].meta["page_number"] == 2
assert documents[1].meta["split_id"] == 1
assert documents[1].meta["split_idx_start"] == text.index(documents[1].content)
assert documents[2].content == "Sentence on page 3. \f\f "
assert documents[2].meta["page_number"] == 3
assert documents[2].meta["split_id"] == 2
assert documents[2].meta["split_idx_start"] == text.index(documents[2].content)
assert documents[3].content == "Sentence on page 5."
assert documents[3].meta["page_number"] == 5
assert documents[3].meta["split_id"] == 3
assert documents[3].meta["split_idx_start"] == text.index(documents[3].content)
def test_run_split_by_sentence_4(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="sentence",
split_length=2,
split_overlap=1,
split_threshold=0,
language="en",
use_split_rules=True,
extend_abbreviations=True,
)
text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5."
documents = document_splitter.run(documents=[Document(content=text)])["documents"]
assert len(documents) == 3
assert documents[0].content == "Sentence on page 1.\fSentence on page 2. \f"
assert documents[0].meta["page_number"] == 1
assert documents[0].meta["split_id"] == 0
assert documents[0].meta["split_idx_start"] == text.index(documents[0].content)
assert documents[1].content == "Sentence on page 2. \fSentence on page 3. \f\f "
assert documents[1].meta["page_number"] == 2
assert documents[1].meta["split_id"] == 1
assert documents[1].meta["split_idx_start"] == text.index(documents[1].content)
assert documents[2].content == "Sentence on page 3. \f\f Sentence on page 5."
assert documents[2].meta["page_number"] == 3
assert documents[2].meta["split_id"] == 2
assert documents[2].meta["split_idx_start"] == text.index(documents[2].content)
class TestNLTKDocumentSplitterRespectSentenceBoundary:
def test_run_split_by_word_respect_sentence_boundary(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="word",
split_length=3,
split_overlap=0,
split_threshold=0,
language="en",
respect_sentence_boundary=True,
)
text = (
"Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night.\f"
"The moon was full."
)
documents = document_splitter.run(documents=[Document(content=text)])["documents"]
assert len(documents) == 3
assert documents[0].content == "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. "
assert documents[0].meta["page_number"] == 1
assert documents[0].meta["split_id"] == 0
assert documents[0].meta["split_idx_start"] == text.index(documents[0].content)
assert documents[1].content == "It was a dark night.\f"
assert documents[1].meta["page_number"] == 1
assert documents[1].meta["split_id"] == 1
assert documents[1].meta["split_idx_start"] == text.index(documents[1].content)
assert documents[2].content == "The moon was full."
assert documents[2].meta["page_number"] == 2
assert documents[2].meta["split_id"] == 2
assert documents[2].meta["split_idx_start"] == text.index(documents[2].content)
def test_run_split_by_word_respect_sentence_boundary_no_repeats(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="word",
split_length=13,
split_overlap=3,
split_threshold=0,
language="en",
respect_sentence_boundary=True,
use_split_rules=False,
extend_abbreviations=False,
)
text = (
"This is a test sentence with many many words that exceeds the split length and should not be repeated. "
"This is another test sentence. (This is a third test sentence.) "
"This is the last test sentence."
)
documents = document_splitter.run([Document(content=text)])["documents"]
assert len(documents) == 3
assert (
documents[0].content
== "This is a test sentence with many many words that exceeds the split length and should not be repeated. "
)
assert "This is a test sentence with many many words" not in documents[1].content
assert "This is a test sentence with many many words" not in documents[2].content
def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page_breaks(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="word",
split_length=5,
split_overlap=1,
split_threshold=0,
language="en",
use_split_rules=True,
extend_abbreviations=True,
respect_sentence_boundary=True,
)
text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5."
documents = document_splitter.run(documents=[Document(content=text)])["documents"]
assert len(documents) == 4
assert documents[0].content == "Sentence on page 1.\f"
assert documents[0].meta["page_number"] == 1
assert documents[0].meta["split_id"] == 0
assert documents[0].meta["split_idx_start"] == text.index(documents[0].content)
assert documents[1].content == "Sentence on page 1.\fSentence on page 2. \f"
assert documents[1].meta["page_number"] == 1
assert documents[1].meta["split_id"] == 1
assert documents[1].meta["split_idx_start"] == text.index(documents[1].content)
assert documents[2].content == "Sentence on page 2. \fSentence on page 3. \f\f "
assert documents[2].meta["page_number"] == 2
assert documents[2].meta["split_id"] == 2
assert documents[2].meta["split_idx_start"] == text.index(documents[2].content)
assert documents[3].content == "Sentence on page 3. \f\f Sentence on page 5."
assert documents[3].meta["page_number"] == 3
assert documents[3].meta["split_id"] == 3
assert documents[3].meta["split_idx_start"] == text.index(documents[3].content)
class TestSentenceSplitter:
def test_apply_split_rules_second_while_loop(self) -> None:
text = "This is a test. (With a parenthetical statement.) And another sentence."
spans = [(0, 15), (16, 50), (51, 74)]
result = SentenceSplitter._apply_split_rules(text, spans)
assert len(result) == 2
assert result == [(0, 50), (51, 74)]
def test_apply_split_rules_no_join(self) -> None:
text = "This is a test. This is another test. And a third test."
spans = [(0, 15), (16, 36), (37, 54)]
result = SentenceSplitter._apply_split_rules(text, spans)
assert len(result) == 3
assert result == [(0, 15), (16, 36), (37, 54)]
@pytest.mark.parametrize(
"text,span,next_span,quote_spans,expected",
[
# triggers sentence boundary is inside a quote
('He said, "Hello World." Then left.', (0, 15), (16, 23), [(9, 23)], True)
],
)
def test_needs_join_cases(self, text, span, next_span, quote_spans, expected):
result = SentenceSplitter._needs_join(text, span, next_span, quote_spans)
assert result == expected, f"Expected {expected} for input: {text}, {span}, {next_span}, {quote_spans}"