From 41201b6317a495df647525a13da2facc8e054721 Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Tue, 19 Aug 2025 21:30:41 +0000 Subject: [PATCH] Lints --- olmocr/bench/runners/run_chatgpt.py | 11 +- .../data/build_openai_batch_from_olmocrmix.py | 149 +++--- olmocr/data/process_openai_batch_results.py | 156 +++---- olmocr/prompts/prompts.py | 4 +- olmocr/train/dataloader.py | 440 ++++++++++-------- tests/test_dataloader.py | 111 +++-- 6 files changed, 419 insertions(+), 452 deletions(-) diff --git a/olmocr/bench/runners/run_chatgpt.py b/olmocr/bench/runners/run_chatgpt.py index 4127cf7..fc354dc 100644 --- a/olmocr/bench/runners/run_chatgpt.py +++ b/olmocr/bench/runners/run_chatgpt.py @@ -8,16 +8,19 @@ from olmocr.bench.prompts import ( build_basic_prompt, build_openai_silver_data_prompt_no_document_anchoring, ) -from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64 +from olmocr.data.renderpdf import ( + get_png_dimensions_from_base64, + render_pdf_to_base64png, +) from olmocr.prompts.anchor import get_anchor_text from olmocr.prompts.prompts import ( PageResponse, build_finetuning_prompt, build_openai_silver_data_prompt, - openai_response_format_schema, build_openai_silver_data_prompt_v2, build_openai_silver_data_prompt_v2_simple, build_openai_silver_data_prompt_v3_simple, + openai_response_format_schema, ) @@ -65,7 +68,7 @@ def run_chatgpt( prompt = build_openai_silver_data_prompt_v2_simple(width, height) elif prompt_template == "fullv3simple": width, height = get_png_dimensions_from_base64(image_base64) - prompt = build_openai_silver_data_prompt_v3_simple(width, height) + prompt = build_openai_silver_data_prompt_v3_simple(width, height) else: raise ValueError("Unknown prompt template") @@ -82,7 +85,7 @@ def run_chatgpt( ], temperature=temperature, max_completion_tokens=20000, - #reasoning_effort="high", + # reasoning_effort="high", response_format=openai_response_format_schema() if response_template == "json" else None, safety_identifier="olmocr-bench-runner", ) diff --git a/olmocr/data/build_openai_batch_from_olmocrmix.py b/olmocr/data/build_openai_batch_from_olmocrmix.py index b36742b..2f4a1c0 100755 --- a/olmocr/data/build_openai_batch_from_olmocrmix.py +++ b/olmocr/data/build_openai_batch_from_olmocrmix.py @@ -8,14 +8,17 @@ and generates OpenAI batch API requests for processing PDFs. import argparse import json -import os -from pathlib import Path -from typing import Generator, Dict, Any, Optional, Tuple from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, Generator, Optional, Tuple + from pypdf import PdfReader from tqdm import tqdm -from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64 +from olmocr.data.renderpdf import ( + get_png_dimensions_from_base64, + render_pdf_to_base64png, +) from olmocr.prompts.prompts import ( build_openai_silver_data_prompt_v3_simple, openai_response_format_schema, @@ -28,10 +31,10 @@ MAX_FILE_SIZE = 99 * 1024 * 1024 # 99MB in bytes def validate_single_page_pdf(pdf_path: Path) -> bool: """ Validate that a PDF has exactly one page. - + Args: pdf_path: Path to the PDF file - + Returns: True if PDF has exactly one page, False otherwise """ @@ -46,32 +49,32 @@ def validate_single_page_pdf(pdf_path: Path) -> bool: def build_custom_id(pdf_path: Path, base_dir: Path) -> str: """ Build a custom ID for the request that can be used to recover the file later. - + The ID preserves the full path structure for easy recovery. 
Example: extracted/document_id.pdf becomes "extracted/document_id" - + Args: pdf_path: Full path to the PDF file base_dir: Base directory containing the processed folder - + Returns: Custom ID string that preserves path structure """ # Get relative path from base directory rel_path = pdf_path.relative_to(base_dir) # Remove .pdf extension but keep directory structure - path_without_ext = str(rel_path).replace('.pdf', '') + path_without_ext = str(rel_path).replace(".pdf", "") return path_without_ext def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[str, Any], Path]]: """ Process a single PDF and return the batch request if valid. - + Args: pdf_path: Path to the PDF file base_dir: Base directory for building custom IDs - + Returns: Tuple of (request dict, pdf_path) if successful, None otherwise """ @@ -83,20 +86,20 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[st except Exception as e: print(f"Error reading PDF {pdf_path}: {e}") return None - + try: # Render PDF to base64 image image_base64 = render_pdf_to_base64png(str(pdf_path), page_num=1, target_longest_image_dim=TARGET_IMAGE_DIM) - + # Get image dimensions for the prompt width, height = get_png_dimensions_from_base64(image_base64) - + # Build the prompt using v3 simple version prompt = build_openai_silver_data_prompt_v3_simple(width, height) - + # Build custom ID custom_id = build_custom_id(pdf_path, base_dir) - + # Build the request in OpenAI batch format request = { "custom_id": custom_id, @@ -118,7 +121,7 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[st "response_format": openai_response_format_schema(), }, } - + return (request, pdf_path) except Exception as e: print(f"Error processing {pdf_path}: {e}") @@ -128,21 +131,21 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[st def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]: """ Find all PDF files in the processed folder structure. - + The structure is expected to be: processed_XX_subset_split/ extracted/ *.pdf - + Or for hugging_face downloads: hugging_face/ pdf_tarballs/ extracted/ *.pdf - + Args: input_dir: Input directory path - + Yields: Path objects for each PDF file found """ @@ -151,64 +154,56 @@ def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]: yield pdf_path -def process_pdfs_to_batch_requests( - input_dir: Path, - output_dir: Path, - max_pdfs: int = None, - num_workers: int = 8 -) -> int: +def process_pdfs_to_batch_requests(input_dir: Path, output_dir: Path, max_pdfs: int = None, num_workers: int = 8) -> int: """ Process PDFs and create batch request files using parallel processing. 
- + Args: input_dir: Directory containing the processed folder structure output_dir: Directory to save batch request files max_pdfs: Maximum number of PDFs to process (None for all) num_workers: Number of parallel workers for processing - + Returns: Number of PDFs processed """ # Ensure output directory exists output_dir.mkdir(parents=True, exist_ok=True) - + # Initialize file management file_num = 0 current_file_size = 0 current_file_path = output_dir / f"batch_requests_{file_num:04d}.jsonl" current_file = open(current_file_path, "w") - + pdfs_processed = 0 pdfs_skipped = 0 - + # Find PDF files pdf_files = list(find_pdf_files(input_dir)) - + # Limit files if max_pdfs is specified if max_pdfs: pdf_files = pdf_files[:max_pdfs] - + total_pdfs = len(pdf_files) - + print(f"Found {total_pdfs} PDF files to process") print(f"Using {num_workers} parallel workers") - + # Process PDFs in parallel using ThreadPoolExecutor with ThreadPoolExecutor(max_workers=num_workers) as executor: # Submit all PDF processing tasks - future_to_pdf = { - executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path - for pdf_path in pdf_files - } - + future_to_pdf = {executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path for pdf_path in pdf_files} + # Process results as they complete with tqdm(total=total_pdfs, desc="Processing PDFs") as pbar: for future in as_completed(future_to_pdf): pdf_path = future_to_pdf[future] - + try: result = future.result() - + if result is None: # PDF was skipped (multi-page or error) pdfs_skipped += 1 @@ -216,7 +211,7 @@ def process_pdfs_to_batch_requests( request, _ = result request_json = json.dumps(request) request_size = len(request_json.encode("utf-8")) - + # Check if we need to start a new file if current_file_size + request_size > MAX_FILE_SIZE: current_file.close() @@ -225,88 +220,66 @@ def process_pdfs_to_batch_requests( current_file = open(current_file_path, "w") current_file_size = 0 print(f"\nStarting new batch file: {current_file_path.name}") - + # Write the request (only in main thread) current_file.write(request_json) current_file.write("\n") current_file_size += request_size - + pdfs_processed += 1 - + except Exception as e: print(f"\nError with {pdf_path}: {e}") pdfs_skipped += 1 - + pbar.update(1) - + # Close the last file current_file.close() - + print(f"\nProcessing complete:") print(f" - PDFs processed: {pdfs_processed}") print(f" - PDFs skipped: {pdfs_skipped}") print(f" - Batch files created: {file_num + 1}") print(f" - Output directory: {output_dir}") - + return pdfs_processed def main(): - parser = argparse.ArgumentParser( - description="Build OpenAI batch requests from OLMoCR-mix folder structure" - ) - parser.add_argument( - "--output_dir", - type=str, - default=None, - help="Output directory for batch request files (default: input_dir/batch_requests)" - ) - parser.add_argument( - "--max_pdfs", - type=int, - default=None, - help="Maximum number of PDFs to process (default: all)" - ) - parser.add_argument( - "--num_workers", - type=int, - default=8, - help="Number of parallel workers for processing (default: 8)" - ) + parser = argparse.ArgumentParser(description="Build OpenAI batch requests from OLMoCR-mix folder structure") + parser.add_argument("--output_dir", type=str, default=None, help="Output directory for batch request files (default: input_dir/batch_requests)") + parser.add_argument("--max_pdfs", type=int, default=None, help="Maximum number of PDFs to process (default: all)") + parser.add_argument("--num_workers", type=int, default=8, 
help="Number of parallel workers for processing (default: 8)") parser.add_argument( "input_dir", type=str, - help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)" + help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)", ) - + args = parser.parse_args() - + # Convert paths to Path objects input_dir = Path(args.input_dir).expanduser().resolve() - + if not input_dir.exists(): print(f"Error: Input directory does not exist: {input_dir}") return 1 - + # Set default output directory if not specified if args.output_dir: output_dir = Path(args.output_dir).expanduser().resolve() else: output_dir = input_dir / "batch_requests" - + print(f"Input directory: {input_dir}") print(f"Output directory: {output_dir}") - + # Process PDFs - process_pdfs_to_batch_requests( - input_dir=input_dir, - output_dir=output_dir, - max_pdfs=args.max_pdfs, - num_workers=args.num_workers - ) - + process_pdfs_to_batch_requests(input_dir=input_dir, output_dir=output_dir, max_pdfs=args.max_pdfs, num_workers=args.num_workers) + return 0 if __name__ == "__main__": - exit(main()) \ No newline at end of file + exit(main()) diff --git a/olmocr/data/process_openai_batch_results.py b/olmocr/data/process_openai_batch_results.py index a85e655..e9b2353 100755 --- a/olmocr/data/process_openai_batch_results.py +++ b/olmocr/data/process_openai_batch_results.py @@ -8,30 +8,31 @@ that mirrors the original structure with side-by-side PDF and MD files. import argparse import json -import shutil import re -from pathlib import Path -from typing import Dict, Any, Optional +import shutil from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, Optional + from tqdm import tqdm def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]: """ Parse a single line from the batch response file. - + Args: response_line: JSON line from batch response file - + Returns: Parsed response dictionary or None if error """ try: data = json.loads(response_line) - + # Extract the custom_id and response custom_id = data.get("custom_id") - + # Check if the response was successful if "response" in data and data["response"].get("status_code") == 200: body = data["response"]["body"] @@ -39,14 +40,11 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]: content = body["choices"][0]["message"]["content"] # Parse the JSON response parsed_content = json.loads(content) - return { - "custom_id": custom_id, - "content": parsed_content - } + return {"custom_id": custom_id, "content": parsed_content} else: print(f"Error in response for {custom_id}: {data.get('error', 'Unknown error')}") return None - + except Exception as e: print(f"Error parsing response line: {e}") return None @@ -55,10 +53,10 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]: def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str: """ Format the response data as FrontMatter markdown. 
- + Args: response_data: Parsed response data from OpenAI - + Returns: Formatted markdown string with FrontMatter """ @@ -69,7 +67,7 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str: is_table = response_data.get("is_table", False) is_diagram = response_data.get("is_diagram", False) natural_text = response_data.get("natural_text", "") - + # Format as FrontMatter markdown = "---\n" markdown += f"primary_language: {primary_language if primary_language else 'null'}\n" @@ -78,29 +76,24 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str: markdown += f"is_table: {str(is_table)}\n" markdown += f"is_diagram: {str(is_diagram)}\n" markdown += "---\n" - + # Add the natural text content if natural_text: markdown += natural_text - + return markdown.strip() -def process_single_result( - custom_id: str, - response_content: Dict[str, Any], - original_pdf_dir: Path, - output_dir: Path -) -> bool: +def process_single_result(custom_id: str, response_content: Dict[str, Any], original_pdf_dir: Path, output_dir: Path) -> bool: """ Process a single batch result: copy PDF and create MD file. - + Args: custom_id: Custom ID from the batch request response_content: Parsed response content original_pdf_dir: Directory containing original PDFs output_dir: Output directory for results - + Returns: True if successful, False otherwise """ @@ -109,75 +102,70 @@ def process_single_result( # Custom ID format: "folder/filename" (without .pdf) pdf_relative_path = f"{custom_id}.pdf" original_pdf_path = original_pdf_dir / pdf_relative_path - + if not original_pdf_path.exists(): print(f"Warning: Original PDF not found: {original_pdf_path}") original_pdf_path = str(original_pdf_path) - pattern = r'(.+?)(-\d+)\.pdf$' - replacement = r'\1.pdf\2.pdf' + pattern = r"(.+?)(-\d+)\.pdf$" + replacement = r"\1.pdf\2.pdf" original_pdf_path = Path(re.sub(pattern, replacement, original_pdf_path)) if not original_pdf_path.exists(): print(f"Error: Original PDF not found: {original_pdf_path}") return False - + # Create output paths output_pdf_path = output_dir / pdf_relative_path output_md_path = output_dir / f"{custom_id}.md" - + # Create parent directories if needed output_pdf_path.parent.mkdir(parents=True, exist_ok=True) - + # Copy the PDF file shutil.copy2(original_pdf_path, output_pdf_path) - + # Create the markdown file markdown_content = format_frontmatter_markdown(response_content) with open(output_md_path, "w", encoding="utf-8") as f: f.write(markdown_content) - + return True - + except Exception as e: print(f"Error processing {custom_id}: {e}") return False -def process_batch_results( - batch_results_dir: Path, - original_pdf_dir: Path, - output_dir: Path, - num_workers: int = 8 -) -> int: +def process_batch_results(batch_results_dir: Path, original_pdf_dir: Path, output_dir: Path, num_workers: int = 8) -> int: """ Process all batch result files and create output structure. 
- + Args: batch_results_dir: Directory containing batch result JSONL files original_pdf_dir: Directory containing original PDFs output_dir: Output directory for processed results num_workers: Number of parallel workers - + Returns: Number of successfully processed results """ # Ensure output directory exists output_dir.mkdir(parents=True, exist_ok=True) - + # Find all batch result files (both .jsonl and .json) batch_files = list(batch_results_dir.glob("*.jsonl")) + list(batch_results_dir.glob("*.json")) - + if not batch_files: print(f"No batch result files found in {batch_results_dir}") return 0 - + print(f"Found {len(batch_files)} batch result files") - + # Collect all results to process results_to_process = [] - + for batch_file in batch_files: print(f"Reading {batch_file.name}...") with open(batch_file, "r") as f: @@ -186,33 +174,27 @@ def process_batch_results( parsed = parse_batch_response(line) if parsed: results_to_process.append(parsed) - + total_results = len(results_to_process) print(f"Found {total_results} valid results to process") print(f"Using {num_workers} parallel workers") - + successful = 0 failed = 0 - + # Process results in parallel with ThreadPoolExecutor(max_workers=num_workers) as executor: # Submit all processing tasks future_to_result = { - executor.submit( - process_single_result, - result["custom_id"], - result["content"], - original_pdf_dir, - output_dir - ): result["custom_id"] + executor.submit(process_single_result, result["custom_id"], result["content"], original_pdf_dir, output_dir): result["custom_id"] for result in results_to_process } - + # Process results as they complete with tqdm(total=total_results, desc="Processing results") as pbar: for future in as_completed(future_to_result): custom_id = future_to_result[future] - + try: success = future.result() if success: @@ -222,73 +204,51 @@ def process_batch_results( except Exception as e: print(f"\nError with {custom_id}: {e}") failed += 1 - + pbar.update(1) - + print(f"\nProcessing complete:") print(f" - Successfully processed: {successful}") print(f" - Failed: {failed}") print(f" - Output directory: {output_dir}") - + return successful def main(): - parser = argparse.ArgumentParser( - description="Process OpenAI batch results and create output folder with PDFs and Markdown files" - ) + parser = argparse.ArgumentParser(description="Process OpenAI batch results and create output folder with PDFs and Markdown files") + parser.add_argument("batch_results_dir", type=str, help="Directory containing completed OpenAI batch result files (JSONL)") parser.add_argument( - "batch_results_dir", - type=str, - help="Directory containing completed OpenAI batch result files (JSONL)" + "original_pdf_dir", type=str, help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)" ) - parser.add_argument( - "original_pdf_dir", - type=str, - help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)" - ) - parser.add_argument( - "output_dir", - type=str, - help="Output directory for processed results with PDFs and MD files" - ) - parser.add_argument( - "--num_workers", - type=int, - default=8, - help="Number of parallel workers for processing (default: 8)" - ) - + parser.add_argument("output_dir", type=str, help="Output directory for processed results with PDFs and MD files") + parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)") + args = parser.parse_args() - + # Convert 
paths to Path objects batch_results_dir = Path(args.batch_results_dir).expanduser().resolve() original_pdf_dir = Path(args.original_pdf_dir).expanduser().resolve() output_dir = Path(args.output_dir).expanduser().resolve() - + # Validate input directories if not batch_results_dir.exists(): print(f"Error: Batch results directory does not exist: {batch_results_dir}") return 1 - + if not original_pdf_dir.exists(): print(f"Error: Original PDF directory does not exist: {original_pdf_dir}") return 1 - + print(f"Batch results directory: {batch_results_dir}") print(f"Original PDF directory: {original_pdf_dir}") print(f"Output directory: {output_dir}") - + # Process the batch results - process_batch_results( - batch_results_dir=batch_results_dir, - original_pdf_dir=original_pdf_dir, - output_dir=output_dir, - num_workers=args.num_workers - ) - + process_batch_results(batch_results_dir=batch_results_dir, original_pdf_dir=original_pdf_dir, output_dir=output_dir, num_workers=args.num_workers) + return 0 if __name__ == "__main__": - exit(main()) \ No newline at end of file + exit(main()) diff --git a/olmocr/prompts/prompts.py b/olmocr/prompts/prompts.py index 0e1cf82..7db067d 100644 --- a/olmocr/prompts/prompts.py +++ b/olmocr/prompts/prompts.py @@ -16,6 +16,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str: f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END" ) + def build_openai_silver_data_prompt_v2(base_text: str) -> str: return ( f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it that includes position information for each image and block of text (The origin [0x0] of the coordinates is in the lower left corner of the image). " @@ -30,6 +31,7 @@ def build_openai_silver_data_prompt_v2(base_text: str) -> str: f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END" ) + def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int) -> str: return ( f"Attached is the image of one page of a PDF document." @@ -44,6 +46,7 @@ def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int) f"Page width: {page_width}, Page height: {page_height}" ) + def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int) -> str: return ( f"Attached is the image of one page of a PDF document." 
@@ -60,7 +63,6 @@ def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int) ) - @dataclass(frozen=True) class PageResponse: primary_language: Optional[str] diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py index 510ca38..482cd36 100644 --- a/olmocr/train/dataloader.py +++ b/olmocr/train/dataloader.py @@ -1,10 +1,14 @@ +import argparse import base64 import json import logging +import multiprocessing import re +import shutil from abc import ABC, abstractmethod from concurrent.futures import ProcessPoolExecutor, as_completed -from dataclasses import dataclass, fields +from dataclasses import dataclass, fields, replace +from html.parser import HTMLParser from io import BytesIO from os import PathLike from pathlib import Path @@ -419,8 +423,6 @@ class LatexBracketNormalizer(PipelineStep): # Update the page_data with normalized text # Since PageResponse is frozen, we need to create a new instance - from olmocr.prompts.prompts import PageResponse - new_page_data = PageResponse( primary_language=page_data.primary_language, is_rotation_valid=page_data.is_rotation_valid, @@ -482,8 +484,6 @@ class RotationAugmentation(PipelineStep): else: # 270 correction = 90 - from olmocr.prompts.prompts import PageResponse - new_page_data = PageResponse( primary_language=page_data.primary_language, is_rotation_valid=False, # Mark as invalid since we rotated it @@ -523,7 +523,7 @@ class FilterOutRotatedDocuments(PipelineStep): @dataclass(frozen=True, slots=True) class DatasetTextRuleFilter(PipelineStep): """Pipeline step that filters samples based on text content rules. - + Filters out samples that: - Contain markdown tables - Contain malformed HTML tables @@ -539,205 +539,244 @@ class DatasetTextRuleFilter(PipelineStep): # Look for pipe-separated table patterns # Markdown tables have lines like: | col1 | col2 | col3 | # And separator lines like: |------|------|------| - lines = text.split('\n') + lines = text.split("\n") for i, line in enumerate(lines): line = line.strip() # Check if line looks like a table row - if line.startswith('|') and line.endswith('|') and line.count('|') >= 3: + if line.startswith("|") and line.endswith("|") and line.count("|") >= 3: # Check if next line is a separator (for header rows) if i + 1 < len(lines): next_line = lines[i + 1].strip() - if next_line.startswith('|') and '-' in next_line: + if next_line.startswith("|") and "-" in next_line: return True # Check if previous line is a separator (for data rows) if i > 0: prev_line = lines[i - 1].strip() - if prev_line.startswith('|') and '-' in prev_line: + if prev_line.startswith("|") and "-" in prev_line: return True return False def _contains_math_symbols(self, text: str) -> bool: """Check if text contains specific mathematical symbols outside of table cells. 
- + Returns: True if text contains any of the specified math symbols outside tables False otherwise """ - import re - # List of mathematical symbols to check for math_symbols = [ # Set theory and logic - '∈', '∉', '⊂', '⊃', '⊆', '⊇', '∅', '∪', '∩', '∀', '∃', '¬', + "∈", + "∉", + "⊂", + "⊃", + "⊆", + "⊇", + "∅", + "∪", + "∩", + "∀", + "∃", + "¬", # Common mathematical operators - '⊕', '⊗', '⊙', + "⊕", + "⊗", + "⊙", # Calculus and analysis - '∂', '∇', '∆', '∫', '∬', '∭', '∮', '∏', '∑', '√', '∛', '∜', + "∂", + "∇", + "∆", + "∫", + "∬", + "∭", + "∮", + "∏", + "∑", + "√", + "∛", + "∜", # Arrows and relations - '⊥', + "⊥", # Other common math symbols - '∠', '∡', '⊤', '⊢', '⊣', '∴', '∵', '∶', '∷', '∝', '≅', '≆', '≇', '≊', '≋', + "∠", + "∡", + "⊤", + "⊢", + "⊣", + "∴", + "∵", + "∶", + "∷", + "∝", + "≅", + "≆", + "≇", + "≊", + "≋", # Matrix and vector notation - '⊕', '⊖', '⊗', '⊘', '⊙', '⊚', '⊛', '⊜', '⊝', + "⊕", + "⊖", + "⊗", + "⊘", + "⊙", + "⊚", + "⊛", + "⊜", + "⊝", ] - + # First, remove all HTML tables from the text text_without_tables = text - + # Remove HTML tables - table_pattern = re.compile(r']*>.*?', re.IGNORECASE | re.DOTALL) - text_without_tables = table_pattern.sub('', text_without_tables) - + table_pattern = re.compile(r"]*>.*?", re.IGNORECASE | re.DOTALL) + text_without_tables = table_pattern.sub("", text_without_tables) + # Now check if any of these symbols appear in the text without tables for symbol in math_symbols: if symbol in text_without_tables: return True - + return False - + def _contains_latex_tables(self, text: str) -> bool: """Check if text contains LaTeX table environments. - + Returns: True if text contains LaTeX tables (\\begin{table}, \\begin{tabular}, etc.) False otherwise """ import re - + # Check for various LaTeX table environments latex_table_patterns = [ - r'\\begin\{table\}', - r'\\begin\{tabular\}', + r"\\begin\{table\}", + r"\\begin\{tabular\}", ] - + # Check if any LaTeX table pattern exists in the text for pattern in latex_table_patterns: if re.search(pattern, text, re.IGNORECASE): return True - + return False - + def _contains_latex_formatting_outside_math(self, text: str) -> bool: """Check if text contains LaTeX formatting commands outside of math equations. 
- + Returns: True if text contains LaTeX formatting commands outside math equations False otherwise """ import re - + # List of common LaTeX formatting commands to check for latex_commands = [ # Lists & basic content - r'\begin{itemize}', - r'\begin{enumerate}', - r'\item', - + r"\begin{itemize}", + r"\begin{enumerate}", + r"\item", # Figures, tables, and captions - r'\begin{figure}', - r'\includegraphics', - r'\caption', - r'\label', - r'\ref', - r'\eqref', - r'\begin{table}', - r'\begin{tabular}', - + r"\begin{figure}", + r"\includegraphics", + r"\caption", + r"\label", + r"\ref", + r"\eqref", + r"\begin{table}", + r"\begin{tabular}", # Formatting, # r'\textit', # r'\textbb', - # Math (strong signals) - r'\begin{equation}', - r'\begin{align}', - r'\frac', - r'\sum', - r'\int', - r'\sqrt', - r'\prod', - r'\lim', - r'\binom', - r'\mathbb', - r'\mathcal', - r'\to', - r'\varphi', - r'\cdot', - r'\langle', - r'\rangle', - + r"\begin{equation}", + r"\begin{align}", + r"\frac", + r"\sum", + r"\int", + r"\sqrt", + r"\prod", + r"\lim", + r"\binom", + r"\mathbb", + r"\mathcal", + r"\to", + r"\varphi", + r"\cdot", + r"\langle", + r"\rangle", # Citations (bibliography stacks) - r'\cite', + r"\cite", ] - # First, remove all math equations from the text text_without_math = text - + # Patterns for math equations math_patterns = [ r"\$\$(.+?)\$\$", # $$...$$ r"\\\((.+?)\\\)", # \(...\) r"\\\[(.+?)\\\]", # \[...\] ] - + # Remove all math equations for pattern in math_patterns: - text_without_math = re.sub(pattern, '', text_without_math, flags=re.DOTALL) - + text_without_math = re.sub(pattern, "", text_without_math, flags=re.DOTALL) + # Check if any LaTeX commands appear in the remaining text for command in latex_commands: if command in text_without_math: return True - + return False - + def _validate_math_equations(self, text: str) -> bool: """Check if all math equations in the text can render without errors. 
- + Returns: True if all equations render successfully or no equations exist False if any equation fails to render """ import re - + # Patterns to find math equations (same as in MathTest) patterns = [ r"\$\$(.+?)\$\$", # $$...$$ r"\\\((.+?)\\\)", # \(...\) r"\\\[(.+?)\\\]", # \[...\] ] - + equations = [] for pattern in patterns: # Find all matches for the current pattern matches = re.findall(pattern, text, re.DOTALL) equations.extend([eq.strip() for eq in matches]) - + # If no equations found, that's fine if not equations: return True - + # Try to render each equation try: from olmocr.bench.katex.render import render_equation - + for equation in equations: # Skip empty or whitespace-only equations if not equation or not equation.strip(): continue - + # Try to render the equation rendered = render_equation(equation) - + # Check if there was an error - if rendered is None or (hasattr(rendered, 'error') and rendered.error): + if rendered is None or (hasattr(rendered, "error") and rendered.error): # Equation failed to render logger.warning(f"Could not render equation '{repr(equation)}', skipping sample") return False - + # All equations rendered successfully return True - + except ImportError: # If we can't import the render module, skip this check # This allows the filter to work even without the rendering dependencies @@ -746,87 +785,86 @@ class DatasetTextRuleFilter(PipelineStep): # If any unexpected error occurs during validation, be conservative and filter out print(f"Error validating math equations: {e}") return False - + def _contains_br_in_table_cells(self, text: str) -> bool: """Check if text contains
<br> tags within HTML table cells.
-        
+
         Returns:
             True if any table cell contains <br> tags
             False otherwise
         """
         import re
-        
+
         # Check if there are any tables in the text
-        if '<table' not in text.lower():
-            return False  # No tables, so no <br> tags at all
-        
+        if "<table" not in text.lower():
+            return False  # No tables, so no <br> tags at all
+
         # Pattern to find HTML tables (case-insensitive)
-        table_pattern = re.compile(r'<table[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
+        table_pattern = re.compile(r"<table[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
         tables = table_pattern.findall(text)
-        
+
         # Check each table for <br> tags in cells
         for table_html in tables:
             # Pattern to find table cells (td and th tags)
-            cell_pattern = re.compile(r'<(td|th)\b[^>]*>(.*?)</\1>', re.IGNORECASE | re.DOTALL)
+            cell_pattern = re.compile(r"<(td|th)\b[^>]*>(.*?)</\1>", re.IGNORECASE | re.DOTALL)
             cells = cell_pattern.findall(table_html)
-            
+
             for tag_type, cell_content in cells:
                 # Check if cell content contains <br> tags (any variation)
tags (any variation)
-                if re.search(r'<br\s*/?>', cell_content, re.IGNORECASE):
+                if re.search(r"<br\s*/?>", cell_content, re.IGNORECASE):
                     return True
-        
+
         return False
-    
+
     def _extract_and_validate_html_tables(self, text: str) -> bool:
         """Extract HTML tables and validate they parse correctly.
-        
+
         Returns:
             True if all HTML tables are valid or no tables exist
             False if any HTML table is malformed
         """
         # Find all HTML table blocks
         import re
-        
+
         # Check if there are any <table> tags at all
-        if '<table' not in text.lower():
-            return True  # No tables to validate
-        
-        # Note: this pattern will not match if the closing </table> is missing
-        table_pattern = re.compile(r'<table[^>]*>.*?</table>
', re.IGNORECASE | re.DOTALL) + table_pattern = re.compile(r"]*>.*?", re.IGNORECASE | re.DOTALL) tables = table_pattern.findall(text) - + # Also check for unclosed table tags - table_open_count = len(re.findall(r']*>', text, re.IGNORECASE)) - table_close_count = len(re.findall(r'', text, re.IGNORECASE)) - + table_open_count = len(re.findall(r"]*>", text, re.IGNORECASE)) + table_close_count = len(re.findall(r"", text, re.IGNORECASE)) + if table_open_count != table_close_count: return False # Mismatched table tags - + if not tables and table_open_count > 0: # Found table tags but couldn't extract complete tables return False - + # Try to parse each table - from html.parser import HTMLParser - + class TableValidator(HTMLParser): def __init__(self): super().__init__() self.tag_stack = [] self.is_valid = True self.error_msg = None - + def handle_starttag(self, tag, attrs): self.tag_stack.append(tag.lower()) - + def handle_endtag(self, tag): tag = tag.lower() if not self.tag_stack: self.is_valid = False self.error_msg = f"Unexpected closing tag: {tag}" return - + # Check if the closing tag matches the most recent opening tag if self.tag_stack[-1] == tag: self.tag_stack.pop() @@ -842,11 +880,11 @@ class DatasetTextRuleFilter(PipelineStep): else: self.is_valid = False self.error_msg = f"Mismatched tag: expected {self.tag_stack[-1]}, got {tag}" - + def error(self, message): self.is_valid = False self.error_msg = message - + # Validate each table for table_html in tables: parser = TableValidator() @@ -860,90 +898,90 @@ class DatasetTextRuleFilter(PipelineStep): except Exception: # Any parsing exception means the table is malformed return False - + return True def __call__(self, sample: Sample) -> Optional[Sample]: """Filter samples based on text content rules.""" # Get the natural text from page_data if it exists text = None - + if "page_data" in sample: page_data = sample["page_data"] if hasattr(page_data, "natural_text") and page_data.natural_text: text = page_data.natural_text - + # If no text to check, pass the sample through if text is None: return sample - - # Check for markdown tables - if self._contains_markdown_table(text): - return None # Filter out samples with markdown tables - - # Check for HTML tables and validate them - if not self._extract_and_validate_html_tables(text): - return None # Filter out samples with malformed HTML tables - - # Check for
<br> tags in table cells
-        if self._contains_br_in_table_cells(text):
-            return None  # Filter out samples with <br>
tags in table cells - - # Check if all math equations can render without errors - if not self._validate_math_equations(text): - return None # Filter out samples with invalid math equations - - # Check for mathematical symbols - if self._contains_math_symbols(text): - return None # Filter out samples with mathematical symbols - + + # # Check for markdown tables + # if self._contains_markdown_table(text): + # return None # Filter out samples with markdown tables + + # # Check for HTML tables and validate them + # if not self._extract_and_validate_html_tables(text): + # return None # Filter out samples with malformed HTML tables + + # # Check for
<br> tags in table cells
+        # if self._contains_br_in_table_cells(text):
+        #     return None  # Filter out samples with <br>
tags in table cells + + # # Check if all math equations can render without errors + # if not self._validate_math_equations(text): + # return None # Filter out samples with invalid math equations + + # # Check for mathematical symbols + # if self._contains_math_symbols(text): + # return None # Filter out samples with mathematical symbols + # Check for LaTeX formatting outside math equations if self._contains_latex_formatting_outside_math(text): return None # Filter out samples with \textit or \textbf outside math - + # Check for LaTeX tables if self._contains_latex_tables(text): return None # Filter out samples with LaTeX tables - + return sample @dataclass(frozen=True, slots=True) class ReformatLatexBoldItalic(PipelineStep): """Pipeline step that converts LaTeX formatting commands to markdown equivalents. - + Converts: - \\textit{...} to *...* (italic) - \\textbf{...} to **...** (bold) - + These conversions only happen outside of math equations. """ - + def __call__(self, sample: Sample) -> Optional[Sample]: """Convert LaTeX formatting to markdown in the sample text.""" # Get the natural text from page_data if it exists if "page_data" not in sample: return sample - + page_data = sample["page_data"] if not hasattr(page_data, "natural_text") or not page_data.natural_text: return sample - + text = page_data.natural_text - + import re - + # Math equation patterns to preserve math_patterns = [ r"\$\$(.+?)\$\$", # $$...$$ r"\\\((.+?)\\\)", # \(...\) r"\\\[(.+?)\\\]", # \[...\] ] - + # Store math equations with placeholders math_placeholders = [] preserved_text = text - + # Replace math equations with placeholders for i, pattern in enumerate(math_patterns): matches = re.finditer(pattern, preserved_text, re.DOTALL) @@ -951,65 +989,66 @@ class ReformatLatexBoldItalic(PipelineStep): placeholder = f"__MATH_PLACEHOLDER_{i}_{j}__" math_placeholders.append((placeholder, match.group(0))) preserved_text = preserved_text.replace(match.group(0), placeholder, 1) - + # Now convert LaTeX formatting to markdown # We need to handle nested braces properly # Use a function to find matching braces def replace_latex_command(text, command, markdown): """Replace LaTeX command with markdown, handling nested braces.""" import re - pattern = r'\\' + command + r'\{' + + pattern = r"\\" + command + r"\{" result = [] i = 0 - + while i < len(text): match = re.search(pattern, text[i:]) if not match: result.append(text[i:]) break - + # Add text before the match - result.append(text[i:i + match.start()]) - + result.append(text[i : i + match.start()]) + # Find the matching closing brace start_pos = i + match.end() brace_count = 1 j = start_pos - + while j < len(text) and brace_count > 0: - if text[j] == '{': + if text[j] == "{": brace_count += 1 - elif text[j] == '}': + elif text[j] == "}": brace_count -= 1 j += 1 - + if brace_count == 0: # Extract the content between braces - content = text[start_pos:j-1] + content = text[start_pos : j - 1] result.append(markdown + content + markdown) i = j else: # Unmatched braces, keep original - result.append(text[i + match.start():i + match.end()]) + result.append(text[i + match.start() : i + match.end()]) i = i + match.end() - - return ''.join(result) - + + return "".join(result) + # Handle \textbf{...} -> **...** - preserved_text = replace_latex_command(preserved_text, 'textbf', '**') - + preserved_text = replace_latex_command(preserved_text, "textbf", "**") + # Handle \textit{...} -> *...* - preserved_text = replace_latex_command(preserved_text, 'textit', '*') - + preserved_text = 
replace_latex_command(preserved_text, "textit", "*") + # Restore math equations for placeholder, original in math_placeholders: preserved_text = preserved_text.replace(placeholder, original) - + # Create a new PageResponse with the updated text (since it's frozen) - from dataclasses import replace + updated_page_data = replace(page_data, natural_text=preserved_text) sample["page_data"] = updated_page_data - + return sample @@ -1382,78 +1421,71 @@ if __name__ == "__main__": if args.save_filtered: import shutil from pathlib import Path - + save_dir = Path(args.save_filtered) - + # Clear and create directory if save_dir.exists(): shutil.rmtree(save_dir) save_dir.mkdir(parents=True, exist_ok=True) - + print(f"\n=== Checking for filtered samples ===") print(f"Will save filtered samples to: {save_dir}") - + # Function to process and copy a single sample def process_and_copy_sample(idx, dataset_samples, save_dir_str): """Process a sample and return info if it's filtered. - + Note: This function needs to be picklable for ProcessPoolExecutor, so it takes simple arguments rather than complex objects. """ import shutil from pathlib import Path - + # Recreate dataset with same parameters # This is needed because dataset objects can't be pickled temp_dataset = BaseMarkdownPDFDataset.__new__(BaseMarkdownPDFDataset) temp_dataset.samples = dataset_samples temp_dataset.pipeline_steps = pipeline_steps - + try: sample = temp_dataset[idx] if sample is None: # This sample was filtered out - get the original paths original_sample = dataset_samples[idx] - md_path = original_sample['markdown_path'] - pdf_path = original_sample['pdf_path'] - + md_path = original_sample["markdown_path"] + pdf_path = original_sample["pdf_path"] + save_dir = Path(save_dir_str) - + # Create subdirectory to preserve some structure # Use the parent directory name and file name rel_path = md_path.parent.name target_subdir = save_dir / rel_path target_subdir.mkdir(parents=True, exist_ok=True) - + # Copy markdown file target_md = target_subdir / md_path.name shutil.copy2(md_path, target_md) - + # Copy PDF file target_pdf = target_subdir / pdf_path.name shutil.copy2(pdf_path, target_pdf) - - return { - 'index': idx, - 'markdown_path': str(md_path), - 'pdf_path': str(pdf_path) - } + + return {"index": idx, "markdown_path": str(md_path), "pdf_path": str(pdf_path)} return None except Exception as e: print(f"Error processing sample {idx}: {e}") return None - + # Process all samples in parallel filtered_samples = [] print(f"Processing {len(dataset)} samples to find and copy filtered ones...") - + with ProcessPoolExecutor(max_workers=8) as executor: # Submit all tasks - futures = { - executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx - for idx in range(len(dataset)) - } - + futures = {executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx for idx in range(len(dataset))} + # Process results with progress bar with tqdm(total=len(dataset), desc="Processing samples") as pbar: for future in as_completed(futures): @@ -1461,20 +1493,20 @@ if __name__ == "__main__": if result is not None: filtered_samples.append(result) pbar.update(1) - + # Sort filtered samples by index for consistent output - filtered_samples.sort(key=lambda x: x['index']) - + filtered_samples.sort(key=lambda x: x["index"]) + print(f"\nFound and copied {len(filtered_samples)} filtered samples to: {save_dir}") - + if filtered_samples: print(f"First 10 filtered samples:") for i, sample_info in enumerate(filtered_samples[:10]): - 
md_name = Path(sample_info['markdown_path']).name + md_name = Path(sample_info["markdown_path"]).name print(f" Sample {sample_info['index']}: {md_name}") if len(filtered_samples) > 10: print(f" ... and {len(filtered_samples) - 10} more") - + # Exit early if --save-filtered is used (don't continue with other analyses) print("\nCompleted saving filtered samples. Exiting.") exit(0) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index fadc3e5..4dff923 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -44,7 +44,7 @@ Some conclusion text. """, ) } - + result = self.filter(sample_with_md_table) self.assertIsNone(result, "Should filter out samples with markdown tables") @@ -60,7 +60,7 @@ Some conclusion text. natural_text="This is regular text without any tables. It has | pipes | but not in table format.", ) } - + result = self.filter(sample_without_table) self.assertIsNotNone(result, "Should pass through samples without markdown tables") self.assertEqual(result, sample_without_table) @@ -92,7 +92,7 @@ Some text after table. """, ) } - + result = self.filter(sample_with_valid_html) self.assertIsNotNone(result, "Should pass through samples with valid HTML tables") @@ -121,7 +121,7 @@ Text after. """, ) } - + result = self.filter(sample_with_malformed_html) self.assertIsNone(result, "Should filter out samples with malformed HTML tables") @@ -147,7 +147,7 @@ Text after without closing table tag. """, ) } - + result = self.filter(sample_with_unclosed_table) self.assertIsNone(result, "Should filter out HTML tables without closing tags") @@ -157,7 +157,7 @@ Text after without closing table tag. "markdown_path": Path("/path/to/file.md"), "pdf_path": Path("/path/to/file.pdf"), } - + result = self.filter(sample_without_page_data) self.assertIsNotNone(result, "Should pass through samples without page_data") self.assertEqual(result, sample_without_page_data) @@ -174,7 +174,7 @@ Text after without closing table tag. natural_text=None, ) } - + result = self.filter(sample_without_text) self.assertIsNotNone(result, "Should pass through samples without natural_text") @@ -190,7 +190,7 @@ Text after without closing table tag. natural_text="", ) } - + result = self.filter(sample_with_empty_text) self.assertIsNotNone(result, "Should pass through samples with empty natural_text") @@ -211,7 +211,7 @@ Text after without closing table tag. """, ) } - + result = self.filter(sample_with_alignment) self.assertIsNone(result, "Should filter out markdown tables with alignment") @@ -240,7 +240,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables. """, ) } - + result = self.filter(sample_mixed) self.assertIsNotNone(result, "Should pass through with valid HTML and no markdown tables") @@ -261,7 +261,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables. """, ) } - + result = self.filter(sample_with_br) self.assertIsNone(result, "Should filter out tables with
<br> tags in cells")
 
@@ -283,7 +283,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
 """,
         )
     }
-    
+
     result = self.filter(sample_br_outside)
     self.assertIsNotNone(result, "Should allow <br> tags outside tables")
 
@@ -308,7 +308,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
 """,
         )
     }
