Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-08-26 17:36:34 +00:00)
feat: PreProcessor split by token (tiktoken & Hugging Face) (#5276)
* #4983 implemented split by token for the tiktoken tokenizer
* #4983 added a unit test for tiktoken splitting
* #4983 implemented and added a test for splitting documents with a HuggingFace tokenizer
* #4983 added support for passing HF model names (instead of objects) and added an example to the HF token splitting test
* mocked HTTP model loading in unit tests, fixed a pylint error
* fixed lossy tokenizer splitting, use LazyImport, ignore UnicodeEncodeError for tiktoken
* reno
* renamed reno file

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
parent: e04a1f16bb
commit: a492771b4d
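A brief usage sketch of the new option follows (illustrative only: the document text and split_length are placeholders, and split_by="token" with the default tokenizer="tiktoken" assumes tiktoken is installed):

from haystack import Document
from haystack.nodes import PreProcessor

docs = [Document(content="A long text that should be chunked by token count rather than by word count.")]

# Split into chunks of at most 128 tiktoken (cl100k_base) tokens.
preprocessor = PreProcessor(
    split_by="token",
    split_length=128,
    split_overlap=0,
    split_respect_sentence_boundary=False,
    tokenizer="tiktoken",
)
chunks = preprocessor.process(docs)
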
@@ -2,6 +2,8 @@ from typing import List, Optional, Union, Literal
 
 from abc import abstractmethod
 
+from transformers import PreTrainedTokenizerBase
+
 from haystack.nodes.base import BaseComponent
 from haystack.schema import Document
 
@@ -17,10 +19,11 @@ class BasePreProcessor(BaseComponent):
         clean_header_footer: Optional[bool] = False,
         clean_empty_lines: Optional[bool] = True,
         remove_substrings: Optional[List[str]] = None,
-        split_by: Literal["word", "sentence", "passage", None] = "word",
+        split_by: Literal["token", "word", "sentence", "passage", None] = "word",
         split_length: Optional[int] = 1000,
         split_overlap: Optional[int] = None,
         split_respect_sentence_boundary: Optional[bool] = True,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
         id_hash_keys: Optional[List[str]] = None,
     ) -> List[Document]:
         """
@@ -44,10 +47,11 @@ class BasePreProcessor(BaseComponent):
     def split(
         self,
         document: Union[dict, Document],
-        split_by: Literal["word", "sentence", "passage", None],
+        split_by: Literal["token", "word", "sentence", "passage", None],
         split_length: int,
         split_overlap: int,
         split_respect_sentence_boundary: bool,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
     ) -> List[Document]:
         raise NotImplementedError
 
@@ -57,10 +61,11 @@ class BasePreProcessor(BaseComponent):
         clean_whitespace: Optional[bool] = None,
         clean_header_footer: Optional[bool] = None,
         clean_empty_lines: Optional[bool] = None,
-        split_by: Literal["word", "sentence", "passage", None] = None,
+        split_by: Literal["token", "word", "sentence", "passage", None] = None,
         split_length: Optional[int] = None,
         split_overlap: Optional[int] = None,
         split_respect_sentence_boundary: Optional[bool] = None,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
         id_hash_keys: Optional[List[str]] = None,
     ):
         processed_documents = self.process(
@@ -83,10 +88,11 @@ class BasePreProcessor(BaseComponent):
         clean_whitespace: Optional[bool] = None,
         clean_header_footer: Optional[bool] = None,
         clean_empty_lines: Optional[bool] = None,
-        split_by: Literal["word", "sentence", "passage", None] = None,
+        split_by: Literal["token", "word", "sentence", "passage", None] = None,
         split_length: Optional[int] = None,
         split_overlap: Optional[int] = None,
         split_respect_sentence_boundary: Optional[bool] = None,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
         id_hash_keys: Optional[List[str]] = None,
     ):
         return self.run(
@@ -1,4 +1,4 @@
-from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal
+from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal, Callable, Any
 
 import logging
 import re
@@ -17,14 +17,18 @@ from haystack.errors import HaystackError
 from haystack.schema import Document
 from haystack.lazy_imports import LazyImport
 
+with LazyImport("Run 'pip install transformers'") as transformers_import:
+    from transformers import PreTrainedTokenizerBase
+    from transformers import AutoTokenizer
+
+with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
+    import tiktoken
+
 logger = logging.getLogger(__name__)
 
 
 with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk'") as nltk_import:
     import nltk
 
 
 iso639_to_nltk = {
     "ru": "russian",
     "sl": "slovene",
@@ -55,11 +59,12 @@ class PreProcessor(BasePreProcessor):
         clean_header_footer: bool = False,
         clean_empty_lines: bool = True,
         remove_substrings: Optional[List[str]] = None,
-        split_by: Optional[Literal["word", "sentence", "passage"]] = "word",
+        split_by: Optional[Literal["token", "word", "sentence", "passage"]] = "word",
         split_length: int = 200,
         split_overlap: int = 0,
         split_respect_sentence_boundary: bool = True,
         tokenizer_model_folder: Optional[Union[str, Path]] = None,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
         language: str = "en",
         id_hash_keys: Optional[List[str]] = None,
         progress_bar: bool = True,
@@ -86,6 +91,9 @@ class PreProcessor(BasePreProcessor):
         :param split_respect_sentence_boundary: Whether to split in partial sentences if split_by -> `word`. If set
                                                 to True, the individual split will always have complete sentences &
                                                 the number of words will be <= split_length.
+        :param tokenizer: Specifies the tokenizer to use if split_by="token". Supported options are "tiktoken"
+                          (for OpenAI's GPT-3.5 and GPT-4) and any HuggingFace tokenizer (e.g. 'bert-base-uncased').
+                          HuggingFace tokenizers can also be passed directly as a PreTrainedTokenizerBase object.
         :param language: The language used by "nltk.tokenize.sent_tokenize" in iso639 format.
                          Available options: "ru","sl","es","sv","tr","cs","da","nl","en","et","fi","fr","de","el","it","no","pl","pt","ml"
         :param tokenizer_model_folder: Path to the folder containing the NLTK PunktSentenceTokenizer models, if loading a model from a local path. Leave empty otherwise.
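As the new docstring entry above notes, the tokenizer can also be given as a Hugging Face model name or as a PreTrainedTokenizerBase object. A minimal sketch, assuming transformers is installed and using 'bert-base-uncased' purely as an example name:

from transformers import AutoTokenizer
from haystack.nodes import PreProcessor

hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
preprocessor = PreProcessor(
    split_by="token",
    split_length=128,
    split_respect_sentence_boundary=True,  # sentences are packed until the token budget is reached
    tokenizer=hf_tokenizer,  # or simply tokenizer="bert-base-uncased"
)
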
@@ -124,6 +132,7 @@ class PreProcessor(BasePreProcessor):
         self.split_length = split_length
         self.split_overlap = split_overlap
         self.split_respect_sentence_boundary = split_respect_sentence_boundary
+        self.tokenizer = tokenizer
         self.language = language
         self.tokenizer_model_folder = tokenizer_model_folder
         self.print_log: Set[str] = set()
@@ -139,10 +148,11 @@ class PreProcessor(BasePreProcessor):
         clean_header_footer: Optional[bool] = None,
         clean_empty_lines: Optional[bool] = None,
         remove_substrings: Optional[List[str]] = None,
-        split_by: Optional[Literal["word", "sentence", "passage"]] = None,
+        split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
         split_length: Optional[int] = None,
         split_overlap: Optional[int] = None,
         split_respect_sentence_boundary: Optional[bool] = None,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
         id_hash_keys: Optional[List[str]] = None,
     ) -> List[Document]:
         """
@@ -167,6 +177,7 @@ class PreProcessor(BasePreProcessor):
             "split_length": split_length,
             "split_overlap": split_overlap,
             "split_respect_sentence_boundary": split_respect_sentence_boundary,
+            "tokenizer": tokenizer,
         }
 
         if id_hash_keys is None:
@@ -219,10 +230,11 @@ class PreProcessor(BasePreProcessor):
         clean_header_footer: Optional[bool] = None,
         clean_empty_lines: Optional[bool] = None,
         remove_substrings: Optional[List[str]] = None,
-        split_by: Optional[Literal["word", "sentence", "passage"]] = None,
+        split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
         split_length: Optional[int] = None,
         split_overlap: Optional[int] = None,
         split_respect_sentence_boundary: Optional[bool] = None,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
         id_hash_keys: Optional[List[str]] = None,
     ) -> List[Document]:
         if remove_substrings is None:
@@ -243,6 +255,8 @@ class PreProcessor(BasePreProcessor):
             split_overlap = self.split_overlap
         if split_respect_sentence_boundary is None:
             split_respect_sentence_boundary = self.split_respect_sentence_boundary
+        if tokenizer is None:
+            tokenizer = self.tokenizer
 
         cleaned_document = self.clean(
             document=document,
@@ -258,6 +272,7 @@ class PreProcessor(BasePreProcessor):
             split_length=split_length,
             split_overlap=split_overlap,
             split_respect_sentence_boundary=split_respect_sentence_boundary,
+            tokenizer=tokenizer,
             id_hash_keys=id_hash_keys,
         )
 
@@ -332,10 +347,11 @@ class PreProcessor(BasePreProcessor):
     def split(
         self,
         document: Union[dict, Document],
-        split_by: Optional[Literal["word", "sentence", "passage"]],
+        split_by: Optional[Literal["token", "word", "sentence", "passage"]],
         split_length: int,
         split_overlap: int,
         split_respect_sentence_boundary: bool,
+        tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
         id_hash_keys: Optional[List[str]] = None,
     ) -> List[Document]:
         """Perform document splitting on a single document. This method can split on different units, at different lengths,
@@ -359,8 +375,10 @@ class PreProcessor(BasePreProcessor):
         if not split_length:
             raise Exception("split_length needs be set when using split_by.")
 
-        if split_respect_sentence_boundary and split_by != "word":
-            raise NotImplementedError("'split_respect_sentence_boundary=True' is only compatible with split_by='word'.")
+        if split_respect_sentence_boundary and split_by not in ["word", "token"]:
+            raise NotImplementedError(
+                "'split_respect_sentence_boundary=True' is only compatible with split_by='word' or 'token'."
+            )
 
         if type(document.content) is not str:
             logger.error("Document content is not of type str. Nothing to split.")
@@ -369,13 +387,17 @@ class PreProcessor(BasePreProcessor):
         text = document.content
         headlines = document.meta["headlines"] if "headlines" in document.meta else []
 
-        if split_respect_sentence_boundary and split_by == "word":
-            text_splits, splits_pages, splits_start_idxs = self._split_by_word_respecting_sent_boundary(
-                text=text, split_length=split_length, split_overlap=split_overlap
+        if split_respect_sentence_boundary and split_by in ["word", "token"]:
+
+            def split_function(text):
+                return self._split_tokens(text, tokenizer=tokenizer) if split_by == "token" else text.split()
+
+            text_splits, splits_pages, splits_start_idxs = self._split_into_units_respecting_sent_boundary(
+                text=text, split_length=split_length, split_overlap=split_overlap, split_function=split_function
             )
         else:
             # create individual "elements" of passage, sentence, or word
-            elements, split_at = self._split_into_units(text=text, split_by=split_by)
+            elements, split_at = self._split_into_units(text=text, split_by=split_by, tokenizer=tokenizer)
 
             # concatenate individual elements based on split_length & split_stride
             text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
@@ -467,15 +489,15 @@ class PreProcessor(BasePreProcessor):
             cleaned_text = text.replace(substring, "")
         return cleaned_text, headlines
 
-    def _split_by_word_respecting_sent_boundary(
-        self, text: str, split_length: int, split_overlap: int
+    def _split_into_units_respecting_sent_boundary(
+        self, text: str, split_length: int, split_overlap: int, split_function: Callable
     ) -> Tuple[List[str], List[int], List[int]]:
         """
         Splits the text into parts of split_length words while respecting sentence boundaries.
         """
         sentences = self._split_sentences(text)
 
-        word_count_slice = 0
+        unit_count_slice = 0
         cur_page = 1
         cur_start_idx = 0
         splits_pages = []
@@ -483,17 +505,17 @@ class PreProcessor(BasePreProcessor):
         splits_start_idxs = []
         current_slice: List[str] = []
         for sen in sentences:
-            word_count_sen = len(sen.split())
+            unit_count_sen = len(split_function(sen))
 
-            if word_count_sen > split_length:
+            if unit_count_sen > split_length:
                 long_sentence_message = (
-                    "We found one or more sentences whose word count is higher than the split length."
+                    "We found one or more sentences whose split count is higher than the split length."
                 )
                 if long_sentence_message not in self.print_log:
                     self.print_log.add(long_sentence_message)
                     logger.warning(long_sentence_message)
 
-            if word_count_slice + word_count_sen > split_length:
+            if unit_count_slice + unit_count_sen > split_length:
                 # Number of words exceeds split_length -> save current slice and start a new one
                 if current_slice:
                     list_splits.append(current_slice)
@@ -501,13 +523,13 @@ class PreProcessor(BasePreProcessor):
                     splits_start_idxs.append(cur_start_idx)
 
                     if split_overlap:
-                        processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
-                            current_slice, split_length, split_overlap
+                        processed_sents, current_slice, unit_count_slice = self._get_overlap_from_slice(
+                            current_slice, split_length, split_overlap, split_function
                         )
                     else:
                         processed_sents = current_slice
                         current_slice = []
-                        word_count_slice = 0
+                        unit_count_slice = 0
 
                 cur_start_idx += len("".join(processed_sents))
 
@@ -522,7 +544,7 @@ class PreProcessor(BasePreProcessor):
                 cur_page += num_page_breaks
 
             current_slice.append(sen)
-            word_count_slice += word_count_sen
+            unit_count_slice += unit_count_sen
 
         if current_slice:
             list_splits.append(current_slice)
@@ -539,7 +561,7 @@ class PreProcessor(BasePreProcessor):
 
     @staticmethod
     def _get_overlap_from_slice(
-        current_slice: List[str], split_length: int, split_overlap: int
+        current_slice: List[str], split_length: int, split_overlap: int, split_function: Callable
     ) -> Tuple[List[str], List[str], int]:
         """
         Returns a tuple with the following elements:
@@ -553,7 +575,7 @@ class PreProcessor(BasePreProcessor):
         current_slice_copy = deepcopy(current_slice)
         # Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
         for idx, s in reversed(list(enumerate(current_slice))[1:]):
-            sen_len = len(s.split())
+            sen_len = len(split_function(s))
             if word_count_overlap < split_overlap and sen_len < split_length:
                 overlap.append(s)
                 word_count_overlap += sen_len
@@ -566,7 +588,7 @@ class PreProcessor(BasePreProcessor):
 
         return processed_sents, next_slice, word_count_slice
 
-    def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
+    def _split_into_units(self, text: str, split_by: str, tokenizer: Any) -> Tuple[List[str], str]:
         if split_by == "passage":
             elements = text.split("\n\n")
             split_at = "\n\n"
@@ -576,8 +598,13 @@ class PreProcessor(BasePreProcessor):
         elif split_by == "word":
             elements = text.split(" ")
             split_at = " "
+        elif split_by == "token":
+            elements = self._split_tokens(text, tokenizer)
+            split_at = ""
         else:
-            raise NotImplementedError("PreProcessor only supports 'passage', 'sentence' or 'word' split_by options.")
+            raise NotImplementedError(
+                "PreProcessor only supports 'passage', 'sentence', 'word' or 'token' split_by options."
+            )
 
         return elements, split_at
 
@@ -823,6 +850,35 @@ class PreProcessor(BasePreProcessor):
             sentences = sentence_tokenizer.tokenize(text)
         return sentences
 
+    def _split_tokens(self, text: str, tokenizer: Any) -> List[str]:
+        if tokenizer == "tiktoken":
+            tiktoken_import.check()
+            enc = tiktoken.get_encoding("cl100k_base")  # tiktoken is reversible and lossless
+            integer_tokens = enc.encode(text, disallowed_special=())
+            elements = [enc.decode_single_token_bytes(token).decode(errors="ignore") for token in integer_tokens]
+            return elements
+        if isinstance(tokenizer, str):
+            transformers_import.check()
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+            except Exception:
+                raise ValueError(
+                    f"Could not load tokenizer '{tokenizer}' from HuggingFace model hub. "
+                    f"Please make sure that the tokenizer is correct and exists."
+                )
+        if isinstance(tokenizer, PreTrainedTokenizerBase):
+            encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
+            elements = []
+            for i in range(l := len(encoded.offset_mapping)):
+                start_current = encoded.offset_mapping[i][0]
+                start_next = encoded.offset_mapping[i + 1][0] if i < l - 1 else len(text)
+                elements.append(text[start_current:start_next])
+            return elements
+        raise ValueError(
+            f"Unsupported tokenizer specification {tokenizer}. "
+            f"Please provide either the string 'tiktoken' or a HuggingFace tokenizer (PreTrainedTokenizerBase)."
+        )
+
     def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokenize.punkt.PunktSentenceTokenizer":
         # Try to load a custom model from 'tokenizer_model_path'
         if self.tokenizer_model_folder is not None:
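The tiktoken branch of the new _split_tokens helper decodes each integer token back to text so that concatenating the pieces reproduces the input. A standalone sketch of that round-trip, assuming tiktoken is installed (the sample sentence is arbitrary):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "Splitting by token keeps chunks aligned with model context windows."
tokens = enc.encode(text, disallowed_special=())
# Decode token by token; errors="ignore" mirrors the lossy-byte handling in the diff above.
pieces = [enc.decode_single_token_bytes(t).decode(errors="ignore") for t in tokens]
assert "".join(pieces) == text  # holds for plain ASCII text; multi-byte characters can drop bytes
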
releasenotes/notes/split-by-token-b9a4f954d4077ecc.yaml (new file, 2 lines)
@@ -0,0 +1,2 @@
+features:
+  - Add `split_length` by token in PreProcessor
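For the Hugging Face branch of _split_tokens, the offset mapping is what allows the original text to be sliced back, including whitespace that the tokens themselves drop. A small sketch of the same idea, assuming transformers is installed and a fast tokenizer is used ('bert-base-uncased' is only an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text = "Token-based splitting respects the tokenizer's own boundaries."
encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
starts = [start for start, _ in encoded["offset_mapping"]]
# Slice from each token's start to the next token's start so the pieces cover the whole string.
pieces = [text[s:e] for s, e in zip(starts, starts[1:] + [len(text)])]
assert "".join(pieces) == text  # assumes the first token starts at offset 0 (no leading whitespace)
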
@@ -1,12 +1,15 @@
 import sys
 from pathlib import Path
 from typing import Any, Optional, List
+from unittest import mock
 from unittest.mock import Mock
 
 import nltk.data
 import pytest
+import tiktoken
 from _pytest.monkeypatch import MonkeyPatch
 from _pytest.tmpdir import TempPathFactory
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from haystack import Document
 from haystack.nodes.file_converter.pdf import PDFToTextConverter
@@ -84,6 +87,44 @@ def patched_nltk_data_path(module_tmp_dir: Path, monkeypatch: MonkeyPatch, tmp_p
     return tmp_path
 
 
+@pytest.fixture
+def mock_huggingface_tokenizer():
+    class MockTokenizer(PreTrainedTokenizerBase):
+        """Simple Mock tokenizer splitting the text into 2-character chunks."""
+
+        @staticmethod
+        def tokenize(text, **kwargs):
+            return [text[i : i + 2] for i in range(0, len(text), 2)]
+
+        @staticmethod
+        def encode_plus(text, **kwargs):
+            return Mock(offset_mapping=[(i, min(len(text), i + 2)) for i in range(0, len(text), 2)])
+
+    mock_tokenizer_instance = MockTokenizer()
+
+    with mock.patch.object(AutoTokenizer, "from_pretrained", return_value=mock_tokenizer_instance):
+        yield mock_tokenizer_instance
+
+
+@pytest.fixture
+def mock_tiktoken_tokenizer():
+    class MockTokenizer:
+        """Simple Mock tokenizer "encoding" the text into a 0 for every 5-character chunk."""
+
+        @staticmethod
+        def encode(text, **kwargs):
+            return [0 for i in range(0, len(text), 5)]
+
+        @staticmethod
+        def decode_single_token_bytes(token):
+            return b"mock "
+
+    mock_tokenizer_instance = MockTokenizer()
+
+    with mock.patch.object(tiktoken, "get_encoding", return_value=mock_tokenizer_instance):
+        yield mock_tokenizer_instance
+
+
 @pytest.mark.unit
 @pytest.mark.parametrize("split_length_and_results", [(1, 15), (10, 2)])
 def test_preprocess_sentence_split(split_length_and_results):
@@ -178,6 +219,75 @@ def test_preprocess_word_split():
     assert len(documents) == 15
 
 
+@pytest.mark.unit
+def test_preprocess_tiktoken_token_split(mock_tiktoken_tokenizer):
+    raw_docs = [
+        "This is a document. It has two sentences and eleven words.",
+        "This is a document with a long sentence (longer than my split length), it has seventeen words.",
+    ]
+    docs = [Document(content=content) for content in raw_docs]
+    split_length = 10
+    token_split_docs_not_respecting_sentences = PreProcessor(
+        split_by="token",
+        split_length=split_length,
+        split_respect_sentence_boundary=False,
+        split_overlap=0,
+        tokenizer="tiktoken",
+    ).process(docs)
+    assert len(token_split_docs_not_respecting_sentences) == 4
+    enc = tiktoken.get_encoding("cl100k_base")
+    split_documents_encoded = [
+        enc.encode(d.content, allowed_special="all", disallowed_special=())
+        for d in token_split_docs_not_respecting_sentences
+    ]
+    assert all([len(d) <= split_length for d in split_documents_encoded])
+    token_split_docs_respecting_sentences = PreProcessor(
+        split_by="token",
+        split_length=split_length,
+        split_respect_sentence_boundary=True,
+        split_overlap=0,
+        tokenizer="tiktoken",
+    ).process(docs)
+    assert len(token_split_docs_respecting_sentences) == 3  # should not be more than there are sentences
+
+
+@pytest.mark.unit
+def test_preprocess_huggingface_token_split(mock_huggingface_tokenizer):
+    raw_docs = [
+        "This is a document. It has two sentences and eleven words.",
+        "This is a document with a long sentence (longer than my split length), it has seventeen words.",
+    ]
+    docs = [Document(content=content) for content in raw_docs]
+    split_length = 10
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    token_split_docs_not_respecting_sentences = PreProcessor(
+        split_by="token",
+        split_length=split_length,
+        split_respect_sentence_boundary=False,
+        split_overlap=0,
+        tokenizer=tokenizer,
+    ).process(docs)
+    assert len(token_split_docs_not_respecting_sentences) == 8
+    split_documents_retokenized = [tokenizer.tokenize(d.content) for d in token_split_docs_not_respecting_sentences]
+    assert all([len(d) <= split_length for d in split_documents_retokenized])
+    token_split_docs_respecting_sentences = PreProcessor(
+        split_by="token",
+        split_length=split_length,
+        split_respect_sentence_boundary=True,
+        split_overlap=0,
+        tokenizer=tokenizer,
+    ).process(docs)
+    assert len(token_split_docs_respecting_sentences) == 3  # should not be more than there are sentences
+    token_split_docs_not_respecting_sentences_instantiate_by_name = PreProcessor(
+        split_by="token",
+        split_length=split_length,
+        split_respect_sentence_boundary=False,
+        split_overlap=0,
+        tokenizer="bert-base-uncased",
+    ).process(docs)
+    assert token_split_docs_not_respecting_sentences == token_split_docs_not_respecting_sentences_instantiate_by_name
+
+
 @pytest.mark.unit
 @pytest.mark.parametrize("split_length_and_results", [(1, 3), (2, 2)])
 def test_preprocess_passage_split(split_length_and_results):