Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-01 18:43:45 +00:00

Commit 41201b6317 ("Lints"), parent 768cb33937
@@ -8,16 +8,19 @@ from olmocr.bench.prompts import (
     build_basic_prompt,
     build_openai_silver_data_prompt_no_document_anchoring,
 )
-from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
+from olmocr.data.renderpdf import (
+    get_png_dimensions_from_base64,
+    render_pdf_to_base64png,
+)
 from olmocr.prompts.anchor import get_anchor_text
 from olmocr.prompts.prompts import (
     PageResponse,
     build_finetuning_prompt,
     build_openai_silver_data_prompt,
-    openai_response_format_schema,
     build_openai_silver_data_prompt_v2,
     build_openai_silver_data_prompt_v2_simple,
     build_openai_silver_data_prompt_v3_simple,
+    openai_response_format_schema,
 )

@@ -65,7 +68,7 @@ def run_chatgpt(
         prompt = build_openai_silver_data_prompt_v2_simple(width, height)
     elif prompt_template == "fullv3simple":
         width, height = get_png_dimensions_from_base64(image_base64)
-        prompt = build_openai_silver_data_prompt_v3_simple(width, height)
+        prompt = build_openai_silver_data_prompt_v3_simple(width, height)
     else:
         raise ValueError("Unknown prompt template")

@@ -82,7 +85,7 @@ def run_chatgpt(
         ],
         temperature=temperature,
         max_completion_tokens=20000,
-        #reasoning_effort="high",
+        # reasoning_effort="high",
         response_format=openai_response_format_schema() if response_template == "json" else None,
         safety_identifier="olmocr-bench-runner",
     )
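For orientation, the two reflowed imports above are the pieces the runner combines when building a prompt; a minimal sketch of that flow (the target_longest_image_dim value is illustrative, not taken from this diff):

    from olmocr.data.renderpdf import get_png_dimensions_from_base64, render_pdf_to_base64png
    from olmocr.prompts.prompts import build_openai_silver_data_prompt_v3_simple

    image_base64 = render_pdf_to_base64png("page.pdf", page_num=1, target_longest_image_dim=1024)  # dim is illustrative
    width, height = get_png_dimensions_from_base64(image_base64)
    prompt = build_openai_silver_data_prompt_v3_simple(width, height)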
@@ -8,14 +8,17 @@ and generates OpenAI batch API requests for processing PDFs.

 import argparse
 import json
 import os
-from pathlib import Path
-from typing import Generator, Dict, Any, Optional, Tuple
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, Generator, Optional, Tuple

 from pypdf import PdfReader
 from tqdm import tqdm

-from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
+from olmocr.data.renderpdf import (
+    get_png_dimensions_from_base64,
+    render_pdf_to_base64png,
+)
 from olmocr.prompts.prompts import (
     build_openai_silver_data_prompt_v3_simple,
     openai_response_format_schema,
@@ -28,10 +31,10 @@ MAX_FILE_SIZE = 99 * 1024 * 1024  # 99MB in bytes
 def validate_single_page_pdf(pdf_path: Path) -> bool:
     """
     Validate that a PDF has exactly one page.

     Args:
         pdf_path: Path to the PDF file

     Returns:
         True if PDF has exactly one page, False otherwise
     """
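The body of validate_single_page_pdf is elided by this hunk. Given that PdfReader is imported above, a plausible sketch of such a check (a reconstruction under that assumption, not the repository's exact code):

    from pypdf import PdfReader

    def validate_single_page_pdf_sketch(pdf_path) -> bool:
        # Hypothetical reconstruction: count pages and require exactly one.
        try:
            return len(PdfReader(str(pdf_path)).pages) == 1
        except Exception:
            return False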
@@ -46,32 +49,32 @@ def validate_single_page_pdf(pdf_path: Path) -> bool:
 def build_custom_id(pdf_path: Path, base_dir: Path) -> str:
     """
     Build a custom ID for the request that can be used to recover the file later.

     The ID preserves the full path structure for easy recovery.
     Example: extracted/document_id.pdf becomes "extracted/document_id"

     Args:
         pdf_path: Full path to the PDF file
         base_dir: Base directory containing the processed folder

     Returns:
         Custom ID string that preserves path structure
     """
     # Get relative path from base directory
     rel_path = pdf_path.relative_to(base_dir)
     # Remove .pdf extension but keep directory structure
-    path_without_ext = str(rel_path).replace('.pdf', '')
+    path_without_ext = str(rel_path).replace(".pdf", "")
     return path_without_ext


 def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[str, Any], Path]]:
     """
     Process a single PDF and return the batch request if valid.

     Args:
         pdf_path: Path to the PDF file
         base_dir: Base directory for building custom IDs

     Returns:
         Tuple of (request dict, pdf_path) if successful, None otherwise
     """
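As the docstring says, the custom ID is simply the base-relative path with ".pdf" stripped. A quick illustration with hypothetical paths; note that str.replace removes every ".pdf" occurrence, not just a trailing suffix, so an unusual name like "a.pdf.pdf" would lose both:

    from pathlib import Path

    base = Path("/data/processed_00_documents_eval_s2pdf")  # hypothetical base directory
    pdf = base / "extracted" / "document_id.pdf"
    custom_id = str(pdf.relative_to(base)).replace(".pdf", "")
    print(custom_id)  # extracted/document_id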
@@ -83,20 +86,20 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[str, Any], Path]]:
     except Exception as e:
         print(f"Error reading PDF {pdf_path}: {e}")
         return None

     try:
         # Render PDF to base64 image
         image_base64 = render_pdf_to_base64png(str(pdf_path), page_num=1, target_longest_image_dim=TARGET_IMAGE_DIM)

         # Get image dimensions for the prompt
         width, height = get_png_dimensions_from_base64(image_base64)

         # Build the prompt using v3 simple version
         prompt = build_openai_silver_data_prompt_v3_simple(width, height)

         # Build custom ID
         custom_id = build_custom_id(pdf_path, base_dir)

         # Build the request in OpenAI batch format
         request = {
             "custom_id": custom_id,
@@ -118,7 +121,7 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[str, Any], Path]]:
                 "response_format": openai_response_format_schema(),
             },
         }

         return (request, pdf_path)
     except Exception as e:
         print(f"Error processing {pdf_path}: {e}")
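The hunks elide the middle of the request dict. For reference, a line in an OpenAI batch input file wraps one chat-completions call in custom_id/method/url/body fields; a schematic of what process_single_pdf likely assembles (model name and message content are illustrative, only the visible keys come from the diff):

    request = {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "gpt-4o",  # illustrative
            "messages": [{"role": "user", "content": "..."}],  # prompt plus the page image
            "response_format": openai_response_format_schema(),
        },
    }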
@@ -128,21 +131,21 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[str, Any], Path]]:
 def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]:
     """
     Find all PDF files in the processed folder structure.

     The structure is expected to be:
         processed_XX_subset_split/
             extracted/
                 *.pdf

     Or for hugging_face downloads:
         hugging_face/
             pdf_tarballs/
                 extracted/
                     *.pdf

     Args:
         input_dir: Input directory path

     Yields:
         Path objects for each PDF file found
     """
@@ -151,64 +154,56 @@ def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]:
             yield pdf_path

-def process_pdfs_to_batch_requests(
-    input_dir: Path,
-    output_dir: Path,
-    max_pdfs: int = None,
-    num_workers: int = 8
-) -> int:
+def process_pdfs_to_batch_requests(input_dir: Path, output_dir: Path, max_pdfs: int = None, num_workers: int = 8) -> int:
     """
     Process PDFs and create batch request files using parallel processing.

     Args:
         input_dir: Directory containing the processed folder structure
         output_dir: Directory to save batch request files
         max_pdfs: Maximum number of PDFs to process (None for all)
         num_workers: Number of parallel workers for processing

     Returns:
         Number of PDFs processed
     """
     # Ensure output directory exists
     output_dir.mkdir(parents=True, exist_ok=True)

     # Initialize file management
     file_num = 0
     current_file_size = 0
     current_file_path = output_dir / f"batch_requests_{file_num:04d}.jsonl"
     current_file = open(current_file_path, "w")

     pdfs_processed = 0
     pdfs_skipped = 0

     # Find PDF files
     pdf_files = list(find_pdf_files(input_dir))

     # Limit files if max_pdfs is specified
     if max_pdfs:
         pdf_files = pdf_files[:max_pdfs]

     total_pdfs = len(pdf_files)

     print(f"Found {total_pdfs} PDF files to process")
     print(f"Using {num_workers} parallel workers")

     # Process PDFs in parallel using ThreadPoolExecutor
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         # Submit all PDF processing tasks
-        future_to_pdf = {
-            executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path
-            for pdf_path in pdf_files
-        }
+        future_to_pdf = {executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path for pdf_path in pdf_files}

         # Process results as they complete
         with tqdm(total=total_pdfs, desc="Processing PDFs") as pbar:
             for future in as_completed(future_to_pdf):
                 pdf_path = future_to_pdf[future]

                 try:
                     result = future.result()

                     if result is None:
                         # PDF was skipped (multi-page or error)
                         pdfs_skipped += 1
@@ -216,7 +211,7 @@ def process_pdfs_to_batch_requests(
                     request, _ = result
                     request_json = json.dumps(request)
                     request_size = len(request_json.encode("utf-8"))

                     # Check if we need to start a new file
                     if current_file_size + request_size > MAX_FILE_SIZE:
                         current_file.close()
@@ -225,88 +220,66 @@ def process_pdfs_to_batch_requests(
                         current_file = open(current_file_path, "w")
                         current_file_size = 0
                         print(f"\nStarting new batch file: {current_file_path.name}")

                     # Write the request (only in main thread)
                     current_file.write(request_json)
                     current_file.write("\n")
                     current_file_size += request_size

                     pdfs_processed += 1

                 except Exception as e:
                     print(f"\nError with {pdf_path}: {e}")
                     pdfs_skipped += 1

                 pbar.update(1)

     # Close the last file
     current_file.close()

     print(f"\nProcessing complete:")
     print(f"  - PDFs processed: {pdfs_processed}")
     print(f"  - PDFs skipped: {pdfs_skipped}")
     print(f"  - Batch files created: {file_num + 1}")
     print(f"  - Output directory: {output_dir}")

     return pdfs_processed

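The rotation above keeps every batch file under MAX_FILE_SIZE; the lines elided between the hunks presumably advance file_num and rebuild the path. A condensed sketch of the whole pattern, with those elided lines reconstructed as an assumption:

    if current_file_size + request_size > MAX_FILE_SIZE:
        current_file.close()
        file_num += 1  # assumed: not shown in the diff
        current_file_path = output_dir / f"batch_requests_{file_num:04d}.jsonl"  # assumed
        current_file = open(current_file_path, "w")
        current_file_size = 0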
 def main():
-    parser = argparse.ArgumentParser(
-        description="Build OpenAI batch requests from OLMoCR-mix folder structure"
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default=None,
-        help="Output directory for batch request files (default: input_dir/batch_requests)"
-    )
-    parser.add_argument(
-        "--max_pdfs",
-        type=int,
-        default=None,
-        help="Maximum number of PDFs to process (default: all)"
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=8,
-        help="Number of parallel workers for processing (default: 8)"
-    )
+    parser = argparse.ArgumentParser(description="Build OpenAI batch requests from OLMoCR-mix folder structure")
+    parser.add_argument("--output_dir", type=str, default=None, help="Output directory for batch request files (default: input_dir/batch_requests)")
+    parser.add_argument("--max_pdfs", type=int, default=None, help="Maximum number of PDFs to process (default: all)")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")
     parser.add_argument(
         "input_dir",
         type=str,
-        help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)"
+        help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)",
     )

     args = parser.parse_args()

     # Convert paths to Path objects
     input_dir = Path(args.input_dir).expanduser().resolve()

     if not input_dir.exists():
         print(f"Error: Input directory does not exist: {input_dir}")
         return 1

     # Set default output directory if not specified
     if args.output_dir:
         output_dir = Path(args.output_dir).expanduser().resolve()
     else:
         output_dir = input_dir / "batch_requests"

     print(f"Input directory: {input_dir}")
     print(f"Output directory: {output_dir}")

     # Process PDFs
-    process_pdfs_to_batch_requests(
-        input_dir=input_dir,
-        output_dir=output_dir,
-        max_pdfs=args.max_pdfs,
-        num_workers=args.num_workers
-    )
+    process_pdfs_to_batch_requests(input_dir=input_dir, output_dir=output_dir, max_pdfs=args.max_pdfs, num_workers=args.num_workers)

     return 0


 if __name__ == "__main__":
-    exit(main())
+    exit(main())
@@ -8,30 +8,31 @@ that mirrors the original structure with side-by-side PDF and MD files.

 import argparse
 import json
-import shutil
 import re
-from pathlib import Path
-from typing import Dict, Any, Optional
+import shutil
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, Optional

 from tqdm import tqdm

 def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
     """
     Parse a single line from the batch response file.

     Args:
         response_line: JSON line from batch response file

     Returns:
         Parsed response dictionary or None if error
     """
     try:
         data = json.loads(response_line)

         # Extract the custom_id and response
         custom_id = data.get("custom_id")

         # Check if the response was successful
         if "response" in data and data["response"].get("status_code") == 200:
             body = data["response"]["body"]
@@ -39,14 +40,11 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
             content = body["choices"][0]["message"]["content"]
             # Parse the JSON response
             parsed_content = json.loads(content)
-            return {
-                "custom_id": custom_id,
-                "content": parsed_content
-            }
+            return {"custom_id": custom_id, "content": parsed_content}
         else:
             print(f"Error in response for {custom_id}: {data.get('error', 'Unknown error')}")
             return None

     except Exception as e:
         print(f"Error parsing response line: {e}")
         return None
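For context, a successful line in an OpenAI batch results file has roughly the shape below (values illustrative), which is what the status_code/body/choices traversal expects:

    line = (
        '{"custom_id": "extracted/document_id",'
        ' "response": {"status_code": 200,'
        ' "body": {"choices": [{"message": {"content": "{\\"natural_text\\": \\"Hello\\"}"}}]}}}'
    )
    parsed = parse_batch_response(line)
    # -> {"custom_id": "extracted/document_id", "content": {"natural_text": "Hello"}}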
@@ -55,10 +53,10 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
 def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
     """
     Format the response data as FrontMatter markdown.

     Args:
         response_data: Parsed response data from OpenAI

     Returns:
         Formatted markdown string with FrontMatter
     """
@@ -69,7 +67,7 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
     is_table = response_data.get("is_table", False)
     is_diagram = response_data.get("is_diagram", False)
     natural_text = response_data.get("natural_text", "")

     # Format as FrontMatter
     markdown = "---\n"
     markdown += f"primary_language: {primary_language if primary_language else 'null'}\n"
@@ -78,29 +76,24 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
     markdown += f"is_table: {str(is_table)}\n"
     markdown += f"is_diagram: {str(is_diagram)}\n"
     markdown += "---\n"

     # Add the natural text content
     if natural_text:
         markdown += natural_text

     return markdown.strip()
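Assembled, the function yields a markdown document whose FrontMatter mirrors the response fields; a rough usage sketch (the hunks elide some fields, e.g. rotation flags, so the real output has more FrontMatter lines than shown):

    md = format_frontmatter_markdown({"primary_language": "en", "natural_text": "The page's text."})
    print(md)
    # ---
    # primary_language: en
    # ...               <- fields emitted by the elided lines appear here
    # is_table: False
    # is_diagram: False
    # ---
    # The page's text.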
-def process_single_result(
-    custom_id: str,
-    response_content: Dict[str, Any],
-    original_pdf_dir: Path,
-    output_dir: Path
-) -> bool:
+def process_single_result(custom_id: str, response_content: Dict[str, Any], original_pdf_dir: Path, output_dir: Path) -> bool:
     """
     Process a single batch result: copy PDF and create MD file.

     Args:
         custom_id: Custom ID from the batch request
         response_content: Parsed response content
         original_pdf_dir: Directory containing original PDFs
         output_dir: Output directory for results

     Returns:
         True if successful, False otherwise
     """
@@ -109,75 +102,70 @@ def process_single_result(
         # Custom ID format: "folder/filename" (without .pdf)
         pdf_relative_path = f"{custom_id}.pdf"
         original_pdf_path = original_pdf_dir / pdf_relative_path

         if not original_pdf_path.exists():
             print(f"Warning: Original PDF not found: {original_pdf_path}")

             original_pdf_path = str(original_pdf_path)
-            pattern = r'(.+?)(-\d+)\.pdf$'
-            replacement = r'\1.pdf\2.pdf'
+            pattern = r"(.+?)(-\d+)\.pdf$"
+            replacement = r"\1.pdf\2.pdf"

             original_pdf_path = Path(re.sub(pattern, replacement, original_pdf_path))

             if not original_pdf_path.exists():
                 print(f"Error: Original PDF not found: {original_pdf_path}")
                 return False

         # Create output paths
         output_pdf_path = output_dir / pdf_relative_path
         output_md_path = output_dir / f"{custom_id}.md"

         # Create parent directories if needed
         output_pdf_path.parent.mkdir(parents=True, exist_ok=True)

         # Copy the PDF file
         shutil.copy2(original_pdf_path, output_pdf_path)

         # Create the markdown file
         markdown_content = format_frontmatter_markdown(response_content)
         with open(output_md_path, "w", encoding="utf-8") as f:
             f.write(markdown_content)

         return True

     except Exception as e:
         print(f"Error processing {custom_id}: {e}")
         return False

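The fallback above covers page-split PDFs whose on-disk name differs from the custom ID: it rewrites name-N.pdf to name.pdf-N.pdf and retries. A quick demonstration of the substitution on a hypothetical path:

    import re

    path = "extracted/document_id-1.pdf"  # hypothetical missing path
    fixed = re.sub(r"(.+?)(-\d+)\.pdf$", r"\1.pdf\2.pdf", path)
    print(fixed)  # extracted/document_id.pdf-1.pdf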
-def process_batch_results(
-    batch_results_dir: Path,
-    original_pdf_dir: Path,
-    output_dir: Path,
-    num_workers: int = 8
-) -> int:
+def process_batch_results(batch_results_dir: Path, original_pdf_dir: Path, output_dir: Path, num_workers: int = 8) -> int:
     """
     Process all batch result files and create output structure.

     Args:
         batch_results_dir: Directory containing batch result JSONL files
         original_pdf_dir: Directory containing original PDFs
         output_dir: Output directory for processed results
         num_workers: Number of parallel workers

     Returns:
         Number of successfully processed results
     """
     # Ensure output directory exists
     output_dir.mkdir(parents=True, exist_ok=True)

     # Find all batch result files (both .jsonl and .json)
     batch_files = list(batch_results_dir.glob("*.jsonl")) + list(batch_results_dir.glob("*.json"))

     if not batch_files:
         print(f"No batch result files found in {batch_results_dir}")
         return 0

     print(f"Found {len(batch_files)} batch result files")

     # Collect all results to process
     results_to_process = []

     for batch_file in batch_files:
         print(f"Reading {batch_file.name}...")
         with open(batch_file, "r") as f:
@@ -186,33 +174,27 @@ def process_batch_results(
                 parsed = parse_batch_response(line)
                 if parsed:
                     results_to_process.append(parsed)

     total_results = len(results_to_process)
     print(f"Found {total_results} valid results to process")
     print(f"Using {num_workers} parallel workers")

     successful = 0
     failed = 0

     # Process results in parallel
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         # Submit all processing tasks
         future_to_result = {
-            executor.submit(
-                process_single_result,
-                result["custom_id"],
-                result["content"],
-                original_pdf_dir,
-                output_dir
-            ): result["custom_id"]
+            executor.submit(process_single_result, result["custom_id"], result["content"], original_pdf_dir, output_dir): result["custom_id"]
             for result in results_to_process
         }

         # Process results as they complete
         with tqdm(total=total_results, desc="Processing results") as pbar:
             for future in as_completed(future_to_result):
                 custom_id = future_to_result[future]

                 try:
                     success = future.result()
                     if success:
@@ -222,73 +204,51 @@ def process_batch_results(
                 except Exception as e:
                     print(f"\nError with {custom_id}: {e}")
                     failed += 1

                 pbar.update(1)

     print(f"\nProcessing complete:")
     print(f"  - Successfully processed: {successful}")
     print(f"  - Failed: {failed}")
     print(f"  - Output directory: {output_dir}")

     return successful


 def main():
-    parser = argparse.ArgumentParser(
-        description="Process OpenAI batch results and create output folder with PDFs and Markdown files"
-    )
-    parser.add_argument(
-        "batch_results_dir",
-        type=str,
-        help="Directory containing completed OpenAI batch result files (JSONL)"
-    )
-    parser.add_argument(
-        "original_pdf_dir",
-        type=str,
-        help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
-    )
-    parser.add_argument(
-        "output_dir",
-        type=str,
-        help="Output directory for processed results with PDFs and MD files"
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=8,
-        help="Number of parallel workers for processing (default: 8)"
-    )
+    parser = argparse.ArgumentParser(description="Process OpenAI batch results and create output folder with PDFs and Markdown files")
+    parser.add_argument("batch_results_dir", type=str, help="Directory containing completed OpenAI batch result files (JSONL)")
+    parser.add_argument(
+        "original_pdf_dir", type=str, help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
+    )
+    parser.add_argument("output_dir", type=str, help="Output directory for processed results with PDFs and MD files")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")

     args = parser.parse_args()

     # Convert paths to Path objects
     batch_results_dir = Path(args.batch_results_dir).expanduser().resolve()
     original_pdf_dir = Path(args.original_pdf_dir).expanduser().resolve()
     output_dir = Path(args.output_dir).expanduser().resolve()

     # Validate input directories
     if not batch_results_dir.exists():
         print(f"Error: Batch results directory does not exist: {batch_results_dir}")
         return 1

     if not original_pdf_dir.exists():
         print(f"Error: Original PDF directory does not exist: {original_pdf_dir}")
         return 1

     print(f"Batch results directory: {batch_results_dir}")
     print(f"Original PDF directory: {original_pdf_dir}")
     print(f"Output directory: {output_dir}")

     # Process the batch results
-    process_batch_results(
-        batch_results_dir=batch_results_dir,
-        original_pdf_dir=original_pdf_dir,
-        output_dir=output_dir,
-        num_workers=args.num_workers
-    )
+    process_batch_results(batch_results_dir=batch_results_dir, original_pdf_dir=original_pdf_dir, output_dir=output_dir, num_workers=args.num_workers)

     return 0


 if __name__ == "__main__":
-    exit(main())
+    exit(main())
@@ -16,6 +16,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )

+
 def build_openai_silver_data_prompt_v2(base_text: str) -> str:
     return (
         f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it that includes position information for each image and block of text (The origin [0x0] of the coordinates is in the lower left corner of the image). "
@@ -30,6 +31,7 @@ def build_openai_silver_data_prompt_v2(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )

+
 def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int) -> str:
     return (
         f"Attached is the image of one page of a PDF document."
@@ -44,6 +46,7 @@ def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int) -> str:
         f"Page width: {page_width}, Page height: {page_height}"
     )

+
 def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int) -> str:
     return (
         f"Attached is the image of one page of a PDF document."
@@ -60,7 +63,6 @@ def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int) -> str:
     )

-
 @dataclass(frozen=True)
 class PageResponse:
     primary_language: Optional[str]
@@ -1,10 +1,14 @@
 import argparse
 import base64
 import json
 import logging
 import multiprocessing
+import re
+import shutil
 from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, as_completed
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
+from html.parser import HTMLParser
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
@@ -419,8 +423,6 @@ class LatexBracketNormalizer(PipelineStep):

         # Update the page_data with normalized text
         # Since PageResponse is frozen, we need to create a new instance
-        from olmocr.prompts.prompts import PageResponse
-
         new_page_data = PageResponse(
             primary_language=page_data.primary_language,
             is_rotation_valid=page_data.is_rotation_valid,
@@ -482,8 +484,6 @@ class RotationAugmentation(PipelineStep):
         else:  # 270
             correction = 90

-        from olmocr.prompts.prompts import PageResponse
-
         new_page_data = PageResponse(
             primary_language=page_data.primary_language,
             is_rotation_valid=False,  # Mark as invalid since we rotated it
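Both hunks drop a redundant local PageResponse import, and the import list now pulls in dataclasses.replace, the idiomatic way to derive an updated copy of a frozen dataclass. A small illustration (Page is a stand-in, not the real PageResponse):

    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class Page:  # stand-in for PageResponse, for illustration only
        primary_language: str
        natural_text: str

    page = Page("en", "old text")
    updated = replace(page, natural_text="new text")  # new frozen instance; other fields copied over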
@@ -523,7 +523,7 @@ class FilterOutRotatedDocuments(PipelineStep):
 @dataclass(frozen=True, slots=True)
 class DatasetTextRuleFilter(PipelineStep):
     """Pipeline step that filters samples based on text content rules.

     Filters out samples that:
     - Contain markdown tables
     - Contain malformed HTML tables
@@ -539,205 +539,244 @@ class DatasetTextRuleFilter(PipelineStep):
         # Look for pipe-separated table patterns
         # Markdown tables have lines like: | col1 | col2 | col3 |
         # And separator lines like: |------|------|------|
-        lines = text.split('\n')
+        lines = text.split("\n")
         for i, line in enumerate(lines):
             line = line.strip()
             # Check if line looks like a table row
-            if line.startswith('|') and line.endswith('|') and line.count('|') >= 3:
+            if line.startswith("|") and line.endswith("|") and line.count("|") >= 3:
                 # Check if next line is a separator (for header rows)
                 if i + 1 < len(lines):
                     next_line = lines[i + 1].strip()
-                    if next_line.startswith('|') and '-' in next_line:
+                    if next_line.startswith("|") and "-" in next_line:
                         return True
                 # Check if previous line is a separator (for data rows)
                 if i > 0:
                     prev_line = lines[i - 1].strip()
-                    if prev_line.startswith('|') and '-' in prev_line:
+                    if prev_line.startswith("|") and "-" in prev_line:
                         return True
         return False
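In short, a stripped line counts as a markdown table row when it is pipe-delimited and adjacent to a dash separator line. A minimal check of the heuristic on a toy table:

    text = "| col1 | col2 |\n|------|------|\n| a | b |"
    lines = text.split("\n")
    row, sep = lines[0].strip(), lines[1].strip()
    is_md_table = row.startswith("|") and row.endswith("|") and row.count("|") >= 3 and sep.startswith("|") and "-" in sep
    print(is_md_table)  # True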
     def _contains_math_symbols(self, text: str) -> bool:
         """Check if text contains specific mathematical symbols outside of table cells.

         Returns:
             True if text contains any of the specified math symbols outside tables
             False otherwise
         """
         import re

         # List of mathematical symbols to check for
         math_symbols = [
             # Set theory and logic
-            '∈', '∉', '⊂', '⊃', '⊆', '⊇', '∅', '∪', '∩', '∀', '∃', '¬',
+            "∈",
+            "∉",
+            "⊂",
+            "⊃",
+            "⊆",
+            "⊇",
+            "∅",
+            "∪",
+            "∩",
+            "∀",
+            "∃",
+            "¬",
             # Common mathematical operators
-            '⊕', '⊗', '⊙',
+            "⊕",
+            "⊗",
+            "⊙",
             # Calculus and analysis
-            '∂', '∇', '∆', '∫', '∬', '∭', '∮', '∏', '∑', '√', '∛', '∜',
+            "∂",
+            "∇",
+            "∆",
+            "∫",
+            "∬",
+            "∭",
+            "∮",
+            "∏",
+            "∑",
+            "√",
+            "∛",
+            "∜",
             # Arrows and relations
-            '⊥',
+            "⊥",
             # Other common math symbols
-            '∠', '∡', '⊤', '⊢', '⊣', '∴', '∵', '∶', '∷', '∝', '≅', '≆', '≇', '≊', '≋',
+            "∠",
+            "∡",
+            "⊤",
+            "⊢",
+            "⊣",
+            "∴",
+            "∵",
+            "∶",
+            "∷",
+            "∝",
+            "≅",
+            "≆",
+            "≇",
+            "≊",
+            "≋",
             # Matrix and vector notation
-            '⊕', '⊖', '⊗', '⊘', '⊙', '⊚', '⊛', '⊜', '⊝',
+            "⊕",
+            "⊖",
+            "⊗",
+            "⊘",
+            "⊙",
+            "⊚",
+            "⊛",
+            "⊜",
+            "⊝",
         ]

         # First, remove all HTML tables from the text
         text_without_tables = text

         # Remove HTML tables
-        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
-        text_without_tables = table_pattern.sub('', text_without_tables)
+        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
+        text_without_tables = table_pattern.sub("", text_without_tables)

         # Now check if any of these symbols appear in the text without tables
         for symbol in math_symbols:
             if symbol in text_without_tables:
                 return True

         return False

     def _contains_latex_tables(self, text: str) -> bool:
         """Check if text contains LaTeX table environments.

         Returns:
             True if text contains LaTeX tables (\\begin{table}, \\begin{tabular}, etc.)
             False otherwise
         """
         import re

         # Check for various LaTeX table environments
         latex_table_patterns = [
-            r'\\begin\{table\}',
-            r'\\begin\{tabular\}',
+            r"\\begin\{table\}",
+            r"\\begin\{tabular\}",
         ]

         # Check if any LaTeX table pattern exists in the text
         for pattern in latex_table_patterns:
             if re.search(pattern, text, re.IGNORECASE):
                 return True

         return False

     def _contains_latex_formatting_outside_math(self, text: str) -> bool:
         """Check if text contains LaTeX formatting commands outside of math equations.

         Returns:
             True if text contains LaTeX formatting commands outside math equations
             False otherwise
         """
         import re

         # List of common LaTeX formatting commands to check for
         latex_commands = [
             # Lists & basic content
-            r'\begin{itemize}',
-            r'\begin{enumerate}',
-            r'\item',
+            r"\begin{itemize}",
+            r"\begin{enumerate}",
+            r"\item",
             # Figures, tables, and captions
-            r'\begin{figure}',
-            r'\includegraphics',
-            r'\caption',
-            r'\label',
-            r'\ref',
-            r'\eqref',
-            r'\begin{table}',
-            r'\begin{tabular}',
+            r"\begin{figure}",
+            r"\includegraphics",
+            r"\caption",
+            r"\label",
+            r"\ref",
+            r"\eqref",
+            r"\begin{table}",
+            r"\begin{tabular}",
             # Formatting,
             # r'\textit',
             # r'\textbb',
             # Math (strong signals)
-            r'\begin{equation}',
-            r'\begin{align}',
-            r'\frac',
-            r'\sum',
-            r'\int',
-            r'\sqrt',
-            r'\prod',
-            r'\lim',
-            r'\binom',
-            r'\mathbb',
-            r'\mathcal',
-            r'\to',
-            r'\varphi',
-            r'\cdot',
-            r'\langle',
-            r'\rangle',
+            r"\begin{equation}",
+            r"\begin{align}",
+            r"\frac",
+            r"\sum",
+            r"\int",
+            r"\sqrt",
+            r"\prod",
+            r"\lim",
+            r"\binom",
+            r"\mathbb",
+            r"\mathcal",
+            r"\to",
+            r"\varphi",
+            r"\cdot",
+            r"\langle",
+            r"\rangle",
             # Citations (bibliography stacks)
-            r'\cite',
+            r"\cite",
         ]

         # First, remove all math equations from the text
         text_without_math = text

         # Patterns for math equations
         math_patterns = [
             r"\$\$(.+?)\$\$",  # $$...$$
             r"\\\((.+?)\\\)",  # \(...\)
             r"\\\[(.+?)\\\]",  # \[...\]
         ]

         # Remove all math equations
         for pattern in math_patterns:
-            text_without_math = re.sub(pattern, '', text_without_math, flags=re.DOTALL)
+            text_without_math = re.sub(pattern, "", text_without_math, flags=re.DOTALL)

         # Check if any LaTeX commands appear in the remaining text
         for command in latex_commands:
             if command in text_without_math:
                 return True

         return False

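The strip-then-scan order matters: math segments are removed first, so commands such as \frac survive only if they occur in prose. A compact demonstration of the same two steps:

    import re

    text = r"Intro \cite{smith} and inline math \( \frac{a}{b} \)."
    for pattern in [r"\$\$(.+?)\$\$", r"\\\((.+?)\\\)", r"\\\[(.+?)\\\]"]:
        text = re.sub(pattern, "", text, flags=re.DOTALL)
    print("\\cite" in text)  # True  -> \cite outside math would trigger the filter
    print("\\frac" in text)  # False -> \frac lived only inside math and was stripped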
     def _validate_math_equations(self, text: str) -> bool:
         """Check if all math equations in the text can render without errors.

         Returns:
             True if all equations render successfully or no equations exist
             False if any equation fails to render
         """
         import re

         # Patterns to find math equations (same as in MathTest)
         patterns = [
             r"\$\$(.+?)\$\$",  # $$...$$
             r"\\\((.+?)\\\)",  # \(...\)
             r"\\\[(.+?)\\\]",  # \[...\]
         ]

         equations = []
         for pattern in patterns:
             # Find all matches for the current pattern
             matches = re.findall(pattern, text, re.DOTALL)
             equations.extend([eq.strip() for eq in matches])

         # If no equations found, that's fine
         if not equations:
             return True

         # Try to render each equation
         try:
             from olmocr.bench.katex.render import render_equation

             for equation in equations:
                 # Skip empty or whitespace-only equations
                 if not equation or not equation.strip():
                     continue

                 # Try to render the equation
                 rendered = render_equation(equation)

                 # Check if there was an error
-                if rendered is None or (hasattr(rendered, 'error') and rendered.error):
+                if rendered is None or (hasattr(rendered, "error") and rendered.error):
                     # Equation failed to render
                     logger.warning(f"Could not render equation '{repr(equation)}', skipping sample")
                     return False

             # All equations rendered successfully
             return True

         except ImportError:
             # If we can't import the render module, skip this check
             # This allows the filter to work even without the rendering dependencies
@@ -746,87 +785,86 @@ class DatasetTextRuleFilter(PipelineStep):
             # If any unexpected error occurs during validation, be conservative and filter out
             print(f"Error validating math equations: {e}")
             return False

     def _contains_br_in_table_cells(self, text: str) -> bool:
         """Check if text contains <br> tags within HTML table cells.

         Returns:
             True if any table cell contains <br> tags
             False otherwise
         """
         import re

         # Check if there are any tables in the text
-        if '<table' not in text.lower() or '<br' not in text.lower():
+        if "<table" not in text.lower() or "<br" not in text.lower():
             return False  # No tables or no <br> tags at all

         # Pattern to find HTML tables (case-insensitive)
-        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
+        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
         tables = table_pattern.findall(text)

         # Check each table for <br> tags in cells
         for table_html in tables:
             # Pattern to find table cells (td and th tags)
-            cell_pattern = re.compile(r'<(td|th)\b[^>]*>(.*?)</\1>', re.IGNORECASE | re.DOTALL)
+            cell_pattern = re.compile(r"<(td|th)\b[^>]*>(.*?)</\1>", re.IGNORECASE | re.DOTALL)
             cells = cell_pattern.findall(table_html)

             for tag_type, cell_content in cells:
                 # Check if cell content contains <br> tags (any variation)
-                if re.search(r'<br\s*/?>', cell_content, re.IGNORECASE):
+                if re.search(r"<br\s*/?>", cell_content, re.IGNORECASE):
                     return True

         return False

     def _extract_and_validate_html_tables(self, text: str) -> bool:
         """Extract HTML tables and validate they parse correctly.

         Returns:
             True if all HTML tables are valid or no tables exist
             False if any HTML table is malformed
         """
         # Find all HTML table blocks
         import re

         # Check if there are any <table> tags at all
-        if '<table' not in text.lower():
+        if "<table" not in text.lower():
             return True  # No tables, that's fine

         # Pattern to find HTML tables (case-insensitive)
         # Note: This pattern might not catch malformed tables where </table> is missing
-        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
+        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
         tables = table_pattern.findall(text)

         # Also check for unclosed table tags
-        table_open_count = len(re.findall(r'<table\b[^>]*>', text, re.IGNORECASE))
-        table_close_count = len(re.findall(r'</table>', text, re.IGNORECASE))
+        table_open_count = len(re.findall(r"<table\b[^>]*>", text, re.IGNORECASE))
+        table_close_count = len(re.findall(r"</table>", text, re.IGNORECASE))

         if table_open_count != table_close_count:
             return False  # Mismatched table tags

         if not tables and table_open_count > 0:
             # Found table tags but couldn't extract complete tables
             return False

         # Try to parse each table
         from html.parser import HTMLParser

         class TableValidator(HTMLParser):
             def __init__(self):
                 super().__init__()
                 self.tag_stack = []
                 self.is_valid = True
                 self.error_msg = None

             def handle_starttag(self, tag, attrs):
                 self.tag_stack.append(tag.lower())

             def handle_endtag(self, tag):
                 tag = tag.lower()
                 if not self.tag_stack:
                     self.is_valid = False
                     self.error_msg = f"Unexpected closing tag: {tag}"
                     return

                 # Check if the closing tag matches the most recent opening tag
                 if self.tag_stack[-1] == tag:
                     self.tag_stack.pop()
@@ -842,11 +880,11 @@ class DatasetTextRuleFilter(PipelineStep):
                 else:
                     self.is_valid = False
                     self.error_msg = f"Mismatched tag: expected {self.tag_stack[-1]}, got {tag}"

             def error(self, message):
                 self.is_valid = False
                 self.error_msg = message

         # Validate each table
         for table_html in tables:
             parser = TableValidator()
@@ -860,90 +898,90 @@ class DatasetTextRuleFilter(PipelineStep):
             except Exception:
                 # Any parsing exception means the table is malformed
                 return False

         return True

     def __call__(self, sample: Sample) -> Optional[Sample]:
         """Filter samples based on text content rules."""
         # Get the natural text from page_data if it exists
         text = None

         if "page_data" in sample:
             page_data = sample["page_data"]
             if hasattr(page_data, "natural_text") and page_data.natural_text:
                 text = page_data.natural_text

         # If no text to check, pass the sample through
         if text is None:
             return sample

-        # Check for markdown tables
-        if self._contains_markdown_table(text):
-            return None  # Filter out samples with markdown tables
-
-        # Check for HTML tables and validate them
-        if not self._extract_and_validate_html_tables(text):
-            return None  # Filter out samples with malformed HTML tables
-
-        # Check for <br> tags in table cells
-        if self._contains_br_in_table_cells(text):
-            return None  # Filter out samples with <br> tags in table cells
-
-        # Check if all math equations can render without errors
-        if not self._validate_math_equations(text):
-            return None  # Filter out samples with invalid math equations
-
-        # Check for mathematical symbols
-        if self._contains_math_symbols(text):
-            return None  # Filter out samples with mathematical symbols
-
+        # # Check for markdown tables
+        # if self._contains_markdown_table(text):
+        #     return None  # Filter out samples with markdown tables
+
+        # # Check for HTML tables and validate them
+        # if not self._extract_and_validate_html_tables(text):
+        #     return None  # Filter out samples with malformed HTML tables
+
+        # # Check for <br> tags in table cells
+        # if self._contains_br_in_table_cells(text):
+        #     return None  # Filter out samples with <br> tags in table cells
+
+        # # Check if all math equations can render without errors
+        # if not self._validate_math_equations(text):
+        #     return None  # Filter out samples with invalid math equations
+
+        # # Check for mathematical symbols
+        # if self._contains_math_symbols(text):
+        #     return None  # Filter out samples with mathematical symbols
+
         # Check for LaTeX formatting outside math equations
         if self._contains_latex_formatting_outside_math(text):
             return None  # Filter out samples with \textit or \textbf outside math

         # Check for LaTeX tables
         if self._contains_latex_tables(text):
             return None  # Filter out samples with LaTeX tables

         return sample

 @dataclass(frozen=True, slots=True)
 class ReformatLatexBoldItalic(PipelineStep):
     """Pipeline step that converts LaTeX formatting commands to markdown equivalents.

     Converts:
     - \\textit{...} to *...* (italic)
     - \\textbf{...} to **...** (bold)

     These conversions only happen outside of math equations.
     """

     def __call__(self, sample: Sample) -> Optional[Sample]:
         """Convert LaTeX formatting to markdown in the sample text."""
         # Get the natural text from page_data if it exists
         if "page_data" not in sample:
             return sample

         page_data = sample["page_data"]
         if not hasattr(page_data, "natural_text") or not page_data.natural_text:
             return sample

         text = page_data.natural_text

         import re

         # Math equation patterns to preserve
         math_patterns = [
             r"\$\$(.+?)\$\$",  # $$...$$
             r"\\\((.+?)\\\)",  # \(...\)
             r"\\\[(.+?)\\\]",  # \[...\]
         ]

         # Store math equations with placeholders
         math_placeholders = []
         preserved_text = text

         # Replace math equations with placeholders
         for i, pattern in enumerate(math_patterns):
             matches = re.finditer(pattern, preserved_text, re.DOTALL)
@@ -951,65 +989,66 @@ class ReformatLatexBoldItalic(PipelineStep):
                 placeholder = f"__MATH_PLACEHOLDER_{i}_{j}__"
                 math_placeholders.append((placeholder, match.group(0)))
                 preserved_text = preserved_text.replace(match.group(0), placeholder, 1)

         # Now convert LaTeX formatting to markdown
         # We need to handle nested braces properly
         # Use a function to find matching braces
         def replace_latex_command(text, command, markdown):
             """Replace LaTeX command with markdown, handling nested braces."""
             import re

-            pattern = r'\\' + command + r'\{'
+            pattern = r"\\" + command + r"\{"
             result = []
             i = 0

             while i < len(text):
                 match = re.search(pattern, text[i:])
                 if not match:
                     result.append(text[i:])
                     break

                 # Add text before the match
-                result.append(text[i:i + match.start()])
+                result.append(text[i : i + match.start()])

                 # Find the matching closing brace
                 start_pos = i + match.end()
                 brace_count = 1
                 j = start_pos

                 while j < len(text) and brace_count > 0:
-                    if text[j] == '{':
+                    if text[j] == "{":
                         brace_count += 1
-                    elif text[j] == '}':
+                    elif text[j] == "}":
                         brace_count -= 1
                     j += 1

                 if brace_count == 0:
                     # Extract the content between braces
-                    content = text[start_pos:j-1]
+                    content = text[start_pos : j - 1]
                     result.append(markdown + content + markdown)
                     i = j
                 else:
                     # Unmatched braces, keep original
-                    result.append(text[i + match.start():i + match.end()])
+                    result.append(text[i + match.start() : i + match.end()])
                     i = i + match.end()

-            return ''.join(result)
+            return "".join(result)

         # Handle \textbf{...} -> **...**
-        preserved_text = replace_latex_command(preserved_text, 'textbf', '**')
+        preserved_text = replace_latex_command(preserved_text, "textbf", "**")

         # Handle \textit{...} -> *...*
-        preserved_text = replace_latex_command(preserved_text, 'textit', '*')
+        preserved_text = replace_latex_command(preserved_text, "textit", "*")

         # Restore math equations
         for placeholder, original in math_placeholders:
             preserved_text = preserved_text.replace(placeholder, original)

         # Create a new PageResponse with the updated text (since it's frozen)
-        from dataclasses import replace
-
         updated_page_data = replace(page_data, natural_text=preserved_text)
         sample["page_data"] = updated_page_data

         return sample

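The manual brace walk in replace_latex_command exists because a simple regex cannot pair nested braces; a short comparison showing where the naive approach breaks down:

    import re

    flat = r"This is \textbf{bold} text."
    print(re.sub(r"\\textbf\{([^{}]*)\}", r"**\1**", flat))  # This is **bold** text.

    nested = r"\textbf{bold with {nested} braces}"
    print(re.sub(r"\\textbf\{([^{}]*)\}", r"**\1**", nested))  # unchanged: the naive pattern finds no match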
@@ -1382,78 +1421,71 @@ if __name__ == "__main__":
     if args.save_filtered:
-        import shutil
-        from pathlib import Path
-
         save_dir = Path(args.save_filtered)

         # Clear and create directory
         if save_dir.exists():
             shutil.rmtree(save_dir)
         save_dir.mkdir(parents=True, exist_ok=True)

         print(f"\n=== Checking for filtered samples ===")
         print(f"Will save filtered samples to: {save_dir}")

         # Function to process and copy a single sample
         def process_and_copy_sample(idx, dataset_samples, save_dir_str):
             """Process a sample and return info if it's filtered.

             Note: This function needs to be picklable for ProcessPoolExecutor,
             so it takes simple arguments rather than complex objects.
             """
             import shutil
             from pathlib import Path

             # Recreate dataset with same parameters
             # This is needed because dataset objects can't be pickled
             temp_dataset = BaseMarkdownPDFDataset.__new__(BaseMarkdownPDFDataset)
             temp_dataset.samples = dataset_samples
             temp_dataset.pipeline_steps = pipeline_steps

             try:
                 sample = temp_dataset[idx]
                 if sample is None:
                     # This sample was filtered out - get the original paths
                     original_sample = dataset_samples[idx]
-                    md_path = original_sample['markdown_path']
-                    pdf_path = original_sample['pdf_path']
+                    md_path = original_sample["markdown_path"]
+                    pdf_path = original_sample["pdf_path"]

                     save_dir = Path(save_dir_str)

                     # Create subdirectory to preserve some structure
                     # Use the parent directory name and file name
                     rel_path = md_path.parent.name
                     target_subdir = save_dir / rel_path
                     target_subdir.mkdir(parents=True, exist_ok=True)

                     # Copy markdown file
                     target_md = target_subdir / md_path.name
                     shutil.copy2(md_path, target_md)

                     # Copy PDF file
                     target_pdf = target_subdir / pdf_path.name
                     shutil.copy2(pdf_path, target_pdf)

-                    return {
-                        'index': idx,
-                        'markdown_path': str(md_path),
-                        'pdf_path': str(pdf_path)
-                    }
+                    return {"index": idx, "markdown_path": str(md_path), "pdf_path": str(pdf_path)}
                 return None
             except Exception as e:
                 print(f"Error processing sample {idx}: {e}")
                 return None

         # Process all samples in parallel
         filtered_samples = []
         print(f"Processing {len(dataset)} samples to find and copy filtered ones...")

         with ProcessPoolExecutor(max_workers=8) as executor:
             # Submit all tasks
-            futures = {
-                executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx
-                for idx in range(len(dataset))
-            }
+            futures = {executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx for idx in range(len(dataset))}

             # Process results with progress bar
             with tqdm(total=len(dataset), desc="Processing samples") as pbar:
                 for future in as_completed(futures):
@@ -1461,20 +1493,20 @@ if __name__ == "__main__":
                     if result is not None:
                         filtered_samples.append(result)
                     pbar.update(1)

         # Sort filtered samples by index for consistent output
-        filtered_samples.sort(key=lambda x: x['index'])
+        filtered_samples.sort(key=lambda x: x["index"])

         print(f"\nFound and copied {len(filtered_samples)} filtered samples to: {save_dir}")

         if filtered_samples:
             print(f"First 10 filtered samples:")
             for i, sample_info in enumerate(filtered_samples[:10]):
-                md_name = Path(sample_info['markdown_path']).name
+                md_name = Path(sample_info["markdown_path"]).name
                 print(f"  Sample {sample_info['index']}: {md_name}")
             if len(filtered_samples) > 10:
                 print(f"  ... and {len(filtered_samples) - 10} more")

         # Exit early if --save-filtered is used (don't continue with other analyses)
         print("\nCompleted saving filtered samples. Exiting.")
         exit(0)

@ -44,7 +44,7 @@ Some conclusion text.
|
||||
""",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_with_md_table)
|
||||
self.assertIsNone(result, "Should filter out samples with markdown tables")
|
||||
|
||||
@ -60,7 +60,7 @@ Some conclusion text.
|
||||
natural_text="This is regular text without any tables. It has | pipes | but not in table format.",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_without_table)
|
||||
self.assertIsNotNone(result, "Should pass through samples without markdown tables")
|
||||
self.assertEqual(result, sample_without_table)
|
||||
@ -92,7 +92,7 @@ Some text after table.
|
||||
""",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_with_valid_html)
|
||||
self.assertIsNotNone(result, "Should pass through samples with valid HTML tables")
|
||||
|
||||
@ -121,7 +121,7 @@ Text after.
|
||||
""",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_with_malformed_html)
|
||||
self.assertIsNone(result, "Should filter out samples with malformed HTML tables")
|
||||
|
||||
@ -147,7 +147,7 @@ Text after without closing table tag.
|
||||
""",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_with_unclosed_table)
|
||||
self.assertIsNone(result, "Should filter out HTML tables without closing tags")
|
||||
|
||||
@ -157,7 +157,7 @@ Text after without closing table tag.
|
||||
"markdown_path": Path("/path/to/file.md"),
|
||||
"pdf_path": Path("/path/to/file.pdf"),
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_without_page_data)
|
||||
self.assertIsNotNone(result, "Should pass through samples without page_data")
|
||||
self.assertEqual(result, sample_without_page_data)
|
||||
@ -174,7 +174,7 @@ Text after without closing table tag.
|
||||
natural_text=None,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_without_text)
|
||||
self.assertIsNotNone(result, "Should pass through samples without natural_text")
|
||||
|
||||
@ -190,7 +190,7 @@ Text after without closing table tag.
|
||||
natural_text="",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_with_empty_text)
|
||||
self.assertIsNotNone(result, "Should pass through samples with empty natural_text")
|
||||
|
||||
@ -211,7 +211,7 @@ Text after without closing table tag.
|
||||
""",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_with_alignment)
|
||||
self.assertIsNone(result, "Should filter out markdown tables with alignment")
|
||||
|
||||
@ -240,7 +240,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
|
||||
""",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_mixed)
|
||||
self.assertIsNotNone(result, "Should pass through with valid HTML and no markdown tables")
|
||||
|
||||
@ -261,7 +261,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
|
||||
</table>""",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_with_br)
|
||||
self.assertIsNone(result, "Should filter out tables with <br> tags in cells")
|
||||
|
||||
@ -283,7 +283,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
|
||||
</table>""",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_br_outside)
|
||||
self.assertIsNotNone(result, "Should allow <br> tags outside tables")
|
||||
|
||||
@ -308,7 +308,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
|
||||
</table>""",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
result = self.filter(sample_br_variations)
|
||||
self.assertIsNone(result, "Should filter out tables with any <br> variation in cells")
|
||||
|
||||
@ -332,7 +332,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="This is \\textbf{bold} text.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "This is **bold** text.")

@ -348,7 +348,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="This is \\textit{italic} text.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "This is *italic* text.")

@ -364,7 +364,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="This has \\textbf{bold} and \\textit{italic} text.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "This has **bold** and *italic* text.")

@ -380,7 +380,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="Text outside $$ \\textbf{x} = \\textit{y} $$ more text.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "Text outside $$ \\textbf{x} = \\textit{y} $$ more text.")

@ -396,7 +396,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="The \\textbf{equation} is $$ \\textbf{x} = 2 $$ and \\textit{important}.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "The **equation** is $$ \\textbf{x} = 2 $$ and *important*.")

@ -412,7 +412,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="This is \\textbf{bold with {nested} braces} text.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "This is **bold with {nested} braces** text.")

@ -428,12 +428,9 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="\\textbf{First} and \\textbf{second} bold, \\textit{first} and \\textit{second} italic.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(
            result["page_data"].natural_text,
            "**First** and **second** bold, *first* and *second* italic."
        )
        self.assertEqual(result["page_data"].natural_text, "**First** and **second** bold, *first* and *second* italic.")
    def test_latex_in_parenthesis_delimiter(self):
        """Test LaTeX preserved in \\(...\\) math delimiter."""

@ -447,7 +444,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="Text \\( \\textbf{math} \\) more text \\textbf{bold}.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "Text \\( \\textbf{math} \\) more text **bold**.")

@ -463,7 +460,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="Text \\[ \\textit{math} \\] more text \\textit{italic}.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "Text \\[ \\textit{math} \\] more text *italic*.")

@ -479,14 +476,14 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="Plain text without any formatting.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "Plain text without any formatting.")

    def test_no_page_data(self):
        """Test handling of samples without page_data."""
        sample = {"markdown_path": Path("/path/to/file.md")}

        result = self.reformatter(sample)
        self.assertEqual(result, sample)

@ -502,10 +499,10 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text=None,
            )
        }

        result = self.reformatter(sample)
        self.assertIsNone(result["page_data"].natural_text)
    def test_complex_latex_with_parenthesis_delimiters(self):
        """Test complex LaTeX text with \\(...\\) delimiters and textit."""
        input_text = """= a_0 \\int_0^P \\cos \\frac{2m\\pi x}{P} dx

@ -517,7 +514,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th
\\[
\\int_0^P \\cos \\frac{2m\\pi x}{P} f(x) dx = \\frac{a_m P}{2},
\\]"""

        expected_text = """= a_0 \\int_0^P \\cos \\frac{2m\\pi x}{P} dx
+ \\sum_{n=1}^{\\infty} \\frac{a_n}{2} \\int_0^P \\cos \\frac{2(m+n)\\pi x}{P} + \\cos \\frac{2(m-n)\\pi x}{P} dx
+ b_n \\int_0^P \\sin \\frac{2(m+n)\\pi x}{P} - \\sin \\frac{2(m-n)\\pi x}{P} dx.

@ -527,7 +524,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th
\\[
\\int_0^P \\cos \\frac{2m\\pi x}{P} f(x) dx = \\frac{a_m P}{2},
\\]"""

        sample = {
            "page_data": PageResponse(
                primary_language="en",

@ -538,7 +535,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th
                natural_text=input_text,
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, expected_text)
@ -562,7 +559,7 @@ class TestFilterOutRotatedDocuments(unittest.TestCase):
                natural_text="Some text",
            )
        }

        result = self.filter(sample)
        self.assertIsNotNone(result, "Should pass through documents with valid rotation")

@ -578,7 +575,7 @@ class TestFilterOutRotatedDocuments(unittest.TestCase):
                natural_text="Some text",
            )
        }

        result = self.filter(sample)
        self.assertIsNone(result, "Should filter out documents with invalid rotation")

@ -594,14 +591,14 @@ class TestFilterOutRotatedDocuments(unittest.TestCase):
                natural_text="Some text",
            )
        }

        result = self.filter(sample)
        self.assertIsNone(result, "Should filter out documents with non-zero rotation correction")

    def test_no_page_data(self):
        """Test that samples without page_data pass through."""
        sample = {"markdown_path": Path("/path/to/file.md")}

        result = self.filter(sample)
        self.assertIsNotNone(result, "Should pass through samples without page_data")
@ -625,7 +622,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
                natural_text="The equation $x^2 + y^2 = z^2$ is famous.",
            )
        }

        result = self.normalizer(sample)
        expected_text = "The equation \\(x^2 + y^2 = z^2\\) is famous."
        self.assertEqual(result["page_data"].natural_text, expected_text)

@ -642,7 +639,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
                natural_text="Display equation:\n$$\\int_0^\\infty e^{-x^2} dx = \\frac{\\sqrt{\\pi}}{2}$$",
            )
        }

        result = self.normalizer(sample)
        expected_text = "Display equation:\n\\[\\int_0^\\infty e^{-x^2} dx = \\frac{\\sqrt{\\pi}}{2}\\]"
        self.assertEqual(result["page_data"].natural_text, expected_text)

@ -659,7 +656,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
                natural_text="Inline $a + b$ and display:\n$$c^2 = a^2 + b^2$$\nMore inline $x = y$.",
            )
        }

        result = self.normalizer(sample)
        expected_text = "Inline \\(a + b\\) and display:\n\\[c^2 = a^2 + b^2\\]\nMore inline \\(x = y\\)."
        self.assertEqual(result["page_data"].natural_text, expected_text)

@ -676,7 +673,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
                natural_text="Regular text without any equations.",
            )
        }

        result = self.normalizer(sample)
        self.assertEqual(result["page_data"].natural_text, "Regular text without any equations.")

@ -692,14 +689,14 @@ class TestLatexBracketNormalizer(unittest.TestCase):
                natural_text=None,
            )
        }

        result = self.normalizer(sample)
        self.assertIsNone(result["page_data"].natural_text)

    def test_no_page_data(self):
        """Test handling of missing page_data."""
        sample = {"markdown_path": Path("/path/to/file.md")}

        result = self.normalizer(sample)
        self.assertEqual(result, sample)
@ -712,7 +709,7 @@ class TestFrontMatterParser(unittest.TestCase):
        self.parser_with_class = FrontMatterParser(front_matter_class=PageResponse)
        self.parser_without_class = FrontMatterParser(front_matter_class=None)

    @patch.object(Path, 'read_text')
    @patch.object(Path, "read_text")
    def test_parse_yaml_front_matter(self, mock_read_text):
        """Test parsing of YAML front matter."""
        mock_read_text.return_value = """---

@ -724,27 +721,27 @@ is_diagram: false
---
This is the document content.
"""

        sample = {"markdown_path": Path("/path/to/file.md")}
        result = self.parser_with_class(sample)

        self.assertIn("page_data", result)
        self.assertIsInstance(result["page_data"], PageResponse)
        self.assertEqual(result["page_data"].primary_language, "en")
        self.assertEqual(result["page_data"].natural_text, "This is the document content.")

    @patch.object(Path, 'read_text')
    @patch.object(Path, "read_text")
    def test_no_front_matter(self, mock_read_text):
        """Test handling of documents without front matter."""
        mock_read_text.return_value = "Just regular content without front matter."

        sample = {"markdown_path": Path("/path/to/file.md")}

        # Should raise an error when front_matter_class is specified
        with self.assertRaises(ValueError):
            self.parser_with_class(sample)

    @patch.object(Path, 'read_text')
    @patch.object(Path, "read_text")
    def test_malformed_yaml(self, mock_read_text):
        """Test handling of malformed YAML."""
        mock_read_text.return_value = """---

@ -753,14 +750,14 @@ is_rotation_valid: [this is not valid yaml}
---
Content
"""

        sample = {"markdown_path": Path("/path/to/file.md")}

        # Parser without class should return empty dict for malformed YAML
        result = self.parser_without_class(sample)
        self.assertEqual(result["page_data"], {})

    @patch.object(Path, 'read_text')
    @patch.object(Path, "read_text")
    def test_preserve_existing_markdown_content(self, mock_read_text):
        """Test that existing markdown_content is preserved if present."""
        sample = {

@ -772,16 +769,16 @@ rotation_correction: 0
is_table: true
is_diagram: false
---
French content."""
French content.""",
        }

        # Should not call read_text since markdown_content exists
        result = self.parser_with_class(sample)
        mock_read_text.assert_not_called()

        self.assertEqual(result["page_data"].primary_language, "fr")
        self.assertEqual(result["page_data"].is_table, True)


if __name__ == "__main__":
    unittest.main()
    unittest.main()