Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-28 08:11:33 +00:00
Lints
parent 768cb33937
commit 41201b6317
@@ -8,16 +8,19 @@ from olmocr.bench.prompts import (
     build_basic_prompt,
     build_openai_silver_data_prompt_no_document_anchoring,
 )
-from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
+from olmocr.data.renderpdf import (
+    get_png_dimensions_from_base64,
+    render_pdf_to_base64png,
+)
 from olmocr.prompts.anchor import get_anchor_text
 from olmocr.prompts.prompts import (
     PageResponse,
     build_finetuning_prompt,
     build_openai_silver_data_prompt,
-    openai_response_format_schema,
     build_openai_silver_data_prompt_v2,
     build_openai_silver_data_prompt_v2_simple,
     build_openai_silver_data_prompt_v3_simple,
+    openai_response_format_schema,
 )
 
 
@@ -65,7 +68,7 @@ def run_chatgpt(
         prompt = build_openai_silver_data_prompt_v2_simple(width, height)
     elif prompt_template == "fullv3simple":
         width, height = get_png_dimensions_from_base64(image_base64)
         prompt = build_openai_silver_data_prompt_v3_simple(width, height)
     else:
         raise ValueError("Unknown prompt template")
 
@@ -82,7 +85,7 @@ def run_chatgpt(
         ],
         temperature=temperature,
         max_completion_tokens=20000,
-        #reasoning_effort="high",
+        # reasoning_effort="high",
         response_format=openai_response_format_schema() if response_template == "json" else None,
         safety_identifier="olmocr-bench-runner",
     )
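
The reshaping in the hunks above and below is consistent with a black/isort-style autoformatter; the exact tool and line-length setting are not shown on this page, so the rule illustrated here is an assumption. Sorted imports that are exploded across lines get a trailing comma, which pins the one-name-per-line layout, while calls and signatures without a trailing comma are collapsed onto a single line when they fit. A minimal sketch using names from this diff:

# Both forms are valid Python; a black-style formatter chooses between them
# based on line length and the "magic" trailing comma.

# No trailing comma and the line fits, so it stays collapsed:
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt

# A trailing comma after the last name keeps the exploded layout that the
# import blocks in this commit were rewritten into (re-import, illustration only):
from olmocr.prompts.prompts import (
    PageResponse,
    build_finetuning_prompt,
)
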
@@ -8,14 +8,17 @@ and generates OpenAI batch API requests for processing PDFs.
 
 import argparse
 import json
-import os
-from pathlib import Path
-from typing import Generator, Dict, Any, Optional, Tuple
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, Generator, Optional, Tuple
 
 from pypdf import PdfReader
 from tqdm import tqdm
 
-from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
+from olmocr.data.renderpdf import (
+    get_png_dimensions_from_base64,
+    render_pdf_to_base64png,
+)
 from olmocr.prompts.prompts import (
     build_openai_silver_data_prompt_v3_simple,
     openai_response_format_schema,
@@ -28,10 +31,10 @@ MAX_FILE_SIZE = 99 * 1024 * 1024  # 99MB in bytes
 def validate_single_page_pdf(pdf_path: Path) -> bool:
     """
     Validate that a PDF has exactly one page.
 
     Args:
         pdf_path: Path to the PDF file
 
     Returns:
         True if PDF has exactly one page, False otherwise
     """
@@ -46,32 +49,32 @@ def validate_single_page_pdf(pdf_path: Path) -> bool:
 def build_custom_id(pdf_path: Path, base_dir: Path) -> str:
     """
     Build a custom ID for the request that can be used to recover the file later.
 
     The ID preserves the full path structure for easy recovery.
     Example: extracted/document_id.pdf becomes "extracted/document_id"
 
     Args:
         pdf_path: Full path to the PDF file
         base_dir: Base directory containing the processed folder
 
     Returns:
         Custom ID string that preserves path structure
     """
     # Get relative path from base directory
     rel_path = pdf_path.relative_to(base_dir)
     # Remove .pdf extension but keep directory structure
-    path_without_ext = str(rel_path).replace('.pdf', '')
+    path_without_ext = str(rel_path).replace(".pdf", "")
     return path_without_ext
 
 
 def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[str, Any], Path]]:
     """
     Process a single PDF and return the batch request if valid.
 
     Args:
         pdf_path: Path to the PDF file
         base_dir: Base directory for building custom IDs
 
     Returns:
         Tuple of (request dict, pdf_path) if successful, None otherwise
     """
@@ -83,20 +86,20 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[str, Any], Path]]:
     except Exception as e:
         print(f"Error reading PDF {pdf_path}: {e}")
         return None
 
     try:
         # Render PDF to base64 image
         image_base64 = render_pdf_to_base64png(str(pdf_path), page_num=1, target_longest_image_dim=TARGET_IMAGE_DIM)
 
         # Get image dimensions for the prompt
         width, height = get_png_dimensions_from_base64(image_base64)
 
         # Build the prompt using v3 simple version
         prompt = build_openai_silver_data_prompt_v3_simple(width, height)
 
         # Build custom ID
         custom_id = build_custom_id(pdf_path, base_dir)
 
         # Build the request in OpenAI batch format
         request = {
             "custom_id": custom_id,
@@ -118,7 +121,7 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[str, Any], Path]]:
                 "response_format": openai_response_format_schema(),
             },
         }
 
         return (request, pdf_path)
     except Exception as e:
         print(f"Error processing {pdf_path}: {e}")
@@ -128,21 +131,21 @@ def process_single_pdf(pdf_path: Path, base_dir: Path) -> Optional[Tuple[Dict[str, Any], Path]]:
 def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]:
     """
     Find all PDF files in the processed folder structure.
 
     The structure is expected to be:
         processed_XX_subset_split/
             extracted/
                 *.pdf
 
     Or for hugging_face downloads:
         hugging_face/
             pdf_tarballs/
                 extracted/
                     *.pdf
 
     Args:
         input_dir: Input directory path
 
     Yields:
         Path objects for each PDF file found
     """
@@ -151,64 +154,56 @@ def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]:
         yield pdf_path
 
 
-def process_pdfs_to_batch_requests(
-    input_dir: Path,
-    output_dir: Path,
-    max_pdfs: int = None,
-    num_workers: int = 8
-) -> int:
+def process_pdfs_to_batch_requests(input_dir: Path, output_dir: Path, max_pdfs: int = None, num_workers: int = 8) -> int:
     """
     Process PDFs and create batch request files using parallel processing.
 
     Args:
         input_dir: Directory containing the processed folder structure
         output_dir: Directory to save batch request files
         max_pdfs: Maximum number of PDFs to process (None for all)
         num_workers: Number of parallel workers for processing
 
     Returns:
         Number of PDFs processed
     """
     # Ensure output directory exists
     output_dir.mkdir(parents=True, exist_ok=True)
 
     # Initialize file management
     file_num = 0
     current_file_size = 0
     current_file_path = output_dir / f"batch_requests_{file_num:04d}.jsonl"
     current_file = open(current_file_path, "w")
 
     pdfs_processed = 0
     pdfs_skipped = 0
 
     # Find PDF files
     pdf_files = list(find_pdf_files(input_dir))
 
     # Limit files if max_pdfs is specified
     if max_pdfs:
         pdf_files = pdf_files[:max_pdfs]
 
     total_pdfs = len(pdf_files)
 
     print(f"Found {total_pdfs} PDF files to process")
     print(f"Using {num_workers} parallel workers")
 
     # Process PDFs in parallel using ThreadPoolExecutor
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         # Submit all PDF processing tasks
-        future_to_pdf = {
-            executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path
-            for pdf_path in pdf_files
-        }
+        future_to_pdf = {executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path for pdf_path in pdf_files}
 
         # Process results as they complete
         with tqdm(total=total_pdfs, desc="Processing PDFs") as pbar:
             for future in as_completed(future_to_pdf):
                 pdf_path = future_to_pdf[future]
 
                 try:
                     result = future.result()
 
                     if result is None:
                         # PDF was skipped (multi-page or error)
                         pdfs_skipped += 1
@@ -216,7 +211,7 @@ def process_pdfs_to_batch_requests(
                         request, _ = result
                         request_json = json.dumps(request)
                         request_size = len(request_json.encode("utf-8"))
 
                         # Check if we need to start a new file
                         if current_file_size + request_size > MAX_FILE_SIZE:
                             current_file.close()
@@ -225,88 +220,66 @@ def process_pdfs_to_batch_requests(
                             current_file = open(current_file_path, "w")
                             current_file_size = 0
                             print(f"\nStarting new batch file: {current_file_path.name}")
 
                         # Write the request (only in main thread)
                         current_file.write(request_json)
                         current_file.write("\n")
                         current_file_size += request_size
 
                         pdfs_processed += 1
 
                 except Exception as e:
                     print(f"\nError with {pdf_path}: {e}")
                     pdfs_skipped += 1
 
                 pbar.update(1)
 
     # Close the last file
     current_file.close()
 
     print(f"\nProcessing complete:")
     print(f"  - PDFs processed: {pdfs_processed}")
     print(f"  - PDFs skipped: {pdfs_skipped}")
     print(f"  - Batch files created: {file_num + 1}")
     print(f"  - Output directory: {output_dir}")
 
     return pdfs_processed
 
 
 def main():
-    parser = argparse.ArgumentParser(
-        description="Build OpenAI batch requests from OLMoCR-mix folder structure"
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default=None,
-        help="Output directory for batch request files (default: input_dir/batch_requests)"
-    )
-    parser.add_argument(
-        "--max_pdfs",
-        type=int,
-        default=None,
-        help="Maximum number of PDFs to process (default: all)"
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=8,
-        help="Number of parallel workers for processing (default: 8)"
-    )
+    parser = argparse.ArgumentParser(description="Build OpenAI batch requests from OLMoCR-mix folder structure")
+    parser.add_argument("--output_dir", type=str, default=None, help="Output directory for batch request files (default: input_dir/batch_requests)")
+    parser.add_argument("--max_pdfs", type=int, default=None, help="Maximum number of PDFs to process (default: all)")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")
     parser.add_argument(
         "input_dir",
         type=str,
-        help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)"
+        help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)",
     )
 
     args = parser.parse_args()
 
     # Convert paths to Path objects
     input_dir = Path(args.input_dir).expanduser().resolve()
 
     if not input_dir.exists():
         print(f"Error: Input directory does not exist: {input_dir}")
         return 1
 
     # Set default output directory if not specified
     if args.output_dir:
         output_dir = Path(args.output_dir).expanduser().resolve()
     else:
         output_dir = input_dir / "batch_requests"
 
     print(f"Input directory: {input_dir}")
     print(f"Output directory: {output_dir}")
 
     # Process PDFs
-    process_pdfs_to_batch_requests(
-        input_dir=input_dir,
-        output_dir=output_dir,
-        max_pdfs=args.max_pdfs,
-        num_workers=args.num_workers
-    )
+    process_pdfs_to_batch_requests(input_dir=input_dir, output_dir=output_dir, max_pdfs=args.max_pdfs, num_workers=args.num_workers)
 
     return 0
 
 
 if __name__ == "__main__":
     exit(main())
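
The script above only writes the batch_requests_*.jsonl shards; submitting them is a separate step that is not part of this diff. A minimal sketch of submitting one shard with the openai Python SDK (the file name below is just the pattern the script generates; error handling omitted):

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Upload one generated shard, then start a batch against the chat completions
# endpoint that the requests in the file target.
with open("batch_requests_0000.jsonl", "rb") as f:
    batch_file = client.files.create(file=f, purpose="batch")

batch = client.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)
print(batch.id, batch.status)
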
@@ -8,30 +8,31 @@ that mirrors the original structure with side-by-side PDF and MD files.
 
 import argparse
 import json
-import shutil
 import re
-from pathlib import Path
-from typing import Dict, Any, Optional
+import shutil
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, Optional
 
 from tqdm import tqdm
 
 
 def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
     """
     Parse a single line from the batch response file.
 
     Args:
         response_line: JSON line from batch response file
 
     Returns:
         Parsed response dictionary or None if error
     """
     try:
         data = json.loads(response_line)
 
         # Extract the custom_id and response
         custom_id = data.get("custom_id")
 
         # Check if the response was successful
         if "response" in data and data["response"].get("status_code") == 200:
             body = data["response"]["body"]
@@ -39,14 +40,11 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
             content = body["choices"][0]["message"]["content"]
             # Parse the JSON response
             parsed_content = json.loads(content)
-            return {
-                "custom_id": custom_id,
-                "content": parsed_content
-            }
+            return {"custom_id": custom_id, "content": parsed_content}
         else:
             print(f"Error in response for {custom_id}: {data.get('error', 'Unknown error')}")
             return None
 
     except Exception as e:
         print(f"Error parsing response line: {e}")
         return None
@@ -55,10 +53,10 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
 def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
     """
     Format the response data as FrontMatter markdown.
 
     Args:
         response_data: Parsed response data from OpenAI
 
     Returns:
         Formatted markdown string with FrontMatter
     """
@@ -69,7 +67,7 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
     is_table = response_data.get("is_table", False)
     is_diagram = response_data.get("is_diagram", False)
     natural_text = response_data.get("natural_text", "")
 
     # Format as FrontMatter
     markdown = "---\n"
     markdown += f"primary_language: {primary_language if primary_language else 'null'}\n"
@@ -78,29 +76,24 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
     markdown += f"is_table: {str(is_table)}\n"
     markdown += f"is_diagram: {str(is_diagram)}\n"
     markdown += "---\n"
 
     # Add the natural text content
     if natural_text:
         markdown += natural_text
 
     return markdown.strip()
 
 
-def process_single_result(
-    custom_id: str,
-    response_content: Dict[str, Any],
-    original_pdf_dir: Path,
-    output_dir: Path
-) -> bool:
+def process_single_result(custom_id: str, response_content: Dict[str, Any], original_pdf_dir: Path, output_dir: Path) -> bool:
     """
     Process a single batch result: copy PDF and create MD file.
 
     Args:
         custom_id: Custom ID from the batch request
         response_content: Parsed response content
         original_pdf_dir: Directory containing original PDFs
         output_dir: Output directory for results
 
     Returns:
         True if successful, False otherwise
     """
@@ -109,75 +102,70 @@ def process_single_result(
         # Custom ID format: "folder/filename" (without .pdf)
         pdf_relative_path = f"{custom_id}.pdf"
         original_pdf_path = original_pdf_dir / pdf_relative_path
 
         if not original_pdf_path.exists():
             print(f"Warning: Original PDF not found: {original_pdf_path}")
 
             original_pdf_path = str(original_pdf_path)
-            pattern = r'(.+?)(-\d+)\.pdf$'
-            replacement = r'\1.pdf\2.pdf'
+            pattern = r"(.+?)(-\d+)\.pdf$"
+            replacement = r"\1.pdf\2.pdf"
 
             original_pdf_path = Path(re.sub(pattern, replacement, original_pdf_path))
 
             if not original_pdf_path.exists():
                 print(f"Error: Original PDF not found: {original_pdf_path}")
                 return False
 
         # Create output paths
         output_pdf_path = output_dir / pdf_relative_path
         output_md_path = output_dir / f"{custom_id}.md"
 
         # Create parent directories if needed
         output_pdf_path.parent.mkdir(parents=True, exist_ok=True)
 
         # Copy the PDF file
         shutil.copy2(original_pdf_path, output_pdf_path)
 
         # Create the markdown file
         markdown_content = format_frontmatter_markdown(response_content)
         with open(output_md_path, "w", encoding="utf-8") as f:
             f.write(markdown_content)
 
         return True
 
     except Exception as e:
         print(f"Error processing {custom_id}: {e}")
         return False
 
 
-def process_batch_results(
-    batch_results_dir: Path,
-    original_pdf_dir: Path,
-    output_dir: Path,
-    num_workers: int = 8
-) -> int:
+def process_batch_results(batch_results_dir: Path, original_pdf_dir: Path, output_dir: Path, num_workers: int = 8) -> int:
     """
     Process all batch result files and create output structure.
 
     Args:
         batch_results_dir: Directory containing batch result JSONL files
         original_pdf_dir: Directory containing original PDFs
         output_dir: Output directory for processed results
         num_workers: Number of parallel workers
 
     Returns:
         Number of successfully processed results
     """
     # Ensure output directory exists
     output_dir.mkdir(parents=True, exist_ok=True)
 
     # Find all batch result files (both .jsonl and .json)
     batch_files = list(batch_results_dir.glob("*.jsonl")) + list(batch_results_dir.glob("*.json"))
 
     if not batch_files:
         print(f"No batch result files found in {batch_results_dir}")
         return 0
 
     print(f"Found {len(batch_files)} batch result files")
 
     # Collect all results to process
     results_to_process = []
 
     for batch_file in batch_files:
         print(f"Reading {batch_file.name}...")
         with open(batch_file, "r") as f:
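
The fallback in process_single_result above handles page-suffixed file names: when extracted/document-3.pdf is missing, the regex pair rewrites the path to extracted/document.pdf-3.pdf before retrying. A standalone illustration of the same pattern and replacement:

import re

pattern = r"(.+?)(-\d+)\.pdf$"
replacement = r"\1.pdf\2.pdf"

# The lazy first group keeps the "-N" suffix match anchored to the end of the name.
print(re.sub(pattern, replacement, "extracted/document-3.pdf"))
# prints: extracted/document.pdf-3.pdf
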
@@ -186,33 +174,27 @@ def process_batch_results(
                 parsed = parse_batch_response(line)
                 if parsed:
                     results_to_process.append(parsed)
 
     total_results = len(results_to_process)
     print(f"Found {total_results} valid results to process")
     print(f"Using {num_workers} parallel workers")
 
     successful = 0
     failed = 0
 
     # Process results in parallel
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         # Submit all processing tasks
         future_to_result = {
-            executor.submit(
-                process_single_result,
-                result["custom_id"],
-                result["content"],
-                original_pdf_dir,
-                output_dir
-            ): result["custom_id"]
+            executor.submit(process_single_result, result["custom_id"], result["content"], original_pdf_dir, output_dir): result["custom_id"]
             for result in results_to_process
         }
 
         # Process results as they complete
         with tqdm(total=total_results, desc="Processing results") as pbar:
             for future in as_completed(future_to_result):
                 custom_id = future_to_result[future]
 
                 try:
                     success = future.result()
                     if success:
@@ -222,73 +204,51 @@ def process_batch_results(
                 except Exception as e:
                     print(f"\nError with {custom_id}: {e}")
                     failed += 1
 
                 pbar.update(1)
 
     print(f"\nProcessing complete:")
     print(f"  - Successfully processed: {successful}")
     print(f"  - Failed: {failed}")
     print(f"  - Output directory: {output_dir}")
 
     return successful
 
 
 def main():
-    parser = argparse.ArgumentParser(
-        description="Process OpenAI batch results and create output folder with PDFs and Markdown files"
-    )
+    parser = argparse.ArgumentParser(description="Process OpenAI batch results and create output folder with PDFs and Markdown files")
+    parser.add_argument("batch_results_dir", type=str, help="Directory containing completed OpenAI batch result files (JSONL)")
     parser.add_argument(
-        "batch_results_dir",
-        type=str,
-        help="Directory containing completed OpenAI batch result files (JSONL)"
+        "original_pdf_dir", type=str, help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
     )
-    parser.add_argument(
-        "original_pdf_dir",
-        type=str,
-        help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
-    )
-    parser.add_argument(
-        "output_dir",
-        type=str,
-        help="Output directory for processed results with PDFs and MD files"
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=8,
-        help="Number of parallel workers for processing (default: 8)"
-    )
+    parser.add_argument("output_dir", type=str, help="Output directory for processed results with PDFs and MD files")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")
 
     args = parser.parse_args()
 
     # Convert paths to Path objects
     batch_results_dir = Path(args.batch_results_dir).expanduser().resolve()
     original_pdf_dir = Path(args.original_pdf_dir).expanduser().resolve()
     output_dir = Path(args.output_dir).expanduser().resolve()
 
     # Validate input directories
     if not batch_results_dir.exists():
         print(f"Error: Batch results directory does not exist: {batch_results_dir}")
         return 1
 
     if not original_pdf_dir.exists():
         print(f"Error: Original PDF directory does not exist: {original_pdf_dir}")
         return 1
 
     print(f"Batch results directory: {batch_results_dir}")
     print(f"Original PDF directory: {original_pdf_dir}")
     print(f"Output directory: {output_dir}")
 
     # Process the batch results
-    process_batch_results(
-        batch_results_dir=batch_results_dir,
-        original_pdf_dir=original_pdf_dir,
-        output_dir=output_dir,
-        num_workers=args.num_workers
-    )
+    process_batch_results(batch_results_dir=batch_results_dir, original_pdf_dir=original_pdf_dir, output_dir=output_dir, num_workers=args.num_workers)
 
     return 0
 
 
 if __name__ == "__main__":
     exit(main())
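
For reference, parse_batch_response above reads only a few keys from each result line. A sketch of the minimal shape it accepts; the values are placeholders, not real model output:

import json

# One line of a batch results file: a custom_id, a 200 status, and a JSON
# string in the first choice's message content (parsed again with json.loads).
line = json.dumps(
    {
        "custom_id": "extracted/document_id",
        "response": {
            "status_code": 200,
            "body": {
                "choices": [
                    {"message": {"content": json.dumps({"primary_language": "en", "is_table": False, "natural_text": "..."})}}
                ]
            },
        },
    }
)
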
@@ -16,6 +16,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )
 
+
 def build_openai_silver_data_prompt_v2(base_text: str) -> str:
     return (
         f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it that includes position information for each image and block of text (The origin [0x0] of the coordinates is in the lower left corner of the image). "
@@ -30,6 +31,7 @@ def build_openai_silver_data_prompt_v2(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )
 
+
 def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int) -> str:
     return (
         f"Attached is the image of one page of a PDF document."
@@ -44,6 +46,7 @@ def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int) -> str:
         f"Page width: {page_width}, Page height: {page_height}"
     )
 
+
 def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int) -> str:
     return (
         f"Attached is the image of one page of a PDF document."
@@ -60,7 +63,6 @@ def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int) -> str:
     )
 
 
-
 @dataclass(frozen=True)
 class PageResponse:
     primary_language: Optional[str]
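
The _simple prompt builders touched here take the rendered page's pixel dimensions instead of anchor text, so the calling pattern used by the scripts in this commit is render, measure, then build. A sketch with the same olmocr helpers; the path and the dimension value are placeholders (the batch script uses its TARGET_IMAGE_DIM constant):

from olmocr.data.renderpdf import (
    get_png_dimensions_from_base64,
    render_pdf_to_base64png,
)
from olmocr.prompts.prompts import build_openai_silver_data_prompt_v3_simple

image_base64 = render_pdf_to_base64png("page.pdf", page_num=1, target_longest_image_dim=1024)
width, height = get_png_dimensions_from_base64(image_base64)
prompt = build_openai_silver_data_prompt_v3_simple(width, height)
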
@@ -1,10 +1,14 @@
+import argparse
 import base64
 import json
 import logging
+import multiprocessing
 import re
+import shutil
 from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, as_completed
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
+from html.parser import HTMLParser
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
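
The import hunk above adds replace to the dataclasses import. The pipeline steps below rebuild frozen PageResponse instances field by field; dataclasses.replace is the standard-library shortcut for exactly that. A generic sketch (the Page class here is hypothetical, not the real PageResponse):

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Page:
    primary_language: str
    natural_text: str


page = Page(primary_language="en", natural_text="raw")
# New frozen instance with one field swapped; the other fields are copied over.
normalized = replace(page, natural_text="normalized")
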
@@ -419,8 +423,6 @@ class LatexBracketNormalizer(PipelineStep):
 
         # Update the page_data with normalized text
         # Since PageResponse is frozen, we need to create a new instance
-        from olmocr.prompts.prompts import PageResponse
-
         new_page_data = PageResponse(
             primary_language=page_data.primary_language,
             is_rotation_valid=page_data.is_rotation_valid,
@@ -482,8 +484,6 @@ class RotationAugmentation(PipelineStep):
             else:  # 270
                 correction = 90
 
-            from olmocr.prompts.prompts import PageResponse
-
             new_page_data = PageResponse(
                 primary_language=page_data.primary_language,
                 is_rotation_valid=False,  # Mark as invalid since we rotated it
@@ -523,7 +523,7 @@ class FilterOutRotatedDocuments(PipelineStep):
 @dataclass(frozen=True, slots=True)
 class DatasetTextRuleFilter(PipelineStep):
     """Pipeline step that filters samples based on text content rules.
 
     Filters out samples that:
     - Contain markdown tables
     - Contain malformed HTML tables
@@ -539,205 +539,244 @@ class DatasetTextRuleFilter(PipelineStep):
         # Look for pipe-separated table patterns
         # Markdown tables have lines like: | col1 | col2 | col3 |
        # And separator lines like: |------|------|------|
-        lines = text.split('\n')
+        lines = text.split("\n")
         for i, line in enumerate(lines):
             line = line.strip()
             # Check if line looks like a table row
-            if line.startswith('|') and line.endswith('|') and line.count('|') >= 3:
+            if line.startswith("|") and line.endswith("|") and line.count("|") >= 3:
                 # Check if next line is a separator (for header rows)
                 if i + 1 < len(lines):
                     next_line = lines[i + 1].strip()
-                    if next_line.startswith('|') and '-' in next_line:
+                    if next_line.startswith("|") and "-" in next_line:
                         return True
                 # Check if previous line is a separator (for data rows)
                 if i > 0:
                     prev_line = lines[i - 1].strip()
-                    if prev_line.startswith('|') and '-' in prev_line:
+                    if prev_line.startswith("|") and "-" in prev_line:
                         return True
         return False
 
     def _contains_math_symbols(self, text: str) -> bool:
         """Check if text contains specific mathematical symbols outside of table cells.
 
         Returns:
             True if text contains any of the specified math symbols outside tables
             False otherwise
         """
-        import re
 
         # List of mathematical symbols to check for
         math_symbols = [
             # Set theory and logic
-            '∈', '∉', '⊂', '⊃', '⊆', '⊇', '∅', '∪', '∩', '∀', '∃', '¬',
+            "∈",
+            "∉",
+            "⊂",
+            "⊃",
+            "⊆",
+            "⊇",
+            "∅",
+            "∪",
+            "∩",
+            "∀",
+            "∃",
+            "¬",
             # Common mathematical operators
-            '⊕', '⊗', '⊙',
+            "⊕",
+            "⊗",
+            "⊙",
             # Calculus and analysis
-            '∂', '∇', '∆', '∫', '∬', '∭', '∮', '∏', '∑', '√', '∛', '∜',
+            "∂",
+            "∇",
+            "∆",
+            "∫",
+            "∬",
+            "∭",
+            "∮",
+            "∏",
+            "∑",
+            "√",
+            "∛",
+            "∜",
             # Arrows and relations
-            '⊥',
+            "⊥",
             # Other common math symbols
-            '∠', '∡', '⊤', '⊢', '⊣', '∴', '∵', '∶', '∷', '∝', '≅', '≆', '≇', '≊', '≋',
+            "∠",
+            "∡",
+            "⊤",
+            "⊢",
+            "⊣",
+            "∴",
+            "∵",
+            "∶",
+            "∷",
+            "∝",
+            "≅",
+            "≆",
+            "≇",
+            "≊",
+            "≋",
             # Matrix and vector notation
-            '⊕', '⊖', '⊗', '⊘', '⊙', '⊚', '⊛', '⊜', '⊝',
+            "⊕",
+            "⊖",
+            "⊗",
+            "⊘",
+            "⊙",
+            "⊚",
+            "⊛",
+            "⊜",
+            "⊝",
         ]
 
         # First, remove all HTML tables from the text
         text_without_tables = text
 
         # Remove HTML tables
-        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
-        text_without_tables = table_pattern.sub('', text_without_tables)
+        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
+        text_without_tables = table_pattern.sub("", text_without_tables)
 
         # Now check if any of these symbols appear in the text without tables
         for symbol in math_symbols:
             if symbol in text_without_tables:
                 return True
 
         return False
 
     def _contains_latex_tables(self, text: str) -> bool:
         """Check if text contains LaTeX table environments.
 
         Returns:
             True if text contains LaTeX tables (\\begin{table}, \\begin{tabular}, etc.)
             False otherwise
         """
         import re
 
         # Check for various LaTeX table environments
         latex_table_patterns = [
-            r'\\begin\{table\}',
-            r'\\begin\{tabular\}',
+            r"\\begin\{table\}",
+            r"\\begin\{tabular\}",
         ]
 
         # Check if any LaTeX table pattern exists in the text
         for pattern in latex_table_patterns:
             if re.search(pattern, text, re.IGNORECASE):
                 return True
 
         return False
 
     def _contains_latex_formatting_outside_math(self, text: str) -> bool:
         """Check if text contains LaTeX formatting commands outside of math equations.
 
         Returns:
             True if text contains LaTeX formatting commands outside math equations
             False otherwise
         """
         import re
 
         # List of common LaTeX formatting commands to check for
         latex_commands = [
             # Lists & basic content
-            r'\begin{itemize}',
-            r'\begin{enumerate}',
-            r'\item',
+            r"\begin{itemize}",
+            r"\begin{enumerate}",
+            r"\item",
 
             # Figures, tables, and captions
-            r'\begin{figure}',
-            r'\includegraphics',
-            r'\caption',
-            r'\label',
-            r'\ref',
-            r'\eqref',
-            r'\begin{table}',
-            r'\begin{tabular}',
+            r"\begin{figure}",
+            r"\includegraphics",
+            r"\caption",
+            r"\label",
+            r"\ref",
+            r"\eqref",
+            r"\begin{table}",
+            r"\begin{tabular}",
 
             # Formatting,
             # r'\textit',
             # r'\textbb',
 
             # Math (strong signals)
-            r'\begin{equation}',
-            r'\begin{align}',
-            r'\frac',
-            r'\sum',
-            r'\int',
-            r'\sqrt',
-            r'\prod',
-            r'\lim',
-            r'\binom',
-            r'\mathbb',
-            r'\mathcal',
-            r'\to',
-            r'\varphi',
-            r'\cdot',
-            r'\langle',
-            r'\rangle',
+            r"\begin{equation}",
+            r"\begin{align}",
+            r"\frac",
+            r"\sum",
+            r"\int",
+            r"\sqrt",
+            r"\prod",
+            r"\lim",
+            r"\binom",
+            r"\mathbb",
+            r"\mathcal",
+            r"\to",
+            r"\varphi",
+            r"\cdot",
+            r"\langle",
+            r"\rangle",
 
             # Citations (bibliography stacks)
-            r'\cite',
+            r"\cite",
         ]
 
 
         # First, remove all math equations from the text
         text_without_math = text
 
         # Patterns for math equations
         math_patterns = [
             r"\$\$(.+?)\$\$",  # $$...$$
             r"\\\((.+?)\\\)",  # \(...\)
             r"\\\[(.+?)\\\]",  # \[...\]
         ]
 
         # Remove all math equations
         for pattern in math_patterns:
-            text_without_math = re.sub(pattern, '', text_without_math, flags=re.DOTALL)
+            text_without_math = re.sub(pattern, "", text_without_math, flags=re.DOTALL)
 
         # Check if any LaTeX commands appear in the remaining text
         for command in latex_commands:
             if command in text_without_math:
                 return True
 
         return False
 
     def _validate_math_equations(self, text: str) -> bool:
         """Check if all math equations in the text can render without errors.
 
         Returns:
             True if all equations render successfully or no equations exist
             False if any equation fails to render
         """
         import re
 
         # Patterns to find math equations (same as in MathTest)
         patterns = [
             r"\$\$(.+?)\$\$",  # $$...$$
             r"\\\((.+?)\\\)",  # \(...\)
             r"\\\[(.+?)\\\]",  # \[...\]
         ]
 
         equations = []
         for pattern in patterns:
             # Find all matches for the current pattern
             matches = re.findall(pattern, text, re.DOTALL)
             equations.extend([eq.strip() for eq in matches])
 
         # If no equations found, that's fine
         if not equations:
             return True
 
         # Try to render each equation
         try:
             from olmocr.bench.katex.render import render_equation
 
             for equation in equations:
                 # Skip empty or whitespace-only equations
                 if not equation or not equation.strip():
                     continue
 
                 # Try to render the equation
                 rendered = render_equation(equation)
 
                 # Check if there was an error
-                if rendered is None or (hasattr(rendered, 'error') and rendered.error):
+                if rendered is None or (hasattr(rendered, "error") and rendered.error):
                     # Equation failed to render
                     logger.warning(f"Could not render equation '{repr(equation)}', skipping sample")
                     return False
 
             # All equations rendered successfully
             return True
 
         except ImportError:
             # If we can't import the render module, skip this check
             # This allows the filter to work even without the rendering dependencies
@ -746,87 +785,86 @@ class DatasetTextRuleFilter(PipelineStep):
            # If any unexpected error occurs during validation, be conservative and filter out
            print(f"Error validating math equations: {e}")
            return False

    def _contains_br_in_table_cells(self, text: str) -> bool:
        """Check if text contains <br> tags within HTML table cells.

        Returns:
            True if any table cell contains <br> tags
            False otherwise
        """
        import re

        # Check if there are any tables in the text
        if '<table' not in text.lower() or '<br' not in text.lower():
        if "<table" not in text.lower() or "<br" not in text.lower():
            return False  # No tables or no <br> tags at all

        # Pattern to find HTML tables (case-insensitive)
        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
        tables = table_pattern.findall(text)

        # Check each table for <br> tags in cells
        for table_html in tables:
            # Pattern to find table cells (td and th tags)
            cell_pattern = re.compile(r'<(td|th)\b[^>]*>(.*?)</\1>', re.IGNORECASE | re.DOTALL)
            cell_pattern = re.compile(r"<(td|th)\b[^>]*>(.*?)</\1>", re.IGNORECASE | re.DOTALL)
            cells = cell_pattern.findall(table_html)

            for tag_type, cell_content in cells:
                # Check if cell content contains <br> tags (any variation)
                if re.search(r'<br\s*/?>', cell_content, re.IGNORECASE):
                if re.search(r"<br\s*/?>", cell_content, re.IGNORECASE):
                    return True

        return False

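# --- Illustrative sketch (not part of the diff), assuming `f` is a
# DatasetTextRuleFilter instance: a <br> inside a <td>/<th> cell trips the check,
# while a <br> outside any table does not.
# f._contains_br_in_table_cells("<table><tr><td>one<br/>two</td></tr></table>")  -> True
# f._contains_br_in_table_cells("line one<br>line two, no table here")           -> False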
    def _extract_and_validate_html_tables(self, text: str) -> bool:
        """Extract HTML tables and validate they parse correctly.

        Returns:
            True if all HTML tables are valid or no tables exist
            False if any HTML table is malformed
        """
        # Find all HTML table blocks
        import re

        # Check if there are any <table> tags at all
        if '<table' not in text.lower():
        if "<table" not in text.lower():
            return True  # No tables, that's fine

        # Pattern to find HTML tables (case-insensitive)
        # Note: This pattern might not catch malformed tables where </table> is missing
        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
        table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
        tables = table_pattern.findall(text)

        # Also check for unclosed table tags
        table_open_count = len(re.findall(r'<table\b[^>]*>', text, re.IGNORECASE))
        table_open_count = len(re.findall(r"<table\b[^>]*>", text, re.IGNORECASE))
        table_close_count = len(re.findall(r'</table>', text, re.IGNORECASE))
        table_close_count = len(re.findall(r"</table>", text, re.IGNORECASE))

        if table_open_count != table_close_count:
            return False  # Mismatched table tags

        if not tables and table_open_count > 0:
            # Found table tags but couldn't extract complete tables
            return False

        # Try to parse each table
        from html.parser import HTMLParser

        class TableValidator(HTMLParser):
            def __init__(self):
                super().__init__()
                self.tag_stack = []
                self.is_valid = True
                self.error_msg = None

            def handle_starttag(self, tag, attrs):
                self.tag_stack.append(tag.lower())

            def handle_endtag(self, tag):
                tag = tag.lower()
                if not self.tag_stack:
                    self.is_valid = False
                    self.error_msg = f"Unexpected closing tag: {tag}"
                    return

                # Check if the closing tag matches the most recent opening tag
                if self.tag_stack[-1] == tag:
                    self.tag_stack.pop()
@ -842,11 +880,11 @@ class DatasetTextRuleFilter(PipelineStep):
                else:
                    self.is_valid = False
                    self.error_msg = f"Mismatched tag: expected {self.tag_stack[-1]}, got {tag}"

            def error(self, message):
                self.is_valid = False
                self.error_msg = message

        # Validate each table
        for table_html in tables:
            parser = TableValidator()
@ -860,90 +898,90 @@ class DatasetTextRuleFilter(PipelineStep):
            except Exception:
                # Any parsing exception means the table is malformed
                return False

        return True

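# --- Illustrative sketch (not part of the diff): the validator above is a plain
# html.parser.HTMLParser that tracks open tags on a stack, so well-nested tables
# parse cleanly while crossed tags should trip is_valid. The same stack idea,
# reduced to a self-contained example:
from html.parser import HTMLParser

class _StackChecker(HTMLParser):
    def __init__(self):
        super().__init__()
        self.stack, self.ok = [], True

    def handle_starttag(self, tag, attrs):
        self.stack.append(tag.lower())

    def handle_endtag(self, tag):
        if self.stack and self.stack[-1] == tag.lower():
            self.stack.pop()
        else:
            self.ok = False  # closing tag does not match the most recent open tag

checker = _StackChecker()
checker.feed("<table><tr><td>ok</td></tr></table>")
print(checker.ok)  # True - every closing tag matched the stack top
checker = _StackChecker()
checker.feed("<table><tr><td>bad</tr></td></table>")
print(checker.ok)  # False - </tr> arrives while <td> is still open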
    def __call__(self, sample: Sample) -> Optional[Sample]:
        """Filter samples based on text content rules."""
        # Get the natural text from page_data if it exists
        text = None

        if "page_data" in sample:
            page_data = sample["page_data"]
            if hasattr(page_data, "natural_text") and page_data.natural_text:
                text = page_data.natural_text

        # If no text to check, pass the sample through
        if text is None:
            return sample

        # Check for markdown tables
        # # Check for markdown tables
        if self._contains_markdown_table(text):
        # if self._contains_markdown_table(text):
            return None  # Filter out samples with markdown tables
        # return None  # Filter out samples with markdown tables

        # Check for HTML tables and validate them
        # # Check for HTML tables and validate them
        if not self._extract_and_validate_html_tables(text):
        # if not self._extract_and_validate_html_tables(text):
            return None  # Filter out samples with malformed HTML tables
        # return None  # Filter out samples with malformed HTML tables

        # Check for <br> tags in table cells
        # # Check for <br> tags in table cells
        if self._contains_br_in_table_cells(text):
        # if self._contains_br_in_table_cells(text):
            return None  # Filter out samples with <br> tags in table cells
        # return None  # Filter out samples with <br> tags in table cells

        # Check if all math equations can render without errors
        # # Check if all math equations can render without errors
        if not self._validate_math_equations(text):
        # if not self._validate_math_equations(text):
            return None  # Filter out samples with invalid math equations
        # return None  # Filter out samples with invalid math equations

        # Check for mathematical symbols
        # # Check for mathematical symbols
        if self._contains_math_symbols(text):
        # if self._contains_math_symbols(text):
            return None  # Filter out samples with mathematical symbols
        # return None  # Filter out samples with mathematical symbols

        # Check for LaTeX formatting outside math equations
        if self._contains_latex_formatting_outside_math(text):
            return None  # Filter out samples with \textit or \textbf outside math

        # Check for LaTeX tables
        if self._contains_latex_tables(text):
            return None  # Filter out samples with LaTeX tables

        return sample


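# --- Illustrative sketch (not part of the diff): because __call__ returns the
# sample unchanged or None, filters like this compose as ordinary pipeline steps.
# A minimal driver, assuming `steps` holds callables with that contract:
def run_steps(sample, steps):
    for step in steps:
        sample = step(sample)
        if sample is None:
            return None  # some step filtered the sample out
    return sample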
@dataclass(frozen=True, slots=True)
class ReformatLatexBoldItalic(PipelineStep):
    """Pipeline step that converts LaTeX formatting commands to markdown equivalents.

    Converts:
    - \\textit{...} to *...* (italic)
    - \\textbf{...} to **...** (bold)

    These conversions only happen outside of math equations.
    """

    def __call__(self, sample: Sample) -> Optional[Sample]:
        """Convert LaTeX formatting to markdown in the sample text."""
        # Get the natural text from page_data if it exists
        if "page_data" not in sample:
            return sample

        page_data = sample["page_data"]
        if not hasattr(page_data, "natural_text") or not page_data.natural_text:
            return sample

        text = page_data.natural_text

        import re

        # Math equation patterns to preserve
        math_patterns = [
            r"\$\$(.+?)\$\$",  # $$...$$
            r"\\\((.+?)\\\)",  # \(...\)
            r"\\\[(.+?)\\\]",  # \[...\]
        ]

        # Store math equations with placeholders
        math_placeholders = []
        preserved_text = text

        # Replace math equations with placeholders
        for i, pattern in enumerate(math_patterns):
            matches = re.finditer(pattern, preserved_text, re.DOTALL)
@ -951,65 +989,66 @@ class ReformatLatexBoldItalic(PipelineStep):
                placeholder = f"__MATH_PLACEHOLDER_{i}_{j}__"
                math_placeholders.append((placeholder, match.group(0)))
                preserved_text = preserved_text.replace(match.group(0), placeholder, 1)

        # Now convert LaTeX formatting to markdown
        # We need to handle nested braces properly
        # Use a function to find matching braces
        def replace_latex_command(text, command, markdown):
            """Replace LaTeX command with markdown, handling nested braces."""
            import re

            pattern = r'\\' + command + r'\{'
            pattern = r"\\" + command + r"\{"
            result = []
            i = 0

            while i < len(text):
                match = re.search(pattern, text[i:])
                if not match:
                    result.append(text[i:])
                    break

                # Add text before the match
                result.append(text[i:i + match.start()])
                result.append(text[i : i + match.start()])

                # Find the matching closing brace
                start_pos = i + match.end()
                brace_count = 1
                j = start_pos

                while j < len(text) and brace_count > 0:
                    if text[j] == '{':
                    if text[j] == "{":
                        brace_count += 1
                    elif text[j] == '}':
                    elif text[j] == "}":
                        brace_count -= 1
                    j += 1

                if brace_count == 0:
                    # Extract the content between braces
                    content = text[start_pos:j-1]
                    content = text[start_pos : j - 1]
                    result.append(markdown + content + markdown)
                    i = j
                else:
                    # Unmatched braces, keep original
                    result.append(text[i + match.start():i + match.end()])
                    result.append(text[i + match.start() : i + match.end()])
                    i = i + match.end()

            return ''.join(result)
            return "".join(result)

        # Handle \textbf{...} -> **...**
        preserved_text = replace_latex_command(preserved_text, 'textbf', '**')
        preserved_text = replace_latex_command(preserved_text, "textbf", "**")

        # Handle \textit{...} -> *...*
        preserved_text = replace_latex_command(preserved_text, 'textit', '*')
        preserved_text = replace_latex_command(preserved_text, "textit", "*")

        # Restore math equations
        for placeholder, original in math_placeholders:
            preserved_text = preserved_text.replace(placeholder, original)

        # Create a new PageResponse with the updated text (since it's frozen)
        from dataclasses import replace

        updated_page_data = replace(page_data, natural_text=preserved_text)
        sample["page_data"] = updated_page_data

        return sample


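# --- Illustrative sketch (not part of the diff): the protect-then-convert pattern
# used above, reduced to a self-contained example. Math spans are swapped for
# opaque placeholders, \textbf is rewritten, then the math is restored untouched.
# The regex here is simplified (no nested braces), unlike replace_latex_command.
import re

def demo(text: str) -> str:
    placeholders = []
    for j, m in enumerate(re.finditer(r"\$\$(.+?)\$\$", text, re.DOTALL)):
        ph = f"__MATH_PLACEHOLDER_0_{j}__"
        placeholders.append((ph, m.group(0)))
        text = text.replace(m.group(0), ph, 1)
    text = re.sub(r"\\textbf\{([^{}]*)\}", r"**\1**", text)
    for ph, original in placeholders:
        text = text.replace(ph, original)
    return text

print(demo(r"\textbf{Bold} outside, $$ \textbf{x} = 2 $$ inside."))
# -> **Bold** outside, $$ \textbf{x} = 2 $$ inside.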
@ -1382,78 +1421,71 @@ if __name__ == "__main__":
    if args.save_filtered:
        import shutil
        from pathlib import Path

        save_dir = Path(args.save_filtered)

        # Clear and create directory
        if save_dir.exists():
            shutil.rmtree(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n=== Checking for filtered samples ===")
        print(f"Will save filtered samples to: {save_dir}")

        # Function to process and copy a single sample
        def process_and_copy_sample(idx, dataset_samples, save_dir_str):
            """Process a sample and return info if it's filtered.

            Note: This function needs to be picklable for ProcessPoolExecutor,
            so it takes simple arguments rather than complex objects.
            """
            import shutil
            from pathlib import Path

            # Recreate dataset with same parameters
            # This is needed because dataset objects can't be pickled
            temp_dataset = BaseMarkdownPDFDataset.__new__(BaseMarkdownPDFDataset)
            temp_dataset.samples = dataset_samples
            temp_dataset.pipeline_steps = pipeline_steps

            try:
                sample = temp_dataset[idx]
                if sample is None:
                    # This sample was filtered out - get the original paths
                    original_sample = dataset_samples[idx]
                    md_path = original_sample['markdown_path']
                    md_path = original_sample["markdown_path"]
                    pdf_path = original_sample['pdf_path']
                    pdf_path = original_sample["pdf_path"]

                    save_dir = Path(save_dir_str)

                    # Create subdirectory to preserve some structure
                    # Use the parent directory name and file name
                    rel_path = md_path.parent.name
                    target_subdir = save_dir / rel_path
                    target_subdir.mkdir(parents=True, exist_ok=True)

                    # Copy markdown file
                    target_md = target_subdir / md_path.name
                    shutil.copy2(md_path, target_md)

                    # Copy PDF file
                    target_pdf = target_subdir / pdf_path.name
                    shutil.copy2(pdf_path, target_pdf)

                    return {
                        'index': idx,
                        'markdown_path': str(md_path),
                        'pdf_path': str(pdf_path)
                    }
                    return {"index": idx, "markdown_path": str(md_path), "pdf_path": str(pdf_path)}
                return None
            except Exception as e:
                print(f"Error processing sample {idx}: {e}")
                return None

        # Process all samples in parallel
        filtered_samples = []
        print(f"Processing {len(dataset)} samples to find and copy filtered ones...")

        with ProcessPoolExecutor(max_workers=8) as executor:
            # Submit all tasks
            futures = {
                executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx
                for idx in range(len(dataset))
            }
            futures = {executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx for idx in range(len(dataset))}

            # Process results with progress bar
            with tqdm(total=len(dataset), desc="Processing samples") as pbar:
                for future in as_completed(futures):
@ -1461,20 +1493,20 @@ if __name__ == "__main__":
                    if result is not None:
                        filtered_samples.append(result)
                    pbar.update(1)

        # Sort filtered samples by index for consistent output
        filtered_samples.sort(key=lambda x: x['index'])
        filtered_samples.sort(key=lambda x: x["index"])

        print(f"\nFound and copied {len(filtered_samples)} filtered samples to: {save_dir}")

        if filtered_samples:
            print(f"First 10 filtered samples:")
            for i, sample_info in enumerate(filtered_samples[:10]):
                md_name = Path(sample_info['markdown_path']).name
                md_name = Path(sample_info["markdown_path"]).name
                print(f"  Sample {sample_info['index']}: {md_name}")
            if len(filtered_samples) > 10:
                print(f"  ... and {len(filtered_samples) - 10} more")

        # Exit early if --save-filtered is used (don't continue with other analyses)
        print("\nCompleted saving filtered samples. Exiting.")
        exit(0)
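# --- Illustrative sketch (not part of the diff): the docstring above notes the
# worker must be picklable. ProcessPoolExecutor serializes the submitted callable
# and its arguments with pickle to ship them to worker processes, which is why
# process_and_copy_sample takes an index, a plain list, and a string instead of a
# dataset object. The same pattern in miniature, with a hypothetical even-index check:
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Optional

def keep_if_even(idx: int) -> Optional[int]:
    # top-level function: picklable, unlike a lambda or a function defined inside another function
    return idx if idx % 2 == 0 else None

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as ex:
        futures = {ex.submit(keep_if_even, i): i for i in range(6)}
        kept = sorted(f.result() for f in as_completed(futures) if f.result() is not None)
    print(kept)  # [0, 2, 4]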
@ -44,7 +44,7 @@ Some conclusion text.
""",
            )
        }

        result = self.filter(sample_with_md_table)
        self.assertIsNone(result, "Should filter out samples with markdown tables")

@ -60,7 +60,7 @@ Some conclusion text.
                natural_text="This is regular text without any tables. It has | pipes | but not in table format.",
            )
        }

        result = self.filter(sample_without_table)
        self.assertIsNotNone(result, "Should pass through samples without markdown tables")
        self.assertEqual(result, sample_without_table)
@ -92,7 +92,7 @@ Some text after table.
""",
            )
        }

        result = self.filter(sample_with_valid_html)
        self.assertIsNotNone(result, "Should pass through samples with valid HTML tables")

@ -121,7 +121,7 @@ Text after.
""",
            )
        }

        result = self.filter(sample_with_malformed_html)
        self.assertIsNone(result, "Should filter out samples with malformed HTML tables")

@ -147,7 +147,7 @@ Text after without closing table tag.
""",
            )
        }

        result = self.filter(sample_with_unclosed_table)
        self.assertIsNone(result, "Should filter out HTML tables without closing tags")

@ -157,7 +157,7 @@ Text after without closing table tag.
            "markdown_path": Path("/path/to/file.md"),
            "pdf_path": Path("/path/to/file.pdf"),
        }

        result = self.filter(sample_without_page_data)
        self.assertIsNotNone(result, "Should pass through samples without page_data")
        self.assertEqual(result, sample_without_page_data)
@ -174,7 +174,7 @@ Text after without closing table tag.
                natural_text=None,
            )
        }

        result = self.filter(sample_without_text)
        self.assertIsNotNone(result, "Should pass through samples without natural_text")

@ -190,7 +190,7 @@ Text after without closing table tag.
                natural_text="",
            )
        }

        result = self.filter(sample_with_empty_text)
        self.assertIsNotNone(result, "Should pass through samples with empty natural_text")

@ -211,7 +211,7 @@ Text after without closing table tag.
""",
            )
        }

        result = self.filter(sample_with_alignment)
        self.assertIsNone(result, "Should filter out markdown tables with alignment")

@ -240,7 +240,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
""",
            )
        }

        result = self.filter(sample_mixed)
        self.assertIsNotNone(result, "Should pass through with valid HTML and no markdown tables")

@ -261,7 +261,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
</table>""",
            )
        }

        result = self.filter(sample_with_br)
        self.assertIsNone(result, "Should filter out tables with <br> tags in cells")

@ -283,7 +283,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
</table>""",
            )
        }

        result = self.filter(sample_br_outside)
        self.assertIsNotNone(result, "Should allow <br> tags outside tables")

@ -308,7 +308,7 @@ But no markdown tables. Just some text with | pipes | that aren't tables.
</table>""",
            )
        }

        result = self.filter(sample_br_variations)
        self.assertIsNone(result, "Should filter out tables with any <br> variation in cells")

@ -332,7 +332,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="This is \\textbf{bold} text.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "This is **bold** text.")

@ -348,7 +348,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="This is \\textit{italic} text.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "This is *italic* text.")

@ -364,7 +364,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="This has \\textbf{bold} and \\textit{italic} text.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "This has **bold** and *italic* text.")

@ -380,7 +380,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="Text outside $$ \\textbf{x} = \\textit{y} $$ more text.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "Text outside $$ \\textbf{x} = \\textit{y} $$ more text.")

@ -396,7 +396,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="The \\textbf{equation} is $$ \\textbf{x} = 2 $$ and \\textit{important}.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "The **equation** is $$ \\textbf{x} = 2 $$ and *important*.")

@ -412,7 +412,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="This is \\textbf{bold with {nested} braces} text.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "This is **bold with {nested} braces** text.")

@ -428,12 +428,9 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="\\textbf{First} and \\textbf{second} bold, \\textit{first} and \\textit{second} italic.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(
            result["page_data"].natural_text,
            "**First** and **second** bold, *first* and *second* italic."
        )
        self.assertEqual(result["page_data"].natural_text, "**First** and **second** bold, *first* and *second* italic.")

    def test_latex_in_parenthesis_delimiter(self):
        """Test LaTeX preserved in \\(...\\) math delimiter."""
@ -447,7 +444,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="Text \\( \\textbf{math} \\) more text \\textbf{bold}.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "Text \\( \\textbf{math} \\) more text **bold**.")

@ -463,7 +460,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="Text \\[ \\textit{math} \\] more text \\textit{italic}.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "Text \\[ \\textit{math} \\] more text *italic*.")

@ -479,14 +476,14 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text="Plain text without any formatting.",
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, "Plain text without any formatting.")

    def test_no_page_data(self):
        """Test handling of samples without page_data."""
        sample = {"markdown_path": Path("/path/to/file.md")}

        result = self.reformatter(sample)
        self.assertEqual(result, sample)

@ -502,10 +499,10 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
                natural_text=None,
            )
        }

        result = self.reformatter(sample)
        self.assertIsNone(result["page_data"].natural_text)

    def test_complex_latex_with_parenthesis_delimiters(self):
        """Test complex LaTeX text with \\(...\\) delimiters and textit."""
        input_text = """= a_0 \\int_0^P \\cos \\frac{2m\\pi x}{P} dx
@ -517,7 +514,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th
\\[
\\int_0^P \\cos \\frac{2m\\pi x}{P} f(x) dx = \\frac{a_m P}{2},
\\]"""

        expected_text = """= a_0 \\int_0^P \\cos \\frac{2m\\pi x}{P} dx
+ \\sum_{n=1}^{\\infty} \\frac{a_n}{2} \\int_0^P \\cos \\frac{2(m+n)\\pi x}{P} + \\cos \\frac{2(m-n)\\pi x}{P} dx
+ b_n \\int_0^P \\sin \\frac{2(m+n)\\pi x}{P} - \\sin \\frac{2(m-n)\\pi x}{P} dx.
@ -527,7 +524,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th
\\[
\\int_0^P \\cos \\frac{2m\\pi x}{P} f(x) dx = \\frac{a_m P}{2},
\\]"""

        sample = {
            "page_data": PageResponse(
                primary_language="en",
@ -538,7 +535,7 @@ Since \\( m \\) and \\( n \\) are both positive integers we have seen already th
                natural_text=input_text,
            )
        }

        result = self.reformatter(sample)
        self.assertEqual(result["page_data"].natural_text, expected_text)

@ -562,7 +559,7 @@ class TestFilterOutRotatedDocuments(unittest.TestCase):
                natural_text="Some text",
            )
        }

        result = self.filter(sample)
        self.assertIsNotNone(result, "Should pass through documents with valid rotation")

@ -578,7 +575,7 @@ class TestFilterOutRotatedDocuments(unittest.TestCase):
                natural_text="Some text",
            )
        }

        result = self.filter(sample)
        self.assertIsNone(result, "Should filter out documents with invalid rotation")

@ -594,14 +591,14 @@ class TestFilterOutRotatedDocuments(unittest.TestCase):
                natural_text="Some text",
            )
        }

        result = self.filter(sample)
        self.assertIsNone(result, "Should filter out documents with non-zero rotation correction")

    def test_no_page_data(self):
        """Test that samples without page_data pass through."""
        sample = {"markdown_path": Path("/path/to/file.md")}

        result = self.filter(sample)
        self.assertIsNotNone(result, "Should pass through samples without page_data")

@ -625,7 +622,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
                natural_text="The equation $x^2 + y^2 = z^2$ is famous.",
            )
        }

        result = self.normalizer(sample)
        expected_text = "The equation \\(x^2 + y^2 = z^2\\) is famous."
        self.assertEqual(result["page_data"].natural_text, expected_text)
@ -642,7 +639,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
                natural_text="Display equation:\n$$\\int_0^\\infty e^{-x^2} dx = \\frac{\\sqrt{\\pi}}{2}$$",
            )
        }

        result = self.normalizer(sample)
        expected_text = "Display equation:\n\\[\\int_0^\\infty e^{-x^2} dx = \\frac{\\sqrt{\\pi}}{2}\\]"
        self.assertEqual(result["page_data"].natural_text, expected_text)
@ -659,7 +656,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
                natural_text="Inline $a + b$ and display:\n$$c^2 = a^2 + b^2$$\nMore inline $x = y$.",
            )
        }

        result = self.normalizer(sample)
        expected_text = "Inline \\(a + b\\) and display:\n\\[c^2 = a^2 + b^2\\]\nMore inline \\(x = y\\)."
        self.assertEqual(result["page_data"].natural_text, expected_text)
@ -676,7 +673,7 @@ class TestLatexBracketNormalizer(unittest.TestCase):
                natural_text="Regular text without any equations.",
            )
        }

        result = self.normalizer(sample)
        self.assertEqual(result["page_data"].natural_text, "Regular text without any equations.")

@ -692,14 +689,14 @@ class TestLatexBracketNormalizer(unittest.TestCase):
                natural_text=None,
            )
        }

        result = self.normalizer(sample)
        self.assertIsNone(result["page_data"].natural_text)

    def test_no_page_data(self):
        """Test handling of missing page_data."""
        sample = {"markdown_path": Path("/path/to/file.md")}

        result = self.normalizer(sample)
        self.assertEqual(result, sample)

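# --- Illustrative sketch (not part of the diff): the LatexBracketNormalizer
# implementation is not shown in this hunk, but the tests above pin down its
# contract ($...$ -> \(...\), $$...$$ -> \[...\]). One way to get that behavior,
# converting $$...$$ before $...$ so the inline pass cannot eat half of a
# display delimiter:
import re

def normalize_brackets(text: str) -> str:
    text = re.sub(r"\$\$(.+?)\$\$", r"\\[\1\\]", text, flags=re.DOTALL)
    text = re.sub(r"\$(.+?)\$", r"\\(\1\\)", text, flags=re.DOTALL)
    return text

print(normalize_brackets("Inline $a + b$ and display: $$c^2 = a^2 + b^2$$"))
# -> Inline \(a + b\) and display: \[c^2 = a^2 + b^2\]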
@ -712,7 +709,7 @@ class TestFrontMatterParser(unittest.TestCase):
        self.parser_with_class = FrontMatterParser(front_matter_class=PageResponse)
        self.parser_without_class = FrontMatterParser(front_matter_class=None)

    @patch.object(Path, 'read_text')
    @patch.object(Path, "read_text")
    def test_parse_yaml_front_matter(self, mock_read_text):
        """Test parsing of YAML front matter."""
        mock_read_text.return_value = """---
@ -724,27 +721,27 @@ is_diagram: false
---
This is the document content.
"""

        sample = {"markdown_path": Path("/path/to/file.md")}
        result = self.parser_with_class(sample)

        self.assertIn("page_data", result)
        self.assertIsInstance(result["page_data"], PageResponse)
        self.assertEqual(result["page_data"].primary_language, "en")
        self.assertEqual(result["page_data"].natural_text, "This is the document content.")

    @patch.object(Path, 'read_text')
    @patch.object(Path, "read_text")
    def test_no_front_matter(self, mock_read_text):
        """Test handling of documents without front matter."""
        mock_read_text.return_value = "Just regular content without front matter."

        sample = {"markdown_path": Path("/path/to/file.md")}

        # Should raise an error when front_matter_class is specified
        with self.assertRaises(ValueError):
            self.parser_with_class(sample)

    @patch.object(Path, 'read_text')
    @patch.object(Path, "read_text")
    def test_malformed_yaml(self, mock_read_text):
        """Test handling of malformed YAML."""
        mock_read_text.return_value = """---
@ -753,14 +750,14 @@ is_rotation_valid: [this is not valid yaml}
---
Content
"""

        sample = {"markdown_path": Path("/path/to/file.md")}

        # Parser without class should return empty dict for malformed YAML
        result = self.parser_without_class(sample)
        self.assertEqual(result["page_data"], {})

    @patch.object(Path, 'read_text')
    @patch.object(Path, "read_text")
    def test_preserve_existing_markdown_content(self, mock_read_text):
        """Test that existing markdown_content is preserved if present."""
        sample = {
@ -772,16 +769,16 @@ rotation_correction: 0
is_table: true
is_diagram: false
---
French content."""
French content.""",
        }

        # Should not call read_text since markdown_content exists
        result = self.parser_with_class(sample)
        mock_read_text.assert_not_called()

        self.assertEqual(result["page_data"].primary_language, "fr")
        self.assertEqual(result["page_data"].is_table, True)


if __name__ == "__main__":
    unittest.main()