diff --git a/haystack/preprocessor/preprocessor.py b/haystack/preprocessor/preprocessor.py
index bb889044d..97b15934e 100644
--- a/haystack/preprocessor/preprocessor.py
+++ b/haystack/preprocessor/preprocessor.py
@@ -98,22 +98,35 @@ class PreProcessor(BasePreProcessor):
             # split by words ensuring no sub sentence splits
             sentences = nltk.tokenize.sent_tokenize(text)
             word_count = 0
-            text_splits = []
-            current_slice = ""
+            list_splits = []
+            current_slice: List[str] = []
             for sen in sentences:
                 current_word_count = len(sen.split(" "))
                 if current_word_count > self.split_length:
                     logger.warning(f"A sentence found with word count higher than the split length.")
                 if word_count + current_word_count > self.split_length:
-                    text_splits.append(current_slice)
-                    current_slice = ""
-                    word_count = 0
-                if len(current_slice) != 0:
-                    sen = " " + sen
-                current_slice += sen
-                word_count += current_word_count
+                    list_splits.append(current_slice)
+                    # Enable split_stride with split_by='word' while respecting sentence boundaries.
+                    if self.split_stride:
+                        # Carry over trailing sentences until at least split_stride words overlap.
+                        overlap = []
+                        w_count = 0
+                        for s in current_slice[::-1]:
+                            sen_len = len(s.split(" "))
+                            if w_count < self.split_stride:
+                                overlap.append(s)
+                                w_count += sen_len
+                            else:
+                                break
+                        current_slice = list(reversed(overlap))
+                        word_count = w_count
+                    else:
+                        current_slice = []
+                        word_count = 0
+                current_slice.append(sen)
+                word_count += current_word_count
             if current_slice:
-                text_splits.append(current_slice)
+                list_splits.append(current_slice)
+            text_splits = [' '.join(sl) for sl in list_splits]
         else:
             # create individual "elements" of passage, sentence, or word
             if self.split_by == "passage":
diff --git a/test/test_preprocessor.py b/test/test_preprocessor.py
index 121c53182..4e0220642 100644
--- a/test/test_preprocessor.py
+++ b/test/test_preprocessor.py
@@ -47,6 +47,10 @@ def test_preprocess_word_split():
     assert len(doc["text"].split(" ")) <= 15 or doc["text"].startswith("This is to trick")
     assert len(documents) == 8
 
+    preprocessor = PreProcessor(split_length=40, split_stride=10, split_by="word", split_respect_sentence_boundary=True)
+    documents = preprocessor.process(document)
+    assert len(documents) == 5
+
 
 @pytest.mark.tika
 def test_preprocess_passage_split():
@@ -71,5 +75,4 @@ def test_clean_header_footer():
     assert len(documents) == 1
     assert "This is a header." not in documents[0]["text"]
 
-    assert "footer" not in documents[0]["text"]
-    
\ No newline at end of file
+    assert "footer" not in documents[0]["text"]
\ No newline at end of file