diff --git a/haystack/preprocessor/preprocessor.py b/haystack/preprocessor/preprocessor.py index d66c02945..8b5fa8e0d 100644 --- a/haystack/preprocessor/preprocessor.py +++ b/haystack/preprocessor/preprocessor.py @@ -80,31 +80,30 @@ class PreProcessor(BasePreProcessor): if not self.split_length: raise Exception("split_length needs be set when using split_by.") + if self.split_respect_sentence_boundary and self.split_by not in ("word", "sentence"): + raise NotImplementedError("'split_respect_sentence_boundary=True' is only compatible with" + " split_by='word' or split_by='sentence'.") + text = document["text"] - if self.split_respect_sentence_boundary: # split by words ensuring no sub sentence splits - if self.split_by == "word": - sentences = nltk.tokenize.sent_tokenize(text) - word_count = 0 - text_splits = [] - current_slice = "" - for sen in sentences: - current_word_count = len(sen.split(" ")) - if current_word_count > self.split_length: - logger.warning(f"A sentence found with word count higher than the split length.") - if word_count + current_word_count > self.split_length: - text_splits.append(current_slice) - current_slice = "" - word_count = 0 - current_slice += sen - word_count += len(sen.split(" ")) - if current_slice: + if self.split_respect_sentence_boundary and self.split_by == "word": + # split by words ensuring no sub sentence splits + sentences = nltk.tokenize.sent_tokenize(text) + word_count = 0 + text_splits = [] + current_slice = "" + for sen in sentences: + current_word_count = len(sen.split(" ")) + if current_word_count > self.split_length: + logger.warning(f"A sentence found with word count higher than the split length.") + if word_count + current_word_count > self.split_length: text_splits.append(current_slice) - - else: - raise NotImplementedError( - "'split_respect_sentence_boundary' parameter is only compatible with " "split_by='word'." 
- ) + current_slice = "" + word_count = 0 + current_slice += sen + word_count += len(sen.split(" ")) + if current_slice: + text_splits.append(current_slice) else: # create individual "elements" of passage, sentence, or word if self.split_by == "passage": diff --git a/test/test_preprocessor.py b/test/test_preprocessor.py index 0b0b3a17d..651257d52 100644 --- a/test/test_preprocessor.py +++ b/test/test_preprocessor.py @@ -44,11 +44,11 @@ def test_preprocess_word_split(): def test_preprocess_passage_split(): document = {"text": TEXT} - preprocessor = PreProcessor(split_length=1, split_stride=0, split_by="passage") + preprocessor = PreProcessor(split_length=1, split_stride=0, split_by="passage", split_respect_sentence_boundary=False) documents = preprocessor.process(document) assert len(documents) == 3 - preprocessor = PreProcessor(split_length=2, split_stride=0, split_by="passage") + preprocessor = PreProcessor(split_length=2, split_stride=0, split_by="passage", split_respect_sentence_boundary=False) documents = preprocessor.process(document) assert len(documents) == 2