feat: PreProcessor split by token (tiktoken & Hugging Face) (#5276)

* #4983 implemented split by token for tiktoken tokenizer

* #4983 added unit test for tiktoken splitting

* #4983 implemented and added a test for splitting documents with HuggingFace tokenizer

* #4983 added support for passing HF model names (instead of objects) and added an example to the HF token splitting test

* mocked HTTP model loading in unit tests, fixed pylint error

* fix lossy tokenizers splitting, use LazyImport, ignore UnicodeEncodeError for tiktoken

* reno

* rename reno file

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
This commit is contained in:
Ben Heckmann 2023-11-23 03:26:37 -08:00 committed by GitHub
parent e04a1f16bb
commit a492771b4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 206 additions and 32 deletions

View File

@ -2,6 +2,8 @@ from typing import List, Optional, Union, Literal
from abc import abstractmethod
from transformers import PreTrainedTokenizerBase
from haystack.nodes.base import BaseComponent
from haystack.schema import Document
@ -17,10 +19,11 @@ class BasePreProcessor(BaseComponent):
clean_header_footer: Optional[bool] = False,
clean_empty_lines: Optional[bool] = True,
remove_substrings: Optional[List[str]] = None,
split_by: Literal["word", "sentence", "passage", None] = "word",
split_by: Literal["token", "word", "sentence", "passage", None] = "word",
split_length: Optional[int] = 1000,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = True,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
"""
@ -44,10 +47,11 @@ class BasePreProcessor(BaseComponent):
def split(
self,
document: Union[dict, Document],
split_by: Literal["word", "sentence", "passage", None],
split_by: Literal["token", "word", "sentence", "passage", None],
split_length: int,
split_overlap: int,
split_respect_sentence_boundary: bool,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
) -> List[Document]:
raise NotImplementedError
@ -57,10 +61,11 @@ class BasePreProcessor(BaseComponent):
clean_whitespace: Optional[bool] = None,
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
split_by: Literal["word", "sentence", "passage", None] = None,
split_by: Literal["token", "word", "sentence", "passage", None] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
):
processed_documents = self.process(
@ -83,10 +88,11 @@ class BasePreProcessor(BaseComponent):
clean_whitespace: Optional[bool] = None,
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
split_by: Literal["word", "sentence", "passage", None] = None,
split_by: Literal["token", "word", "sentence", "passage", None] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
):
return self.run(

View File

@ -1,4 +1,4 @@
from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal
from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal, Callable, Any
import logging
import re
@ -17,14 +17,18 @@ from haystack.errors import HaystackError
from haystack.schema import Document
from haystack.lazy_imports import LazyImport
with LazyImport("Run 'pip install transformers'") as transformers_import:
from transformers import PreTrainedTokenizerBase
from transformers import AutoTokenizer
with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
import tiktoken
logger = logging.getLogger(__name__)
with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk'") as nltk_import:
import nltk
iso639_to_nltk = {
"ru": "russian",
"sl": "slovene",
@ -55,11 +59,12 @@ class PreProcessor(BasePreProcessor):
clean_header_footer: bool = False,
clean_empty_lines: bool = True,
remove_substrings: Optional[List[str]] = None,
split_by: Optional[Literal["word", "sentence", "passage"]] = "word",
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = "word",
split_length: int = 200,
split_overlap: int = 0,
split_respect_sentence_boundary: bool = True,
tokenizer_model_folder: Optional[Union[str, Path]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
language: str = "en",
id_hash_keys: Optional[List[str]] = None,
progress_bar: bool = True,
@ -86,6 +91,9 @@ class PreProcessor(BasePreProcessor):
:param split_respect_sentence_boundary: Whether to split in partial sentences if split_by -> `word`. If set
to True, the individual split will always have complete sentences &
the number of words will be <= split_length.
:param tokenizer: Specifies the tokenizer to use if split_by="token". Supported options are "tiktoken"
(for OpenAI's GPT-3.5 and GPT-4) and any HuggingFace tokenizer (e.g. 'bert-base-uncased').
HuggingFace tokenizers can also be passed directly as an PreTrainedTokenizerBase object.
:param language: The language used by "nltk.tokenize.sent_tokenize" in iso639 format.
Available options: "ru","sl","es","sv","tr","cs","da","nl","en","et","fi","fr","de","el","it","no","pl","pt","ml"
:param tokenizer_model_folder: Path to the folder containing the NTLK PunktSentenceTokenizer models, if loading a model from a local path. Leave empty otherwise.
@ -124,6 +132,7 @@ class PreProcessor(BasePreProcessor):
self.split_length = split_length
self.split_overlap = split_overlap
self.split_respect_sentence_boundary = split_respect_sentence_boundary
self.tokenizer = tokenizer
self.language = language
self.tokenizer_model_folder = tokenizer_model_folder
self.print_log: Set[str] = set()
@ -139,10 +148,11 @@ class PreProcessor(BasePreProcessor):
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
remove_substrings: Optional[List[str]] = None,
split_by: Optional[Literal["word", "sentence", "passage"]] = None,
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
"""
@ -167,6 +177,7 @@ class PreProcessor(BasePreProcessor):
"split_length": split_length,
"split_overlap": split_overlap,
"split_respect_sentence_boundary": split_respect_sentence_boundary,
"tokenizer": tokenizer,
}
if id_hash_keys is None:
@ -219,10 +230,11 @@ class PreProcessor(BasePreProcessor):
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
remove_substrings: Optional[List[str]] = None,
split_by: Optional[Literal["word", "sentence", "passage"]] = None,
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
if remove_substrings is None:
@ -243,6 +255,8 @@ class PreProcessor(BasePreProcessor):
split_overlap = self.split_overlap
if split_respect_sentence_boundary is None:
split_respect_sentence_boundary = self.split_respect_sentence_boundary
if tokenizer is None:
tokenizer = self.tokenizer
cleaned_document = self.clean(
document=document,
@ -258,6 +272,7 @@ class PreProcessor(BasePreProcessor):
split_length=split_length,
split_overlap=split_overlap,
split_respect_sentence_boundary=split_respect_sentence_boundary,
tokenizer=tokenizer,
id_hash_keys=id_hash_keys,
)
@ -332,10 +347,11 @@ class PreProcessor(BasePreProcessor):
def split(
self,
document: Union[dict, Document],
split_by: Optional[Literal["word", "sentence", "passage"]],
split_by: Optional[Literal["token", "word", "sentence", "passage"]],
split_length: int,
split_overlap: int,
split_respect_sentence_boundary: bool,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
"""Perform document splitting on a single document. This method can split on different units, at different lengths,
@ -359,8 +375,10 @@ class PreProcessor(BasePreProcessor):
if not split_length:
raise Exception("split_length needs be set when using split_by.")
if split_respect_sentence_boundary and split_by != "word":
raise NotImplementedError("'split_respect_sentence_boundary=True' is only compatible with split_by='word'.")
if split_respect_sentence_boundary and split_by not in ["word", "token"]:
raise NotImplementedError(
"'split_respect_sentence_boundary=True' is only compatible with split_by='word' or 'token'."
)
if type(document.content) is not str:
logger.error("Document content is not of type str. Nothing to split.")
@ -369,13 +387,17 @@ class PreProcessor(BasePreProcessor):
text = document.content
headlines = document.meta["headlines"] if "headlines" in document.meta else []
if split_respect_sentence_boundary and split_by == "word":
text_splits, splits_pages, splits_start_idxs = self._split_by_word_respecting_sent_boundary(
text=text, split_length=split_length, split_overlap=split_overlap
if split_respect_sentence_boundary and split_by in ["word", "token"]:
def split_function(text):
return self._split_tokens(text, tokenizer=tokenizer) if split_by == "token" else text.split()
text_splits, splits_pages, splits_start_idxs = self._split_into_units_respecting_sent_boundary(
text=text, split_length=split_length, split_overlap=split_overlap, split_function=split_function
)
else:
# create individual "elements" of passage, sentence, or word
elements, split_at = self._split_into_units(text=text, split_by=split_by)
elements, split_at = self._split_into_units(text=text, split_by=split_by, tokenizer=tokenizer)
# concatenate individual elements based on split_length & split_stride
text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
@ -467,15 +489,15 @@ class PreProcessor(BasePreProcessor):
cleaned_text = text.replace(substring, "")
return cleaned_text, headlines
def _split_by_word_respecting_sent_boundary(
self, text: str, split_length: int, split_overlap: int
def _split_into_units_respecting_sent_boundary(
self, text: str, split_length: int, split_overlap: int, split_function: Callable
) -> Tuple[List[str], List[int], List[int]]:
"""
Splits the text into parts of split_length words while respecting sentence boundaries.
"""
sentences = self._split_sentences(text)
word_count_slice = 0
unit_count_slice = 0
cur_page = 1
cur_start_idx = 0
splits_pages = []
@ -483,17 +505,17 @@ class PreProcessor(BasePreProcessor):
splits_start_idxs = []
current_slice: List[str] = []
for sen in sentences:
word_count_sen = len(sen.split())
unit_count_sen = len(split_function(sen))
if word_count_sen > split_length:
if unit_count_sen > split_length:
long_sentence_message = (
"We found one or more sentences whose word count is higher than the split length."
"We found one or more sentences whose split count is higher than the split length."
)
if long_sentence_message not in self.print_log:
self.print_log.add(long_sentence_message)
logger.warning(long_sentence_message)
if word_count_slice + word_count_sen > split_length:
if unit_count_slice + unit_count_sen > split_length:
# Number of words exceeds split_length -> save current slice and start a new one
if current_slice:
list_splits.append(current_slice)
@ -501,13 +523,13 @@ class PreProcessor(BasePreProcessor):
splits_start_idxs.append(cur_start_idx)
if split_overlap:
processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
current_slice, split_length, split_overlap
processed_sents, current_slice, unit_count_slice = self._get_overlap_from_slice(
current_slice, split_length, split_overlap, split_function
)
else:
processed_sents = current_slice
current_slice = []
word_count_slice = 0
unit_count_slice = 0
cur_start_idx += len("".join(processed_sents))
@ -522,7 +544,7 @@ class PreProcessor(BasePreProcessor):
cur_page += num_page_breaks
current_slice.append(sen)
word_count_slice += word_count_sen
unit_count_slice += unit_count_sen
if current_slice:
list_splits.append(current_slice)
@ -539,7 +561,7 @@ class PreProcessor(BasePreProcessor):
@staticmethod
def _get_overlap_from_slice(
current_slice: List[str], split_length: int, split_overlap: int
current_slice: List[str], split_length: int, split_overlap: int, split_function: Callable
) -> Tuple[List[str], List[str], int]:
"""
Returns a tuple with the following elements:
@ -553,7 +575,7 @@ class PreProcessor(BasePreProcessor):
current_slice_copy = deepcopy(current_slice)
# Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
for idx, s in reversed(list(enumerate(current_slice))[1:]):
sen_len = len(s.split())
sen_len = len(split_function(s))
if word_count_overlap < split_overlap and sen_len < split_length:
overlap.append(s)
word_count_overlap += sen_len
@ -566,7 +588,7 @@ class PreProcessor(BasePreProcessor):
return processed_sents, next_slice, word_count_slice
def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
def _split_into_units(self, text: str, split_by: str, tokenizer: Any) -> Tuple[List[str], str]:
if split_by == "passage":
elements = text.split("\n\n")
split_at = "\n\n"
@ -576,8 +598,13 @@ class PreProcessor(BasePreProcessor):
elif split_by == "word":
elements = text.split(" ")
split_at = " "
elif split_by == "token":
elements = self._split_tokens(text, tokenizer)
split_at = ""
else:
raise NotImplementedError("PreProcessor only supports 'passage', 'sentence' or 'word' split_by options.")
raise NotImplementedError(
"PreProcessor only supports 'passage', 'sentence', 'word' or 'token' split_by options."
)
return elements, split_at
@ -823,6 +850,35 @@ class PreProcessor(BasePreProcessor):
sentences = sentence_tokenizer.tokenize(text)
return sentences
def _split_tokens(self, text: str, tokenizer: Any) -> List[str]:
if tokenizer == "tiktoken":
tiktoken_import.check()
enc = tiktoken.get_encoding("cl100k_base") # tiktoken is reversible and lossless
integer_tokens = enc.encode(text, disallowed_special=())
elements = [enc.decode_single_token_bytes(token).decode(errors="ignore") for token in integer_tokens]
return elements
if isinstance(tokenizer, str):
transformers_import.check()
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
except Exception:
raise ValueError(
f"Could not load tokenizer '{tokenizer}' from HuggingFace model hub. "
f"Please make sure that the tokenizer is correct and exists."
)
if isinstance(tokenizer, PreTrainedTokenizerBase):
encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
elements = []
for i in range(l := len(encoded.offset_mapping)):
start_current = encoded.offset_mapping[i][0]
start_next = encoded.offset_mapping[i + 1][0] if i < l - 1 else len(text)
elements.append(text[start_current:start_next])
return elements
raise ValueError(
f"Unsupported tokenizer specification {tokenizer}. "
f"Please provide either the string 'tiktoken' or a HuggingFace tokenizer (PreTrainedTokenizerBase)."
)
def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokenize.punkt.PunktSentenceTokenizer":
# Try to load a custom model from 'tokenizer_model_path'
if self.tokenizer_model_folder is not None:

View File

@ -0,0 +1,2 @@
features:
- Add `split_length` by token in PreProcessor

View File

@ -1,12 +1,15 @@
import sys
from pathlib import Path
from typing import Any, Optional, List
from unittest import mock
from unittest.mock import Mock
import nltk.data
import pytest
import tiktoken
from _pytest.monkeypatch import MonkeyPatch
from _pytest.tmpdir import TempPathFactory
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from haystack import Document
from haystack.nodes.file_converter.pdf import PDFToTextConverter
@ -84,6 +87,44 @@ def patched_nltk_data_path(module_tmp_dir: Path, monkeypatch: MonkeyPatch, tmp_p
return tmp_path
@pytest.fixture
def mock_huggingface_tokenizer():
class MockTokenizer(PreTrainedTokenizerBase):
"""Simple Mock tokenizer splitting the text into 2-character chunks."""
@staticmethod
def tokenize(text, **kwargs):
return [text[i : i + 2] for i in range(0, len(text), 2)]
@staticmethod
def encode_plus(text, **kwargs):
return Mock(offset_mapping=[(i, min(len(text), i + 2)) for i in range(0, len(text), 2)])
mock_tokenizer_instance = MockTokenizer()
with mock.patch.object(AutoTokenizer, "from_pretrained", return_value=mock_tokenizer_instance):
yield mock_tokenizer_instance
@pytest.fixture
def mock_tiktoken_tokenizer():
class MockTokenizer:
"""Simple Mock tokenizer "encoding" the text into a 0 for every 5-character chunk."""
@staticmethod
def encode(text, **kwargs):
return [0 for i in range(0, len(text), 5)]
@staticmethod
def decode_single_token_bytes(token):
return b"mock "
mock_tokenizer_instance = MockTokenizer()
with mock.patch.object(tiktoken, "get_encoding", return_value=mock_tokenizer_instance):
yield mock_tokenizer_instance
@pytest.mark.unit
@pytest.mark.parametrize("split_length_and_results", [(1, 15), (10, 2)])
def test_preprocess_sentence_split(split_length_and_results):
@ -178,6 +219,75 @@ def test_preprocess_word_split():
assert len(documents) == 15
@pytest.mark.unit
def test_preprocess_tiktoken_token_split(mock_tiktoken_tokenizer):
raw_docs = [
"This is a document. It has two sentences and eleven words.",
"This is a document with a long sentence (longer than my split length), it has seventeen words.",
]
docs = [Document(content=content) for content in raw_docs]
split_length = 10
token_split_docs_not_respecting_sentences = PreProcessor(
split_by="token",
split_length=split_length,
split_respect_sentence_boundary=False,
split_overlap=0,
tokenizer="tiktoken",
).process(docs)
assert len(token_split_docs_not_respecting_sentences) == 4
enc = tiktoken.get_encoding("cl100k_base")
split_documents_encoded = [
enc.encode(d.content, allowed_special="all", disallowed_special=())
for d in token_split_docs_not_respecting_sentences
]
assert all([len(d) <= split_length for d in split_documents_encoded])
token_split_docs_respecting_sentences = PreProcessor(
split_by="token",
split_length=split_length,
split_respect_sentence_boundary=True,
split_overlap=0,
tokenizer="tiktoken",
).process(docs)
assert len(token_split_docs_respecting_sentences) == 3 # should not be more than there are sentences
@pytest.mark.unit
def test_preprocess_huggingface_token_split(mock_huggingface_tokenizer):
raw_docs = [
"This is a document. It has two sentences and eleven words.",
"This is a document with a long sentence (longer than my split length), it has seventeen words.",
]
docs = [Document(content=content) for content in raw_docs]
split_length = 10
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
token_split_docs_not_respecting_sentences = PreProcessor(
split_by="token",
split_length=split_length,
split_respect_sentence_boundary=False,
split_overlap=0,
tokenizer=tokenizer,
).process(docs)
assert len(token_split_docs_not_respecting_sentences) == 8
split_documents_retokenized = [tokenizer.tokenize(d.content) for d in token_split_docs_not_respecting_sentences]
assert all([len(d) <= split_length for d in split_documents_retokenized])
token_split_docs_respecting_sentences = PreProcessor(
split_by="token",
split_length=split_length,
split_respect_sentence_boundary=True,
split_overlap=0,
tokenizer=tokenizer,
).process(docs)
assert len(token_split_docs_respecting_sentences) == 3 # should not be more than there are sentences
token_split_docs_not_respecting_sentences_instantiate_by_name = PreProcessor(
split_by="token",
split_length=split_length,
split_respect_sentence_boundary=False,
split_overlap=0,
tokenizer="bert-base-uncased",
).process(docs)
assert token_split_docs_not_respecting_sentences == token_split_docs_not_respecting_sentences_instantiate_by_name
@pytest.mark.unit
@pytest.mark.parametrize("split_length_and_results", [(1, 3), (2, 2)])
def test_preprocess_passage_split(split_length_and_results):