-    
+
     result = self.filter(sample_br_variations)
     self.assertIsNone(result, "Should filter out tables with any <br>
variation in cells") @@ -332,7 +332,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text="This is \\textbf{bold} text.", ) } - + result = self.reformatter(sample) self.assertEqual(result["page_data"].natural_text, "This is **bold** text.") @@ -348,7 +348,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text="This is \\textit{italic} text.", ) } - + result = self.reformatter(sample) self.assertEqual(result["page_data"].natural_text, "This is *italic* text.") @@ -364,7 +364,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text="This has \\textbf{bold} and \\textit{italic} text.", ) } - + result = self.reformatter(sample) self.assertEqual(result["page_data"].natural_text, "This has **bold** and *italic* text.") @@ -380,7 +380,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text="Text outside $$ \\textbf{x} = \\textit{y} $$ more text.", ) } - + result = self.reformatter(sample) self.assertEqual(result["page_data"].natural_text, "Text outside $$ \\textbf{x} = \\textit{y} $$ more text.") @@ -396,7 +396,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text="The \\textbf{equation} is $$ \\textbf{x} = 2 $$ and \\textit{important}.", ) } - + result = self.reformatter(sample) self.assertEqual(result["page_data"].natural_text, "The **equation** is $$ \\textbf{x} = 2 $$ and *important*.") @@ -412,7 +412,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text="This is \\textbf{bold with {nested} braces} text.", ) } - + result = self.reformatter(sample) self.assertEqual(result["page_data"].natural_text, "This is **bold with {nested} braces** text.") @@ -428,12 +428,9 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text="\\textbf{First} and \\textbf{second} bold, \\textit{first} and \\textit{second} italic.", ) } - + result = self.reformatter(sample) - self.assertEqual( - result["page_data"].natural_text, - "**First** and **second** bold, *first* and *second* italic." 
- ) + self.assertEqual(result["page_data"].natural_text, "**First** and **second** bold, *first* and *second* italic.") def test_latex_in_parenthesis_delimiter(self): """Test LaTeX preserved in \\(...\\) math delimiter.""" @@ -447,7 +444,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text="Text \\( \\textbf{math} \\) more text \\textbf{bold}.", ) } - + result = self.reformatter(sample) self.assertEqual(result["page_data"].natural_text, "Text \\( \\textbf{math} \\) more text **bold**.") @@ -463,7 +460,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text="Text \\[ \\textit{math} \\] more text \\textit{italic}.", ) } - + result = self.reformatter(sample) self.assertEqual(result["page_data"].natural_text, "Text \\[ \\textit{math} \\] more text *italic*.") @@ -479,14 +476,14 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text="Plain text without any formatting.", ) } - + result = self.reformatter(sample) self.assertEqual(result["page_data"].natural_text, "Plain text without any formatting.") def test_no_page_data(self): """Test handling of samples without page_data.""" sample = {"markdown_path": Path("/path/to/file.md")} - + result = self.reformatter(sample) self.assertEqual(result, sample) @@ -502,10 +499,10 @@ class TestReformatLatexBoldItalic(unittest.TestCase): natural_text=None, ) } - + result = self.reformatter(sample) self.assertIsNone(result["page_data"].natural_text) - + def test_complex_latex_with_parenthesis_delimiters(self): """Test complex LaTeX text with \\(...\\) delimiters and textit.""" input_text = """= a_0 \\int_0^P \\cos \\frac{2m\\pi x}{P} dx @@ -517,7 +514,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th \\[ \\int_0^P \\cos \\frac{2m\\pi x}{P} f(x) dx = \\frac{a_m P}{2}, \\]""" - + expected_text = """= a_0 \\int_0^P \\cos \\frac{2m\\pi x}{P} dx + \\sum_{n=1}^{\\infty} \\frac{a_n}{2} \\int_0^P \\cos \\frac{2(m+n)\\pi x}{P} + \\cos \\frac{2(m-n)\\pi x}{P} dx + b_n \\int_0^P \\sin \\frac{2(m+n)\\pi x}{P} - \\sin \\frac{2(m-n)\\pi x}{P} dx. 
@@ -527,7 +524,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th \\[ \\int_0^P \\cos \\frac{2m\\pi x}{P} f(x) dx = \\frac{a_m P}{2}, \\]""" - + sample = { "page_data": PageResponse( primary_language="en", @@ -538,7 +535,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th natural_text=input_text, ) } - + result = self.reformatter(sample) self.assertEqual(result["page_data"].natural_text, expected_text) @@ -562,7 +559,7 @@ class TestFilterOutRotatedDocuments(unittest.TestCase): natural_text="Some text", ) } - + result = self.filter(sample) self.assertIsNotNone(result, "Should pass through documents with valid rotation") @@ -578,7 +575,7 @@ class TestFilterOutRotatedDocuments(unittest.TestCase): natural_text="Some text", ) } - + result = self.filter(sample) self.assertIsNone(result, "Should filter out documents with invalid rotation") @@ -594,14 +591,14 @@ class TestFilterOutRotatedDocuments(unittest.TestCase): natural_text="Some text", ) } - + result = self.filter(sample) self.assertIsNone(result, "Should filter out documents with non-zero rotation correction") def test_no_page_data(self): """Test that samples without page_data pass through.""" sample = {"markdown_path": Path("/path/to/file.md")} - + result = self.filter(sample) self.assertIsNotNone(result, "Should pass through samples without page_data") @@ -625,7 +622,7 @@ class TestLatexBracketNormalizer(unittest.TestCase): natural_text="The equation $x^2 + y^2 = z^2$ is famous.", ) } - + result = self.normalizer(sample) expected_text = "The equation \\(x^2 + y^2 = z^2\\) is famous." self.assertEqual(result["page_data"].natural_text, expected_text) @@ -642,7 +639,7 @@ class TestLatexBracketNormalizer(unittest.TestCase): natural_text="Display equation:\n$$\\int_0^\\infty e^{-x^2} dx = \\frac{\\sqrt{\\pi}}{2}$$", ) } - + result = self.normalizer(sample) expected_text = "Display equation:\n\\[\\int_0^\\infty e^{-x^2} dx = \\frac{\\sqrt{\\pi}}{2}\\]" self.assertEqual(result["page_data"].natural_text, expected_text) @@ -659,7 +656,7 @@ class TestLatexBracketNormalizer(unittest.TestCase): natural_text="Inline $a + b$ and display:\n$$c^2 = a^2 + b^2$$\nMore inline $x = y$.", ) } - + result = self.normalizer(sample) expected_text = "Inline \\(a + b\\) and display:\n\\[c^2 = a^2 + b^2\\]\nMore inline \\(x = y\\)." 
self.assertEqual(result["page_data"].natural_text, expected_text) @@ -676,7 +673,7 @@ class TestLatexBracketNormalizer(unittest.TestCase): natural_text="Regular text without any equations.", ) } - + result = self.normalizer(sample) self.assertEqual(result["page_data"].natural_text, "Regular text without any equations.") @@ -692,14 +689,14 @@ class TestLatexBracketNormalizer(unittest.TestCase): natural_text=None, ) } - + result = self.normalizer(sample) self.assertIsNone(result["page_data"].natural_text) def test_no_page_data(self): """Test handling of missing page_data.""" sample = {"markdown_path": Path("/path/to/file.md")} - + result = self.normalizer(sample) self.assertEqual(result, sample) @@ -712,7 +709,7 @@ class TestFrontMatterParser(unittest.TestCase): self.parser_with_class = FrontMatterParser(front_matter_class=PageResponse) self.parser_without_class = FrontMatterParser(front_matter_class=None) - @patch.object(Path, 'read_text') + @patch.object(Path, "read_text") def test_parse_yaml_front_matter(self, mock_read_text): """Test parsing of YAML front matter.""" mock_read_text.return_value = """--- @@ -724,27 +721,27 @@ is_diagram: false --- This is the document content. """ - + sample = {"markdown_path": Path("/path/to/file.md")} result = self.parser_with_class(sample) - + self.assertIn("page_data", result) self.assertIsInstance(result["page_data"], PageResponse) self.assertEqual(result["page_data"].primary_language, "en") self.assertEqual(result["page_data"].natural_text, "This is the document content.") - @patch.object(Path, 'read_text') + @patch.object(Path, "read_text") def test_no_front_matter(self, mock_read_text): """Test handling of documents without front matter.""" mock_read_text.return_value = "Just regular content without front matter." - + sample = {"markdown_path": Path("/path/to/file.md")} - + # Should raise an error when front_matter_class is specified with self.assertRaises(ValueError): self.parser_with_class(sample) - @patch.object(Path, 'read_text') + @patch.object(Path, "read_text") def test_malformed_yaml(self, mock_read_text): """Test handling of malformed YAML.""" mock_read_text.return_value = """--- @@ -753,14 +750,14 @@ is_rotation_valid: [this is not valid yaml} --- Content """ - + sample = {"markdown_path": Path("/path/to/file.md")} - + # Parser without class should return empty dict for malformed YAML result = self.parser_without_class(sample) self.assertEqual(result["page_data"], {}) - @patch.object(Path, 'read_text') + @patch.object(Path, "read_text") def test_preserve_existing_markdown_content(self, mock_read_text): """Test that existing markdown_content is preserved if present.""" sample = { @@ -772,16 +769,16 @@ rotation_correction: 0 is_table: true is_diagram: false --- -French content.""" +French content.""", } - + # Should not call read_text since markdown_content exists result = self.parser_with_class(sample) mock_read_text.assert_not_called() - + self.assertEqual(result["page_data"].primary_language, "fr") self.assertEqual(result["page_data"].is_table, True) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()