mirror of https://github.com/deepset-ai/haystack.git, synced 2026-01-06 12:07:04 +00:00
feat: Add DocumentCleaner 2.0 (#5976)
* remove whitespaces, substrings, regex, empty lines
* remove repeated substrings
* reno
* return empty string as shortest common ngram
* address first half of review feedback
* address second half of review feedback
* mention \f page separator for header/footer removal
* mention \f page separator for header/footer removal
* mark example usage as python code
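For orientation, the options named in these bullets all surface as constructor arguments of the new component. A minimal sketch of how they fit together (not part of the diff; the sample strings and option values are invented, the import paths are the ones added below):

```python
from haystack.preview import Document
from haystack.preview.components.preprocessors import DocumentCleaner

cleaner = DocumentCleaner(
    remove_empty_lines=True,           # drop blank lines
    remove_extra_whitespaces=True,     # collapse runs of whitespace
    remove_substrings=["DRAFT"],       # drop exact substrings (sample value)
    remove_regex=r"\(internal\)",      # drop regex matches (sample value)
    remove_repeated_substrings=False,  # header/footer removal needs "\f" page separators
)

result = cleaner.run(documents=[Document(text="DRAFT  report (internal)\n\ntext body")])
print(result["documents"][0].text)
```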
This commit is contained in:
parent ad25041618
commit aaee03aee8
@@ -1,4 +1,5 @@
+from haystack.preview.components.preprocessors.text_document_cleaner import DocumentCleaner
 from haystack.preview.components.preprocessors.text_document_splitter import TextDocumentSplitter
 from haystack.preview.components.preprocessors.text_language_classifier import TextLanguageClassifier
 
-__all__ = ["TextDocumentSplitter", "TextLanguageClassifier"]
+__all__ = ["TextDocumentSplitter", "DocumentCleaner", "TextLanguageClassifier"]
@@ -0,0 +1,247 @@
import logging
import re
from copy import deepcopy
from functools import partial, reduce
from itertools import chain
from typing import Any, Dict, Generator, List, Optional, Set

from haystack.preview import Document, component, default_from_dict, default_to_dict

logger = logging.getLogger(__name__)


@component
class DocumentCleaner:
    """
    Makes text documents more readable by removing extra whitespaces, empty lines, specified substrings, regexes, page headers and footers (in this order).
    This is useful for preparing the documents for further processing by LLMs.

    Example usage in an indexing pipeline:

    ```python
    document_store = MemoryDocumentStore()
    p = Pipeline()
    p.add_component(instance=TextFileToDocument(), name="text_file_converter")
    p.add_component(instance=DocumentCleaner(), name="cleaner")
    p.add_component(instance=TextDocumentSplitter(split_by="sentence", split_length=1), name="splitter")
    p.add_component(instance=DocumentWriter(document_store=document_store), name="writer")
    p.connect("text_file_converter.documents", "cleaner.documents")
    p.connect("cleaner.documents", "splitter.documents")
    p.connect("splitter.documents", "writer.documents")
    ```
    """

    def __init__(
        self,
        remove_empty_lines: bool = True,
        remove_extra_whitespaces: bool = True,
        remove_repeated_substrings: bool = False,
        remove_substrings: Optional[List[str]] = None,
        remove_regex: Optional[str] = None,
    ):
        """
        :param remove_empty_lines: Whether to remove empty lines.
        :param remove_extra_whitespaces: Whether to remove extra whitespaces.
        :param remove_repeated_substrings: Whether to remove repeated substrings (headers/footers) from pages.
            Pages in the text need to be separated by form feed character "\f",
            which is supported by TextFileToDocument and AzureOCRDocumentConverter.
        :param remove_substrings: List of substrings to remove from the text.
        :param remove_regex: Regex to match and replace substrings by "".
        """

        self.remove_empty_lines = remove_empty_lines
        self.remove_extra_whitespaces = remove_extra_whitespaces
        self.remove_repeated_substrings = remove_repeated_substrings
        self.remove_substrings = remove_substrings
        self.remove_regex = remove_regex

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Run the DocumentCleaner on the given list of documents
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError("DocumentCleaner expects a List of Documents as input.")

        cleaned_docs = []
        for doc in documents:
            if doc.text is None:
                logger.warning(
                    "DocumentCleaner only cleans text documents but document.text for document ID %s is None.", doc.id
                )
                cleaned_docs.append(doc)
                continue
            text = doc.text

            if self.remove_extra_whitespaces:
                text = self._remove_extra_whitespaces(text)
            if self.remove_empty_lines:
                text = self._remove_empty_lines(text)
            if self.remove_substrings:
                text = self._remove_substrings(text, self.remove_substrings)
            if self.remove_regex:
                text = self._remove_regex(text, self.remove_regex)
            if self.remove_repeated_substrings:
                text = self._remove_repeated_substrings(text)

            cleaned_docs.append(Document(text=text, metadata=deepcopy(doc.metadata)))

        return {"documents": cleaned_docs}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        return default_to_dict(
            self,
            remove_empty_lines=self.remove_empty_lines,
            remove_extra_whitespaces=self.remove_extra_whitespaces,
            remove_repeated_substrings=self.remove_repeated_substrings,
            remove_substrings=self.remove_substrings,
            remove_regex=self.remove_regex,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DocumentCleaner":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    def _remove_empty_lines(self, text: str) -> str:
        """
        Remove empty lines and lines that contain nothing but whitespaces from text.
        :param text: Text to clean.
        :return: The text without empty lines.
        """
        lines = text.split("\n")
        non_empty_lines = filter(lambda line: line.strip() != "", lines)
        return "\n".join(non_empty_lines)

    def _remove_extra_whitespaces(self, text: str) -> str:
        """
        Remove extra whitespaces from text.
        :param text: Text to clean.
        :return: The text without extra whitespaces.
        """
        return re.sub(r"\s\s+", " ", text).strip()

    def _remove_regex(self, text: str, regex: str) -> str:
        """
        Remove substrings that match the specified regex from the text.
        :param text: Text to clean.
        :param regex: Regex to match and replace substrings by "".
        :return: The text without any substrings that match the regex.
        """
        return re.sub(regex, "", text).strip()

    def _remove_substrings(self, text: str, substrings: List[str]) -> str:
        """
        Remove all specified substrings from the text.
        :param text: Text to clean.
        :param substrings: Substrings to remove.
        :return: The text without the specified substrings.
        """
        for substring in substrings:
            text = text.replace(substring, "")
        return text

    def _remove_repeated_substrings(self, text: str) -> str:
        """
        Remove any substrings from the text that occur repeatedly on every page. For example headers or footers.
        Pages in the text need to be separated by form feed character "\f".
        :param text: Text to clean.
        :return: The text without the repeated substrings.
        """
        return self._find_and_remove_header_footer(
            text, n_chars=300, n_first_pages_to_ignore=1, n_last_pages_to_ignore=1
        )

    def _find_and_remove_header_footer(
        self, text: str, n_chars: int, n_first_pages_to_ignore: int, n_last_pages_to_ignore: int
    ) -> str:
        """
        Heuristic to find footers and headers across different pages by searching for the longest common string.
        Pages in the text need to be separated by form feed character "\f".
        For headers, we only search in the first n_chars characters (for footer: last n_chars).
        Note: This heuristic uses exact matches and therefore works well for footers like "Copyright 2019 by XXX",
        but won't detect "Page 3 of 4" or similar.

        :param n_chars: The number of first/last characters where the header/footer shall be searched in.
        :param n_first_pages_to_ignore: The number of first pages to ignore (e.g. TOCs often don't contain footer/header).
        :param n_last_pages_to_ignore: The number of last pages to ignore.
        :return: The text without the found headers and footers.
        """

        pages = text.split("\f")

        # header
        start_of_pages = [p[:n_chars] for p in pages[n_first_pages_to_ignore:-n_last_pages_to_ignore]]
        found_header = self._find_longest_common_ngram(start_of_pages)
        if found_header:
            pages = [page.replace(found_header, "") for page in pages]

        # footer
        end_of_pages = [p[-n_chars:] for p in pages[n_first_pages_to_ignore:-n_last_pages_to_ignore]]
        found_footer = self._find_longest_common_ngram(end_of_pages)
        if found_footer:
            pages = [page.replace(found_footer, "") for page in pages]

        logger.debug("Removed header '%s' and footer '%s' in document", found_header, found_footer)
        text = "\f".join(pages)
        return text

    def _ngram(self, seq: str, n: int) -> Generator[str, None, None]:
        """
        Return all ngrams of length n from a text sequence. Each ngram consists of n words split by whitespace.
        :param seq: The sequence to generate ngrams from.
        :param n: The length of the ngrams to generate.
        :return: A Generator generating all ngrams of length n from the given sequence.
        """

        # In order to maintain the original whitespace, but still consider \n and \t for n-gram tokenization,
        # we add a space here and remove it after creation of the ngrams again (see below)
        seq = seq.replace("\n", " \n")
        seq = seq.replace("\t", " \t")

        words = seq.split(" ")
        ngrams = (
            " ".join(words[i : i + n]).replace(" \n", "\n").replace(" \t", "\t") for i in range(0, len(words) - n + 1)
        )

        return ngrams

    def _allngram(self, seq: str, min_ngram: int, max_ngram: int) -> Set[str]:
        """
        Generates all possible ngrams from a given sequence of text.
        Considering all ngram lengths between the minimum and maximum length.

        :param seq: The sequence to generate ngrams from.
        :param min_ngram: The minimum length of ngram to consider.
        :param max_ngram: The maximum length of ngram to consider.
        :return: A set of all ngrams from the given sequence.
        """
        lengths = range(min_ngram, max_ngram) if max_ngram else range(min_ngram, len(seq))
        ngrams = map(partial(self._ngram, seq), lengths)
        res = set(chain.from_iterable(ngrams))
        return res

    def _find_longest_common_ngram(self, sequences: List[str], min_ngram: int = 3, max_ngram: int = 30) -> str:
        """
        Find the longest common ngram across a list of text sequences (e.g. start of pages).
        Considering all ngram lengths between the minimum and maximum length. Helpful for finding footers, headers etc.
        Empty sequences are ignored.

        :param sequences: The list of strings that shall be searched for common n_grams.
        :param max_ngram: The maximum length of ngram to consider.
        :param min_ngram: The minimum length of ngram to consider.
        :return: The longest ngram that all sequences have in common.
        """
        sequences = [s for s in sequences if s]  # filter empty sequences
        if not sequences:
            return ""
        seqs_ngrams = map(partial(self._allngram, min_ngram=min_ngram, max_ngram=max_ngram), sequences)
        intersection = reduce(set.intersection, seqs_ngrams)

        longest = max(intersection, key=len, default="")
        return longest if longest.strip() else ""
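An aside on the header/footer heuristic above: since the search only compares the pages between `n_first_pages_to_ignore` and `n_last_pages_to_ignore`, the effect is easiest to see on a small multi-page string. A minimal sketch (not part of the diff; the page contents are made up, the import path is the one this commit adds to the preprocessors package):

```python
from haystack.preview import Document
from haystack.preview.components.preprocessors import DocumentCleaner

# Made-up pages; "\f" is the page separator the docstrings require.
pages = [
    "Table of contents",  # ignored: first page
    "ACME Corp Annual Report\nRevenue grew in Q1.\nCopyright 2023 ACME Corp",
    "ACME Corp Annual Report\nCosts were stable in Q2.\nCopyright 2023 ACME Corp",
    "Appendix",  # ignored: last page
]

cleaner = DocumentCleaner(
    remove_empty_lines=False, remove_extra_whitespaces=False, remove_repeated_substrings=True
)
result = cleaner.run(documents=[Document(text="\f".join(pages))])

# The repeated first and last lines of the middle pages are found as the
# longest common n-grams and stripped from every page; the ignored first
# and last pages are left untouched.
print(result["documents"][0].text.split("\f"))
```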
@@ -0,0 +1,5 @@
---
preview:
  - |
    Added DocumentCleaner, which removes extra whitespace, empty lines, headers, etc. from Documents containing text.
    Useful as a preprocessing step before splitting into shorter text documents.
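The note's point about cleaning before splitting can also be shown without a full Pipeline, by calling the two components directly. A minimal sketch (not part of the diff; it assumes TextDocumentSplitter's run() follows the same documents-in/documents-out convention that the pipeline example in the DocumentCleaner docstring relies on):

```python
from haystack.preview import Document
from haystack.preview.components.preprocessors import DocumentCleaner, TextDocumentSplitter

cleaner = DocumentCleaner()
splitter = TextDocumentSplitter(split_by="sentence", split_length=1)

# Clean first, then split; both steps take and return {"documents": [...]}.
cleaned = cleaner.run(documents=[Document(text="First sentence.   \n\n  Second sentence.")])
split = splitter.run(documents=cleaned["documents"])
print([doc.text for doc in split["documents"]])
```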
@@ -0,0 +1,180 @@
import logging

import pytest

from haystack.preview import Document
from haystack.preview.components.preprocessors import DocumentCleaner


class TestDocumentCleaner:
    @pytest.mark.unit
    def test_init(self):
        cleaner = DocumentCleaner()
        assert cleaner.remove_empty_lines == True
        assert cleaner.remove_extra_whitespaces == True
        assert cleaner.remove_repeated_substrings == False
        assert cleaner.remove_substrings is None
        assert cleaner.remove_regex is None

    @pytest.mark.unit
    def test_to_dict(self):
        cleaner = DocumentCleaner()
        data = cleaner.to_dict()
        assert data == {
            "type": "DocumentCleaner",
            "init_parameters": {
                "remove_empty_lines": True,
                "remove_extra_whitespaces": True,
                "remove_repeated_substrings": False,
                "remove_substrings": None,
                "remove_regex": None,
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        cleaner = DocumentCleaner(
            remove_empty_lines=False,
            remove_extra_whitespaces=False,
            remove_repeated_substrings=True,
            remove_substrings=["a", "b"],
            remove_regex=r"\s\s+",
        )
        data = cleaner.to_dict()
        assert data == {
            "type": "DocumentCleaner",
            "init_parameters": {
                "remove_empty_lines": False,
                "remove_extra_whitespaces": False,
                "remove_repeated_substrings": True,
                "remove_substrings": ["a", "b"],
                "remove_regex": r"\s\s+",
            },
        }

    @pytest.mark.unit
    def test_from_dict(self):
        data = {
            "type": "DocumentCleaner",
            "init_parameters": {
                "remove_empty_lines": False,
                "remove_extra_whitespaces": False,
                "remove_repeated_substrings": True,
                "remove_substrings": ["a", "b"],
                "remove_regex": r"\s\s+",
            },
        }
        cleaner = DocumentCleaner.from_dict(data)
        assert cleaner.remove_empty_lines == False
        assert cleaner.remove_extra_whitespaces == False
        assert cleaner.remove_repeated_substrings == True
        assert cleaner.remove_substrings == ["a", "b"]
        assert cleaner.remove_regex == r"\s\s+"

    @pytest.mark.unit
    def test_non_text_document(self, caplog):
        with caplog.at_level(logging.WARNING):
            cleaner = DocumentCleaner()
            cleaner.run(documents=[Document()])
            assert "DocumentCleaner only cleans text documents but document.text for document ID" in caplog.text

    @pytest.mark.unit
    def test_single_document(self):
        with pytest.raises(TypeError, match="DocumentCleaner expects a List of Documents as input."):
            cleaner = DocumentCleaner()
            cleaner.run(documents=Document())

    @pytest.mark.unit
    def test_empty_list(self):
        cleaner = DocumentCleaner()
        result = cleaner.run(documents=[])
        assert result == {"documents": []}

    @pytest.mark.unit
    def test_remove_empty_lines(self):
        cleaner = DocumentCleaner(remove_extra_whitespaces=False)
        result = cleaner.run(
            documents=[
                Document(
                    text="This is a text with some words. "
                    ""
                    "There is a second sentence. "
                    ""
                    "And there is a third sentence."
                )
            ]
        )
        assert len(result["documents"]) == 1
        assert (
            result["documents"][0].text
            == "This is a text with some words. There is a second sentence. And there is a third sentence."
        )

    @pytest.mark.unit
    def test_remove_whitespaces(self):
        cleaner = DocumentCleaner(remove_empty_lines=False)
        result = cleaner.run(
            documents=[
                Document(
                    text=" This is a text with some words. "
                    ""
                    "There is a second sentence. "
                    ""
                    "And there is a third sentence. "
                )
            ]
        )
        assert len(result["documents"]) == 1
        assert result["documents"][0].text == (
            "This is a text with some words. " "" "There is a second sentence. " "" "And there is a third sentence."
        )

    @pytest.mark.unit
    def test_remove_substrings(self):
        cleaner = DocumentCleaner(remove_substrings=["This", "A", "words", "🪲"])
        result = cleaner.run(documents=[Document(text="This is a text with some words.🪲")])
        assert len(result["documents"]) == 1
        assert result["documents"][0].text == " is a text with some ."

    @pytest.mark.unit
    def test_remove_regex(self):
        cleaner = DocumentCleaner(remove_regex=r"\s\s+")
        result = cleaner.run(documents=[Document(text="This is a text with some words.")])
        assert len(result["documents"]) == 1
        assert result["documents"][0].text == "This is a text with some words."

    @pytest.mark.unit
    def test_remove_repeated_substrings(self):
        cleaner = DocumentCleaner(
            remove_empty_lines=False, remove_extra_whitespaces=False, remove_repeated_substrings=True
        )

        text = """First PageThis is a header.
Page of
2
4
Lorem ipsum dolor sit amet
This is a footer number 1
This is footer number 2This is a header.
Page of
3
4
Sid ut perspiciatis unde
This is a footer number 1
This is footer number 2This is a header.
Page of
4
4
Sed do eiusmod tempor.
This is a footer number 1
This is footer number 2"""

        expected_text = """First Page 2
4
Lorem ipsum dolor sit amet 3
4
Sid ut perspiciatis unde 4
4
Sed do eiusmod tempor."""
        result = cleaner.run(documents=[Document(text=text)])
        assert result["documents"][0].text == expected_text