mirror of https://github.com/allenai/olmocr.git
synced 2025-11-18 19:38:33 +00:00

Lints

This commit is contained in:
parent 768cb33937
commit 41201b6317
@@ -8,16 +8,19 @@ from olmocr.bench.prompts import (
     build_basic_prompt,
     build_openai_silver_data_prompt_no_document_anchoring,
 )
-from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
+from olmocr.data.renderpdf import (
+    get_png_dimensions_from_base64,
+    render_pdf_to_base64png,
+)
 from olmocr.prompts.anchor import get_anchor_text
 from olmocr.prompts.prompts import (
     PageResponse,
     build_finetuning_prompt,
     build_openai_silver_data_prompt,
-    openai_response_format_schema,
     build_openai_silver_data_prompt_v2,
     build_openai_silver_data_prompt_v2_simple,
     build_openai_silver_data_prompt_v3_simple,
+    openai_response_format_schema,
 )


@@ -8,14 +8,17 @@ and generates OpenAI batch API requests for processing PDFs.
 import argparse
 import json
 import os
-from pathlib import Path
-from typing import Generator, Dict, Any, Optional, Tuple
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, Generator, Optional, Tuple

 from pypdf import PdfReader
 from tqdm import tqdm

-from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
+from olmocr.data.renderpdf import (
+    get_png_dimensions_from_base64,
+    render_pdf_to_base64png,
+)
 from olmocr.prompts.prompts import (
     build_openai_silver_data_prompt_v3_simple,
     openai_response_format_schema,
@@ -60,7 +63,7 @@ def build_custom_id(pdf_path: Path, base_dir: Path) -> str:
     # Get relative path from base directory
     rel_path = pdf_path.relative_to(base_dir)
     # Remove .pdf extension but keep directory structure
-    path_without_ext = str(rel_path).replace('.pdf', '')
+    path_without_ext = str(rel_path).replace(".pdf", "")
     return path_without_ext


@@ -151,12 +154,7 @@ def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]:
         yield pdf_path


-def process_pdfs_to_batch_requests(
-    input_dir: Path,
-    output_dir: Path,
-    max_pdfs: int = None,
-    num_workers: int = 8
-) -> int:
+def process_pdfs_to_batch_requests(input_dir: Path, output_dir: Path, max_pdfs: int = None, num_workers: int = 8) -> int:
     """
     Process PDFs and create batch request files using parallel processing.

@@ -196,10 +194,7 @@ def process_pdfs_to_batch_requests(
     # Process PDFs in parallel using ThreadPoolExecutor
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         # Submit all PDF processing tasks
-        future_to_pdf = {
-            executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path
-            for pdf_path in pdf_files
-        }
+        future_to_pdf = {executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path for pdf_path in pdf_files}

         # Process results as they complete
         with tqdm(total=total_pdfs, desc="Processing PDFs") as pbar:
@@ -252,31 +247,14 @@ def process_pdfs_to_batch_requests(


 def main():
-    parser = argparse.ArgumentParser(
-        description="Build OpenAI batch requests from OLMoCR-mix folder structure"
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default=None,
-        help="Output directory for batch request files (default: input_dir/batch_requests)"
-    )
-    parser.add_argument(
-        "--max_pdfs",
-        type=int,
-        default=None,
-        help="Maximum number of PDFs to process (default: all)"
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=8,
-        help="Number of parallel workers for processing (default: 8)"
-    )
+    parser = argparse.ArgumentParser(description="Build OpenAI batch requests from OLMoCR-mix folder structure")
+    parser.add_argument("--output_dir", type=str, default=None, help="Output directory for batch request files (default: input_dir/batch_requests)")
+    parser.add_argument("--max_pdfs", type=int, default=None, help="Maximum number of PDFs to process (default: all)")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")
     parser.add_argument(
         "input_dir",
         type=str,
-        help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)"
+        help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)",
     )

     args = parser.parse_args()
@@ -298,12 +276,7 @@ def main():
     print(f"Output directory: {output_dir}")

     # Process PDFs
-    process_pdfs_to_batch_requests(
-        input_dir=input_dir,
-        output_dir=output_dir,
-        max_pdfs=args.max_pdfs,
-        num_workers=args.num_workers
-    )
+    process_pdfs_to_batch_requests(input_dir=input_dir, output_dir=output_dir, max_pdfs=args.max_pdfs, num_workers=args.num_workers)

     return 0

@@ -8,11 +8,12 @@ that mirrors the original structure with side-by-side PDF and MD files.

 import argparse
 import json
-import shutil
 import re
-from pathlib import Path
-from typing import Dict, Any, Optional
+import shutil
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, Optional

 from tqdm import tqdm

@@ -39,10 +40,7 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
             content = body["choices"][0]["message"]["content"]
             # Parse the JSON response
             parsed_content = json.loads(content)
-            return {
-                "custom_id": custom_id,
-                "content": parsed_content
-            }
+            return {"custom_id": custom_id, "content": parsed_content}
         else:
             print(f"Error in response for {custom_id}: {data.get('error', 'Unknown error')}")
             return None
@@ -86,12 +84,7 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
     return markdown.strip()


-def process_single_result(
-    custom_id: str,
-    response_content: Dict[str, Any],
-    original_pdf_dir: Path,
-    output_dir: Path
-) -> bool:
+def process_single_result(custom_id: str, response_content: Dict[str, Any], original_pdf_dir: Path, output_dir: Path) -> bool:
     """
     Process a single batch result: copy PDF and create MD file.

