This commit is contained in:
Jake Poznanski 2025-08-19 21:30:41 +00:00
parent 768cb33937
commit 41201b6317
6 changed files with 419 additions and 452 deletions

View File

@ -8,16 +8,19 @@ from olmocr.bench.prompts import (
build_basic_prompt,
build_openai_silver_data_prompt_no_document_anchoring,
)
from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
from olmocr.data.renderpdf import (
get_png_dimensions_from_base64,
render_pdf_to_base64png,
)
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import (
PageResponse,
build_finetuning_prompt,
build_openai_silver_data_prompt,
openai_response_format_schema,
build_openai_silver_data_prompt_v2,
build_openai_silver_data_prompt_v2_simple,
build_openai_silver_data_prompt_v3_simple,
openai_response_format_schema,
)
@ -65,7 +68,7 @@ def run_chatgpt(
prompt = build_openai_silver_data_prompt_v2_simple(width, height)
elif prompt_template == "fullv3simple":
width, height = get_png_dimensions_from_base64(image_base64)
prompt = build_openai_silver_data_prompt_v3_simple(width, height)
prompt = build_openai_silver_data_prompt_v3_simple(width, height)
else:
raise ValueError("Unknown prompt template")
@ -82,7 +85,7 @@ def run_chatgpt(
],
temperature=temperature,
max_completion_tokens=20000,
#reasoning_effort="high",
# reasoning_effort="high",
response_format=openai_response_format_schema() if response_template == "json" else None,
safety_identifier="olmocr-bench-runner",
)

View File

@ -8,14 +8,17 @@ and generates OpenAI batch API requests for processing PDFs.
import argparse
import json
import os
from pathlib import Path
from typing import Generator, Dict, Any, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Tuple
from pypdf import PdfReader
from tqdm import tqdm
from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
from olmocr.data.renderpdf import (
get_png_dimensions_from_base64,
render_pdf_to_base64png,
)
from olmocr.prompts.prompts import (
build_openai_silver_data_prompt_v3_simple,
openai_response_format_schema,
@ -28,10 +31,10 @@ MAX_FILE_SIZE = 99 * 1024 * 1024 # 99MB in bytes
def validate_single_page_pdf(pdf_path: Path) -> bool:
"""
Validate that a PDF has exactly one page.
Args:
pdf_path: Path to the PDF file
Returns:
True if PDF has exactly one page, False otherwise
"""
@ -46,32 +49,32 @@ def validate_single_page_pdf(pdf_path: Path) -> bool:
def build_custom_id(pdf_path: Path, base_dir: Path) -> str:
"""
Build a custom ID for the request that can be used to recover the file later.
The ID preserves the full path structure for easy recovery.
Example: extracted/document_id.pdf becomes "extracted/document_id"
Args:
pdf_path: Full path to the PDF file
base_dir: Base directory containing the processed folder
Returns:
Custom ID string that preserves path structure
"""
# Get relative path from base directory
rel_path = pdf_path.relative_to(base_dir)
# Remove .pdf extension but keep directory structure
path_without_ext = str(rel_path).replace('.pdf', '')
path_without_ext = str(rel_path).replace(".pdf", "")
return path_without_ext
def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[str, Any], Path]]:
"""
Process a single PDF and return the batch request if valid.
Args:
pdf_path: Path to the PDF file
base_dir: Base directory for building custom IDs
Returns:
Tuple of (request dict, pdf_path) if successful, None otherwise
"""
@ -83,20 +86,20 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[st
except Exception as e:
print(f"Error reading PDF {pdf_path}: {e}")
return None
try:
# Render PDF to base64 image
image_base64 = render_pdf_to_base64png(str(pdf_path), page_num=1, target_longest_image_dim=TARGET_IMAGE_DIM)
# Get image dimensions for the prompt
width, height = get_png_dimensions_from_base64(image_base64)
# Build the prompt using v3 simple version
prompt = build_openai_silver_data_prompt_v3_simple(width, height)
# Build custom ID
custom_id = build_custom_id(pdf_path, base_dir)
# Build the request in OpenAI batch format
request = {
"custom_id": custom_id,
@ -118,7 +121,7 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[st
"response_format": openai_response_format_schema(),
},
}
return (request, pdf_path)
except Exception as e:
print(f"Error processing {pdf_path}: {e}")
@ -128,21 +131,21 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[st
def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]:
"""
Find all PDF files in the processed folder structure.
The structure is expected to be:
processed_XX_subset_split/
extracted/
*.pdf
Or for hugging_face downloads:
hugging_face/
pdf_tarballs/
extracted/
*.pdf
Args:
input_dir: Input directory path
Yields:
Path objects for each PDF file found
"""
@ -151,64 +154,56 @@ def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]:
yield pdf_path
def process_pdfs_to_batch_requests(
input_dir: Path,
output_dir: Path,
max_pdfs: int = None,
num_workers: int = 8
) -> int:
def process_pdfs_to_batch_requests(input_dir: Path, output_dir: Path, max_pdfs: int = None, num_workers: int = 8) -> int:
"""
Process PDFs and create batch request files using parallel processing.
Args:
input_dir: Directory containing the processed folder structure
output_dir: Directory to save batch request files
max_pdfs: Maximum number of PDFs to process (None for all)
num_workers: Number of parallel workers for processing
Returns:
Number of PDFs processed
"""
# Ensure output directory exists
output_dir.mkdir(parents=True, exist_ok=True)
# Initialize file management
file_num = 0
current_file_size = 0
current_file_path = output_dir / f"batch_requests_{file_num:04d}.jsonl"
current_file = open(current_file_path, "w")
pdfs_processed = 0
pdfs_skipped = 0
# Find PDF files
pdf_files = list(find_pdf_files(input_dir))
# Limit files if max_pdfs is specified
if max_pdfs:
pdf_files = pdf_files[:max_pdfs]
total_pdfs = len(pdf_files)
print(f"Found {total_pdfs} PDF files to process")
print(f"Using {num_workers} parallel workers")
# Process PDFs in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Submit all PDF processing tasks
future_to_pdf = {
executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path
for pdf_path in pdf_files
}
future_to_pdf = {executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path for pdf_path in pdf_files}
# Process results as they complete
with tqdm(total=total_pdfs, desc="Processing PDFs") as pbar:
for future in as_completed(future_to_pdf):
pdf_path = future_to_pdf[future]
try:
result = future.result()
if result is None:
# PDF was skipped (multi-page or error)
pdfs_skipped += 1
@ -216,7 +211,7 @@ def process_pdfs_to_batch_requests(
request, _ = result
request_json = json.dumps(request)
request_size = len(request_json.encode("utf-8"))
# Check if we need to start a new file
if current_file_size + request_size > MAX_FILE_SIZE:
current_file.close()
@ -225,88 +220,66 @@ def process_pdfs_to_batch_requests(
current_file = open(current_file_path, "w")
current_file_size = 0
print(f"\nStarting new batch file: {current_file_path.name}")
# Write the request (only in main thread)
current_file.write(request_json)
current_file.write("\n")
current_file_size += request_size
pdfs_processed += 1
except Exception as e:
print(f"\nError with {pdf_path}: {e}")
pdfs_skipped += 1
pbar.update(1)
# Close the last file
current_file.close()
print(f"\nProcessing complete:")
print(f" - PDFs processed: {pdfs_processed}")
print(f" - PDFs skipped: {pdfs_skipped}")
print(f" - Batch files created: {file_num + 1}")
print(f" - Output directory: {output_dir}")
return pdfs_processed
def main():
parser = argparse.ArgumentParser(
description="Build OpenAI batch requests from OLMoCR-mix folder structure"
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Output directory for batch request files (default: input_dir/batch_requests)"
)
parser.add_argument(
"--max_pdfs",
type=int,
default=None,
help="Maximum number of PDFs to process (default: all)"
)
parser.add_argument(
"--num_workers",
type=int,
default=8,
help="Number of parallel workers for processing (default: 8)"
)
parser = argparse.ArgumentParser(description="Build OpenAI batch requests from OLMoCR-mix folder structure")
parser.add_argument("--output_dir", type=str, default=None, help="Output directory for batch request files (default: input_dir/batch_requests)")
parser.add_argument("--max_pdfs", type=int, default=None, help="Maximum number of PDFs to process (default: all)")
parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")
parser.add_argument(
"input_dir",
type=str,
help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)"
help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)",
)
args = parser.parse_args()
# Convert paths to Path objects
input_dir = Path(args.input_dir).expanduser().resolve()
if not input_dir.exists():
print(f"Error: Input directory does not exist: {input_dir}")
return 1
# Set default output directory if not specified
if args.output_dir:
output_dir = Path(args.output_dir).expanduser().resolve()
else:
output_dir = input_dir / "batch_requests"
print(f"Input directory: {input_dir}")
print(f"Output directory: {output_dir}")
# Process PDFs
process_pdfs_to_batch_requests(
input_dir=input_dir,
output_dir=output_dir,
max_pdfs=args.max_pdfs,
num_workers=args.num_workers
)
process_pdfs_to_batch_requests(input_dir=input_dir, output_dir=output_dir, max_pdfs=args.max_pdfs, num_workers=args.num_workers)
return 0
if __name__ == "__main__":
exit(main())
exit(main())

