Jake Poznanski 2025-08-19 21:30:41 +00:00
parent 768cb33937
commit 41201b6317
6 changed files with 419 additions and 452 deletions

View File

@@ -8,16 +8,19 @@ from olmocr.bench.prompts import (
     build_basic_prompt,
     build_openai_silver_data_prompt_no_document_anchoring,
 )
-from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
+from olmocr.data.renderpdf import (
+    get_png_dimensions_from_base64,
+    render_pdf_to_base64png,
+)
 from olmocr.prompts.anchor import get_anchor_text
 from olmocr.prompts.prompts import (
     PageResponse,
     build_finetuning_prompt,
     build_openai_silver_data_prompt,
-    openai_response_format_schema,
     build_openai_silver_data_prompt_v2,
     build_openai_silver_data_prompt_v2_simple,
     build_openai_silver_data_prompt_v3_simple,
+    openai_response_format_schema,
 )

View File

@@ -8,14 +8,17 @@ and generates OpenAI batch API requests for processing PDFs.
 import argparse
 import json
+import os
-from pathlib import Path
-from typing import Generator, Dict, Any, Optional, Tuple
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, Generator, Optional, Tuple

 from pypdf import PdfReader
 from tqdm import tqdm

-from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
+from olmocr.data.renderpdf import (
+    get_png_dimensions_from_base64,
+    render_pdf_to_base64png,
+)
 from olmocr.prompts.prompts import (
     build_openai_silver_data_prompt_v3_simple,
     openai_response_format_schema,
@@ -60,7 +63,7 @@ def build_custom_id(pdf_path: Path, base_dir: Path) -> str:
     # Get relative path from base directory
     rel_path = pdf_path.relative_to(base_dir)
     # Remove .pdf extension but keep directory structure
-    path_without_ext = str(rel_path).replace('.pdf', '')
+    path_without_ext = str(rel_path).replace(".pdf", "")
     return path_without_ext
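Worth noting while reading this hunk: str.replace strips every ".pdf" substring, not just the trailing extension. A standalone sketch of the helper as it reads after this change (the sample paths are invented for illustration):

    from pathlib import Path

    # Sketch of build_custom_id as it reads after this hunk; example paths invented.
    def build_custom_id(pdf_path: Path, base_dir: Path) -> str:
        rel_path = pdf_path.relative_to(base_dir)
        return str(rel_path).replace(".pdf", "")

    print(build_custom_id(Path("/mix/docs/report-3.pdf"), Path("/mix")))   # docs/report-3
    print(build_custom_id(Path("/mix/docs/doc.pdf-3.pdf"), Path("/mix")))  # docs/doc-3

The second case is the one the results processor later has to undo with its doubled-extension regex (see the @@ -114,8 hunk further below).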
@@ -151,12 +154,7 @@ def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]:
             yield pdf_path


-def process_pdfs_to_batch_requests(
-    input_dir: Path,
-    output_dir: Path,
-    max_pdfs: int = None,
-    num_workers: int = 8
-) -> int:
+def process_pdfs_to_batch_requests(input_dir: Path, output_dir: Path, max_pdfs: int = None, num_workers: int = 8) -> int:
     """
     Process PDFs and create batch request files using parallel processing.
@@ -196,10 +194,7 @@ def process_pdfs_to_batch_requests(
     # Process PDFs in parallel using ThreadPoolExecutor
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         # Submit all PDF processing tasks
-        future_to_pdf = {
-            executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path
-            for pdf_path in pdf_files
-        }
+        future_to_pdf = {executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path for pdf_path in pdf_files}

         # Process results as they complete
         with tqdm(total=total_pdfs, desc="Processing PDFs") as pbar:
@@ -252,31 +247,14 @@ def process_pdfs_to_batch_requests(

 def main():
-    parser = argparse.ArgumentParser(
-        description="Build OpenAI batch requests from OLMoCR-mix folder structure"
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default=None,
-        help="Output directory for batch request files (default: input_dir/batch_requests)"
-    )
-    parser.add_argument(
-        "--max_pdfs",
-        type=int,
-        default=None,
-        help="Maximum number of PDFs to process (default: all)"
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=8,
-        help="Number of parallel workers for processing (default: 8)"
-    )
+    parser = argparse.ArgumentParser(description="Build OpenAI batch requests from OLMoCR-mix folder structure")
+    parser.add_argument("--output_dir", type=str, default=None, help="Output directory for batch request files (default: input_dir/batch_requests)")
+    parser.add_argument("--max_pdfs", type=int, default=None, help="Maximum number of PDFs to process (default: all)")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")
     parser.add_argument(
         "input_dir",
         type=str,
-        help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)"
+        help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)",
     )

     args = parser.parse_args()
@@ -298,12 +276,7 @@ def main():
     print(f"Output directory: {output_dir}")

     # Process PDFs
-    process_pdfs_to_batch_requests(
-        input_dir=input_dir,
-        output_dir=output_dir,
-        max_pdfs=args.max_pdfs,
-        num_workers=args.num_workers
-    )
+    process_pdfs_to_batch_requests(input_dir=input_dir, output_dir=output_dir, max_pdfs=args.max_pdfs, num_workers=args.num_workers)

     return 0
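For orientation, one way this builder might be invoked after the refactor; the script name build_batch_requests.py is a hypothetical stand-in, since this view does not show file names:

    python build_batch_requests.py ~/olmOCR-mix-0225 --output_dir ~/olmOCR-mix-0225/batch_requests --max_pdfs 100 --num_workers 16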

View File

@@ -8,11 +8,12 @@ that mirrors the original structure with side-by-side PDF and MD files.
 import argparse
 import json
-import shutil
 import re
-from pathlib import Path
-from typing import Dict, Any, Optional
+import shutil
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, Optional

 from tqdm import tqdm
@@ -39,10 +40,7 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
             content = body["choices"][0]["message"]["content"]
             # Parse the JSON response
             parsed_content = json.loads(content)
-            return {
-                "custom_id": custom_id,
-                "content": parsed_content
-            }
+            return {"custom_id": custom_id, "content": parsed_content}
         else:
             print(f"Error in response for {custom_id}: {data.get('error', 'Unknown error')}")
             return None
@@ -86,12 +84,7 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
     return markdown.strip()


-def process_single_result(
-    custom_id: str,
-    response_content: Dict[str, Any],
-    original_pdf_dir: Path,
-    output_dir: Path
-) -> bool:
+def process_single_result(custom_id: str, response_content: Dict[str, Any], original_pdf_dir: Path, output_dir: Path) -> bool:
     """
     Process a single batch result: copy PDF and create MD file.
@@ -114,8 +107,8 @@ def process_single_result(
         print(f"Warning: Original PDF not found: {original_pdf_path}")
         original_pdf_path = str(original_pdf_path)

-        pattern = r'(.+?)(-\d+)\.pdf$'
-        replacement = r'\1.pdf\2.pdf'
+        pattern = r"(.+?)(-\d+)\.pdf$"
+        replacement = r"\1.pdf\2.pdf"
         original_pdf_path = Path(re.sub(pattern, replacement, original_pdf_path))
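To make the fallback concrete, a minimal sketch of the rename (the path is invented): a name reconstructed from a custom_id, such as doc-3.pdf, is rewritten to the doubled-extension form doc.pdf-3.pdf actually used on disk.

    import re

    # Fallback rename from the hunk above; the sample path is invented.
    print(re.sub(r"(.+?)(-\d+)\.pdf$", r"\1.pdf\2.pdf", "mix/0001/doc-3.pdf"))
    # -> mix/0001/doc.pdf-3.pdf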
@@ -145,12 +138,7 @@ def process_single_result(
         return False


-def process_batch_results(
-    batch_results_dir: Path,
-    original_pdf_dir: Path,
-    output_dir: Path,
-    num_workers: int = 8
-) -> int:
+def process_batch_results(batch_results_dir: Path, original_pdf_dir: Path, output_dir: Path, num_workers: int = 8) -> int:
     """
     Process all batch result files and create output structure.
@@ -198,13 +186,7 @@ def process_batch_results(
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         # Submit all processing tasks
         future_to_result = {
-            executor.submit(
-                process_single_result,
-                result["custom_id"],
-                result["content"],
-                original_pdf_dir,
-                output_dir
-            ): result["custom_id"]
+            executor.submit(process_single_result, result["custom_id"], result["content"], original_pdf_dir, output_dir): result["custom_id"]
             for result in results_to_process
         }
@@ -234,30 +216,13 @@ def process_batch_results(

 def main():
-    parser = argparse.ArgumentParser(
-        description="Process OpenAI batch results and create output folder with PDFs and Markdown files"
-    )
-    parser.add_argument(
-        "batch_results_dir",
-        type=str,
-        help="Directory containing completed OpenAI batch result files (JSONL)"
-    )
-    parser.add_argument(
-        "original_pdf_dir",
-        type=str,
-        help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
-    )
-    parser.add_argument(
-        "output_dir",
-        type=str,
-        help="Output directory for processed results with PDFs and MD files"
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=8,
-        help="Number of parallel workers for processing (default: 8)"
-    )
+    parser = argparse.ArgumentParser(description="Process OpenAI batch results and create output folder with PDFs and Markdown files")
+    parser.add_argument("batch_results_dir", type=str, help="Directory containing completed OpenAI batch result files (JSONL)")
+    parser.add_argument(
+        "original_pdf_dir", type=str, help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
+    )
+    parser.add_argument("output_dir", type=str, help="Output directory for processed results with PDFs and MD files")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")

     args = parser.parse_args()
@@ -280,12 +245,7 @@ def main():
     print(f"Output directory: {output_dir}")

     # Process the batch results
-    process_batch_results(
-        batch_results_dir=batch_results_dir,
-        original_pdf_dir=original_pdf_dir,
-        output_dir=output_dir,
-        num_workers=args.num_workers
-    )
+    process_batch_results(batch_results_dir=batch_results_dir, original_pdf_dir=original_pdf_dir, output_dir=output_dir, num_workers=args.num_workers)

     return 0

View File

@@ -16,6 +16,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )

+
 def build_openai_silver_data_prompt_v2(base_text: str) -> str:
     return (
         f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it that includes position information for each image and block of text (The origin [0x0] of the coordinates is in the lower left corner of the image). "

@@ -30,6 +31,7 @@ def build_openai_silver_data_prompt_v2(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )

+
 def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int) -> str:
     return (
         f"Attached is the image of one page of a PDF document."

@@ -44,6 +46,7 @@ def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int)
         f"Page width: {page_width}, Page height: {page_height}"
     )

+
 def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int) -> str:
     return (
         f"Attached is the image of one page of a PDF document."

@@ -60,7 +63,6 @@ def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int)
     )

-
 @dataclass(frozen=True)
 class PageResponse:
     primary_language: Optional[str]

View File

@@ -1,10 +1,14 @@
+import argparse
 import base64
 import json
 import logging
+import multiprocessing
 import re
+import shutil
 from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, as_completed
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
+from html.parser import HTMLParser
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
@@ -419,8 +423,6 @@ class LatexBracketNormalizer(PipelineStep):
         # Update the page_data with normalized text
         # Since PageResponse is frozen, we need to create a new instance
-        from olmocr.prompts.prompts import PageResponse
-
         new_page_data = PageResponse(
             primary_language=page_data.primary_language,
             is_rotation_valid=page_data.is_rotation_valid,
@@ -482,8 +484,6 @@ class RotationAugmentation(PipelineStep):
         else:  # 270
             correction = 90

-        from olmocr.prompts.prompts import PageResponse
-
         new_page_data = PageResponse(
             primary_language=page_data.primary_language,
             is_rotation_valid=False,  # Mark as invalid since we rotated it
@@ -539,20 +539,20 @@ class DatasetTextRuleFilter(PipelineStep):
         # Look for pipe-separated table patterns
         # Markdown tables have lines like: | col1 | col2 | col3 |
         # And separator lines like: |------|------|------|
-        lines = text.split('\n')
+        lines = text.split("\n")

         for i, line in enumerate(lines):
             line = line.strip()
             # Check if line looks like a table row
-            if line.startswith('|') and line.endswith('|') and line.count('|') >= 3:
+            if line.startswith("|") and line.endswith("|") and line.count("|") >= 3:
                 # Check if next line is a separator (for header rows)
                 if i + 1 < len(lines):
                     next_line = lines[i + 1].strip()
-                    if next_line.startswith('|') and '-' in next_line:
+                    if next_line.startswith("|") and "-" in next_line:
                         return True
                 # Check if previous line is a separator (for data rows)
                 if i > 0:
                     prev_line = lines[i - 1].strip()
-                    if prev_line.startswith('|') and '-' in prev_line:
+                    if prev_line.startswith("|") and "-" in prev_line:
                         return True

         return False
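A standalone illustration of the heuristic above (the sample text is invented): a pipe-delimited row adjacent to a dashed separator line is treated as a markdown table.

    # Minimal re-statement of the detection logic above; sample text invented.
    sample = "| col1 | col2 |\n|------|------|\n| a | b |"
    lines = [ln.strip() for ln in sample.split("\n")]
    looks_like_row = lines[0].startswith("|") and lines[0].endswith("|") and lines[0].count("|") >= 3
    separator_below = lines[1].startswith("|") and "-" in lines[1]
    print(looks_like_row and separator_below)  # True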
@@ -563,30 +563,74 @@ class DatasetTextRuleFilter(PipelineStep):
             True if text contains any of the specified math symbols outside tables
             False otherwise
         """
-        import re
-
         # List of mathematical symbols to check for
         math_symbols = [
             # Set theory and logic
-            '', '', '', '', '', '', '', '', '', '', '', '¬',
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "¬",
             # Common mathematical operators
-            '', '', '',
+            "",
+            "",
+            "",
             # Calculus and analysis
-            '', '', '', '', '', '', '', '', '', '', '', '',
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
             # Arrows and relations
-            '',
+            "",
             # Other common math symbols
-            '', '', '', '', '', '', '', '', '', '', '', '', '', '', '',
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
             # Matrix and vector notation
-            '', '', '', '', '', '', '', '', '',
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
+            "",
         ]

         # First, remove all HTML tables from the text
         text_without_tables = text

         # Remove HTML tables
-        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
-        text_without_tables = table_pattern.sub('', text_without_tables)
+        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
+        text_without_tables = table_pattern.sub("", text_without_tables)

         # Now check if any of these symbols appear in the text without tables
         for symbol in math_symbols:
@@ -606,8 +650,8 @@ class DatasetTextRuleFilter(PipelineStep):
         # Check for various LaTeX table environments
         latex_table_patterns = [
-            r'\\begin\{table\}',
-            r'\\begin\{tabular\}',
+            r"\\begin\{table\}",
+            r"\\begin\{tabular\}",
         ]

         # Check if any LaTeX table pattern exists in the text
@@ -629,47 +673,42 @@ class DatasetTextRuleFilter(PipelineStep):
         # List of common LaTeX formatting commands to check for
         latex_commands = [
             # Lists & basic content
-            r'\begin{itemize}',
-            r'\begin{enumerate}',
-            r'\item',
+            r"\begin{itemize}",
+            r"\begin{enumerate}",
+            r"\item",
             # Figures, tables, and captions
-            r'\begin{figure}',
-            r'\includegraphics',
-            r'\caption',
-            r'\label',
-            r'\ref',
-            r'\eqref',
-            r'\begin{table}',
-            r'\begin{tabular}',
+            r"\begin{figure}",
+            r"\includegraphics",
+            r"\caption",
+            r"\label",
+            r"\ref",
+            r"\eqref",
+            r"\begin{table}",
+            r"\begin{tabular}",
             # Formatting,
             # r'\textit',
             # r'\textbb',
             # Math (strong signals)
-            r'\begin{equation}',
-            r'\begin{align}',
-            r'\frac',
-            r'\sum',
-            r'\int',
-            r'\sqrt',
-            r'\prod',
-            r'\lim',
-            r'\binom',
-            r'\mathbb',
-            r'\mathcal',
-            r'\to',
-            r'\varphi',
-            r'\cdot',
-            r'\langle',
-            r'\rangle',
+            r"\begin{equation}",
+            r"\begin{align}",
+            r"\frac",
+            r"\sum",
+            r"\int",
+            r"\sqrt",
+            r"\prod",
+            r"\lim",
+            r"\binom",
+            r"\mathbb",
+            r"\mathcal",
+            r"\to",
+            r"\varphi",
+            r"\cdot",
+            r"\langle",
+            r"\rangle",
             # Citations (bibliography stacks)
-            r'\cite',
+            r"\cite",
         ]

         # First, remove all math equations from the text
         text_without_math = text
@@ -682,7 +721,7 @@ class DatasetTextRuleFilter(PipelineStep):
         # Remove all math equations
         for pattern in math_patterns:
-            text_without_math = re.sub(pattern, '', text_without_math, flags=re.DOTALL)
+            text_without_math = re.sub(pattern, "", text_without_math, flags=re.DOTALL)

         # Check if any LaTeX commands appear in the remaining text
         for command in latex_commands:
@@ -730,7 +769,7 @@ class DatasetTextRuleFilter(PipelineStep):
             rendered = render_equation(equation)

             # Check if there was an error
-            if rendered is None or (hasattr(rendered, 'error') and rendered.error):
+            if rendered is None or (hasattr(rendered, "error") and rendered.error):
                 # Equation failed to render
                 logger.warning(f"Could not render equation '{repr(equation)}', skipping sample")
                 return False
@@ -757,22 +796,22 @@ class DatasetTextRuleFilter(PipelineStep):
         import re

         # Check if there are any tables in the text
-        if '<table' not in text.lower() or '<br' not in text.lower():
+        if "<table" not in text.lower() or "<br" not in text.lower():
             return False  # No tables or no <br> tags at all

         # Pattern to find HTML tables (case-insensitive)
-        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
+        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
         tables = table_pattern.findall(text)

         # Check each table for <br> tags in cells
         for table_html in tables:
             # Pattern to find table cells (td and th tags)
-            cell_pattern = re.compile(r'<(td|th)\b[^>]*>(.*?)</\1>', re.IGNORECASE | re.DOTALL)
+            cell_pattern = re.compile(r"<(td|th)\b[^>]*>(.*?)</\1>", re.IGNORECASE | re.DOTALL)
             cells = cell_pattern.findall(table_html)

             for tag_type, cell_content in cells:
                 # Check if cell content contains <br> tags (any variation)
-                if re.search(r'<br\s*/?>', cell_content, re.IGNORECASE):
+                if re.search(r"<br\s*/?>", cell_content, re.IGNORECASE):
                     return True

         return False
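A quick standalone check of the cell scan above (the sample HTML is invented):

    import re

    # Same patterns as the hunk above; the table snippet is invented.
    table_html = "<table><tr><td>line one<br/>line two</td></tr></table>"
    cells = re.findall(r"<(td|th)\b[^>]*>(.*?)</\1>", table_html, re.IGNORECASE | re.DOTALL)
    print(any(re.search(r"<br\s*/?>", content, re.IGNORECASE) for _, content in cells))  # True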
@@ -788,17 +827,17 @@ class DatasetTextRuleFilter(PipelineStep):
         import re

         # Check if there are any <table> tags at all
-        if '<table' not in text.lower():
+        if "<table" not in text.lower():
             return True  # No tables, that's fine

         # Pattern to find HTML tables (case-insensitive)
         # Note: This pattern might not catch malformed tables where </table> is missing
-        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
+        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
         tables = table_pattern.findall(text)

         # Also check for unclosed table tags
-        table_open_count = len(re.findall(r'<table\b[^>]*>', text, re.IGNORECASE))
-        table_close_count = len(re.findall(r'</table>', text, re.IGNORECASE))
+        table_open_count = len(re.findall(r"<table\b[^>]*>", text, re.IGNORECASE))
+        table_close_count = len(re.findall(r"</table>", text, re.IGNORECASE))

         if table_open_count != table_close_count:
             return False  # Mismatched table tags
@@ -808,7 +847,6 @@ class DatasetTextRuleFilter(PipelineStep):
             return False

         # Try to parse each table
-        from html.parser import HTMLParser

         class TableValidator(HTMLParser):
             def __init__(self):
@@ -877,25 +915,25 @@ class DatasetTextRuleFilter(PipelineStep):
         if text is None:
             return sample

-        # Check for markdown tables
-        if self._contains_markdown_table(text):
-            return None  # Filter out samples with markdown tables
+        # # Check for markdown tables
+        # if self._contains_markdown_table(text):
+        #     return None  # Filter out samples with markdown tables

-        # Check for HTML tables and validate them
-        if not self._extract_and_validate_html_tables(text):
-            return None  # Filter out samples with malformed HTML tables
+        # # Check for HTML tables and validate them
+        # if not self._extract_and_validate_html_tables(text):
+        #     return None  # Filter out samples with malformed HTML tables

-        # Check for <br> tags in table cells
-        if self._contains_br_in_table_cells(text):
-            return None  # Filter out samples with <br> tags in table cells
+        # # Check for <br> tags in table cells
+        # if self._contains_br_in_table_cells(text):
+        #     return None  # Filter out samples with <br> tags in table cells

-        # Check if all math equations can render without errors
-        if not self._validate_math_equations(text):
-            return None  # Filter out samples with invalid math equations
+        # # Check if all math equations can render without errors
+        # if not self._validate_math_equations(text):
+        #     return None  # Filter out samples with invalid math equations

-        # Check for mathematical symbols
-        if self._contains_math_symbols(text):
-            return None  # Filter out samples with mathematical symbols
+        # # Check for mathematical symbols
+        # if self._contains_math_symbols(text):
+        #     return None  # Filter out samples with mathematical symbols

         # Check for LaTeX formatting outside math equations
         if self._contains_latex_formatting_outside_math(text):
@@ -958,7 +996,8 @@ class ReformatLatexBoldItalic(PipelineStep):
         def replace_latex_command(text, command, markdown):
             """Replace LaTeX command with markdown, handling nested braces."""
             import re
-            pattern = r'\\' + command + r'\{'
+
+            pattern = r"\\" + command + r"\{"
             result = []
             i = 0
@@ -977,9 +1016,9 @@ class ReformatLatexBoldItalic(PipelineStep):
                 j = start_pos
                 while j < len(text) and brace_count > 0:
-                    if text[j] == '{':
+                    if text[j] == "{":
                         brace_count += 1
-                    elif text[j] == '}':
+                    elif text[j] == "}":
                         brace_count -= 1
                     j += 1
@@ -993,20 +1032,20 @@ class ReformatLatexBoldItalic(PipelineStep):
                 result.append(text[i + match.start() : i + match.end()])
                 i = i + match.end()

-            return ''.join(result)
+            return "".join(result)

         # Handle \textbf{...} -> **...**
-        preserved_text = replace_latex_command(preserved_text, 'textbf', '**')
+        preserved_text = replace_latex_command(preserved_text, "textbf", "**")
         # Handle \textit{...} -> *...*
-        preserved_text = replace_latex_command(preserved_text, 'textit', '*')
+        preserved_text = replace_latex_command(preserved_text, "textit", "*")

         # Restore math equations
         for placeholder, original in math_placeholders:
             preserved_text = preserved_text.replace(placeholder, original)

         # Create a new PageResponse with the updated text (since it's frozen)
-        from dataclasses import replace
         updated_page_data = replace(page_data, natural_text=preserved_text)

         sample["page_data"] = updated_page_data
@@ -1414,8 +1453,8 @@ if __name__ == "__main__":
                 if sample is None:
                     # This sample was filtered out - get the original paths
                     original_sample = dataset_samples[idx]
-                    md_path = original_sample['markdown_path']
-                    pdf_path = original_sample['pdf_path']
+                    md_path = original_sample["markdown_path"]
+                    pdf_path = original_sample["pdf_path"]

                     save_dir = Path(save_dir_str)
@@ -1433,11 +1472,7 @@ if __name__ == "__main__":
                     target_pdf = target_subdir / pdf_path.name
                     shutil.copy2(pdf_path, target_pdf)

-                    return {
-                        'index': idx,
-                        'markdown_path': str(md_path),
-                        'pdf_path': str(pdf_path)
-                    }
+                    return {"index": idx, "markdown_path": str(md_path), "pdf_path": str(pdf_path)}
                 return None
             except Exception as e:
                 print(f"Error processing sample {idx}: {e}")
@@ -1449,10 +1484,7 @@ if __name__ == "__main__":
     with ProcessPoolExecutor(max_workers=8) as executor:
         # Submit all tasks
-        futures = {
-            executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx
-            for idx in range(len(dataset))
-        }
+        futures = {executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx for idx in range(len(dataset))}

         # Process results with progress bar
         with tqdm(total=len(dataset), desc="Processing samples") as pbar:
@@ -1463,14 +1495,14 @@ if __name__ == "__main__":
                     pbar.update(1)

     # Sort filtered samples by index for consistent output
-    filtered_samples.sort(key=lambda x: x['index'])
+    filtered_samples.sort(key=lambda x: x["index"])

     print(f"\nFound and copied {len(filtered_samples)} filtered samples to: {save_dir}")
     if filtered_samples:
         print(f"First 10 filtered samples:")
         for i, sample_info in enumerate(filtered_samples[:10]):
-            md_name = Path(sample_info['markdown_path']).name
+            md_name = Path(sample_info["markdown_path"]).name
             print(f"  Sample {sample_info['index']}: {md_name}")
         if len(filtered_samples) > 10:
             print(f"  ... and {len(filtered_samples) - 10} more")

View File

@@ -430,10 +430,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
         }
         result = self.reformatter(sample)
-        self.assertEqual(
-            result["page_data"].natural_text,
-            "**First** and **second** bold, *first* and *second* italic."
-        )
+        self.assertEqual(result["page_data"].natural_text, "**First** and **second** bold, *first* and *second* italic.")

     def test_latex_in_parenthesis_delimiter(self):
         """Test LaTeX preserved in \\(...\\) math delimiter."""
@@ -712,7 +709,7 @@ class TestFrontMatterParser(unittest.TestCase):
         self.parser_with_class = FrontMatterParser(front_matter_class=PageResponse)
         self.parser_without_class = FrontMatterParser(front_matter_class=None)

-    @patch.object(Path, 'read_text')
+    @patch.object(Path, "read_text")
     def test_parse_yaml_front_matter(self, mock_read_text):
         """Test parsing of YAML front matter."""
         mock_read_text.return_value = """---
@@ -733,7 +730,7 @@ This is the document content.
         self.assertEqual(result["page_data"].primary_language, "en")
         self.assertEqual(result["page_data"].natural_text, "This is the document content.")

-    @patch.object(Path, 'read_text')
+    @patch.object(Path, "read_text")
     def test_no_front_matter(self, mock_read_text):
         """Test handling of documents without front matter."""
         mock_read_text.return_value = "Just regular content without front matter."
@@ -744,7 +741,7 @@ This is the document content.
         with self.assertRaises(ValueError):
             self.parser_with_class(sample)

-    @patch.object(Path, 'read_text')
+    @patch.object(Path, "read_text")
     def test_malformed_yaml(self, mock_read_text):
         """Test handling of malformed YAML."""
         mock_read_text.return_value = """---
@@ -760,7 +757,7 @@ Content
         result = self.parser_without_class(sample)
         self.assertEqual(result["page_data"], {})

-    @patch.object(Path, 'read_text')
+    @patch.object(Path, "read_text")
     def test_preserve_existing_markdown_content(self, mock_read_text):
         """Test that existing markdown_content is preserved if present."""
         sample = {
@@ -772,7 +769,7 @@ rotation_correction: 0
 is_table: true
 is_diagram: false
 ---
-French content."""
+French content.""",
         }

         # Should not call read_text since markdown_content exists