@@ -114,8 +107,8 @@ def process_single_result(
         print(f"Warning: Original PDF not found: {original_pdf_path}")

     original_pdf_path = str(original_pdf_path)
-    pattern = r'(.+?)(-\d+)\.pdf$'
-    replacement = r'\1.pdf\2.pdf'
+    pattern = r"(.+?)(-\d+)\.pdf$"
+    replacement = r"\1.pdf\2.pdf"

     original_pdf_path = Path(re.sub(pattern, replacement, original_pdf_path))

@@ -145,12 +138,7 @@ def process_single_result(
         return False


-def process_batch_results(
-    batch_results_dir: Path,
-    original_pdf_dir: Path,
-    output_dir: Path,
-    num_workers: int = 8
-) -> int:
+def process_batch_results(batch_results_dir: Path, original_pdf_dir: Path, output_dir: Path, num_workers: int = 8) -> int:
     """
     Process all batch result files and create output structure.

@@ -198,13 +186,7 @@ def process_batch_results(
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         # Submit all processing tasks
         future_to_result = {
-            executor.submit(
-                process_single_result,
-                result["custom_id"],
-                result["content"],
-                original_pdf_dir,
-                output_dir
-            ): result["custom_id"]
+            executor.submit(process_single_result, result["custom_id"], result["content"], original_pdf_dir, output_dir): result["custom_id"]
             for result in results_to_process
         }

@@ -234,30 +216,13 @@ def process_batch_results(


 def main():
-    parser = argparse.ArgumentParser(
-        description="Process OpenAI batch results and create output folder with PDFs and Markdown files"
-    )
+    parser = argparse.ArgumentParser(description="Process OpenAI batch results and create output folder with PDFs and Markdown files")
+    parser.add_argument("batch_results_dir", type=str, help="Directory containing completed OpenAI batch result files (JSONL)")
-    parser.add_argument(
-        "batch_results_dir",
-        type=str,
-        help="Directory containing completed OpenAI batch result files (JSONL)"
-    )
     parser.add_argument(
-        "original_pdf_dir",
-        type=str,
-        help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
-    )
-    parser.add_argument(
-        "output_dir",
-        type=str,
-        help="Output directory for processed results with PDFs and MD files"
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=8,
-        help="Number of parallel workers for processing (default: 8)"
+        "original_pdf_dir", type=str, help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
     )
+    parser.add_argument("output_dir", type=str, help="Output directory for processed results with PDFs and MD files")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")

     args = parser.parse_args()

@@ -280,12 +245,7 @@ def main():
     print(f"Output directory: {output_dir}")

     # Process the batch results
-    process_batch_results(
-        batch_results_dir=batch_results_dir,
-        original_pdf_dir=original_pdf_dir,
-        output_dir=output_dir,
-        num_workers=args.num_workers
-    )
+    process_batch_results(batch_results_dir=batch_results_dir, original_pdf_dir=original_pdf_dir, output_dir=output_dir, num_workers=args.num_workers)

     return 0

@@ -16,6 +16,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )

+
 def build_openai_silver_data_prompt_v2(base_text: str) -> str:
     return (
         f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it that includes position information for each image and block of text (The origin [0x0] of the coordinates is in the lower left corner of the image). "
@@ -30,6 +31,7 @@ def build_openai_silver_data_prompt_v2(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )

+
 def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int) -> str:
     return (
         f"Attached is the image of one page of a PDF document."
@@ -44,6 +46,7 @@ def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int)
         f"Page width: {page_width}, Page height: {page_height}"
     )

+
 def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int) -> str:
     return (
         f"Attached is the image of one page of a PDF document."
@@ -60,7 +63,6 @@ def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int)
     )


-
 @dataclass(frozen=True)
 class PageResponse:
     primary_language: Optional[str]
@@ -1,10 +1,14 @@
 import argparse
 import base64
 import json
 import logging
+import multiprocessing
+import re
+import shutil
 from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, as_completed
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
+from html.parser import HTMLParser
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
@@ -419,8 +423,6 @@ class LatexBracketNormalizer(PipelineStep):

         # Update the page_data with normalized text
         # Since PageResponse is frozen, we need to create a new instance
-        from olmocr.prompts.prompts import PageResponse
-
         new_page_data = PageResponse(
             primary_language=page_data.primary_language,
             is_rotation_valid=page_data.is_rotation_valid,
@@ -482,8 +484,6 @@ class RotationAugmentation(PipelineStep):
         else:  # 270
             correction = 90

-        from olmocr.prompts.prompts import PageResponse
-
         new_page_data = PageResponse(
             primary_language=page_data.primary_language,
             is_rotation_valid=False,  # Mark as invalid since we rotated it
@@ -539,20 +539,20 @@ class DatasetTextRuleFilter(PipelineStep):
         # Look for pipe-separated table patterns
         # Markdown tables have lines like: | col1 | col2 | col3 |
         # And separator lines like: |------|------|------|
-        lines = text.split('\n')
+        lines = text.split("\n")
         for i, line in enumerate(lines):
             line = line.strip()
             # Check if line looks like a table row
-            if line.startswith('|') and line.endswith('|') and line.count('|') >= 3:
+            if line.startswith("|") and line.endswith("|") and line.count("|") >= 3:
                 # Check if next line is a separator (for header rows)
                 if i + 1 < len(lines):
                     next_line = lines[i + 1].strip()
-                    if next_line.startswith('|') and '-' in next_line:
+                    if next_line.startswith("|") and "-" in next_line:
                         return True
                 # Check if previous line is a separator (for data rows)
                 if i > 0:
                     prev_line = lines[i - 1].strip()
-                    if prev_line.startswith('|') and '-' in prev_line:
+                    if prev_line.startswith("|") and "-" in prev_line:
                         return True
         return False

@@ -563,30 +563,74 @@ class DatasetTextRuleFilter(PipelineStep):
         True if text contains any of the specified math symbols outside tables
         False otherwise
         """
-        import re
-
         # List of mathematical symbols to check for
         math_symbols = [
             # Set theory and logic
-            '∈', '∉', '⊂', '⊃', '⊆', '⊇', '∅', '∪', '∩', '∀', '∃', '¬',
+            "∈",
+            "∉",
+            "⊂",
+            "⊃",
+            "⊆",
+            "⊇",
+            "∅",
+            "∪",
+            "∩",
+            "∀",
+            "∃",
+            "¬",
             # Common mathematical operators
-            '⊕', '⊗', '⊙',
+            "⊕",
+            "⊗",
+            "⊙",
             # Calculus and analysis
-            '∂', '∇', '∆', '∫', '∬', '∭', '∮', '∏', '∑', '√', '∛', '∜',
+            "∂",
+            "∇",
+            "∆",
+            "∫",
+            "∬",
+            "∭",
+            "∮",
+            "∏",
+            "∑",
+            "√",
+            "∛",
+            "∜",
             # Arrows and relations
-            '⊥',
+            "⊥",
             # Other common math symbols
-            '∠', '∡', '⊤', '⊢', '⊣', '∴', '∵', '∶', '∷', '∝', '≅', '≆', '≇', '≊', '≋',
+            "∠",
+            "∡",
+            "⊤",
+            "⊢",
+            "⊣",
+            "∴",
+            "∵",
+            "∶",
+            "∷",
+            "∝",
+            "≅",
+            "≆",
+            "≇",
+            "≊",
+            "≋",
             # Matrix and vector notation
-            '⊕', '⊖', '⊗', '⊘', '⊙', '⊚', '⊛', '⊜', '⊝',
+            "⊕",
+            "⊖",
+            "⊗",
+            "⊘",
+            "⊙",
+            "⊚",
+            "⊛",
+            "⊜",
+            "⊝",
         ]

         # First, remove all HTML tables from the text
         text_without_tables = text

         # Remove HTML tables
-        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
-        text_without_tables = table_pattern.sub('', text_without_tables)
+        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
+        text_without_tables = table_pattern.sub("", text_without_tables)

         # Now check if any of these symbols appear in the text without tables
         for symbol in math_symbols:
@@ -606,8 +650,8 @@ class DatasetTextRuleFilter(PipelineStep):

         # Check for various LaTeX table environments
         latex_table_patterns = [
-            r'\\begin\{table\}',
-            r'\\begin\{tabular\}',
+            r"\\begin\{table\}",
+            r"\\begin\{tabular\}",
         ]

         # Check if any LaTeX table pattern exists in the text
@@ -629,47 +673,42 @@ class DatasetTextRuleFilter(PipelineStep):
         # List of common LaTeX formatting commands to check for
         latex_commands = [
             # Lists & basic content
-            r'\begin{itemize}',
-            r'\begin{enumerate}',
-            r'\item',
-
+            r"\begin{itemize}",
+            r"\begin{enumerate}",
+            r"\item",
             # Figures, tables, and captions
-            r'\begin{figure}',
-            r'\includegraphics',
-            r'\caption',
-            r'\label',
-            r'\ref',
-            r'\eqref',
-            r'\begin{table}',
-            r'\begin{tabular}',
-
+            r"\begin{figure}",
+            r"\includegraphics",
+            r"\caption",
+            r"\label",
+            r"\ref",
+            r"\eqref",
+            r"\begin{table}",
+            r"\begin{tabular}",
             # Formatting,
             # r'\textit',
             # r'\textbb',
-
             # Math (strong signals)
-            r'\begin{equation}',
-            r'\begin{align}',
-            r'\frac',
-            r'\sum',
-            r'\int',
-            r'\sqrt',
-            r'\prod',
-            r'\lim',
-            r'\binom',
-            r'\mathbb',
-            r'\mathcal',
-            r'\to',
-            r'\varphi',
-            r'\cdot',
-            r'\langle',
-            r'\rangle',
-
+            r"\begin{equation}",
+            r"\begin{align}",
+            r"\frac",
+            r"\sum",
+            r"\int",
+            r"\sqrt",
+            r"\prod",
+            r"\lim",
+            r"\binom",
+            r"\mathbb",
+            r"\mathcal",
+            r"\to",
+            r"\varphi",
+            r"\cdot",
+            r"\langle",
+            r"\rangle",
             # Citations (bibliography stacks)
-            r'\cite',
+            r"\cite",
         ]

-
         # First, remove all math equations from the text
         text_without_math = text

@@ -682,7 +721,7 @@ class DatasetTextRuleFilter(PipelineStep):

         # Remove all math equations
         for pattern in math_patterns:
-            text_without_math = re.sub(pattern, '', text_without_math, flags=re.DOTALL)
+            text_without_math = re.sub(pattern, "", text_without_math, flags=re.DOTALL)

         # Check if any LaTeX commands appear in the remaining text
         for command in latex_commands:
@@ -730,7 +769,7 @@ class DatasetTextRuleFilter(PipelineStep):
             rendered = render_equation(equation)

             # Check if there was an error
-            if rendered is None or (hasattr(rendered, 'error') and rendered.error):
+            if rendered is None or (hasattr(rendered, "error") and rendered.error):
                 # Equation failed to render
                 logger.warning(f"Could not render equation '{repr(equation)}', skipping sample")
                 return False
@@ -757,22 +796,22 @@ class DatasetTextRuleFilter(PipelineStep):
         import re

         # Check if there are any tables in the text
-        if '<table' not in text.lower() or '<br' not in text.lower():
+        if "<table" not in text.lower() or "<br" not in text.lower():
             return False  # No tables or no <br> tags at all

         # Pattern to find HTML tables (case-insensitive)
-        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
+        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
         tables = table_pattern.findall(text)

         # Check each table for <br> tags in cells
         for table_html in tables:
             # Pattern to find table cells (td and th tags)
-            cell_pattern = re.compile(r'<(td|th)\b[^>]*>(.*?)</\1>', re.IGNORECASE | re.DOTALL)
+            cell_pattern = re.compile(r"<(td|th)\b[^>]*>(.*?)</\1>", re.IGNORECASE | re.DOTALL)
             cells = cell_pattern.findall(table_html)

             for tag_type, cell_content in cells:
                 # Check if cell content contains <br> tags (any variation)
-                if re.search(r'<br\s*/?>', cell_content, re.IGNORECASE):
+                if re.search(r"<br\s*/?>", cell_content, re.IGNORECASE):
                     return True

         return False
@@ -788,17 +827,17 @@ class DatasetTextRuleFilter(PipelineStep):
         import re

         # Check if there are any <table> tags at all
-        if '<table' not in text.lower():
+        if "<table" not in text.lower():
             return True  # No tables, that's fine

         # Pattern to find HTML tables (case-insensitive)
         # Note: This pattern might not catch malformed tables where </table> is missing
-        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
+        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
         tables = table_pattern.findall(text)

         # Also check for unclosed table tags
-        table_open_count = len(re.findall(r'<table\b[^>]*>', text, re.IGNORECASE))
-        table_close_count = len(re.findall(r'</table>', text, re.IGNORECASE))
+        table_open_count = len(re.findall(r"<table\b[^>]*>", text, re.IGNORECASE))
+        table_close_count = len(re.findall(r"</table>", text, re.IGNORECASE))

         if table_open_count != table_close_count:
             return False  # Mismatched table tags
@@ -808,7 +847,6 @@ class DatasetTextRuleFilter(PipelineStep):
             return False

         # Try to parse each table
-        from html.parser import HTMLParser

         class TableValidator(HTMLParser):
             def __init__(self):
@@ -877,25 +915,25 @@ class DatasetTextRuleFilter(PipelineStep):
         if text is None:
             return sample

-        # Check for markdown tables
-        if self._contains_markdown_table(text):
-            return None  # Filter out samples with markdown tables
+        # # Check for markdown tables
+        # if self._contains_markdown_table(text):
+        #     return None  # Filter out samples with markdown tables

-        # Check for HTML tables and validate them
-        if not self._extract_and_validate_html_tables(text):
-            return None  # Filter out samples with malformed HTML tables
+        # # Check for HTML tables and validate them
+        # if not self._extract_and_validate_html_tables(text):
+        #     return None  # Filter out samples with malformed HTML tables

-        # Check for <br> tags in table cells
-        if self._contains_br_in_table_cells(text):
-            return None  # Filter out samples with <br> tags in table cells
+        # # Check for <br> tags in table cells
+        # if self._contains_br_in_table_cells(text):
+        #     return None  # Filter out samples with <br> tags in table cells

-        # Check if all math equations can render without errors
-        if not self._validate_math_equations(text):
-            return None  # Filter out samples with invalid math equations
+        # # Check if all math equations can render without errors
+        # if not self._validate_math_equations(text):
+        #     return None  # Filter out samples with invalid math equations

-        # Check for mathematical symbols
-        if self._contains_math_symbols(text):
-            return None  # Filter out samples with mathematical symbols
+        # # Check for mathematical symbols
+        # if self._contains_math_symbols(text):
+        #     return None  # Filter out samples with mathematical symbols

         # Check for LaTeX formatting outside math equations
         if self._contains_latex_formatting_outside_math(text):
@@ -958,7 +996,8 @@ class ReformatLatexBoldItalic(PipelineStep):
         def replace_latex_command(text, command, markdown):
             """Replace LaTeX command with markdown, handling nested braces."""
             import re
-            pattern = r'\\' + command + r'\{'
+
+            pattern = r"\\" + command + r"\{"
             result = []
             i = 0

@@ -977,9 +1016,9 @@ class ReformatLatexBoldItalic(PipelineStep):
                 j = start_pos

                 while j < len(text) and brace_count > 0:
-                    if text[j] == '{':
+                    if text[j] == "{":
                         brace_count += 1
-                    elif text[j] == '}':
+                    elif text[j] == "}":
                         brace_count -= 1
                     j += 1

@@ -993,20 +1032,20 @@ class ReformatLatexBoldItalic(PipelineStep):
                     result.append(text[i + match.start() : i + match.end()])
                     i = i + match.end()

-            return ''.join(result)
+            return "".join(result)

         # Handle \textbf{...} -> **...**
-        preserved_text = replace_latex_command(preserved_text, 'textbf', '**')
+        preserved_text = replace_latex_command(preserved_text, "textbf", "**")

         # Handle \textit{...} -> *...*
-        preserved_text = replace_latex_command(preserved_text, 'textit', '*')
+        preserved_text = replace_latex_command(preserved_text, "textit", "*")

         # Restore math equations
         for placeholder, original in math_placeholders:
             preserved_text = preserved_text.replace(placeholder, original)

         # Create a new PageResponse with the updated text (since it's frozen)
         from dataclasses import replace

         updated_page_data = replace(page_data, natural_text=preserved_text)
         sample["page_data"] = updated_page_data

@@ -1414,8 +1453,8 @@ if __name__ == "__main__":
             if sample is None:
                 # This sample was filtered out - get the original paths
                 original_sample = dataset_samples[idx]
-                md_path = original_sample['markdown_path']
-                pdf_path = original_sample['pdf_path']
+                md_path = original_sample["markdown_path"]
+                pdf_path = original_sample["pdf_path"]

                 save_dir = Path(save_dir_str)

@@ -1433,11 +1472,7 @@ if __name__ == "__main__":
                 target_pdf = target_subdir / pdf_path.name
                 shutil.copy2(pdf_path, target_pdf)

-                return {
-                    'index': idx,
-                    'markdown_path': str(md_path),
-                    'pdf_path': str(pdf_path)
-                }
+                return {"index": idx, "markdown_path": str(md_path), "pdf_path": str(pdf_path)}
             return None
         except Exception as e:
             print(f"Error processing sample {idx}: {e}")
@@ -1449,10 +1484,7 @@ if __name__ == "__main__":

     with ProcessPoolExecutor(max_workers=8) as executor:
         # Submit all tasks
-        futures = {
-            executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx
-            for idx in range(len(dataset))
-        }
+        futures = {executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx for idx in range(len(dataset))}

         # Process results with progress bar
         with tqdm(total=len(dataset), desc="Processing samples") as pbar:
@@ -1463,14 +1495,14 @@ if __name__ == "__main__":
                 pbar.update(1)

     # Sort filtered samples by index for consistent output
-    filtered_samples.sort(key=lambda x: x['index'])
+    filtered_samples.sort(key=lambda x: x["index"])

     print(f"\nFound and copied {len(filtered_samples)} filtered samples to: {save_dir}")

     if filtered_samples:
         print(f"First 10 filtered samples:")
         for i, sample_info in enumerate(filtered_samples[:10]):
-            md_name = Path(sample_info['markdown_path']).name
+            md_name = Path(sample_info["markdown_path"]).name
             print(f"  Sample {sample_info['index']}: {md_name}")
         if len(filtered_samples) > 10:
             print(f"  ... and {len(filtered_samples) - 10} more")

@@ -430,10 +430,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
         }

         result = self.reformatter(sample)
-        self.assertEqual(
-            result["page_data"].natural_text,
-            "**First** and **second** bold, *first* and *second* italic."
-        )
+        self.assertEqual(result["page_data"].natural_text, "**First** and **second** bold, *first* and *second* italic.")

     def test_latex_in_parenthesis_delimiter(self):
         """Test LaTeX preserved in \\(...\\) math delimiter."""
@@ -712,7 +709,7 @@ class TestFrontMatterParser(unittest.TestCase):
         self.parser_with_class = FrontMatterParser(front_matter_class=PageResponse)
         self.parser_without_class = FrontMatterParser(front_matter_class=None)

-    @patch.object(Path, 'read_text')
+    @patch.object(Path, "read_text")
     def test_parse_yaml_front_matter(self, mock_read_text):
         """Test parsing of YAML front matter."""
         mock_read_text.return_value = """---
@@ -733,7 +730,7 @@ This is the document content.
         self.assertEqual(result["page_data"].primary_language, "en")
         self.assertEqual(result["page_data"].natural_text, "This is the document content.")

-    @patch.object(Path, 'read_text')
+    @patch.object(Path, "read_text")
     def test_no_front_matter(self, mock_read_text):
         """Test handling of documents without front matter."""
         mock_read_text.return_value = "Just regular content without front matter."
@@ -744,7 +741,7 @@ This is the document content.
         with self.assertRaises(ValueError):
             self.parser_with_class(sample)

-    @patch.object(Path, 'read_text')
+    @patch.object(Path, "read_text")
     def test_malformed_yaml(self, mock_read_text):
         """Test handling of malformed YAML."""
         mock_read_text.return_value = """---
@@ -760,7 +757,7 @@ Content
         result = self.parser_without_class(sample)
         self.assertEqual(result["page_data"], {})

-    @patch.object(Path, 'read_text')
+    @patch.object(Path, "read_text")
     def test_preserve_existing_markdown_content(self, mock_read_text):
         """Test that existing markdown_content is preserved if present."""
         sample = {
@@ -772,7 +769,7 @@ rotation_correction: 0
 is_table: true
 is_diagram: false
 ---
-French content."""
+French content.""",
         }

         # Should not call read_text since markdown_content exists