mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-29 08:41:00 +00:00
Filter refactor
This commit is contained in:
parent
3ecbeae6dc
commit
dd4f9670b5
@ -1,16 +1,8 @@
|
||||
import csv
|
||||
import datetime
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import subprocess
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from io import StringIO
|
||||
|
||||
import requests
|
||||
from lingua import Language, LanguageDetectorBuilder
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import DependencyError, PyPdfError
|
||||
@ -20,61 +12,33 @@ logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class PdfFilter:
|
||||
def __init__(self):
|
||||
def __init__(
|
||||
self,
|
||||
languages_to_keep=None,
|
||||
apply_form_check=True,
|
||||
apply_download_spam_check=True,
|
||||
download_spam_threshold=0.004,
|
||||
):
|
||||
super().__init__()
|
||||
self.language_detector = LanguageDetectorBuilder.from_all_languages().with_preloaded_language_models().build()
|
||||
self.ngram_log_probs = self._build_ngram_log_probs()
|
||||
|
||||
# Used for comparing frequency of words to eliminate bad documents
|
||||
def _build_ngram_log_probs(self):
|
||||
NGRAM_DATASET_LINK = (
|
||||
"https://ai2-s2-research-public.s3-us-west-2.amazonaws.com/lucas/google-1T-unigram/unigram_freq.csv"
|
||||
self.language_detector = (
|
||||
LanguageDetectorBuilder.from_all_languages()
|
||||
.with_preloaded_language_models()
|
||||
.build()
|
||||
)
|
||||
self.languages_to_keep = (
|
||||
languages_to_keep if languages_to_keep is not None else [Language.ENGLISH]
|
||||
)
|
||||
self.apply_form_check = apply_form_check
|
||||
self.apply_download_spam_check = apply_download_spam_check
|
||||
self.download_spam_threshold = download_spam_threshold
|
||||
|
||||
ngrams = {}
|
||||
def _is_form(self, pdf_reader) -> bool:
|
||||
# Check if the PDF is a form
|
||||
if pdf_reader.get_form_text_fields():
|
||||
return True
|
||||
return False # Not a form
|
||||
|
||||
# Download the dataset
|
||||
response = requests.get(NGRAM_DATASET_LINK)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to download data, status code: {response.status_code}")
|
||||
|
||||
# Read the CSV content
|
||||
csv_content = StringIO(response.text)
|
||||
reader = csv.DictReader(csv_content)
|
||||
|
||||
# Build the frequency dictionary
|
||||
total_count = 0
|
||||
|
||||
for row in reader:
|
||||
word = row["word"]
|
||||
count = int(row["count"])
|
||||
total_count += count
|
||||
ngrams[word] = count
|
||||
|
||||
# Convert to log probs
|
||||
return {word: math.log(count / total_count) for word, count in ngrams.items()}
|
||||
|
||||
def _is_form(self, local_pdf_path: str) -> bool:
|
||||
# Remove PDFs which are forms
|
||||
try:
|
||||
pdf_reader = PdfReader(local_pdf_path)
|
||||
if pdf_reader.get_form_text_fields():
|
||||
return True
|
||||
except PyPdfError as pex:
|
||||
logger.exception(pex)
|
||||
logger.warning("Invalid PDF, filtering out")
|
||||
return False
|
||||
except DependencyError as dex:
|
||||
logger.warning(f"PDF requires external dependency {dex}, filtering out")
|
||||
return False
|
||||
except Exception as ex:
|
||||
logger.exception(ex)
|
||||
logger.warning(f"Internal error reading PDF, filtering out")
|
||||
return False
|
||||
|
||||
# TODO: If distribution of _ characters is very high, it's probably a form
|
||||
|
||||
def _is_download_spam(self, base_text: str, threshold: float = 0.004) -> bool:
|
||||
def _is_download_spam(self, base_text: str) -> bool:
|
||||
seo_words = {
|
||||
"download",
|
||||
"pdf",
|
||||
@ -89,7 +53,6 @@ class PdfFilter:
|
||||
"cialis",
|
||||
"ciprofloxacin",
|
||||
}
|
||||
seo_word_probs = {word: self.ngram_log_probs[word] for word in seo_words}
|
||||
|
||||
base_text = base_text.strip().lower()
|
||||
clean_text = re.sub(r"\W+", " ", base_text)
|
||||
@ -99,13 +62,21 @@ class PdfFilter:
|
||||
|
||||
seo_score = sum(word_counts[word] for word in seo_words if word in word_counts)
|
||||
|
||||
return seo_score / total_words > threshold
|
||||
return (seo_score / total_words) > self.download_spam_threshold
|
||||
|
||||
# Returns True if there is something wrong with this PDF
|
||||
def filter_out_pdf(self, local_pdf_path: str) -> bool:
|
||||
# Basic metadata-level filtering
|
||||
if self._is_form(local_pdf_path):
|
||||
return False
|
||||
try:
|
||||
# Attempt to read the PDF at the beginning
|
||||
pdf_reader = PdfReader(local_pdf_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading PDF {local_pdf_path}: {e}")
|
||||
return True # Filter out the PDF if an exception occurs
|
||||
|
||||
# Form check
|
||||
if self.apply_form_check and self._is_form(pdf_reader):
|
||||
logger.info(f"Filtering out {local_pdf_path} because it's a form")
|
||||
return True # Filter out
|
||||
|
||||
# Read the first five pages of text for language calculation
|
||||
pdftotext_result = subprocess.run(
|
||||
@ -115,28 +86,24 @@ class PdfFilter:
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
if pdftotext_result.returncode != 0:
|
||||
logger.warn(f"pdftotext returned {pdftotext_result.returncode} on {local_pdf_path}")
|
||||
return False
|
||||
logger.warning(
|
||||
f"pdftotext returned {pdftotext_result.returncode} on {local_pdf_path}"
|
||||
)
|
||||
return True # Filter out
|
||||
|
||||
base_text = pdftotext_result.stdout.decode("utf-8")
|
||||
|
||||
# Other filter ideas:
|
||||
# - Remove patents, they tend to be ocred, multicolumn, and should come in through a cleaner dataset
|
||||
# - Detect things with too many figures
|
||||
# - Detect too many pages with no input
|
||||
# - Off distribution in terms of words per page, etc
|
||||
if len(base_text) < 100 or len(base_text.split()) < 50:
|
||||
logger.warn("PDF is too short, skipping")
|
||||
return False
|
||||
|
||||
# Language check
|
||||
language = self.language_detector.detect_language_of(base_text)
|
||||
if language not in self.languages_to_keep:
|
||||
logger.info(
|
||||
f"Filtering out {local_pdf_path} because language was {language}"
|
||||
)
|
||||
return True # Filter out
|
||||
|
||||
if language != Language.ENGLISH:
|
||||
logger.info(f"Filtering out {local_pdf_path} because language was {language}")
|
||||
return True
|
||||
|
||||
if self._is_download_spam(base_text):
|
||||
# Download spam check
|
||||
if self.apply_download_spam_check and self._is_download_spam(base_text):
|
||||
logger.info(f"Filtering out {local_pdf_path} because of SEO/download spam")
|
||||
return True
|
||||
return True # Filter out
|
||||
|
||||
return False
|
||||
return False # Keep the PDF
|
||||
|
||||
@ -7,11 +7,19 @@ from pdelfin.filter import PdfFilter
|
||||
|
||||
|
||||
class PdfFilterTest(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.filter = PdfFilter()
|
||||
|
||||
def testFormLaterPages(self):
|
||||
self.assertTrue(
|
||||
self.filter._is_form(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "form_on_later_pages.pdf"))
|
||||
)
|
||||
self.filter = PdfFilter(apply_form_check=True)
|
||||
|
||||
self.assertTrue(self.filter.filter_out_pdf(os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"gnarly_pdfs",
|
||||
"form_on_later_pages.pdf"
|
||||
)))
|
||||
|
||||
self.filter = PdfFilter(apply_form_check=False)
|
||||
|
||||
self.assertFalse(self.filter.filter_out_pdf(os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"gnarly_pdfs",
|
||||
"form_on_later_pages.pdf"
|
||||
)))
|
||||
Loading…
x
Reference in New Issue
Block a user