View File

@ -8,30 +8,31 @@ that mirrors the original structure with side-by-side PDF and MD files.
import argparse
import json
import shutil
import re
from pathlib import Path
from typing import Dict, Any, Optional
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, Optional
from tqdm import tqdm
def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
"""
Parse a single line from the batch response file.
Args:
response_line: JSON line from batch response file
Returns:
Parsed response dictionary or None if error
"""
try:
data = json.loads(response_line)
# Extract the custom_id and response
custom_id = data.get("custom_id")
# Check if the response was successful
if "response" in data and data["response"].get("status_code") == 200:
body = data["response"]["body"]
@ -39,14 +40,11 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
content = body["choices"][0]["message"]["content"]
# Parse the JSON response
parsed_content = json.loads(content)
return {
"custom_id": custom_id,
"content": parsed_content
}
return {"custom_id": custom_id, "content": parsed_content}
else:
print(f"Error in response for {custom_id}: {data.get('error', 'Unknown error')}")
return None
except Exception as e:
print(f"Error parsing response line: {e}")
return None
@ -55,10 +53,10 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
"""
Format the response data as FrontMatter markdown.
Args:
response_data: Parsed response data from OpenAI
Returns:
Formatted markdown string with FrontMatter
"""
@ -69,7 +67,7 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
is_table = response_data.get("is_table", False)
is_diagram = response_data.get("is_diagram", False)
natural_text = response_data.get("natural_text", "")
# Format as FrontMatter
markdown = "---\n"
markdown += f"primary_language: {primary_language if primary_language else 'null'}\n"
@ -78,29 +76,24 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
markdown += f"is_table: {str(is_table)}\n"
markdown += f"is_diagram: {str(is_diagram)}\n"
markdown += "---\n"
# Add the natural text content
if natural_text:
markdown += natural_text
return markdown.strip()
def process_single_result(
custom_id: str,
response_content: Dict[str, Any],
original_pdf_dir: Path,
output_dir: Path
) -> bool:
def process_single_result(custom_id: str, response_content: Dict[str, Any], original_pdf_dir: Path, output_dir: Path) -> bool:
"""
Process a single batch result: copy PDF and create MD file.
Args:
custom_id: Custom ID from the batch request
response_content: Parsed response content
original_pdf_dir: Directory containing original PDFs
output_dir: Output directory for results
Returns:
True if successful, False otherwise
"""
@ -109,75 +102,70 @@ def process_single_result(
# Custom ID format: "folder/filename" (without .pdf)
pdf_relative_path = f"{custom_id}.pdf"
original_pdf_path = original_pdf_dir / pdf_relative_path
if not original_pdf_path.exists():
print(f"Warning: Original PDF not found: {original_pdf_path}")
original_pdf_path = str(original_pdf_path)
pattern = r'(.+?)(-\d+)\.pdf$'
replacement = r'\1.pdf\2.pdf'
pattern = r"(.+?)(-\d+)\.pdf$"
replacement = r"\1.pdf\2.pdf"
original_pdf_path = Path(re.sub(pattern, replacement, original_pdf_path))
if not original_pdf_path.exists():
print(f"Error: Original PDF not found: {original_pdf_path}")
return False
# Create output paths
output_pdf_path = output_dir / pdf_relative_path
output_md_path = output_dir / f"{custom_id}.md"
# Create parent directories if needed
output_pdf_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the PDF file
shutil.copy2(original_pdf_path, output_pdf_path)
# Create the markdown file
markdown_content = format_frontmatter_markdown(response_content)
with open(output_md_path, "w", encoding="utf-8") as f:
f.write(markdown_content)
return True
except Exception as e:
print(f"Error processing {custom_id}: {e}")
return False
def process_batch_results(
batch_results_dir: Path,
original_pdf_dir: Path,
output_dir: Path,
num_workers: int = 8
) -> int:
def process_batch_results(batch_results_dir: Path, original_pdf_dir: Path, output_dir: Path, num_workers: int = 8) -> int:
"""
Process all batch result files and create output structure.
Args:
batch_results_dir: Directory containing batch result JSONL files
original_pdf_dir: Directory containing original PDFs
output_dir: Output directory for processed results
num_workers: Number of parallel workers
Returns:
Number of successfully processed results
"""
# Ensure output directory exists
output_dir.mkdir(parents=True, exist_ok=True)
# Find all batch result files (both .jsonl and .json)
batch_files = list(batch_results_dir.glob("*.jsonl")) + list(batch_results_dir.glob("*.json"))
if not batch_files:
print(f"No batch result files found in {batch_results_dir}")
return 0
print(f"Found {len(batch_files)} batch result files")
# Collect all results to process
results_to_process = []
for batch_file in batch_files:
print(f"Reading {batch_file.name}...")
with open(batch_file, "r") as f:
@ -186,33 +174,27 @@ def process_batch_results(
parsed = parse_batch_response(line)
if parsed:
results_to_process.append(parsed)
total_results = len(results_to_process)
print(f"Found {total_results} valid results to process")
print(f"Using {num_workers} parallel workers")
successful = 0
failed = 0
# Process results in parallel
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Submit all processing tasks
future_to_result = {
executor.submit(
process_single_result,
result["custom_id"],
result["content"],
original_pdf_dir,
output_dir
): result["custom_id"]
executor.submit(process_single_result, result["custom_id"], result["content"], original_pdf_dir, output_dir): result["custom_id"]
for result in results_to_process
}
# Process results as they complete
with tqdm(total=total_results, desc="Processing results") as pbar:
for future in as_completed(future_to_result):
custom_id = future_to_result[future]
try:
success = future.result()
if success:
@ -222,73 +204,51 @@ def process_batch_results(
except Exception as e:
print(f"\nError with {custom_id}: {e}")
failed += 1
pbar.update(1)
print(f"\nProcessing complete:")
print(f" - Successfully processed: {successful}")
print(f" - Failed: {failed}")
print(f" - Output directory: {output_dir}")
return successful
def main():
parser = argparse.ArgumentParser(
description="Process OpenAI batch results and create output folder with PDFs and Markdown files"
)
parser = argparse.ArgumentParser(description="Process OpenAI batch results and create output folder with PDFs and Markdown files")
parser.add_argument("batch_results_dir", type=str, help="Directory containing completed OpenAI batch result files (JSONL)")
parser.add_argument(
"batch_results_dir",
type=str,
help="Directory containing completed OpenAI batch result files (JSONL)"
"original_pdf_dir", type=str, help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
)
parser.add_argument(
"original_pdf_dir",
type=str,
help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
)
parser.add_argument(
"output_dir",
type=str,
help="Output directory for processed results with PDFs and MD files"
)
parser.add_argument(
"--num_workers",
type=int,
default=8,
help="Number of parallel workers for processing (default: 8)"
)
parser.add_argument("output_dir", type=str, help="Output directory for processed results with PDFs and MD files")
parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")
args = parser.parse_args()
# Convert paths to Path objects
batch_results_dir = Path(args.batch_results_dir).expanduser().resolve()
original_pdf_dir = Path(args.original_pdf_dir).expanduser().resolve()
output_dir = Path(args.output_dir).expanduser().resolve()
# Validate input directories
if not batch_results_dir.exists():
print(f"Error: Batch results directory does not exist: {batch_results_dir}")
return 1
if not original_pdf_dir.exists():
print(f"Error: Original PDF directory does not exist: {original_pdf_dir}")
return 1
print(f"Batch results directory: {batch_results_dir}")
print(f"Original PDF directory: {original_pdf_dir}")
print(f"Output directory: {output_dir}")
# Process the batch results
process_batch_results(
batch_results_dir=batch_results_dir,
original_pdf_dir=original_pdf_dir,
output_dir=output_dir,
num_workers=args.num_workers
)
process_batch_results(batch_results_dir=batch_results_dir, original_pdf_dir=original_pdf_dir, output_dir=output_dir, num_workers=args.num_workers)
return 0
if __name__ == "__main__":
exit(main())
exit(main())

View File

@ -16,6 +16,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
def build_openai_silver_data_prompt_v2(base_text: str) -> str:
return (
f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it that includes position information for each image and block of text (The origin [0x0] of the coordinates is in the lower left corner of the image). "
@ -30,6 +31,7 @@ def build_openai_silver_data_prompt_v2(base_text: str) -> str:
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int) -> str:
return (
f"Attached is the image of one page of a PDF document."
@ -44,6 +46,7 @@ def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int)
f"Page width: {page_width}, Page height: {page_height}"
)
def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int) -> str:
return (
f"Attached is the image of one page of a PDF document."
@ -60,7 +63,6 @@ def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int)
)
@dataclass(frozen=True)
class PageResponse:
primary_language: Optional[str]

