feat: PreProcessor split by token (tiktoken & Hugging Face) (#5276)

* #4983 implemented split by token for tiktoken tokenizer

* #4983 added unit test for tiktoken splitting

* #4983 implemented and added a test for splitting documents with HuggingFace tokenizer

* #4983 added support for passing HF model names (instead of objects) and added an example to the HF token splitting test

* mocked HTTP model loading in unit tests, fixed pylint error

* fix lossy tokenizers splitting, use LazyImport, ignore UnicodeEncodeError for tiktoken

* reno

* rename reno file

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
Ben Heckmann authored 2023-11-23 03:26:37 -08:00, committed by GitHub
parent e04a1f16bb
commit a492771b4d
4 changed files with 206 additions and 32 deletions
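
For context, a minimal usage sketch of the option this commit adds, written against the Haystack v1 API shown in the diffs below (the sample text and the 'bert-base-uncased' model name are illustrative, and the Hugging Face variant downloads that tokenizer from the Hub):

from haystack import Document
from haystack.nodes import PreProcessor

docs = [Document(content="Example text. It will be split into chunks of at most 128 tokens.")]

# Count and split by tiktoken tokens (the default tokenizer, cl100k_base encoding)
tiktoken_preprocessor = PreProcessor(
    split_by="token",
    split_length=128,
    split_overlap=0,
    split_respect_sentence_boundary=False,
    tokenizer="tiktoken",
)
chunks = tiktoken_preprocessor.process(docs)

# Count and split by Hugging Face tokens, passing either a model name
# or a PreTrainedTokenizerBase object
hf_preprocessor = PreProcessor(
    split_by="token",
    split_length=128,
    split_overlap=0,
    split_respect_sentence_boundary=False,
    tokenizer="bert-base-uncased",  # or AutoTokenizer.from_pretrained(...)
)
hf_chunks = hf_preprocessor.process(docs)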

View File

@@ -2,6 +2,8 @@ from typing import List, Optional, Union, Literal
 from abc import abstractmethod
+from transformers import PreTrainedTokenizerBase
 from haystack.nodes.base import BaseComponent
 from haystack.schema import Document
@@ -17,10 +19,11 @@ class BasePreProcessor(BaseComponent):
         clean_header_footer: Optional[bool] = False,
         clean_empty_lines: Optional[bool] = True,
         remove_substrings: Optional[List[str]] = None,
-        split_by: Literal["word", "sentence", "passage", None] = "word",
+        split_by: Literal["token", "word", "sentence", "passage", None] = "word",
         split_length: Optional[int] = 1000,
         split_overlap: Optional[int] = None,
         split_respect_sentence_boundary: Optional[bool] = True,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
         id_hash_keys: Optional[List[str]] = None,
     ) -> List[Document]:
         """
@@ -44,10 +47,11 @@ class BasePreProcessor(BaseComponent):
     def split(
         self,
         document: Union[dict, Document],
-        split_by: Literal["word", "sentence", "passage", None],
+        split_by: Literal["token", "word", "sentence", "passage", None],
         split_length: int,
         split_overlap: int,
         split_respect_sentence_boundary: bool,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
     ) -> List[Document]:
         raise NotImplementedError
@@ -57,10 +61,11 @@ class BasePreProcessor(BaseComponent):
         clean_whitespace: Optional[bool] = None,
         clean_header_footer: Optional[bool] = None,
         clean_empty_lines: Optional[bool] = None,
-        split_by: Literal["word", "sentence", "passage", None] = None,
+        split_by: Literal["token", "word", "sentence", "passage", None] = None,
         split_length: Optional[int] = None,
        split_overlap: Optional[int] = None,
         split_respect_sentence_boundary: Optional[bool] = None,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
         id_hash_keys: Optional[List[str]] = None,
     ):
         processed_documents = self.process(
@@ -83,10 +88,11 @@ class BasePreProcessor(BaseComponent):
         clean_whitespace: Optional[bool] = None,
         clean_header_footer: Optional[bool] = None,
         clean_empty_lines: Optional[bool] = None,
-        split_by: Literal["word", "sentence", "passage", None] = None,
+        split_by: Literal["token", "word", "sentence", "passage", None] = None,
         split_length: Optional[int] = None,
         split_overlap: Optional[int] = None,
         split_respect_sentence_boundary: Optional[bool] = None,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
         id_hash_keys: Optional[List[str]] = None,
     ):
         return self.run(

View File

@@ -1,4 +1,4 @@
-from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal
+from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal, Callable, Any
 import logging
 import re
@@ -17,14 +17,18 @@ from haystack.errors import HaystackError
 from haystack.schema import Document
 from haystack.lazy_imports import LazyImport
+with LazyImport("Run 'pip install transformers'") as transformers_import:
+    from transformers import PreTrainedTokenizerBase
+    from transformers import AutoTokenizer
+with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
+    import tiktoken
 logger = logging.getLogger(__name__)
 with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk'") as nltk_import:
     import nltk
 iso639_to_nltk = {
     "ru": "russian",
     "sl": "slovene",
@@ -55,11 +59,12 @@ class PreProcessor(BasePreProcessor):
         clean_header_footer: bool = False,
         clean_empty_lines: bool = True,
         remove_substrings: Optional[List[str]] = None,
-        split_by: Optional[Literal["word", "sentence", "passage"]] = "word",
+        split_by: Optional[Literal["token", "word", "sentence", "passage"]] = "word",
         split_length: int = 200,
         split_overlap: int = 0,
         split_respect_sentence_boundary: bool = True,
         tokenizer_model_folder: Optional[Union[str, Path]] = None,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
         language: str = "en",
         id_hash_keys: Optional[List[str]] = None,
         progress_bar: bool = True,
@@ -86,6 +91,9 @@ class PreProcessor(BasePreProcessor):
         :param split_respect_sentence_boundary: Whether to split in partial sentences if split_by -> `word`. If set
                         to True, the individual split will always have complete sentences &
                         the number of words will be <= split_length.
+        :param tokenizer: Specifies the tokenizer to use if split_by="token". Supported options are "tiktoken"
+                        (for OpenAI's GPT-3.5 and GPT-4) and any HuggingFace tokenizer (e.g. 'bert-base-uncased').
+                        HuggingFace tokenizers can also be passed directly as an PreTrainedTokenizerBase object.
         :param language: The language used by "nltk.tokenize.sent_tokenize" in iso639 format.
                         Available options: "ru","sl","es","sv","tr","cs","da","nl","en","et","fi","fr","de","el","it","no","pl","pt","ml"
         :param tokenizer_model_folder: Path to the folder containing the NTLK PunktSentenceTokenizer models, if loading a model from a local path. Leave empty otherwise.
@@ -124,6 +132,7 @@ class PreProcessor(BasePreProcessor):
         self.split_length = split_length
         self.split_overlap = split_overlap
         self.split_respect_sentence_boundary = split_respect_sentence_boundary
+        self.tokenizer = tokenizer
         self.language = language
         self.tokenizer_model_folder = tokenizer_model_folder
         self.print_log: Set[str] = set()
@@ -139,10 +148,11 @@ class PreProcessor(BasePreProcessor):
         clean_header_footer: Optional[bool] = None,
         clean_empty_lines: Optional[bool] = None,
         remove_substrings: Optional[List[str]] = None,
-        split_by: Optional[Literal["word", "sentence", "passage"]] = None,
+        split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
         split_length: Optional[int] = None,
         split_overlap: Optional[int] = None,
         split_respect_sentence_boundary: Optional[bool] = None,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
         id_hash_keys: Optional[List[str]] = None,
     ) -> List[Document]:
         """
@@ -167,6 +177,7 @@ class PreProcessor(BasePreProcessor):
            "split_length": split_length,
            "split_overlap": split_overlap,
            "split_respect_sentence_boundary": split_respect_sentence_boundary,
+           "tokenizer": tokenizer,
         }
         if id_hash_keys is None:
@@ -219,10 +230,11 @@ class PreProcessor(BasePreProcessor):
         clean_header_footer: Optional[bool] = None,
         clean_empty_lines: Optional[bool] = None,
         remove_substrings: Optional[List[str]] = None,
-        split_by: Optional[Literal["word", "sentence", "passage"]] = None,
+        split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
         split_length: Optional[int] = None,
         split_overlap: Optional[int] = None,
         split_respect_sentence_boundary: Optional[bool] = None,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
         id_hash_keys: Optional[List[str]] = None,
     ) -> List[Document]:
         if remove_substrings is None:
@@ -243,6 +255,8 @@ class PreProcessor(BasePreProcessor):
            split_overlap = self.split_overlap
         if split_respect_sentence_boundary is None:
            split_respect_sentence_boundary = self.split_respect_sentence_boundary
+        if tokenizer is None:
+           tokenizer = self.tokenizer
         cleaned_document = self.clean(
            document=document,
@@ -258,6 +272,7 @@ class PreProcessor(BasePreProcessor):
            split_length=split_length,
            split_overlap=split_overlap,
            split_respect_sentence_boundary=split_respect_sentence_boundary,
+           tokenizer=tokenizer,
            id_hash_keys=id_hash_keys,
         )
@@ -332,10 +347,11 @@ class PreProcessor(BasePreProcessor):
     def split(
         self,
         document: Union[dict, Document],
-        split_by: Optional[Literal["word", "sentence", "passage"]],
+        split_by: Optional[Literal["token", "word", "sentence", "passage"]],
         split_length: int,
         split_overlap: int,
         split_respect_sentence_boundary: bool,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
         id_hash_keys: Optional[List[str]] = None,
     ) -> List[Document]:
         """Perform document splitting on a single document. This method can split on different units, at different lengths,
@@ -359,8 +375,10 @@ class PreProcessor(BasePreProcessor):
         if not split_length:
            raise Exception("split_length needs be set when using split_by.")
-        if split_respect_sentence_boundary and split_by != "word":
-            raise NotImplementedError("'split_respect_sentence_boundary=True' is only compatible with split_by='word'.")
+        if split_respect_sentence_boundary and split_by not in ["word", "token"]:
+            raise NotImplementedError(
+                "'split_respect_sentence_boundary=True' is only compatible with split_by='word' or 'token'."
+            )
         if type(document.content) is not str:
            logger.error("Document content is not of type str. Nothing to split.")
@@ -369,13 +387,17 @@ class PreProcessor(BasePreProcessor):
         text = document.content
         headlines = document.meta["headlines"] if "headlines" in document.meta else []
-        if split_respect_sentence_boundary and split_by == "word":
-            text_splits, splits_pages, splits_start_idxs = self._split_by_word_respecting_sent_boundary(
-                text=text, split_length=split_length, split_overlap=split_overlap
+        if split_respect_sentence_boundary and split_by in ["word", "token"]:
+
+            def split_function(text):
+                return self._split_tokens(text, tokenizer=tokenizer) if split_by == "token" else text.split()
+
+            text_splits, splits_pages, splits_start_idxs = self._split_into_units_respecting_sent_boundary(
+                text=text, split_length=split_length, split_overlap=split_overlap, split_function=split_function
             )
         else:
            # create individual "elements" of passage, sentence, or word
-            elements, split_at = self._split_into_units(text=text, split_by=split_by)
+            elements, split_at = self._split_into_units(text=text, split_by=split_by, tokenizer=tokenizer)
            # concatenate individual elements based on split_length & split_stride
            text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
@@ -467,15 +489,15 @@ class PreProcessor(BasePreProcessor):
            cleaned_text = text.replace(substring, "")
         return cleaned_text, headlines
-    def _split_by_word_respecting_sent_boundary(
-        self, text: str, split_length: int, split_overlap: int
+    def _split_into_units_respecting_sent_boundary(
+        self, text: str, split_length: int, split_overlap: int, split_function: Callable
     ) -> Tuple[List[str], List[int], List[int]]:
         """
         Splits the text into parts of split_length words while respecting sentence boundaries.
         """
         sentences = self._split_sentences(text)
-        word_count_slice = 0
+        unit_count_slice = 0
         cur_page = 1
         cur_start_idx = 0
         splits_pages = []
@@ -483,17 +505,17 @@ class PreProcessor(BasePreProcessor):
         splits_start_idxs = []
         current_slice: List[str] = []
         for sen in sentences:
-            word_count_sen = len(sen.split())
-            if word_count_sen > split_length:
+            unit_count_sen = len(split_function(sen))
+            if unit_count_sen > split_length:
                long_sentence_message = (
-                    "We found one or more sentences whose word count is higher than the split length."
+                    "We found one or more sentences whose split count is higher than the split length."
                )
                if long_sentence_message not in self.print_log:
                    self.print_log.add(long_sentence_message)
                    logger.warning(long_sentence_message)
-            if word_count_slice + word_count_sen > split_length:
+            if unit_count_slice + unit_count_sen > split_length:
                # Number of words exceeds split_length -> save current slice and start a new one
                if current_slice:
                    list_splits.append(current_slice)
@@ -501,13 +523,13 @@ class PreProcessor(BasePreProcessor):
                    splits_start_idxs.append(cur_start_idx)
                if split_overlap:
-                    processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
-                        current_slice, split_length, split_overlap
+                    processed_sents, current_slice, unit_count_slice = self._get_overlap_from_slice(
+                        current_slice, split_length, split_overlap, split_function
                    )
                else:
                    processed_sents = current_slice
                    current_slice = []
-                    word_count_slice = 0
+                    unit_count_slice = 0
                cur_start_idx += len("".join(processed_sents))
@@ -522,7 +544,7 @@ class PreProcessor(BasePreProcessor):
                cur_page += num_page_breaks
            current_slice.append(sen)
-            word_count_slice += word_count_sen
+            unit_count_slice += unit_count_sen
         if current_slice:
            list_splits.append(current_slice)
@@ -539,7 +561,7 @@ class PreProcessor(BasePreProcessor):
     @staticmethod
     def _get_overlap_from_slice(
-        current_slice: List[str], split_length: int, split_overlap: int
+        current_slice: List[str], split_length: int, split_overlap: int, split_function: Callable
     ) -> Tuple[List[str], List[str], int]:
         """
         Returns a tuple with the following elements:
@@ -553,7 +575,7 @@ class PreProcessor(BasePreProcessor):
         current_slice_copy = deepcopy(current_slice)
         # Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
         for idx, s in reversed(list(enumerate(current_slice))[1:]):
-            sen_len = len(s.split())
+            sen_len = len(split_function(s))
            if word_count_overlap < split_overlap and sen_len < split_length:
                overlap.append(s)
                word_count_overlap += sen_len
@@ -566,7 +588,7 @@ class PreProcessor(BasePreProcessor):
         return processed_sents, next_slice, word_count_slice
-    def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
+    def _split_into_units(self, text: str, split_by: str, tokenizer: Any) -> Tuple[List[str], str]:
         if split_by == "passage":
            elements = text.split("\n\n")
            split_at = "\n\n"
@@ -576,8 +598,13 @@ class PreProcessor(BasePreProcessor):
         elif split_by == "word":
            elements = text.split(" ")
            split_at = " "
+        elif split_by == "token":
+            elements = self._split_tokens(text, tokenizer)
+            split_at = ""
         else:
-            raise NotImplementedError("PreProcessor only supports 'passage', 'sentence' or 'word' split_by options.")
+            raise NotImplementedError(
+                "PreProcessor only supports 'passage', 'sentence', 'word' or 'token' split_by options."
+            )
         return elements, split_at
@@ -823,6 +850,35 @@ class PreProcessor(BasePreProcessor):
            sentences = sentence_tokenizer.tokenize(text)
         return sentences
+    def _split_tokens(self, text: str, tokenizer: Any) -> List[str]:
+        if tokenizer == "tiktoken":
+            tiktoken_import.check()
+            enc = tiktoken.get_encoding("cl100k_base")  # tiktoken is reversible and lossless
+            integer_tokens = enc.encode(text, disallowed_special=())
+            elements = [enc.decode_single_token_bytes(token).decode(errors="ignore") for token in integer_tokens]
+            return elements
+        if isinstance(tokenizer, str):
+            transformers_import.check()
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+            except Exception:
+                raise ValueError(
+                    f"Could not load tokenizer '{tokenizer}' from HuggingFace model hub. "
+                    f"Please make sure that the tokenizer is correct and exists."
+                )
+        if isinstance(tokenizer, PreTrainedTokenizerBase):
+            encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
+            elements = []
+            for i in range(l := len(encoded.offset_mapping)):
+                start_current = encoded.offset_mapping[i][0]
+                start_next = encoded.offset_mapping[i + 1][0] if i < l - 1 else len(text)
+                elements.append(text[start_current:start_next])
+            return elements
+        raise ValueError(
+            f"Unsupported tokenizer specification {tokenizer}. "
+            f"Please provide either the string 'tiktoken' or a HuggingFace tokenizer (PreTrainedTokenizerBase)."
+        )
     def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokenize.punkt.PunktSentenceTokenizer":
         # Try to load a custom model from 'tokenizer_model_path'
         if self.tokenizer_model_folder is not None:
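
The `_split_tokens` helper added above relies on the per-token strings concatenating back to the original text: the tiktoken path decodes each integer token to its byte string, and the Hugging Face path slices the original text between consecutive token start offsets. A standalone sketch of that property, assuming `tiktoken` and `transformers` are installed (the sample text and 'bert-base-uncased' are illustrative; loading the tokenizer contacts the Hugging Face Hub):

import tiktoken
from transformers import AutoTokenizer

text = "Tokenizers split text into sub-word units."

# tiktoken path: decode every integer token individually; for plain ASCII text the
# decoded pieces join back into the original string
enc = tiktoken.get_encoding("cl100k_base")
pieces = [
    enc.decode_single_token_bytes(t).decode(errors="ignore")
    for t in enc.encode(text, disallowed_special=())
]
assert "".join(pieces) == text

# Hugging Face path: offset mappings point into the original text, so slicing between
# consecutive token start offsets is lossless (whitespace stays with the preceding token)
hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = hf_tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
starts = [start for start, _ in encoded.offset_mapping]
spans = [
    text[starts[i] : (starts[i + 1] if i + 1 < len(starts) else len(text))]
    for i in range(len(starts))
]
assert "".join(spans) == text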

View File

@@ -0,0 +1,2 @@
+features:
+  - Add `split_length` by token in PreProcessor

View File

@@ -1,12 +1,15 @@
 import sys
 from pathlib import Path
 from typing import Any, Optional, List
+from unittest import mock
 from unittest.mock import Mock
 import nltk.data
 import pytest
+import tiktoken
 from _pytest.monkeypatch import MonkeyPatch
 from _pytest.tmpdir import TempPathFactory
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
 from haystack import Document
 from haystack.nodes.file_converter.pdf import PDFToTextConverter
@@ -84,6 +87,44 @@ def patched_nltk_data_path(module_tmp_dir: Path, monkeypatch: MonkeyPatch, tmp_p
     return tmp_path
+@pytest.fixture
+def mock_huggingface_tokenizer():
+    class MockTokenizer(PreTrainedTokenizerBase):
+        """Simple Mock tokenizer splitting the text into 2-character chunks."""
+
+        @staticmethod
+        def tokenize(text, **kwargs):
+            return [text[i : i + 2] for i in range(0, len(text), 2)]
+
+        @staticmethod
+        def encode_plus(text, **kwargs):
+            return Mock(offset_mapping=[(i, min(len(text), i + 2)) for i in range(0, len(text), 2)])
+
+    mock_tokenizer_instance = MockTokenizer()
+    with mock.patch.object(AutoTokenizer, "from_pretrained", return_value=mock_tokenizer_instance):
+        yield mock_tokenizer_instance
+
+
+@pytest.fixture
+def mock_tiktoken_tokenizer():
+    class MockTokenizer:
+        """Simple Mock tokenizer "encoding" the text into a 0 for every 5-character chunk."""
+
+        @staticmethod
+        def encode(text, **kwargs):
+            return [0 for i in range(0, len(text), 5)]
+
+        @staticmethod
+        def decode_single_token_bytes(token):
+            return b"mock "
+
+    mock_tokenizer_instance = MockTokenizer()
+    with mock.patch.object(tiktoken, "get_encoding", return_value=mock_tokenizer_instance):
+        yield mock_tokenizer_instance
 @pytest.mark.unit
 @pytest.mark.parametrize("split_length_and_results", [(1, 15), (10, 2)])
 def test_preprocess_sentence_split(split_length_and_results):
@@ -178,6 +219,75 @@ def test_preprocess_word_split():
     assert len(documents) == 15
+@pytest.mark.unit
+def test_preprocess_tiktoken_token_split(mock_tiktoken_tokenizer):
+    raw_docs = [
+        "This is a document. It has two sentences and eleven words.",
+        "This is a document with a long sentence (longer than my split length), it has seventeen words.",
+    ]
+    docs = [Document(content=content) for content in raw_docs]
+    split_length = 10
+
+    token_split_docs_not_respecting_sentences = PreProcessor(
+        split_by="token",
+        split_length=split_length,
+        split_respect_sentence_boundary=False,
+        split_overlap=0,
+        tokenizer="tiktoken",
+    ).process(docs)
+    assert len(token_split_docs_not_respecting_sentences) == 4
+    enc = tiktoken.get_encoding("cl100k_base")
+    split_documents_encoded = [
+        enc.encode(d.content, allowed_special="all", disallowed_special=())
+        for d in token_split_docs_not_respecting_sentences
+    ]
+    assert all([len(d) <= split_length for d in split_documents_encoded])
+
+    token_split_docs_respecting_sentences = PreProcessor(
+        split_by="token",
+        split_length=split_length,
+        split_respect_sentence_boundary=True,
+        split_overlap=0,
+        tokenizer="tiktoken",
+    ).process(docs)
+    assert len(token_split_docs_respecting_sentences) == 3  # should not be more than there are sentences
+
+
+@pytest.mark.unit
+def test_preprocess_huggingface_token_split(mock_huggingface_tokenizer):
+    raw_docs = [
+        "This is a document. It has two sentences and eleven words.",
+        "This is a document with a long sentence (longer than my split length), it has seventeen words.",
+    ]
+    docs = [Document(content=content) for content in raw_docs]
+    split_length = 10
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+    token_split_docs_not_respecting_sentences = PreProcessor(
+        split_by="token",
+        split_length=split_length,
+        split_respect_sentence_boundary=False,
+        split_overlap=0,
+        tokenizer=tokenizer,
+    ).process(docs)
+    assert len(token_split_docs_not_respecting_sentences) == 8
+    split_documents_retokenized = [tokenizer.tokenize(d.content) for d in token_split_docs_not_respecting_sentences]
+    assert all([len(d) <= split_length for d in split_documents_retokenized])
+
+    token_split_docs_respecting_sentences = PreProcessor(
+        split_by="token",
+        split_length=split_length,
+        split_respect_sentence_boundary=True,
+        split_overlap=0,
+        tokenizer=tokenizer,
+    ).process(docs)
+    assert len(token_split_docs_respecting_sentences) == 3  # should not be more than there are sentences
+
+    token_split_docs_not_respecting_sentences_instantiate_by_name = PreProcessor(
+        split_by="token",
+        split_length=split_length,
+        split_respect_sentence_boundary=False,
+        split_overlap=0,
+        tokenizer="bert-base-uncased",
+    ).process(docs)
+    assert token_split_docs_not_respecting_sentences == token_split_docs_not_respecting_sentences_instantiate_by_name
 @pytest.mark.unit
 @pytest.mark.parametrize("split_length_and_results", [(1, 3), (2, 2)])
 def test_preprocess_passage_split(split_length_and_results):