mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-26 17:36:34 +00:00
feat: PreProcessor split by token (tiktoken & Hugging Face) (#5276)
* #4983 implemented split by token for tiktoken tokenizer * #4983 added unit test for tiktoken splitting * #4983 implemented and added a test for splitting documents with HuggingFace tokenizer * #4983 added support for passing HF model names (instead of objects) and added an example to the HF token splitting test * mocked HTTP model loading in unit tests, fixed pylint error * fix lossy tokenizers splitting, use LazyImport, ignore UnicodeEncodeError for tiktoken * reno * rename reno file --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
This commit is contained in:
parent
e04a1f16bb
commit
a492771b4d
@ -2,6 +2,8 @@ from typing import List, Optional, Union, Literal
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.schema import Document
|
||||
|
||||
@ -17,10 +19,11 @@ class BasePreProcessor(BaseComponent):
|
||||
clean_header_footer: Optional[bool] = False,
|
||||
clean_empty_lines: Optional[bool] = True,
|
||||
remove_substrings: Optional[List[str]] = None,
|
||||
split_by: Literal["word", "sentence", "passage", None] = "word",
|
||||
split_by: Literal["token", "word", "sentence", "passage", None] = "word",
|
||||
split_length: Optional[int] = 1000,
|
||||
split_overlap: Optional[int] = None,
|
||||
split_respect_sentence_boundary: Optional[bool] = True,
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
|
||||
id_hash_keys: Optional[List[str]] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
@ -44,10 +47,11 @@ class BasePreProcessor(BaseComponent):
|
||||
def split(
|
||||
self,
|
||||
document: Union[dict, Document],
|
||||
split_by: Literal["word", "sentence", "passage", None],
|
||||
split_by: Literal["token", "word", "sentence", "passage", None],
|
||||
split_length: int,
|
||||
split_overlap: int,
|
||||
split_respect_sentence_boundary: bool,
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -57,10 +61,11 @@ class BasePreProcessor(BaseComponent):
|
||||
clean_whitespace: Optional[bool] = None,
|
||||
clean_header_footer: Optional[bool] = None,
|
||||
clean_empty_lines: Optional[bool] = None,
|
||||
split_by: Literal["word", "sentence", "passage", None] = None,
|
||||
split_by: Literal["token", "word", "sentence", "passage", None] = None,
|
||||
split_length: Optional[int] = None,
|
||||
split_overlap: Optional[int] = None,
|
||||
split_respect_sentence_boundary: Optional[bool] = None,
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
|
||||
id_hash_keys: Optional[List[str]] = None,
|
||||
):
|
||||
processed_documents = self.process(
|
||||
@ -83,10 +88,11 @@ class BasePreProcessor(BaseComponent):
|
||||
clean_whitespace: Optional[bool] = None,
|
||||
clean_header_footer: Optional[bool] = None,
|
||||
clean_empty_lines: Optional[bool] = None,
|
||||
split_by: Literal["word", "sentence", "passage", None] = None,
|
||||
split_by: Literal["token", "word", "sentence", "passage", None] = None,
|
||||
split_length: Optional[int] = None,
|
||||
split_overlap: Optional[int] = None,
|
||||
split_respect_sentence_boundary: Optional[bool] = None,
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
|
||||
id_hash_keys: Optional[List[str]] = None,
|
||||
):
|
||||
return self.run(
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal
|
||||
from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal, Callable, Any
|
||||
|
||||
import logging
|
||||
import re
|
||||
@ -17,14 +17,18 @@ from haystack.errors import HaystackError
|
||||
from haystack.schema import Document
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
with LazyImport("Run 'pip install transformers'") as transformers_import:
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
|
||||
import tiktoken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk'") as nltk_import:
|
||||
import nltk
|
||||
|
||||
|
||||
iso639_to_nltk = {
|
||||
"ru": "russian",
|
||||
"sl": "slovene",
|
||||
@ -55,11 +59,12 @@ class PreProcessor(BasePreProcessor):
|
||||
clean_header_footer: bool = False,
|
||||
clean_empty_lines: bool = True,
|
||||
remove_substrings: Optional[List[str]] = None,
|
||||
split_by: Optional[Literal["word", "sentence", "passage"]] = "word",
|
||||
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = "word",
|
||||
split_length: int = 200,
|
||||
split_overlap: int = 0,
|
||||
split_respect_sentence_boundary: bool = True,
|
||||
tokenizer_model_folder: Optional[Union[str, Path]] = None,
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
|
||||
language: str = "en",
|
||||
id_hash_keys: Optional[List[str]] = None,
|
||||
progress_bar: bool = True,
|
||||
@ -86,6 +91,9 @@ class PreProcessor(BasePreProcessor):
|
||||
:param split_respect_sentence_boundary: Whether to split in partial sentences if split_by -> `word`. If set
|
||||
to True, the individual split will always have complete sentences &
|
||||
the number of words will be <= split_length.
|
||||
:param tokenizer: Specifies the tokenizer to use if split_by="token". Supported options are "tiktoken"
|
||||
(for OpenAI's GPT-3.5 and GPT-4) and any HuggingFace tokenizer (e.g. 'bert-base-uncased').
|
||||
HuggingFace tokenizers can also be passed directly as an PreTrainedTokenizerBase object.
|
||||
:param language: The language used by "nltk.tokenize.sent_tokenize" in iso639 format.
|
||||
Available options: "ru","sl","es","sv","tr","cs","da","nl","en","et","fi","fr","de","el","it","no","pl","pt","ml"
|
||||
:param tokenizer_model_folder: Path to the folder containing the NTLK PunktSentenceTokenizer models, if loading a model from a local path. Leave empty otherwise.
|
||||
@ -124,6 +132,7 @@ class PreProcessor(BasePreProcessor):
|
||||
self.split_length = split_length
|
||||
self.split_overlap = split_overlap
|
||||
self.split_respect_sentence_boundary = split_respect_sentence_boundary
|
||||
self.tokenizer = tokenizer
|
||||
self.language = language
|
||||
self.tokenizer_model_folder = tokenizer_model_folder
|
||||
self.print_log: Set[str] = set()
|
||||
@ -139,10 +148,11 @@ class PreProcessor(BasePreProcessor):
|
||||
clean_header_footer: Optional[bool] = None,
|
||||
clean_empty_lines: Optional[bool] = None,
|
||||
remove_substrings: Optional[List[str]] = None,
|
||||
split_by: Optional[Literal["word", "sentence", "passage"]] = None,
|
||||
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
|
||||
split_length: Optional[int] = None,
|
||||
split_overlap: Optional[int] = None,
|
||||
split_respect_sentence_boundary: Optional[bool] = None,
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
|
||||
id_hash_keys: Optional[List[str]] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
@ -167,6 +177,7 @@ class PreProcessor(BasePreProcessor):
|
||||
"split_length": split_length,
|
||||
"split_overlap": split_overlap,
|
||||
"split_respect_sentence_boundary": split_respect_sentence_boundary,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
|
||||
if id_hash_keys is None:
|
||||
@ -219,10 +230,11 @@ class PreProcessor(BasePreProcessor):
|
||||
clean_header_footer: Optional[bool] = None,
|
||||
clean_empty_lines: Optional[bool] = None,
|
||||
remove_substrings: Optional[List[str]] = None,
|
||||
split_by: Optional[Literal["word", "sentence", "passage"]] = None,
|
||||
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
|
||||
split_length: Optional[int] = None,
|
||||
split_overlap: Optional[int] = None,
|
||||
split_respect_sentence_boundary: Optional[bool] = None,
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
|
||||
id_hash_keys: Optional[List[str]] = None,
|
||||
) -> List[Document]:
|
||||
if remove_substrings is None:
|
||||
@ -243,6 +255,8 @@ class PreProcessor(BasePreProcessor):
|
||||
split_overlap = self.split_overlap
|
||||
if split_respect_sentence_boundary is None:
|
||||
split_respect_sentence_boundary = self.split_respect_sentence_boundary
|
||||
if tokenizer is None:
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
cleaned_document = self.clean(
|
||||
document=document,
|
||||
@ -258,6 +272,7 @@ class PreProcessor(BasePreProcessor):
|
||||
split_length=split_length,
|
||||
split_overlap=split_overlap,
|
||||
split_respect_sentence_boundary=split_respect_sentence_boundary,
|
||||
tokenizer=tokenizer,
|
||||
id_hash_keys=id_hash_keys,
|
||||
)
|
||||
|
||||
@ -332,10 +347,11 @@ class PreProcessor(BasePreProcessor):
|
||||
def split(
|
||||
self,
|
||||
document: Union[dict, Document],
|
||||
split_by: Optional[Literal["word", "sentence", "passage"]],
|
||||
split_by: Optional[Literal["token", "word", "sentence", "passage"]],
|
||||
split_length: int,
|
||||
split_overlap: int,
|
||||
split_respect_sentence_boundary: bool,
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
|
||||
id_hash_keys: Optional[List[str]] = None,
|
||||
) -> List[Document]:
|
||||
"""Perform document splitting on a single document. This method can split on different units, at different lengths,
|
||||
@ -359,8 +375,10 @@ class PreProcessor(BasePreProcessor):
|
||||
if not split_length:
|
||||
raise Exception("split_length needs be set when using split_by.")
|
||||
|
||||
if split_respect_sentence_boundary and split_by != "word":
|
||||
raise NotImplementedError("'split_respect_sentence_boundary=True' is only compatible with split_by='word'.")
|
||||
if split_respect_sentence_boundary and split_by not in ["word", "token"]:
|
||||
raise NotImplementedError(
|
||||
"'split_respect_sentence_boundary=True' is only compatible with split_by='word' or 'token'."
|
||||
)
|
||||
|
||||
if type(document.content) is not str:
|
||||
logger.error("Document content is not of type str. Nothing to split.")
|
||||
@ -369,13 +387,17 @@ class PreProcessor(BasePreProcessor):
|
||||
text = document.content
|
||||
headlines = document.meta["headlines"] if "headlines" in document.meta else []
|
||||
|
||||
if split_respect_sentence_boundary and split_by == "word":
|
||||
text_splits, splits_pages, splits_start_idxs = self._split_by_word_respecting_sent_boundary(
|
||||
text=text, split_length=split_length, split_overlap=split_overlap
|
||||
if split_respect_sentence_boundary and split_by in ["word", "token"]:
|
||||
|
||||
def split_function(text):
|
||||
return self._split_tokens(text, tokenizer=tokenizer) if split_by == "token" else text.split()
|
||||
|
||||
text_splits, splits_pages, splits_start_idxs = self._split_into_units_respecting_sent_boundary(
|
||||
text=text, split_length=split_length, split_overlap=split_overlap, split_function=split_function
|
||||
)
|
||||
else:
|
||||
# create individual "elements" of passage, sentence, or word
|
||||
elements, split_at = self._split_into_units(text=text, split_by=split_by)
|
||||
elements, split_at = self._split_into_units(text=text, split_by=split_by, tokenizer=tokenizer)
|
||||
|
||||
# concatenate individual elements based on split_length & split_stride
|
||||
text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
|
||||
@ -467,15 +489,15 @@ class PreProcessor(BasePreProcessor):
|
||||
cleaned_text = text.replace(substring, "")
|
||||
return cleaned_text, headlines
|
||||
|
||||
def _split_by_word_respecting_sent_boundary(
|
||||
self, text: str, split_length: int, split_overlap: int
|
||||
def _split_into_units_respecting_sent_boundary(
|
||||
self, text: str, split_length: int, split_overlap: int, split_function: Callable
|
||||
) -> Tuple[List[str], List[int], List[int]]:
|
||||
"""
|
||||
Splits the text into parts of split_length words while respecting sentence boundaries.
|
||||
"""
|
||||
sentences = self._split_sentences(text)
|
||||
|
||||
word_count_slice = 0
|
||||
unit_count_slice = 0
|
||||
cur_page = 1
|
||||
cur_start_idx = 0
|
||||
splits_pages = []
|
||||
@ -483,17 +505,17 @@ class PreProcessor(BasePreProcessor):
|
||||
splits_start_idxs = []
|
||||
current_slice: List[str] = []
|
||||
for sen in sentences:
|
||||
word_count_sen = len(sen.split())
|
||||
unit_count_sen = len(split_function(sen))
|
||||
|
||||
if word_count_sen > split_length:
|
||||
if unit_count_sen > split_length:
|
||||
long_sentence_message = (
|
||||
"We found one or more sentences whose word count is higher than the split length."
|
||||
"We found one or more sentences whose split count is higher than the split length."
|
||||
)
|
||||
if long_sentence_message not in self.print_log:
|
||||
self.print_log.add(long_sentence_message)
|
||||
logger.warning(long_sentence_message)
|
||||
|
||||
if word_count_slice + word_count_sen > split_length:
|
||||
if unit_count_slice + unit_count_sen > split_length:
|
||||
# Number of words exceeds split_length -> save current slice and start a new one
|
||||
if current_slice:
|
||||
list_splits.append(current_slice)
|
||||
@ -501,13 +523,13 @@ class PreProcessor(BasePreProcessor):
|
||||
splits_start_idxs.append(cur_start_idx)
|
||||
|
||||
if split_overlap:
|
||||
processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
|
||||
current_slice, split_length, split_overlap
|
||||
processed_sents, current_slice, unit_count_slice = self._get_overlap_from_slice(
|
||||
current_slice, split_length, split_overlap, split_function
|
||||
)
|
||||
else:
|
||||
processed_sents = current_slice
|
||||
current_slice = []
|
||||
word_count_slice = 0
|
||||
unit_count_slice = 0
|
||||
|
||||
cur_start_idx += len("".join(processed_sents))
|
||||
|
||||
@ -522,7 +544,7 @@ class PreProcessor(BasePreProcessor):
|
||||
cur_page += num_page_breaks
|
||||
|
||||
current_slice.append(sen)
|
||||
word_count_slice += word_count_sen
|
||||
unit_count_slice += unit_count_sen
|
||||
|
||||
if current_slice:
|
||||
list_splits.append(current_slice)
|
||||
@ -539,7 +561,7 @@ class PreProcessor(BasePreProcessor):
|
||||
|
||||
@staticmethod
|
||||
def _get_overlap_from_slice(
|
||||
current_slice: List[str], split_length: int, split_overlap: int
|
||||
current_slice: List[str], split_length: int, split_overlap: int, split_function: Callable
|
||||
) -> Tuple[List[str], List[str], int]:
|
||||
"""
|
||||
Returns a tuple with the following elements:
|
||||
@ -553,7 +575,7 @@ class PreProcessor(BasePreProcessor):
|
||||
current_slice_copy = deepcopy(current_slice)
|
||||
# Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
|
||||
for idx, s in reversed(list(enumerate(current_slice))[1:]):
|
||||
sen_len = len(s.split())
|
||||
sen_len = len(split_function(s))
|
||||
if word_count_overlap < split_overlap and sen_len < split_length:
|
||||
overlap.append(s)
|
||||
word_count_overlap += sen_len
|
||||
@ -566,7 +588,7 @@ class PreProcessor(BasePreProcessor):
|
||||
|
||||
return processed_sents, next_slice, word_count_slice
|
||||
|
||||
def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
|
||||
def _split_into_units(self, text: str, split_by: str, tokenizer: Any) -> Tuple[List[str], str]:
|
||||
if split_by == "passage":
|
||||
elements = text.split("\n\n")
|
||||
split_at = "\n\n"
|
||||
@ -576,8 +598,13 @@ class PreProcessor(BasePreProcessor):
|
||||
elif split_by == "word":
|
||||
elements = text.split(" ")
|
||||
split_at = " "
|
||||
elif split_by == "token":
|
||||
elements = self._split_tokens(text, tokenizer)
|
||||
split_at = ""
|
||||
else:
|
||||
raise NotImplementedError("PreProcessor only supports 'passage', 'sentence' or 'word' split_by options.")
|
||||
raise NotImplementedError(
|
||||
"PreProcessor only supports 'passage', 'sentence', 'word' or 'token' split_by options."
|
||||
)
|
||||
|
||||
return elements, split_at
|
||||
|
||||
@ -823,6 +850,35 @@ class PreProcessor(BasePreProcessor):
|
||||
sentences = sentence_tokenizer.tokenize(text)
|
||||
return sentences
|
||||
|
||||
def _split_tokens(self, text: str, tokenizer: Any) -> List[str]:
|
||||
if tokenizer == "tiktoken":
|
||||
tiktoken_import.check()
|
||||
enc = tiktoken.get_encoding("cl100k_base") # tiktoken is reversible and lossless
|
||||
integer_tokens = enc.encode(text, disallowed_special=())
|
||||
elements = [enc.decode_single_token_bytes(token).decode(errors="ignore") for token in integer_tokens]
|
||||
return elements
|
||||
if isinstance(tokenizer, str):
|
||||
transformers_import.check()
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Could not load tokenizer '{tokenizer}' from HuggingFace model hub. "
|
||||
f"Please make sure that the tokenizer is correct and exists."
|
||||
)
|
||||
if isinstance(tokenizer, PreTrainedTokenizerBase):
|
||||
encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
|
||||
elements = []
|
||||
for i in range(l := len(encoded.offset_mapping)):
|
||||
start_current = encoded.offset_mapping[i][0]
|
||||
start_next = encoded.offset_mapping[i + 1][0] if i < l - 1 else len(text)
|
||||
elements.append(text[start_current:start_next])
|
||||
return elements
|
||||
raise ValueError(
|
||||
f"Unsupported tokenizer specification {tokenizer}. "
|
||||
f"Please provide either the string 'tiktoken' or a HuggingFace tokenizer (PreTrainedTokenizerBase)."
|
||||
)
|
||||
|
||||
def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokenize.punkt.PunktSentenceTokenizer":
|
||||
# Try to load a custom model from 'tokenizer_model_path'
|
||||
if self.tokenizer_model_folder is not None:
|
||||
|
2
releasenotes/notes/split-by-token-b9a4f954d4077ecc.yaml
Normal file
2
releasenotes/notes/split-by-token-b9a4f954d4077ecc.yaml
Normal file
@ -0,0 +1,2 @@
|
||||
features:
|
||||
- Add `split_length` by token in PreProcessor
|
@ -1,12 +1,15 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, List
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock
|
||||
|
||||
import nltk.data
|
||||
import pytest
|
||||
import tiktoken
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from _pytest.tmpdir import TempPathFactory
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from haystack import Document
|
||||
from haystack.nodes.file_converter.pdf import PDFToTextConverter
|
||||
@ -84,6 +87,44 @@ def patched_nltk_data_path(module_tmp_dir: Path, monkeypatch: MonkeyPatch, tmp_p
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_huggingface_tokenizer():
|
||||
class MockTokenizer(PreTrainedTokenizerBase):
|
||||
"""Simple Mock tokenizer splitting the text into 2-character chunks."""
|
||||
|
||||
@staticmethod
|
||||
def tokenize(text, **kwargs):
|
||||
return [text[i : i + 2] for i in range(0, len(text), 2)]
|
||||
|
||||
@staticmethod
|
||||
def encode_plus(text, **kwargs):
|
||||
return Mock(offset_mapping=[(i, min(len(text), i + 2)) for i in range(0, len(text), 2)])
|
||||
|
||||
mock_tokenizer_instance = MockTokenizer()
|
||||
|
||||
with mock.patch.object(AutoTokenizer, "from_pretrained", return_value=mock_tokenizer_instance):
|
||||
yield mock_tokenizer_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tiktoken_tokenizer():
|
||||
class MockTokenizer:
|
||||
"""Simple Mock tokenizer "encoding" the text into a 0 for every 5-character chunk."""
|
||||
|
||||
@staticmethod
|
||||
def encode(text, **kwargs):
|
||||
return [0 for i in range(0, len(text), 5)]
|
||||
|
||||
@staticmethod
|
||||
def decode_single_token_bytes(token):
|
||||
return b"mock "
|
||||
|
||||
mock_tokenizer_instance = MockTokenizer()
|
||||
|
||||
with mock.patch.object(tiktoken, "get_encoding", return_value=mock_tokenizer_instance):
|
||||
yield mock_tokenizer_instance
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.parametrize("split_length_and_results", [(1, 15), (10, 2)])
|
||||
def test_preprocess_sentence_split(split_length_and_results):
|
||||
@ -178,6 +219,75 @@ def test_preprocess_word_split():
|
||||
assert len(documents) == 15
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_preprocess_tiktoken_token_split(mock_tiktoken_tokenizer):
|
||||
raw_docs = [
|
||||
"This is a document. It has two sentences and eleven words.",
|
||||
"This is a document with a long sentence (longer than my split length), it has seventeen words.",
|
||||
]
|
||||
docs = [Document(content=content) for content in raw_docs]
|
||||
split_length = 10
|
||||
token_split_docs_not_respecting_sentences = PreProcessor(
|
||||
split_by="token",
|
||||
split_length=split_length,
|
||||
split_respect_sentence_boundary=False,
|
||||
split_overlap=0,
|
||||
tokenizer="tiktoken",
|
||||
).process(docs)
|
||||
assert len(token_split_docs_not_respecting_sentences) == 4
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
split_documents_encoded = [
|
||||
enc.encode(d.content, allowed_special="all", disallowed_special=())
|
||||
for d in token_split_docs_not_respecting_sentences
|
||||
]
|
||||
assert all([len(d) <= split_length for d in split_documents_encoded])
|
||||
token_split_docs_respecting_sentences = PreProcessor(
|
||||
split_by="token",
|
||||
split_length=split_length,
|
||||
split_respect_sentence_boundary=True,
|
||||
split_overlap=0,
|
||||
tokenizer="tiktoken",
|
||||
).process(docs)
|
||||
assert len(token_split_docs_respecting_sentences) == 3 # should not be more than there are sentences
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_preprocess_huggingface_token_split(mock_huggingface_tokenizer):
|
||||
raw_docs = [
|
||||
"This is a document. It has two sentences and eleven words.",
|
||||
"This is a document with a long sentence (longer than my split length), it has seventeen words.",
|
||||
]
|
||||
docs = [Document(content=content) for content in raw_docs]
|
||||
split_length = 10
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
token_split_docs_not_respecting_sentences = PreProcessor(
|
||||
split_by="token",
|
||||
split_length=split_length,
|
||||
split_respect_sentence_boundary=False,
|
||||
split_overlap=0,
|
||||
tokenizer=tokenizer,
|
||||
).process(docs)
|
||||
assert len(token_split_docs_not_respecting_sentences) == 8
|
||||
split_documents_retokenized = [tokenizer.tokenize(d.content) for d in token_split_docs_not_respecting_sentences]
|
||||
assert all([len(d) <= split_length for d in split_documents_retokenized])
|
||||
token_split_docs_respecting_sentences = PreProcessor(
|
||||
split_by="token",
|
||||
split_length=split_length,
|
||||
split_respect_sentence_boundary=True,
|
||||
split_overlap=0,
|
||||
tokenizer=tokenizer,
|
||||
).process(docs)
|
||||
assert len(token_split_docs_respecting_sentences) == 3 # should not be more than there are sentences
|
||||
token_split_docs_not_respecting_sentences_instantiate_by_name = PreProcessor(
|
||||
split_by="token",
|
||||
split_length=split_length,
|
||||
split_respect_sentence_boundary=False,
|
||||
split_overlap=0,
|
||||
tokenizer="bert-base-uncased",
|
||||
).process(docs)
|
||||
assert token_split_docs_not_respecting_sentences == token_split_docs_not_respecting_sentences_instantiate_by_name
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.parametrize("split_length_and_results", [(1, 3), (2, 2)])
|
||||
def test_preprocess_passage_split(split_length_and_results):
|
||||
|
Loading…
x
Reference in New Issue
Block a user