View File

@ -1,10 +1,14 @@
import argparse
import base64
import json
import logging
import multiprocessing
import re
import shutil
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, fields
from dataclasses import dataclass, fields, replace
from html.parser import HTMLParser
from io import BytesIO
from os import PathLike
from pathlib import Path
@ -419,8 +423,6 @@ class LatexBracketNormalizer(PipelineStep):
# Update the page_data with normalized text
# Since PageResponse is frozen, we need to create a new instance
from olmocr.prompts.prompts import PageResponse
new_page_data = PageResponse(
primary_language=page_data.primary_language,
is_rotation_valid=page_data.is_rotation_valid,
@ -482,8 +484,6 @@ class RotationAugmentation(PipelineStep):
else: # 270
correction = 90
from olmocr.prompts.prompts import PageResponse
new_page_data = PageResponse(
primary_language=page_data.primary_language,
is_rotation_valid=False, # Mark as invalid since we rotated it
@ -523,7 +523,7 @@ class FilterOutRotatedDocuments(PipelineStep):
@dataclass(frozen=True, slots=True)
class DatasetTextRuleFilter(PipelineStep):
"""Pipeline step that filters samples based on text content rules.
Filters out samples that:
- Contain markdown tables
- Contain malformed HTML tables
@ -539,205 +539,244 @@ class DatasetTextRuleFilter(PipelineStep):
# Look for pipe-separated table patterns
# Markdown tables have lines like: | col1 | col2 | col3 |
# And separator lines like: |------|------|------|
lines = text.split('\n')
lines = text.split("\n")
for i, line in enumerate(lines):
line = line.strip()
# Check if line looks like a table row
if line.startswith('|') and line.endswith('|') and line.count('|') >= 3:
if line.startswith("|") and line.endswith("|") and line.count("|") >= 3:
# Check if next line is a separator (for header rows)
if i + 1 < len(lines):
next_line = lines[i + 1].strip()
if next_line.startswith('|') and '-' in next_line:
if next_line.startswith("|") and "-" in next_line:
return True
# Check if previous line is a separator (for data rows)
if i > 0:
prev_line = lines[i - 1].strip()
if prev_line.startswith('|') and '-' in prev_line:
if prev_line.startswith("|") and "-" in prev_line:
return True
return False
def _contains_math_symbols(self, text: str) -> bool:
"""Check if text contains specific mathematical symbols outside of table cells.
Returns:
True if text contains any of the specified math symbols outside tables
False otherwise
"""
import re
# List of mathematical symbols to check for
math_symbols = [
# Set theory and logic
'', '', '', '', '', '', '', '', '', '', '', '¬',
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"¬",
# Common mathematical operators
'', '', '',
"",
"",
"",
# Calculus and analysis
'', '', '', '', '', '', '', '', '', '', '', '',
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
# Arrows and relations
'',
"",
# Other common math symbols
'', '', '', '', '', '', '', '', '', '', '', '', '', '', '',
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
# Matrix and vector notation
'', '', '', '', '', '', '', '', '',
"",
"",
"",
"",
"",
"",
"",
"",
"",
]
# First, remove all HTML tables from the text
text_without_tables = text
# Remove HTML tables
table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
text_without_tables = table_pattern.sub('', text_without_tables)
table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
text_without_tables = table_pattern.sub("", text_without_tables)
# Now check if any of these symbols appear in the text without tables
for symbol in math_symbols:
if symbol in text_without_tables:
return True
return False
def _contains_latex_tables(self, text: str) -> bool:
"""Check if text contains LaTeX table environments.
Returns:
True if text contains LaTeX tables (\\begin{table}, \\begin{tabular}, etc.)
False otherwise
"""
import re
# Check for various LaTeX table environments
latex_table_patterns = [
r'\\begin\{table\}',
r'\\begin\{tabular\}',
r"\\begin\{table\}",
r"\\begin\{tabular\}",
]
# Check if any LaTeX table pattern exists in the text
for pattern in latex_table_patterns:
if re.search(pattern, text, re.IGNORECASE):
return True
return False
def _contains_latex_formatting_outside_math(self, text: str) -> bool:
"""Check if text contains LaTeX formatting commands outside of math equations.
Returns:
True if text contains LaTeX formatting commands outside math equations
False otherwise
"""
import re
# List of common LaTeX formatting commands to check for
latex_commands = [
# Lists & basic content
r'\begin{itemize}',
r'\begin{enumerate}',
r'\item',
r"\begin{itemize}",
r"\begin{enumerate}",
r"\item",
# Figures, tables, and captions
r'\begin{figure}',
r'\includegraphics',
r'\caption',
r'\label',
r'\ref',
r'\eqref',
r'\begin{table}',
r'\begin{tabular}',
r"\begin{figure}",
r"\includegraphics",
r"\caption",
r"\label",
r"\ref",
r"\eqref",
r"\begin{table}",
r"\begin{tabular}",
# Formatting,
# r'\textit',
# r'\textbb',
# Math (strong signals)
r'\begin{equation}',
r'\begin{align}',
r'\frac',
r'\sum',
r'\int',
r'\sqrt',
r'\prod',
r'\lim',
r'\binom',
r'\mathbb',
r'\mathcal',
r'\to',
r'\varphi',
r'\cdot',
r'\langle',
r'\rangle',
r"\begin{equation}",
r"\begin{align}",
r"\frac",
r"\sum",
r"\int",
r"\sqrt",
r"\prod",
r"\lim",
r"\binom",
r"\mathbb",
r"\mathcal",
r"\to",
r"\varphi",
r"\cdot",
r"\langle",
r"\rangle",
# Citations (bibliography stacks)
r'\cite',
r"\cite",
]
# First, remove all math equations from the text
text_without_math = text
# Patterns for math equations
math_patterns = [
r"\$\$(.+?)\$\$", # $$...$$
r"\\\((.+?)\\\)", # \(...\)
r"\\\[(.+?)\\\]", # \[...\]
]
# Remove all math equations
for pattern in math_patterns:
text_without_math = re.sub(pattern, '', text_without_math, flags=re.DOTALL)
text_without_math = re.sub(pattern, "", text_without_math, flags=re.DOTALL)
# Check if any LaTeX commands appear in the remaining text
for command in latex_commands:
if command in text_without_math:
return True
return False
def _validate_math_equations(self, text: str) -> bool:
"""Check if all math equations in the text can render without errors.
Returns:
True if all equations render successfully or no equations exist
False if any equation fails to render
"""
import re
# Patterns to find math equations (same as in MathTest)
patterns = [
r"\$\$(.+?)\$\$", # $$...$$
r"\\\((.+?)\\\)", # \(...\)
r"\\\[(.+?)\\\]", # \[...\]
]
equations = []
for pattern in patterns:
# Find all matches for the current pattern
matches = re.findall(pattern, text, re.DOTALL)
equations.extend([eq.strip() for eq in matches])
# If no equations found, that's fine
if not equations:
return True
# Try to render each equation
try:
from olmocr.bench.katex.render import render_equation
for equation in equations:
# Skip empty or whitespace-only equations
if not equation or not equation.strip():
continue
# Try to render the equation
rendered = render_equation(equation)
# Check if there was an error
if rendered is None or (hasattr(rendered, 'error') and rendered.error):
if rendered is None or (hasattr(rendered, "error") and rendered.error):
# Equation failed to render
logger.warning(f"Could not render equation '{repr(equation)}', skipping sample")
return False
# All equations rendered successfully
return True
except ImportError:
# If we can't import the render module, skip this check
# This allows the filter to work even without the rendering dependencies
@ -746,87 +785,86 @@ class DatasetTextRuleFilter(PipelineStep):
# If any unexpected error occurs during validation, be conservative and filter out
print(f"Error validating math equations: {e}")
return False
def _contains_br_in_table_cells(self, text: str) -> bool:
"""Check if text contains <br> tags within HTML table cells.
Returns:
True if any table cell contains <br> tags
False otherwise
"""
import re
# Check if there are any tables in the text
if '<table' not in text.lower() or '<br' not in text.lower():
if "<table" not in text.lower() or "<br" not in text.lower():
return False # No tables or no <br> tags at all
# Pattern to find HTML tables (case-insensitive)
table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
tables = table_pattern.findall(text)
# Check each table for <br> tags in cells
for table_html in tables:
# Pattern to find table cells (td and th tags)
cell_pattern = re.compile(r'<(td|th)\b[^>]*>(.*?)</\1>', re.IGNORECASE | re.DOTALL)
cell_pattern = re.compile(r"<(td|th)\b[^>]*>(.*?)</\1>", re.IGNORECASE | re.DOTALL)
cells = cell_pattern.findall(table_html)
for tag_type, cell_content in cells:
# Check if cell content contains <br> tags (any variation)
if re.search(r'<br\s*/?>', cell_content, re.IGNORECASE):
if re.search(r"<br\s*/?>", cell_content, re.IGNORECASE):
return True
return False
def _extract_and_validate_html_tables(self, text: str) -> bool:
"""Extract HTML tables and validate they parse correctly.
Returns:
True if all HTML tables are valid or no tables exist
False if any HTML table is malformed
"""
# Find all HTML table blocks
import re
# Check if there are any <table> tags at all
if '<table' not in text.lower():
if "<table" not in text.lower():
return True # No tables, that's fine
# Pattern to find HTML tables (case-insensitive)
# Note: This pattern might not catch malformed tables where </table> is missing
table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
tables = table_pattern.findall(text)
# Also check for unclosed table tags
table_open_count = len(re.findall(r'<table\b[^>]*>', text, re.IGNORECASE))
table_close_count = len(re.findall(r'</table>', text, re.IGNORECASE))
table_open_count = len(re.findall(r"<table\b[^>]*>", text, re.IGNORECASE))
table_close_count = len(re.findall(r"</table>", text, re.IGNORECASE))
if table_open_count != table_close_count:
return False # Mismatched table tags
if not tables and table_open_count > 0:
# Found table tags but couldn't extract complete tables
return False
# Try to parse each table
from html.parser import HTMLParser
class TableValidator(HTMLParser):
def __init__(self):
super().__init__()
self.tag_stack = []
self.is_valid = True
self.error_msg = None
def handle_starttag(self, tag, attrs):
self.tag_stack.append(tag.lower())
def handle_endtag(self, tag):
tag = tag.lower()
if not self.tag_stack:
self.is_valid = False
self.error_msg = f"Unexpected closing tag: {tag}"
return
# Check if the closing tag matches the most recent opening tag
if self.tag_stack[-1] == tag:
self.tag_stack.pop()
@ -842,11 +880,11 @@ class DatasetTextRuleFilter(PipelineStep):
else:
self.is_valid = False
self.error_msg = f"Mismatched tag: expected {self.tag_stack[-1]}, got {tag}"
def error(self, message):
self.is_valid = False
self.error_msg = message
# Validate each table
for table_html in tables:
parser = TableValidator()
@ -860,90 +898,90 @@ class DatasetTextRuleFilter(PipelineStep):
except Exception:
# Any parsing exception means the table is malformed
return False
return True
def __call__(self, sample: Sample) -> Optional[Sample]:
"""Filter samples based on text content rules."""
# Get the natural text from page_data if it exists
text = None
if "page_data" in sample:
page_data = sample["page_data"]
if hasattr(page_data, "natural_text") and page_data.natural_text:
text = page_data.natural_text
# If no text to check, pass the sample through
if text is None:
return sample
# Check for markdown tables
if self._contains_markdown_table(text):
return None # Filter out samples with markdown tables
# Check for HTML tables and validate them
if not self._extract_and_validate_html_tables(text):
return None # Filter out samples with malformed HTML tables
# Check for <br> tags in table cells
if self._contains_br_in_table_cells(text):
return None # Filter out samples with <br> tags in table cells
# Check if all math equations can render without errors
if not self._validate_math_equations(text):
return None # Filter out samples with invalid math equations
# Check for mathematical symbols
if self._contains_math_symbols(text):
return None # Filter out samples with mathematical symbols
# # Check for markdown tables
# if self._contains_markdown_table(text):
# return None # Filter out samples with markdown tables
# # Check for HTML tables and validate them
# if not self._extract_and_validate_html_tables(text):
# return None # Filter out samples with malformed HTML tables
# # Check for <br> tags in table cells
# if self._contains_br_in_table_cells(text):
# return None # Filter out samples with <br> tags in table cells
# # Check if all math equations can render without errors
# if not self._validate_math_equations(text):
# return None # Filter out samples with invalid math equations
# # Check for mathematical symbols
# if self._contains_math_symbols(text):
# return None # Filter out samples with mathematical symbols
# Check for LaTeX formatting outside math equations
if self._contains_latex_formatting_outside_math(text):
return None # Filter out samples with \textit or \textbf outside math
# Check for LaTeX tables
if self._contains_latex_tables(text):
return None # Filter out samples with LaTeX tables
return sample
@dataclass(frozen=True, slots=True)
class ReformatLatexBoldItalic(PipelineStep):
"""Pipeline step that converts LaTeX formatting commands to markdown equivalents.
Converts:
- \\textit{...} to *...* (italic)
- \\textbf{...} to **...** (bold)
These conversions only happen outside of math equations.
"""
def __call__(self, sample: Sample) -> Optional[Sample]:
"""Convert LaTeX formatting to markdown in the sample text."""
# Get the natural text from page_data if it exists
if "page_data" not in sample:
return sample
page_data = sample["page_data"]
if not hasattr(page_data, "natural_text") or not page_data.natural_text:
return sample
text = page_data.natural_text
import re
# Math equation patterns to preserve
math_patterns = [
r"\$\$(.+?)\$\$", # $$...$$
r"\\\((.+?)\\\)", # \(...\)
r"\\\[(.+?)\\\]", # \[...\]
]
# Store math equations with placeholders
math_placeholders = []
preserved_text = text
# Replace math equations with placeholders
for i, pattern in enumerate(math_patterns):
matches = re.finditer(pattern, preserved_text, re.DOTALL)
@ -951,65 +989,66 @@ class ReformatLatexBoldItalic(PipelineStep):
placeholder = f"__MATH_PLACEHOLDER_{i}_{j}__"
math_placeholders.append((placeholder, match.group(0)))
preserved_text = preserved_text.replace(match.group(0), placeholder, 1)
# Now convert LaTeX formatting to markdown
# We need to handle nested braces properly
# Use a function to find matching braces
def replace_latex_command(text, command, markdown):
"""Replace LaTeX command with markdown, handling nested braces."""
import re
pattern = r'\\' + command + r'\{'
pattern = r"\\" + command + r"\{"
result = []
i = 0
while i < len(text):
match = re.search(pattern, text[i:])
if not match:
result.append(text[i:])
break
# Add text before the match
result.append(text[i:i + match.start()])
result.append(text[i : i + match.start()])
# Find the matching closing brace
start_pos = i + match.end()
brace_count = 1
j = start_pos
while j < len(text) and brace_count > 0:
if text[j] == '{':
if text[j] == "{":
brace_count += 1
elif text[j] == '}':
elif text[j] == "}":
brace_count -= 1
j += 1
if brace_count == 0:
# Extract the content between braces
content = text[start_pos:j-1]
content = text[start_pos : j - 1]
result.append(markdown + content + markdown)
i = j
else:
# Unmatched braces, keep original
result.append(text[i + match.start():i + match.end()])
result.append(text[i + match.start() : i + match.end()])
i = i + match.end()
return ''.join(result)
return "".join(result)
# Handle \textbf{...} -> **...**
preserved_text = replace_latex_command(preserved_text, 'textbf', '**')
preserved_text = replace_latex_command(preserved_text, "textbf", "**")
# Handle \textit{...} -> *...*
preserved_text = replace_latex_command(preserved_text, 'textit', '*')
preserved_text = replace_latex_command(preserved_text, "textit", "*")
# Restore math equations
for placeholder, original in math_placeholders:
preserved_text = preserved_text.replace(placeholder, original)
# Create a new PageResponse with the updated text (since it's frozen)
from dataclasses import replace
updated_page_data = replace(page_data, natural_text=preserved_text)
sample["page_data"] = updated_page_data
return sample
@ -1382,78 +1421,71 @@ if __name__ == "__main__":
if args.save_filtered:
import shutil
from pathlib import Path
save_dir = Path(args.save_filtered)
# Clear and create directory
if save_dir.exists():
shutil.rmtree(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
print(f"\n=== Checking for filtered samples ===")
print(f"Will save filtered samples to: {save_dir}")
# Function to process and copy a single sample
def process_and_copy_sample(idx, dataset_samples, save_dir_str):
"""Process a sample and return info if it's filtered.
Note: This function needs to be picklable for ProcessPoolExecutor,
so it takes simple arguments rather than complex objects.
"""
import shutil
from pathlib import Path
# Recreate dataset with same parameters
# This is needed because dataset objects can't be pickled
temp_dataset = BaseMarkdownPDFDataset.__new__(BaseMarkdownPDFDataset)
temp_dataset.samples = dataset_samples
temp_dataset.pipeline_steps = pipeline_steps
try:
sample = temp_dataset[idx]
if sample is None:
# This sample was filtered out - get the original paths
original_sample = dataset_samples[idx]
md_path = original_sample['markdown_path']
pdf_path = original_sample['pdf_path']
md_path = original_sample["markdown_path"]
pdf_path = original_sample["pdf_path"]
save_dir = Path(save_dir_str)
# Create subdirectory to preserve some structure
# Use the parent directory name and file name
rel_path = md_path.parent.name
target_subdir = save_dir / rel_path
target_subdir.mkdir(parents=True, exist_ok=True)
# Copy markdown file
target_md = target_subdir / md_path.name
shutil.copy2(md_path, target_md)
# Copy PDF file
target_pdf = target_subdir / pdf_path.name
shutil.copy2(pdf_path, target_pdf)
return {
'index': idx,
'markdown_path': str(md_path),
'pdf_path': str(pdf_path)
}
return {"index": idx, "markdown_path": str(md_path), "pdf_path": str(pdf_path)}
return None
except Exception as e:
print(f"Error processing sample {idx}: {e}")
return None
# Process all samples in parallel
filtered_samples = []
print(f"Processing {len(dataset)} samples to find and copy filtered ones...")
with ProcessPoolExecutor(max_workers=8) as executor:
# Submit all tasks
futures = {
executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx
for idx in range(len(dataset))
}
futures = {executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx for idx in range(len(dataset))}
# Process results with progress bar
with tqdm(total=len(dataset), desc="Processing samples") as pbar:
for future in as_completed(futures):
@ -1461,20 +1493,20 @@ if __name__ == "__main__":
if result is not None:
filtered_samples.append(result)
pbar.update(1)
# Sort filtered samples by index for consistent output
filtered_samples.sort(key=lambda x: x['index'])
filtered_samples.sort(key=lambda x: x["index"])
print(f"\nFound and copied {len(filtered_samples)} filtered samples to: {save_dir}")
if filtered_samples:
print(f"First 10 filtered samples:")
for i, sample_info in enumerate(filtered_samples[:10]):
md_name = Path(sample_info['markdown_path']).name
md_name = Path(sample_info["markdown_path"]).name
print(f" Sample {sample_info['index']}: {md_name}")
if len(filtered_samples) > 10:
print(f" ... and {len(filtered_samples) - 10} more")
# Exit early if --save-filtered is used (don't continue with other analyses)
print("\nCompleted saving filtered samples. Exiting.")
exit(0)

View File

@ -44,7 +44,7 @@ Some conclusion text.
""",
)
}
result = self.filter(sample_with_md_table)
self.assertIsNone(result, "Should filter out samples with markdown tables")
@ -60,7 +60,7 @@ Some conclusion text.
natural_text="This is regular text without any tables. It has | pipes | but not in table format.",
)
}
result = self.filter(sample_without_table)
self.assertIsNotNone(result, "Should pass through samples without markdown tables")
self.assertEqual(result, sample_without_table)
@ -92,7 +92,7 @@ Some text after table.
""",
)
}
result = self.filter(sample_with_valid_html)
self.assertIsNotNone(result, "Should pass through samples with valid HTML tables")
@ -121,7 +121,7 @@ Text after.
""",
)
}
result = self.filter(sample_with_malformed_html)
self.assertIsNone(result, "Should filter out samples with malformed HTML tables")
@ -147,7 +147,7 @@ Text after without closing table tag.
""",
)
}
result = self.filter(sample_with_unclosed_table)
self.assertIsNone(result, "Should filter out HTML tables without closing tags")
@ -157,7 +157,7 @@ Text after without closing table tag.
"markdown_path": Path("/path/to/file.md"),
"pdf_path": Path("/path/to/file.pdf"),
}
result = self.filter(sample_without_page_data)
self.assertIsNotNone(result, "Should pass through samples without page_data")
self.assertEqual(result, sample_without_page_data)
@ -174,7 +174,7 @@ Text after without closing table tag.
natural_text=None,
)
}
result = self.filter(sample_without_text)
self.assertIsNotNone(result, "Should pass through samples without natural_text")
@ -190,7 +190,7 @@ Text after without closing table tag.
natural_text="",
)
}
result = self.filter(sample_with_empty_text)
self.assertIsNotNone(result, "Should pass through samples with empty natural_text")
@ -211,7 +211,7 @@ Text after without closing table tag.
""",
)
}
result = self.filter(sample_with_alignment)
self.assertIsNone(result, "Should filter out markdown tables with alignment")
@ -240,7 +240,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
""",
)
}
result = self.filter(sample_mixed)
self.assertIsNotNone(result, "Should pass through with valid HTML and no markdown tables")
@ -261,7 +261,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
</table>""",
)
}
result = self.filter(sample_with_br)
self.assertIsNone(result, "Should filter out tables with <br> tags in cells")
@ -283,7 +283,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
</table>""",
)
}
result = self.filter(sample_br_outside)
self.assertIsNotNone(result, "Should allow <br> tags outside tables")
@ -308,7 +308,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
</table>""",
)
}
result = self.filter(sample_br_variations)
self.assertIsNone(result, "Should filter out tables with any <br> variation in cells")
@ -332,7 +332,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text="This is \\textbf{bold} text.",
)
}
result = self.reformatter(sample)
self.assertEqual(result["page_data"].natural_text, "This is **bold** text.")
@ -348,7 +348,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text="This is \\textit{italic} text.",
)
}
result = self.reformatter(sample)
self.assertEqual(result["page_data"].natural_text, "This is *italic* text.")
@ -364,7 +364,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text="This has \\textbf{bold} and \\textit{italic} text.",
)
}
result = self.reformatter(sample)
self.assertEqual(result["page_data"].natural_text, "This has **bold** and *italic* text.")
@ -380,7 +380,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text="Text outside $$ \\textbf{x} = \\textit{y} $$ more text.",
)
}
result = self.reformatter(sample)
self.assertEqual(result["page_data"].natural_text, "Text outside $$ \\textbf{x} = \\textit{y} $$ more text.")
@ -396,7 +396,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text="The \\textbf{equation} is $$ \\textbf{x} = 2 $$ and \\textit{important}.",
)
}
result = self.reformatter(sample)
self.assertEqual(result["page_data"].natural_text, "The **equation** is $$ \\textbf{x} = 2 $$ and *important*.")
@ -412,7 +412,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text="This is \\textbf{bold with {nested} braces} text.",
)
}
result = self.reformatter(sample)
self.assertEqual(result["page_data"].natural_text, "This is **bold with {nested} braces** text.")
@ -428,12 +428,9 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text="\\textbf{First} and \\textbf{second} bold, \\textit{first} and \\textit{second} italic.",
)
}
result = self.reformatter(sample)
self.assertEqual(
result["page_data"].natural_text,
"**First** and **second** bold, *first* and *second* italic."
)
self.assertEqual(result["page_data"].natural_text, "**First** and **second** bold, *first* and *second* italic.")
def test_latex_in_parenthesis_delimiter(self):
"""Test LaTeX preserved in \\(...\\) math delimiter."""
@ -447,7 +444,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text="Text \\( \\textbf{math} \\) more text \\textbf{bold}.",
)
}
result = self.reformatter(sample)
self.assertEqual(result["page_data"].natural_text, "Text \\( \\textbf{math} \\) more text **bold**.")
@ -463,7 +460,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text="Text \\[ \\textit{math} \\] more text \\textit{italic}.",
)
}
result = self.reformatter(sample)
self.assertEqual(result["page_data"].natural_text, "Text \\[ \\textit{math} \\] more text *italic*.")
@ -479,14 +476,14 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text="Plain text without any formatting.",
)
}
result = self.reformatter(sample)
self.assertEqual(result["page_data"].natural_text, "Plain text without any formatting.")
def test_no_page_data(self):
"""Test handling of samples without page_data."""
sample = {"markdown_path": Path("/path/to/file.md")}
result = self.reformatter(sample)
self.assertEqual(result, sample)
@ -502,10 +499,10 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
natural_text=None,
)
}
result = self.reformatter(sample)
self.assertIsNone(result["page_data"].natural_text)
def test_complex_latex_with_parenthesis_delimiters(self):
"""Test complex LaTeX text with \\(...\\) delimiters and textit."""
input_text = """= a_0 \\int_0^P \\cos \\frac{2m\\pi x}{P} dx
@ -517,7 +514,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th
\\[
\\int_0^P \\cos \\frac{2m\\pi x}{P} f(x) dx = \\frac{a_m P}{2},
\\]"""
expected_text = """= a_0 \\int_0^P \\cos \\frac{2m\\pi x}{P} dx
+ \\sum_{n=1}^{\\infty} \\frac{a_n}{2} \\int_0^P \\cos \\frac{2(m+n)\\pi x}{P} + \\cos \\frac{2(m-n)\\pi x}{P} dx
+ b_n \\int_0^P \\sin \\frac{2(m+n)\\pi x}{P} - \\sin \\frac{2(m-n)\\pi x}{P} dx.
@ -527,7 +524,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th
\\[
\\int_0^P \\cos \\frac{2m\\pi x}{P} f(x) dx = \\frac{a_m P}{2},
\\]"""
sample = {
"page_data": PageResponse(
primary_language="en",
@ -538,7 +535,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th
natural_text=input_text,
)
}
result = self.reformatter(sample)
self.assertEqual(result["page_data"].natural_text, expected_text)
@ -562,7 +559,7 @@ class TestFilterOutRotatedDocuments(unittest.TestCase):
natural_text="Some text",
)
}
result = self.filter(sample)
self.assertIsNotNone(result, "Should pass through documents with valid rotation")
@ -578,7 +575,7 @@ class TestFilterOutRotatedDocuments(unittest.TestCase):
natural_text="Some text",
)
}
result = self.filter(sample)
self.assertIsNone(result, "Should filter out documents with invalid rotation")
@ -594,14 +591,14 @@ class TestFilterOutRotatedDocuments(unittest.TestCase):
natural_text="Some text",
)
}
result = self.filter(sample)
self.assertIsNone(result, "Should filter out documents with non-zero rotation correction")
def test_no_page_data(self):
"""Test that samples without page_data pass through."""
sample = {"markdown_path": Path("/path/to/file.md")}
result = self.filter(sample)
self.assertIsNotNone(result, "Should pass through samples without page_data")
@ -625,7 +622,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
natural_text="The equation $x^2 + y^2 = z^2$ is famous.",
)
}
result = self.normalizer(sample)
expected_text = "The equation \\(x^2 + y^2 = z^2\\) is famous."
self.assertEqual(result["page_data"].natural_text, expected_text)
@ -642,7 +639,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
natural_text="Display equation:\n$$\\int_0^\\infty e^{-x^2} dx = \\frac{\\sqrt{\\pi}}{2}$$",
)
}
result = self.normalizer(sample)
expected_text = "Display equation:\n\\[\\int_0^\\infty e^{-x^2} dx = \\frac{\\sqrt{\\pi}}{2}\\]"
self.assertEqual(result["page_data"].natural_text, expected_text)
@ -659,7 +656,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
natural_text="Inline $a + b$ and display:\n$$c^2 = a^2 + b^2$$\nMore inline $x = y$.",
)
}
result = self.normalizer(sample)
expected_text = "Inline \\(a + b\\) and display:\n\\[c^2 = a^2 + b^2\\]\nMore inline \\(x = y\\)."
self.assertEqual(result["page_data"].natural_text, expected_text)
@ -676,7 +673,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
natural_text="Regular text without any equations.",
)
}
result = self.normalizer(sample)
self.assertEqual(result["page_data"].natural_text, "Regular text without any equations.")
@ -692,14 +689,14 @@ class TestLatexBracketNormalizer(unittest.TestCase):
natural_text=None,
)
}
result = self.normalizer(sample)
self.assertIsNone(result["page_data"].natural_text)
def test_no_page_data(self):
"""Test handling of missing page_data."""
sample = {"markdown_path": Path("/path/to/file.md")}
result = self.normalizer(sample)
self.assertEqual(result, sample)
@ -712,7 +709,7 @@ class TestFrontMatterParser(unittest.TestCase):
self.parser_with_class = FrontMatterParser(front_matter_class=PageResponse)
self.parser_without_class = FrontMatterParser(front_matter_class=None)
@patch.object(Path, 'read_text')
@patch.object(Path, "read_text")
def test_parse_yaml_front_matter(self, mock_read_text):
"""Test parsing of YAML front matter."""
mock_read_text.return_value = """---
@ -724,27 +721,27 @@ is_diagram: false
---
This is the document content.
"""
sample = {"markdown_path": Path("/path/to/file.md")}
result = self.parser_with_class(sample)
self.assertIn("page_data", result)
self.assertIsInstance(result["page_data"], PageResponse)
self.assertEqual(result["page_data"].primary_language, "en")
self.assertEqual(result["page_data"].natural_text, "This is the document content.")
@patch.object(Path, 'read_text')
@patch.object(Path, "read_text")
def test_no_front_matter(self, mock_read_text):
"""Test handling of documents without front matter."""
mock_read_text.return_value = "Just regular content without front matter."
sample = {"markdown_path": Path("/path/to/file.md")}
# Should raise an error when front_matter_class is specified
with self.assertRaises(ValueError):
self.parser_with_class(sample)
@patch.object(Path, 'read_text')
@patch.object(Path, "read_text")
def test_malformed_yaml(self, mock_read_text):
"""Test handling of malformed YAML."""
mock_read_text.return_value = """---
@ -753,14 +750,14 @@ is_rotation_valid: [this is not valid yaml}
---
Content
"""
sample = {"markdown_path": Path("/path/to/file.md")}
# Parser without class should return empty dict for malformed YAML
result = self.parser_without_class(sample)
self.assertEqual(result["page_data"], {})
@patch.object(Path, 'read_text')
@patch.object(Path, "read_text")
def test_preserve_existing_markdown_content(self, mock_read_text):
"""Test that existing markdown_content is preserved if present."""
sample = {
@ -772,16 +769,16 @@ rotation_correction: 0
is_table: true
is_diagram: false
---
French content."""
French content.""",
}
# Should not call read_text since markdown_content exists
result = self.parser_with_class(sample)
mock_read_text.assert_not_called()
self.assertEqual(result["page_data"].primary_language, "fr")
self.assertEqual(result["page_data"].is_table, True)
if __name__ == "__main__":
unittest.main()
unittest.main()