feat: Add DocumentCleaner 2.0 (#5976)

* remove whitespaces, substrings, regex, empty lines

* remove repeated substrings

* reno

* return empty string as shortest common ngram

* address first half of review feedback

* address second half of review feedback

* mention \f page separator for header/footer removal

* mention \f page separator for header/footer removal

* mark example usage as python code
Julian Risch 2023-10-13 12:39:55 +02:00 committed by GitHub
parent ad25041618
commit aaee03aee8
4 changed files with 434 additions and 1 deletion
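For orientation before the diff: the new component takes a list of Documents and returns the cleaned copies under a "documents" key. A minimal sketch of that contract, using only the API introduced below (the sample text is invented):

```python
from haystack.preview import Document
from haystack.preview.components.preprocessors import DocumentCleaner

# Defaults: remove_extra_whitespaces=True, remove_empty_lines=True.
cleaner = DocumentCleaner()
result = cleaner.run(documents=[Document(text="Hello   world.\n\nGoodbye.")])

# Extra whitespace is collapsed first, which also swallows the empty line.
assert result["documents"][0].text == "Hello world. Goodbye."
```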

haystack/preview/components/preprocessors/__init__.py

@@ -1,4 +1,5 @@
+from haystack.preview.components.preprocessors.text_document_cleaner import DocumentCleaner
from haystack.preview.components.preprocessors.text_document_splitter import TextDocumentSplitter
from haystack.preview.components.preprocessors.text_language_classifier import TextLanguageClassifier
__all__ = ["TextDocumentSplitter", "TextLanguageClassifier"]
__all__ = ["TextDocumentSplitter", "DocumentCleaner", "TextLanguageClassifier"]

haystack/preview/components/preprocessors/text_document_cleaner.py

@@ -0,0 +1,247 @@
import logging
import re
from copy import deepcopy
from functools import partial, reduce
from itertools import chain
from typing import Any, Dict, Generator, List, Optional, Set
from haystack.preview import Document, component, default_from_dict, default_to_dict
logger = logging.getLogger(__name__)
@component
class DocumentCleaner:
"""
Makes text documents more readable by removing extra whitespaces, empty lines, specified substrings, substrings matching a regex, and page headers and footers (in this order).
This is useful for preparing the documents for further processing by LLMs.
Example usage in an indexing pipeline:
```python
document_store = MemoryDocumentStore()
p = Pipeline()
p.add_component(instance=TextFileToDocument(), name="text_file_converter")
p.add_component(instance=DocumentCleaner(), name="cleaner")
p.add_component(instance=TextDocumentSplitter(split_by="sentence", split_length=1), name="splitter")
p.add_component(instance=DocumentWriter(document_store=document_store), name="writer")
p.connect("text_file_converter.documents", "cleaner.documents")
p.connect("cleaner.documents", "splitter.documents")
p.connect("splitter.documents", "writer.documents")
```
"""
def __init__(
self,
remove_empty_lines: bool = True,
remove_extra_whitespaces: bool = True,
remove_repeated_substrings: bool = False,
remove_substrings: Optional[List[str]] = None,
remove_regex: Optional[str] = None,
):
"""
:param remove_empty_lines: Whether to remove empty lines.
:param remove_extra_whitespaces: Whether to remove extra whitespaces.
:param remove_repeated_substrings: Whether to remove repeated substrings (headers/footers) from pages.
Pages in the text need to be separated by form feed character "\f",
which is supported by TextFileToDocument and AzureOCRDocumentConverter.
:param remove_substrings: List of substrings to remove from the text.
:param remove_regex: Regex to match and replace substrings by "".
"""
self.remove_empty_lines = remove_empty_lines
self.remove_extra_whitespaces = remove_extra_whitespaces
self.remove_repeated_substrings = remove_repeated_substrings
self.remove_substrings = remove_substrings
self.remove_regex = remove_regex
@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
"""
Run the DocumentCleaner on the given list of documents.
"""
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
raise TypeError("DocumentCleaner expects a List of Documents as input.")
cleaned_docs = []
for doc in documents:
if doc.text is None:
logger.warning(
"DocumentCleaner only cleans text documents but document.text for document ID %s is None.", doc.id
)
cleaned_docs.append(doc)
continue
text = doc.text
if self.remove_extra_whitespaces:
text = self._remove_extra_whitespaces(text)
if self.remove_empty_lines:
text = self._remove_empty_lines(text)
if self.remove_substrings:
text = self._remove_substrings(text, self.remove_substrings)
if self.remove_regex:
text = self._remove_regex(text, self.remove_regex)
if self.remove_repeated_substrings:
text = self._remove_repeated_substrings(text)
cleaned_docs.append(Document(text=text, metadata=deepcopy(doc.metadata)))
return {"documents": cleaned_docs}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
remove_empty_lines=self.remove_empty_lines,
remove_extra_whitespaces=self.remove_extra_whitespaces,
remove_repeated_substrings=self.remove_repeated_substrings,
remove_substrings=self.remove_substrings,
remove_regex=self.remove_regex,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DocumentCleaner":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def _remove_empty_lines(self, text: str) -> str:
"""
Remove empty lines and lines that contain nothing but whitespaces from text.
:param text: Text to clean.
:return: The text without empty lines.
"""
lines = text.split("\n")
non_empty_lines = filter(lambda line: line.strip() != "", lines)
return "\n".join(non_empty_lines)
def _remove_extra_whitespaces(self, text: str) -> str:
"""
Remove extra whitespaces from text.
:param text: Text to clean.
:return: The text without extra whitespaces.
"""
return re.sub(r"\s\s+", " ", text).strip()
def _remove_regex(self, text: str, regex: str) -> str:
"""
Remove substrings that match the specified regex from the text.
:param text: Text to clean.
:param regex: Regex to match and replace substrings by "".
:return: The text without any substrings that match the regex.
"""
return re.sub(regex, "", text).strip()
def _remove_substrings(self, text: str, substrings: List[str]) -> str:
"""
Remove all specified substrings from the text.
:param text: Text to clean.
:param substrings: Substrings to remove.
:return: The text without the specified substrings.
"""
for substring in substrings:
text = text.replace(substring, "")
return text
def _remove_repeated_substrings(self, text: str) -> str:
"""
Remove any substrings from the text that occur repeatedly on every page, such as headers or footers.
Pages in the text need to be separated by form feed character "\f".
:param text: Text to clean.
:return: The text without the repeated substrings.
"""
return self._find_and_remove_header_footer(
text, n_chars=300, n_first_pages_to_ignore=1, n_last_pages_to_ignore=1
)
def _find_and_remove_header_footer(
self, text: str, n_chars: int, n_first_pages_to_ignore: int, n_last_pages_to_ignore: int
) -> str:
"""
Heuristic to find footers and headers across different pages by searching for the longest common string.
Pages in the text need to be separated by form feed character "\f".
For headers, we only search in the first n_chars characters (for footer: last n_chars).
Note: This heuristic uses exact matches and therefore works well for footers like "Copyright 2019 by XXX",
but won't detect "Page 3 of 4" or similar.
:param n_chars: The number of first/last characters in which to search for the header or footer.
:param n_first_pages_to_ignore: The number of first pages to ignore (e.g. TOCs often don't contain footer/header).
:param n_last_pages_to_ignore: The number of last pages to ignore.
:return: The text without the found headers and footers.
"""
pages = text.split("\f")
# header
start_of_pages = [p[:n_chars] for p in pages[n_first_pages_to_ignore:-n_last_pages_to_ignore]]
found_header = self._find_longest_common_ngram(start_of_pages)
if found_header:
pages = [page.replace(found_header, "") for page in pages]
# footer
end_of_pages = [p[-n_chars:] for p in pages[n_first_pages_to_ignore:-n_last_pages_to_ignore]]
found_footer = self._find_longest_common_ngram(end_of_pages)
if found_footer:
pages = [page.replace(found_footer, "") for page in pages]
logger.debug("Removed header '%s' and footer '%s' in document", found_header, found_footer)
text = "\f".join(pages)
return text
def _ngram(self, seq: str, n: int) -> Generator[str, None, None]:
"""
Return all ngrams of length n from a text sequence. Each ngram consists of n words split by whitespace.
:param seq: The sequence to generate ngrams from.
:param n: The length of the ngrams to generate.
:return: A Generator generating all ngrams of length n from the given sequence.
"""
# In order to maintain the original whitespace, but still consider \n and \t for n-gram tokenization,
# we add a space here and remove it again after the ngrams have been created (see below)
seq = seq.replace("\n", " \n")
seq = seq.replace("\t", " \t")
words = seq.split(" ")
ngrams = (
" ".join(words[i : i + n]).replace(" \n", "\n").replace(" \t", "\t") for i in range(0, len(words) - n + 1)
)
return ngrams
def _allngram(self, seq: str, min_ngram: int, max_ngram: int) -> Set[str]:
"""
Generate all possible ngrams from a given sequence of text,
considering all ngram lengths between the minimum and maximum length.
:param seq: The sequence to generate ngrams from.
:param min_ngram: The minimum length of ngram to consider.
:param max_ngram: The maximum length of ngram to consider.
:return: A set of all ngrams from the given sequence.
"""
lengths = range(min_ngram, max_ngram) if max_ngram else range(min_ngram, len(seq))
ngrams = map(partial(self._ngram, seq), lengths)
res = set(chain.from_iterable(ngrams))
return res
def _find_longest_common_ngram(self, sequences: List[str], min_ngram: int = 3, max_ngram: int = 30) -> str:
"""
Find the longest common ngram across a list of text sequences (e.g. start of pages),
considering all ngram lengths between the minimum and maximum length. Helpful for finding footers, headers, etc.
Empty sequences are ignored.
:param sequences: The list of strings that shall be searched for common n_grams.
:param max_ngram: The maximum length of ngram to consider.
:param min_ngram: The minimum length of ngram to consider.
:return: The longest ngram that all sequences have in common.
"""
sequences = [s for s in sequences if s] # filter empty sequences
if not sequences:
return ""
seqs_ngrams = map(partial(self._allngram, min_ngram=min_ngram, max_ngram=max_ngram), sequences)
intersection = reduce(set.intersection, seqs_ngrams)
longest = max(intersection, key=len, default="")
return longest if longest.strip() else ""
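To make the header/footer heuristic concrete: a sketch of remove_repeated_substrings on a small \f-separated text (page contents invented). Because each page here is shorter than n_chars=300, the repeated line is already caught in the header pass even though it sits at the bottom of each page, and the match found on the middle pages is removed from every page:

```python
from haystack.preview import Document
from haystack.preview.components.preprocessors import DocumentCleaner

# Four pages: the first and last are ignored while searching for the common
# ngram (n_first_pages_to_ignore=1, n_last_pages_to_ignore=1), but the found
# string is removed from all pages.
text = (
    "Title page\f"
    "Alpha beta gamma delta\nConfidential - ACME Corp\f"
    "One two three four five\nConfidential - ACME Corp\f"
    "The end\nConfidential - ACME Corp"
)
cleaner = DocumentCleaner(
    remove_empty_lines=False, remove_extra_whitespaces=False, remove_repeated_substrings=True
)
result = cleaner.run(documents=[Document(text=text)])
assert result["documents"][0].text == (
    "Title page\fAlpha beta gamma delta\fOne two three four five\fThe end"
)
```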

releasenotes/notes/ (new release note)

@@ -0,0 +1,5 @@
---
preview:
- |
Added DocumentCleaner, which removes extra whitespace, empty lines, headers, etc. from Documents containing text.
Useful as a preprocessing step before splitting into shorter text documents.
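The component named in this note also round-trips through the serialization helpers shown above; a minimal sketch (parameter value invented):

```python
from haystack.preview.components.preprocessors import DocumentCleaner

cleaner = DocumentCleaner(remove_substrings=["DRAFT"])
data = cleaner.to_dict()

# default_to_dict() records the type name plus all init parameters.
assert data["type"] == "DocumentCleaner"
assert data["init_parameters"]["remove_substrings"] == ["DRAFT"]

# default_from_dict() reconstructs an equivalent instance.
restored = DocumentCleaner.from_dict(data)
assert restored.remove_substrings == ["DRAFT"]
```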

test/preview/components/preprocessors/test_text_document_cleaner.py

@@ -0,0 +1,180 @@
import logging
import pytest
from haystack.preview import Document
from haystack.preview.components.preprocessors import DocumentCleaner
class TestDocumentCleaner:
@pytest.mark.unit
def test_init(self):
cleaner = DocumentCleaner()
assert cleaner.remove_empty_lines == True
assert cleaner.remove_extra_whitespaces == True
assert cleaner.remove_repeated_substrings == False
assert cleaner.remove_substrings is None
assert cleaner.remove_regex is None
@pytest.mark.unit
def test_to_dict(self):
cleaner = DocumentCleaner()
data = cleaner.to_dict()
assert data == {
"type": "DocumentCleaner",
"init_parameters": {
"remove_empty_lines": True,
"remove_extra_whitespaces": True,
"remove_repeated_substrings": False,
"remove_substrings": None,
"remove_regex": None,
},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
cleaner = DocumentCleaner(
remove_empty_lines=False,
remove_extra_whitespaces=False,
remove_repeated_substrings=True,
remove_substrings=["a", "b"],
remove_regex=r"\s\s+",
)
data = cleaner.to_dict()
assert data == {
"type": "DocumentCleaner",
"init_parameters": {
"remove_empty_lines": False,
"remove_extra_whitespaces": False,
"remove_repeated_substrings": True,
"remove_substrings": ["a", "b"],
"remove_regex": r"\s\s+",
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "DocumentCleaner",
"init_parameters": {
"remove_empty_lines": False,
"remove_extra_whitespaces": False,
"remove_repeated_substrings": True,
"remove_substrings": ["a", "b"],
"remove_regex": r"\s\s+",
},
}
cleaner = DocumentCleaner.from_dict(data)
assert cleaner.remove_empty_lines == False
assert cleaner.remove_extra_whitespaces == False
assert cleaner.remove_repeated_substrings == True
assert cleaner.remove_substrings == ["a", "b"]
assert cleaner.remove_regex == r"\s\s+"
@pytest.mark.unit
def test_non_text_document(self, caplog):
with caplog.at_level(logging.WARNING):
cleaner = DocumentCleaner()
cleaner.run(documents=[Document()])
assert "DocumentCleaner only cleans text documents but document.text for document ID" in caplog.text
@pytest.mark.unit
def test_single_document(self):
with pytest.raises(TypeError, match="DocumentCleaner expects a List of Documents as input."):
cleaner = DocumentCleaner()
cleaner.run(documents=Document())
@pytest.mark.unit
def test_empty_list(self):
cleaner = DocumentCleaner()
result = cleaner.run(documents=[])
assert result == {"documents": []}
@pytest.mark.unit
def test_remove_empty_lines(self):
cleaner = DocumentCleaner(remove_extra_whitespaces=False)
result = cleaner.run(
documents=[
Document(
text="This is a text with some words. "
""
"There is a second sentence. "
""
"And there is a third sentence."
)
]
)
assert len(result["documents"]) == 1
assert (
result["documents"][0].text
== "This is a text with some words. There is a second sentence. And there is a third sentence."
)
@pytest.mark.unit
def test_remove_whitespaces(self):
cleaner = DocumentCleaner(remove_empty_lines=False)
result = cleaner.run(
documents=[
Document(
text=" This is a text with some words. "
""
"There is a second sentence. "
""
"And there is a third sentence. "
)
]
)
assert len(result["documents"]) == 1
assert result["documents"][0].text == (
"This is a text with some words. " "" "There is a second sentence. " "" "And there is a third sentence."
)
@pytest.mark.unit
def test_remove_substrings(self):
cleaner = DocumentCleaner(remove_substrings=["This", "A", "words", "🪲"])
result = cleaner.run(documents=[Document(text="This is a text with some words.🪲")])
assert len(result["documents"]) == 1
assert result["documents"][0].text == " is a text with some ."
@pytest.mark.unit
def test_remove_regex(self):
cleaner = DocumentCleaner(remove_regex=r"\s\s+")
result = cleaner.run(documents=[Document(text="This is a text with some words.")])
assert len(result["documents"]) == 1
assert result["documents"][0].text == "This is a text with some words."
@pytest.mark.unit
def test_remove_repeated_substrings(self):
cleaner = DocumentCleaner(
remove_empty_lines=False, remove_extra_whitespaces=False, remove_repeated_substrings=True
)
text = """First Page This is a header.
Page of
2
4
Lorem ipsum dolor sit amet
This is a footer number 1
This is footer number 2\fThis is a header.
Page of
3
4
Sid ut perspiciatis unde
This is a footer number 1
This is footer number 2\fThis is a header.
Page of
4
4
Sed do eiusmod tempor.
This is a footer number 1
This is footer number 2"""
expected_text = """First Page 2
4
Lorem ipsum dolor sit amet\f3
4
Sid ut perspiciatis unde\f4
4
Sed do eiusmod tempor."""
result = cleaner.run(documents=[Document(text=text)])
assert result["documents"][0].text == expected_text