From a492771b4d4bffe1d163fe03359f29da63a76c4e Mon Sep 17 00:00:00 2001 From: Ben Heckmann <79015931+benheckmann@users.noreply.github.com> Date: Thu, 23 Nov 2023 03:26:37 -0800 Subject: [PATCH] feat: PreProcessor split by token (tiktoken & Hugging Face) (#5276) * #4983 implemented split by token for tiktoken tokenizer * #4983 added unit test for tiktoken splitting * #4983 implemented and added a test for splitting documents with HuggingFace tokenizer * #4983 added support for passing HF model names (instead of objects) and added an example to the HF token splitting test * mocked HTTP model loading in unit tests, fixed pylint error * fix lossy tokenizers splitting, use LazyImport, ignore UnicodeEncodeError for tiktoken * reno * rename reno file --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Co-authored-by: ZanSara --- haystack/nodes/preprocessor/base.py | 14 ++- haystack/nodes/preprocessor/preprocessor.py | 112 +++++++++++++----- .../split-by-token-b9a4f954d4077ecc.yaml | 2 + test/nodes/test_preprocessor.py | 110 +++++++++++++++++ 4 files changed, 206 insertions(+), 32 deletions(-) create mode 100644 releasenotes/notes/split-by-token-b9a4f954d4077ecc.yaml diff --git a/haystack/nodes/preprocessor/base.py b/haystack/nodes/preprocessor/base.py index 6213fe1ac..79480f254 100644 --- a/haystack/nodes/preprocessor/base.py +++ b/haystack/nodes/preprocessor/base.py @@ -2,6 +2,8 @@ from typing import List, Optional, Union, Literal from abc import abstractmethod +from transformers import PreTrainedTokenizerBase + from haystack.nodes.base import BaseComponent from haystack.schema import Document @@ -17,10 +19,11 @@ class BasePreProcessor(BaseComponent): clean_header_footer: Optional[bool] = False, clean_empty_lines: Optional[bool] = True, remove_substrings: Optional[List[str]] = None, - split_by: Literal["word", "sentence", "passage", None] = "word", + split_by: Literal["token", "word", "sentence", "passage", None] = "word", split_length: Optional[int] = 1000, split_overlap: Optional[int] = None, split_respect_sentence_boundary: Optional[bool] = True, + tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken", id_hash_keys: Optional[List[str]] = None, ) -> List[Document]: """ @@ -44,10 +47,11 @@ class BasePreProcessor(BaseComponent): def split( self, document: Union[dict, Document], - split_by: Literal["word", "sentence", "passage", None], + split_by: Literal["token", "word", "sentence", "passage", None], split_length: int, split_overlap: int, split_respect_sentence_boundary: bool, + tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None, ) -> List[Document]: raise NotImplementedError @@ -57,10 +61,11 @@ class BasePreProcessor(BaseComponent): clean_whitespace: Optional[bool] = None, clean_header_footer: Optional[bool] = None, clean_empty_lines: Optional[bool] = None, - split_by: Literal["word", "sentence", "passage", None] = None, + split_by: Literal["token", "word", "sentence", "passage", None] = None, split_length: Optional[int] = None, split_overlap: Optional[int] = None, split_respect_sentence_boundary: Optional[bool] = None, + tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None, id_hash_keys: Optional[List[str]] = None, ): processed_documents = self.process( @@ -83,10 +88,11 @@ class BasePreProcessor(BaseComponent): clean_whitespace: Optional[bool] = None, clean_header_footer: Optional[bool] = None, clean_empty_lines: Optional[bool] = None, - split_by: Literal["word", "sentence", "passage", None] = None, + split_by: 
Literal["token", "word", "sentence", "passage", None] = None, split_length: Optional[int] = None, split_overlap: Optional[int] = None, split_respect_sentence_boundary: Optional[bool] = None, + tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None, id_hash_keys: Optional[List[str]] = None, ): return self.run( diff --git a/haystack/nodes/preprocessor/preprocessor.py b/haystack/nodes/preprocessor/preprocessor.py index ca8ec6dd8..9dddc22e7 100644 --- a/haystack/nodes/preprocessor/preprocessor.py +++ b/haystack/nodes/preprocessor/preprocessor.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal +from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal, Callable, Any import logging import re @@ -17,14 +17,18 @@ from haystack.errors import HaystackError from haystack.schema import Document from haystack.lazy_imports import LazyImport +with LazyImport("Run 'pip install transformers'") as transformers_import: + from transformers import PreTrainedTokenizerBase + from transformers import AutoTokenizer + +with LazyImport("Run 'pip install tiktoken'") as tiktoken_import: + import tiktoken logger = logging.getLogger(__name__) - with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk'") as nltk_import: import nltk - iso639_to_nltk = { "ru": "russian", "sl": "slovene", @@ -55,11 +59,12 @@ class PreProcessor(BasePreProcessor): clean_header_footer: bool = False, clean_empty_lines: bool = True, remove_substrings: Optional[List[str]] = None, - split_by: Optional[Literal["word", "sentence", "passage"]] = "word", + split_by: Optional[Literal["token", "word", "sentence", "passage"]] = "word", split_length: int = 200, split_overlap: int = 0, split_respect_sentence_boundary: bool = True, tokenizer_model_folder: Optional[Union[str, Path]] = None, + tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken", language: str = "en", id_hash_keys: Optional[List[str]] = None, progress_bar: bool = True, @@ -86,6 +91,9 @@ class PreProcessor(BasePreProcessor): :param split_respect_sentence_boundary: Whether to split in partial sentences if split_by -> `word`. If set to True, the individual split will always have complete sentences & the number of words will be <= split_length. + :param tokenizer: Specifies the tokenizer to use if split_by="token". Supported options are "tiktoken" + (for OpenAI's GPT-3.5 and GPT-4) and any HuggingFace tokenizer (e.g. 'bert-base-uncased'). + HuggingFace tokenizers can also be passed directly as a PreTrainedTokenizerBase object. :param language: The language used by "nltk.tokenize.sent_tokenize" in iso639 format. Available options: "ru","sl","es","sv","tr","cs","da","nl","en","et","fi","fr","de","el","it","no","pl","pt","ml" :param tokenizer_model_folder: Path to the folder containing the NTLK PunktSentenceTokenizer models, if loading a model from a local path. Leave empty otherwise.
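As a minimal usage sketch of the `tokenizer` parameter documented above (assuming the `transformers` and `tiktoken` extras are installed; "bert-base-uncased" is only the illustrative model name taken from the docstring, not a requirement):

    from transformers import AutoTokenizer
    from haystack.nodes import PreProcessor

    # Split on tiktoken's cl100k_base tokens by passing the string "tiktoken"
    processor = PreProcessor(split_by="token", split_length=200, split_overlap=0, tokenizer="tiktoken")

    # Or pass a Hugging Face tokenizer, either by model name or as a PreTrainedTokenizerBase object
    hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    processor = PreProcessor(split_by="token", split_length=200, split_overlap=0, tokenizer=hf_tokenizer)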
@@ -124,6 +132,7 @@ class PreProcessor(BasePreProcessor): self.split_length = split_length self.split_overlap = split_overlap self.split_respect_sentence_boundary = split_respect_sentence_boundary + self.tokenizer = tokenizer self.language = language self.tokenizer_model_folder = tokenizer_model_folder self.print_log: Set[str] = set() @@ -139,10 +148,11 @@ class PreProcessor(BasePreProcessor): clean_header_footer: Optional[bool] = None, clean_empty_lines: Optional[bool] = None, remove_substrings: Optional[List[str]] = None, - split_by: Optional[Literal["word", "sentence", "passage"]] = None, + split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None, split_length: Optional[int] = None, split_overlap: Optional[int] = None, split_respect_sentence_boundary: Optional[bool] = None, + tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None, id_hash_keys: Optional[List[str]] = None, ) -> List[Document]: """ @@ -167,6 +177,7 @@ class PreProcessor(BasePreProcessor): "split_length": split_length, "split_overlap": split_overlap, "split_respect_sentence_boundary": split_respect_sentence_boundary, + "tokenizer": tokenizer, } if id_hash_keys is None: @@ -219,10 +230,11 @@ class PreProcessor(BasePreProcessor): clean_header_footer: Optional[bool] = None, clean_empty_lines: Optional[bool] = None, remove_substrings: Optional[List[str]] = None, - split_by: Optional[Literal["word", "sentence", "passage"]] = None, + split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None, split_length: Optional[int] = None, split_overlap: Optional[int] = None, split_respect_sentence_boundary: Optional[bool] = None, + tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None, id_hash_keys: Optional[List[str]] = None, ) -> List[Document]: if remove_substrings is None: @@ -243,6 +255,8 @@ class PreProcessor(BasePreProcessor): split_overlap = self.split_overlap if split_respect_sentence_boundary is None: split_respect_sentence_boundary = self.split_respect_sentence_boundary + if tokenizer is None: + tokenizer = self.tokenizer cleaned_document = self.clean( document=document, @@ -258,6 +272,7 @@ class PreProcessor(BasePreProcessor): split_length=split_length, split_overlap=split_overlap, split_respect_sentence_boundary=split_respect_sentence_boundary, + tokenizer=tokenizer, id_hash_keys=id_hash_keys, ) @@ -332,10 +347,11 @@ class PreProcessor(BasePreProcessor): def split( self, document: Union[dict, Document], - split_by: Optional[Literal["word", "sentence", "passage"]], + split_by: Optional[Literal["token", "word", "sentence", "passage"]], split_length: int, split_overlap: int, split_respect_sentence_boundary: bool, + tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None, id_hash_keys: Optional[List[str]] = None, ) -> List[Document]: """Perform document splitting on a single document. This method can split on different units, at different lengths, @@ -359,8 +375,10 @@ class PreProcessor(BasePreProcessor): if not split_length: raise Exception("split_length needs be set when using split_by.") - if split_respect_sentence_boundary and split_by != "word": - raise NotImplementedError("'split_respect_sentence_boundary=True' is only compatible with split_by='word'.") + if split_respect_sentence_boundary and split_by not in ["word", "token"]: + raise NotImplementedError( + "'split_respect_sentence_boundary=True' is only compatible with split_by='word' or 'token'." + ) if type(document.content) is not str: logger.error("Document content is not of type str. 
Nothing to split.") @@ -369,13 +387,17 @@ class PreProcessor(BasePreProcessor): text = document.content headlines = document.meta["headlines"] if "headlines" in document.meta else [] - if split_respect_sentence_boundary and split_by == "word": - text_splits, splits_pages, splits_start_idxs = self._split_by_word_respecting_sent_boundary( - text=text, split_length=split_length, split_overlap=split_overlap + if split_respect_sentence_boundary and split_by in ["word", "token"]: + + def split_function(text): + return self._split_tokens(text, tokenizer=tokenizer) if split_by == "token" else text.split() + + text_splits, splits_pages, splits_start_idxs = self._split_into_units_respecting_sent_boundary( + text=text, split_length=split_length, split_overlap=split_overlap, split_function=split_function ) else: # create individual "elements" of passage, sentence, or word - elements, split_at = self._split_into_units(text=text, split_by=split_by) + elements, split_at = self._split_into_units(text=text, split_by=split_by, tokenizer=tokenizer) # concatenate individual elements based on split_length & split_stride text_splits, splits_pages, splits_start_idxs = self._concatenate_units( @@ -467,15 +489,15 @@ class PreProcessor(BasePreProcessor): cleaned_text = text.replace(substring, "") return cleaned_text, headlines - def _split_by_word_respecting_sent_boundary( - self, text: str, split_length: int, split_overlap: int + def _split_into_units_respecting_sent_boundary( + self, text: str, split_length: int, split_overlap: int, split_function: Callable ) -> Tuple[List[str], List[int], List[int]]: """ Splits the text into parts of split_length words while respecting sentence boundaries. """ sentences = self._split_sentences(text) - word_count_slice = 0 + unit_count_slice = 0 cur_page = 1 cur_start_idx = 0 splits_pages = [] @@ -483,17 +505,17 @@ class PreProcessor(BasePreProcessor): splits_start_idxs = [] current_slice: List[str] = [] for sen in sentences: - word_count_sen = len(sen.split()) + unit_count_sen = len(split_function(sen)) - if word_count_sen > split_length: + if unit_count_sen > split_length: long_sentence_message = ( - "We found one or more sentences whose word count is higher than the split length." + "We found one or more sentences whose split count is higher than the split length." 
) if long_sentence_message not in self.print_log: self.print_log.add(long_sentence_message) logger.warning(long_sentence_message) - if word_count_slice + word_count_sen > split_length: + if unit_count_slice + unit_count_sen > split_length: # Number of words exceeds split_length -> save current slice and start a new one if current_slice: list_splits.append(current_slice) @@ -501,13 +523,13 @@ class PreProcessor(BasePreProcessor): splits_start_idxs.append(cur_start_idx) if split_overlap: - processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice( - current_slice, split_length, split_overlap + processed_sents, current_slice, unit_count_slice = self._get_overlap_from_slice( + current_slice, split_length, split_overlap, split_function ) else: processed_sents = current_slice current_slice = [] - word_count_slice = 0 + unit_count_slice = 0 cur_start_idx += len("".join(processed_sents)) @@ -522,7 +544,7 @@ class PreProcessor(BasePreProcessor): cur_page += num_page_breaks current_slice.append(sen) - word_count_slice += word_count_sen + unit_count_slice += unit_count_sen if current_slice: list_splits.append(current_slice) @@ -539,7 +561,7 @@ class PreProcessor(BasePreProcessor): @staticmethod def _get_overlap_from_slice( - current_slice: List[str], split_length: int, split_overlap: int + current_slice: List[str], split_length: int, split_overlap: int, split_function: Callable ) -> Tuple[List[str], List[str], int]: """ Returns a tuple with the following elements: @@ -553,7 +575,7 @@ class PreProcessor(BasePreProcessor): current_slice_copy = deepcopy(current_slice) # Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence for idx, s in reversed(list(enumerate(current_slice))[1:]): - sen_len = len(s.split()) + sen_len = len(split_function(s)) if word_count_overlap < split_overlap and sen_len < split_length: overlap.append(s) word_count_overlap += sen_len @@ -566,7 +588,7 @@ class PreProcessor(BasePreProcessor): return processed_sents, next_slice, word_count_slice - def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]: + def _split_into_units(self, text: str, split_by: str, tokenizer: Any) -> Tuple[List[str], str]: if split_by == "passage": elements = text.split("\n\n") split_at = "\n\n" @@ -576,8 +598,13 @@ class PreProcessor(BasePreProcessor): elif split_by == "word": elements = text.split(" ") split_at = " " + elif split_by == "token": + elements = self._split_tokens(text, tokenizer) + split_at = "" else: - raise NotImplementedError("PreProcessor only supports 'passage', 'sentence' or 'word' split_by options.") + raise NotImplementedError( + "PreProcessor only supports 'passage', 'sentence', 'word' or 'token' split_by options." + ) return elements, split_at @@ -823,6 +850,35 @@ class PreProcessor(BasePreProcessor): sentences = sentence_tokenizer.tokenize(text) return sentences + def _split_tokens(self, text: str, tokenizer: Any) -> List[str]: + if tokenizer == "tiktoken": + tiktoken_import.check() + enc = tiktoken.get_encoding("cl100k_base") # tiktoken is reversible and lossless + integer_tokens = enc.encode(text, disallowed_special=()) + elements = [enc.decode_single_token_bytes(token).decode(errors="ignore") for token in integer_tokens] + return elements + if isinstance(tokenizer, str): + transformers_import.check() + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + except Exception: + raise ValueError( + f"Could not load tokenizer '{tokenizer}' from HuggingFace model hub. 
" + f"Please make sure that the tokenizer is correct and exists." + ) + if isinstance(tokenizer, PreTrainedTokenizerBase): + encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False) + elements = [] + for i in range(l := len(encoded.offset_mapping)): + start_current = encoded.offset_mapping[i][0] + start_next = encoded.offset_mapping[i + 1][0] if i < l - 1 else len(text) + elements.append(text[start_current:start_next]) + return elements + raise ValueError( + f"Unsupported tokenizer specification {tokenizer}. " + f"Please provide either the string 'tiktoken' or a HuggingFace tokenizer (PreTrainedTokenizerBase)." + ) + def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokenize.punkt.PunktSentenceTokenizer": # Try to load a custom model from 'tokenizer_model_path' if self.tokenizer_model_folder is not None: diff --git a/releasenotes/notes/split-by-token-b9a4f954d4077ecc.yaml b/releasenotes/notes/split-by-token-b9a4f954d4077ecc.yaml new file mode 100644 index 000000000..643793b7b --- /dev/null +++ b/releasenotes/notes/split-by-token-b9a4f954d4077ecc.yaml @@ -0,0 +1,2 @@ +features: + - Add `split_length` by token in PreProcessor diff --git a/test/nodes/test_preprocessor.py b/test/nodes/test_preprocessor.py index 222841a54..604654d57 100644 --- a/test/nodes/test_preprocessor.py +++ b/test/nodes/test_preprocessor.py @@ -1,12 +1,15 @@ import sys from pathlib import Path from typing import Any, Optional, List +from unittest import mock from unittest.mock import Mock import nltk.data import pytest +import tiktoken from _pytest.monkeypatch import MonkeyPatch from _pytest.tmpdir import TempPathFactory +from transformers import AutoTokenizer, PreTrainedTokenizerBase from haystack import Document from haystack.nodes.file_converter.pdf import PDFToTextConverter @@ -84,6 +87,44 @@ def patched_nltk_data_path(module_tmp_dir: Path, monkeypatch: MonkeyPatch, tmp_p return tmp_path +@pytest.fixture +def mock_huggingface_tokenizer(): + class MockTokenizer(PreTrainedTokenizerBase): + """Simple Mock tokenizer splitting the text into 2-character chunks.""" + + @staticmethod + def tokenize(text, **kwargs): + return [text[i : i + 2] for i in range(0, len(text), 2)] + + @staticmethod + def encode_plus(text, **kwargs): + return Mock(offset_mapping=[(i, min(len(text), i + 2)) for i in range(0, len(text), 2)]) + + mock_tokenizer_instance = MockTokenizer() + + with mock.patch.object(AutoTokenizer, "from_pretrained", return_value=mock_tokenizer_instance): + yield mock_tokenizer_instance + + +@pytest.fixture +def mock_tiktoken_tokenizer(): + class MockTokenizer: + """Simple Mock tokenizer "encoding" the text into a 0 for every 5-character chunk.""" + + @staticmethod + def encode(text, **kwargs): + return [0 for i in range(0, len(text), 5)] + + @staticmethod + def decode_single_token_bytes(token): + return b"mock " + + mock_tokenizer_instance = MockTokenizer() + + with mock.patch.object(tiktoken, "get_encoding", return_value=mock_tokenizer_instance): + yield mock_tokenizer_instance + + @pytest.mark.unit @pytest.mark.parametrize("split_length_and_results", [(1, 15), (10, 2)]) def test_preprocess_sentence_split(split_length_and_results): @@ -178,6 +219,75 @@ def test_preprocess_word_split(): assert len(documents) == 15 +@pytest.mark.unit +def test_preprocess_tiktoken_token_split(mock_tiktoken_tokenizer): + raw_docs = [ + "This is a document. 
It has two sentences and eleven words.", + "This is a document with a long sentence (longer than my split length), it has seventeen words.", + ] + docs = [Document(content=content) for content in raw_docs] + split_length = 10 + token_split_docs_not_respecting_sentences = PreProcessor( + split_by="token", + split_length=split_length, + split_respect_sentence_boundary=False, + split_overlap=0, + tokenizer="tiktoken", + ).process(docs) + assert len(token_split_docs_not_respecting_sentences) == 4 + enc = tiktoken.get_encoding("cl100k_base") + split_documents_encoded = [ + enc.encode(d.content, allowed_special="all", disallowed_special=()) + for d in token_split_docs_not_respecting_sentences + ] + assert all([len(d) <= split_length for d in split_documents_encoded]) + token_split_docs_respecting_sentences = PreProcessor( + split_by="token", + split_length=split_length, + split_respect_sentence_boundary=True, + split_overlap=0, + tokenizer="tiktoken", + ).process(docs) + assert len(token_split_docs_respecting_sentences) == 3 # should not be more than there are sentences + + +@pytest.mark.unit +def test_preprocess_huggingface_token_split(mock_huggingface_tokenizer): + raw_docs = [ + "This is a document. It has two sentences and eleven words.", + "This is a document with a long sentence (longer than my split length), it has seventeen words.", + ] + docs = [Document(content=content) for content in raw_docs] + split_length = 10 + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + token_split_docs_not_respecting_sentences = PreProcessor( + split_by="token", + split_length=split_length, + split_respect_sentence_boundary=False, + split_overlap=0, + tokenizer=tokenizer, + ).process(docs) + assert len(token_split_docs_not_respecting_sentences) == 8 + split_documents_retokenized = [tokenizer.tokenize(d.content) for d in token_split_docs_not_respecting_sentences] + assert all([len(d) <= split_length for d in split_documents_retokenized]) + token_split_docs_respecting_sentences = PreProcessor( + split_by="token", + split_length=split_length, + split_respect_sentence_boundary=True, + split_overlap=0, + tokenizer=tokenizer, + ).process(docs) + assert len(token_split_docs_respecting_sentences) == 3 # should not be more than there are sentences + token_split_docs_not_respecting_sentences_instantiate_by_name = PreProcessor( + split_by="token", + split_length=split_length, + split_respect_sentence_boundary=False, + split_overlap=0, + tokenizer="bert-base-uncased", + ).process(docs) + assert token_split_docs_not_respecting_sentences == token_split_docs_not_respecting_sentences_instantiate_by_name + + @pytest.mark.unit @pytest.mark.parametrize("split_length_and_results", [(1, 3), (2, 2)]) def test_preprocess_passage_split(split_length_and_results):
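For reference, a minimal sketch of the offset-mapping approach used by the Hugging Face branch of `_split_tokens` above, which is what keeps the splits lossless even for lossy tokenizers (assuming `transformers` is installed and that "bert-base-uncased" loads a fast tokenizer with offset support):

    from transformers import AutoTokenizer

    text = "This is a document. It has two sentences and eleven words."
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
    offsets = encoded.offset_mapping
    # Slice the original text from each token's start offset to the next token's start offset,
    # so whitespace is preserved and the pieces concatenate back to the input text.
    elements = [
        text[offsets[i][0] : offsets[i + 1][0] if i + 1 < len(offsets) else len(text)]
        for i in range(len(offsets))
    ]
    assert "".join(elements) == text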