Restructure checks in PreProcessor (#504)

* restructure checks

* fix variable name

* Fix test
This commit is contained in:
Malte Pietsch 2020-10-20 06:43:59 +02:00 committed by GitHub
parent c13abba6d6
commit 956543e239
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 24 deletions

View File

@ -80,10 +80,14 @@ class PreProcessor(BasePreProcessor):
if not self.split_length: if not self.split_length:
raise Exception("split_length needs be set when using split_by.") raise Exception("split_length needs be set when using split_by.")
if self.split_respect_sentence_boundary and self.split_by not in("word","sentence"):
raise NotImplementedError("'split_respect_sentence_boundary=True' is only compatible with"
" split_by='word' or split_by='sentence'.")
text = document["text"] text = document["text"]
if self.split_respect_sentence_boundary: # split by words ensuring no sub sentence splits if self.split_respect_sentence_boundary and self.split_by == "word":
if self.split_by == "word": # split by words ensuring no sub sentence splits
sentences = nltk.tokenize.sent_tokenize(text) sentences = nltk.tokenize.sent_tokenize(text)
word_count = 0 word_count = 0
text_splits = [] text_splits = []
@ -100,11 +104,6 @@ class PreProcessor(BasePreProcessor):
word_count += len(sen.split(" ")) word_count += len(sen.split(" "))
if current_slice: if current_slice:
text_splits.append(current_slice) text_splits.append(current_slice)
else:
raise NotImplementedError(
"'split_respect_sentence_boundary' parameter is only compatible with " "split_by='word'."
)
else: else:
# create individual "elements" of passage, sentence, or word # create individual "elements" of passage, sentence, or word
if self.split_by == "passage": if self.split_by == "passage":

View File

@ -44,11 +44,11 @@ def test_preprocess_word_split():
def test_preprocess_passage_split(): def test_preprocess_passage_split():
document = {"text": TEXT} document = {"text": TEXT}
preprocessor = PreProcessor(split_length=1, split_stride=0, split_by="passage") preprocessor = PreProcessor(split_length=1, split_stride=0, split_by="passage", split_respect_sentence_boundary=False)
documents = preprocessor.process(document) documents = preprocessor.process(document)
assert len(documents) == 3 assert len(documents) == 3
preprocessor = PreProcessor(split_length=2, split_stride=0, split_by="passage") preprocessor = PreProcessor(split_length=2, split_stride=0, split_by="passage", split_respect_sentence_boundary=False)
documents = preprocessor.process(document) documents = preprocessor.process(document)
assert len(documents) == 2 assert len(documents) == 2