This commit is contained in:
Jake Poznanski 2025-08-19 21:30:41 +00:00
parent 768cb33937
commit 41201b6317
6 changed files with 419 additions and 452 deletions

View File

@ -8,16 +8,19 @@ from olmocr.bench.prompts import (
build_basic_prompt,
build_openai_silver_data_prompt_no_document_anchoring,
)
from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
from olmocr.data.renderpdf import (
get_png_dimensions_from_base64,
render_pdf_to_base64png,
)
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import (
PageResponse,
build_finetuning_prompt,
build_openai_silver_data_prompt,
openai_response_format_schema,
build_openai_silver_data_prompt_v2,
build_openai_silver_data_prompt_v2_simple,
build_openai_silver_data_prompt_v3_simple,
openai_response_format_schema,
)

View File

@ -8,14 +8,17 @@ and generates OpenAI batch API requests for processing PDFs.
import argparse
import json
import os
from pathlib import Path
from typing import Generator, Dict, Any, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Tuple
from pypdf import PdfReader
from tqdm import tqdm
from olmocr.data.renderpdf import render_pdf_to_base64png, get_png_dimensions_from_base64
from olmocr.data.renderpdf import (
get_png_dimensions_from_base64,
render_pdf_to_base64png,
)
from olmocr.prompts.prompts import (
build_openai_silver_data_prompt_v3_simple,
openai_response_format_schema,
@ -60,7 +63,7 @@ def build_custom_id(pdf_path: Path, base_dir: Path) -> str:
# Get relative path from base directory
rel_path = pdf_path.relative_to(base_dir)
# Remove .pdf extension but keep directory structure
path_without_ext = str(rel_path).replace('.pdf', '')
path_without_ext = str(rel_path).replace(".pdf", "")
return path_without_ext
@ -151,12 +154,7 @@ def find_pdf_files(input_dir: Path) -> Generator[Path, None, None]:
yield pdf_path
def process_pdfs_to_batch_requests(
input_dir: Path,
output_dir: Path,
max_pdfs: int = None,
num_workers: int = 8
) -> int:
def process_pdfs_to_batch_requests(input_dir: Path, output_dir: Path, max_pdfs: int = None, num_workers: int = 8) -> int:
"""
Process PDFs and create batch request files using parallel processing.
@ -196,10 +194,7 @@ def process_pdfs_to_batch_requests(
# Process PDFs in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Submit all PDF processing tasks
future_to_pdf = {
executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path
for pdf_path in pdf_files
}
future_to_pdf = {executor.submit(process_single_pdf, pdf_path, input_dir): pdf_path for pdf_path in pdf_files}
# Process results as they complete
with tqdm(total=total_pdfs, desc="Processing PDFs") as pbar:
@ -252,31 +247,14 @@ def process_pdfs_to_batch_requests(
def main():
parser = argparse.ArgumentParser(
description="Build OpenAI batch requests from OLMoCR-mix folder structure"
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Output directory for batch request files (default: input_dir/batch_requests)"
)
parser.add_argument(
"--max_pdfs",
type=int,
default=None,
help="Maximum number of PDFs to process (default: all)"
)
parser.add_argument(
"--num_workers",
type=int,
default=8,
help="Number of parallel workers for processing (default: 8)"
)
parser = argparse.ArgumentParser(description="Build OpenAI batch requests from OLMoCR-mix folder structure")
parser.add_argument("--output_dir", type=str, default=None, help="Output directory for batch request files (default: input_dir/batch_requests)")
parser.add_argument("--max_pdfs", type=int, default=None, help="Maximum number of PDFs to process (default: all)")
parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")
parser.add_argument(
"input_dir",
type=str,
help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)"
help="Input directory containing processed folder structure (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf or ~/olmOCR-mix-0225)",
)
args = parser.parse_args()
@ -298,12 +276,7 @@ def main():
print(f"Output directory: {output_dir}")
# Process PDFs
process_pdfs_to_batch_requests(
input_dir=input_dir,
output_dir=output_dir,
max_pdfs=args.max_pdfs,
num_workers=args.num_workers
)
process_pdfs_to_batch_requests(input_dir=input_dir, output_dir=output_dir, max_pdfs=args.max_pdfs, num_workers=args.num_workers)
return 0

View File

@ -8,11 +8,12 @@ that mirrors the original structure with side-by-side PDF and MD files.
import argparse
import json
import shutil
import re
from pathlib import Path
from typing import Dict, Any, Optional
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, Optional
from tqdm import tqdm
@ -39,10 +40,7 @@ def parse_batch_response(response_line: str) -> Optional[Dict[str, Any]]:
content = body["choices"][0]["message"]["content"]
# Parse the JSON response
parsed_content = json.loads(content)
return {
"custom_id": custom_id,
"content": parsed_content
}
return {"custom_id": custom_id, "content": parsed_content}
else:
print(f"Error in response for {custom_id}: {data.get('error', 'Unknown error')}")
return None
@ -86,12 +84,7 @@ def format_frontmatter_markdown(response_data: Dict[str, Any]) -> str:
return markdown.strip()
def process_single_result(
custom_id: str,
response_content: Dict[str, Any],
original_pdf_dir: Path,
output_dir: Path
) -> bool:
def process_single_result(custom_id: str, response_content: Dict[str, Any], original_pdf_dir: Path, output_dir: Path) -> bool:
"""
Process a single batch result: copy PDF and create MD file.
@ -114,8 +107,8 @@ def process_single_result(
print(f"Warning: Original PDF not found: {original_pdf_path}")
original_pdf_path = str(original_pdf_path)
pattern = r'(.+?)(-\d+)\.pdf$'
replacement = r'\1.pdf\2.pdf'
pattern = r"(.+?)(-\d+)\.pdf$"
replacement = r"\1.pdf\2.pdf"
original_pdf_path = Path(re.sub(pattern, replacement, original_pdf_path))
@ -145,12 +138,7 @@ def process_single_result(
return False
def process_batch_results(
batch_results_dir: Path,
original_pdf_dir: Path,
output_dir: Path,
num_workers: int = 8
) -> int:
def process_batch_results(batch_results_dir: Path, original_pdf_dir: Path, output_dir: Path, num_workers: int = 8) -> int:
"""
Process all batch result files and create output structure.
@ -198,13 +186,7 @@ def process_batch_results(
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Submit all processing tasks
future_to_result = {
executor.submit(
process_single_result,
result["custom_id"],
result["content"],
original_pdf_dir,
output_dir
): result["custom_id"]
executor.submit(process_single_result, result["custom_id"], result["content"], original_pdf_dir, output_dir): result["custom_id"]
for result in results_to_process
}
@ -234,30 +216,13 @@ def process_batch_results(
def main():
parser = argparse.ArgumentParser(
description="Process OpenAI batch results and create output folder with PDFs and Markdown files"
)
parser = argparse.ArgumentParser(description="Process OpenAI batch results and create output folder with PDFs and Markdown files")
parser.add_argument("batch_results_dir", type=str, help="Directory containing completed OpenAI batch result files (JSONL)")
parser.add_argument(
"batch_results_dir",
type=str,
help="Directory containing completed OpenAI batch result files (JSONL)"
)
parser.add_argument(
"original_pdf_dir",
type=str,
help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
)
parser.add_argument(
"output_dir",
type=str,
help="Output directory for processed results with PDFs and MD files"
)
parser.add_argument(
"--num_workers",
type=int,
default=8,
help="Number of parallel workers for processing (default: 8)"
"original_pdf_dir", type=str, help="Directory containing original PDF files (e.g., ~/olmOCR-mix-0225/processed_00_documents_eval_s2pdf)"
)
parser.add_argument("output_dir", type=str, help="Output directory for processed results with PDFs and MD files")
parser.add_argument("--num_workers", type=int, default=8, help="Number of parallel workers for processing (default: 8)")
args = parser.parse_args()
@ -280,12 +245,7 @@ def main():
print(f"Output directory: {output_dir}")
# Process the batch results
process_batch_results(
batch_results_dir=batch_results_dir,
original_pdf_dir=original_pdf_dir,
output_dir=output_dir,
num_workers=args.num_workers
)
process_batch_results(batch_results_dir=batch_results_dir, original_pdf_dir=original_pdf_dir, output_dir=output_dir, num_workers=args.num_workers)
return 0

View File

@ -16,6 +16,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
def build_openai_silver_data_prompt_v2(base_text: str) -> str:
return (
f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it that includes position information for each image and block of text (The origin [0x0] of the coordinates is in the lower left corner of the image). "
@ -30,6 +31,7 @@ def build_openai_silver_data_prompt_v2(base_text: str) -> str:
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int) -> str:
return (
f"Attached is the image of one page of a PDF document."
@ -44,6 +46,7 @@ def build_openai_silver_data_prompt_v2_simple(page_width: int, page_height: int)
f"Page width: {page_width}, Page height: {page_height}"
)
def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int) -> str:
return (
f"Attached is the image of one page of a PDF document."
@ -60,7 +63,6 @@ def build_openai_silver_data_prompt_v3_simple(page_width: int, page_height: int)
)
@dataclass(frozen=True)
class PageResponse:
primary_language: Optional[str]

View File

@ -1,10 +1,14 @@
import argparse
import base64
import json
import logging
import multiprocessing
import re
import shutil
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, fields
from dataclasses import dataclass, fields, replace
from html.parser import HTMLParser
from io import BytesIO
from os import PathLike
from pathlib import Path
@ -419,8 +423,6 @@ class LatexBracketNormalizer(PipelineStep):
# Update the page_data with normalized text
# Since PageResponse is frozen, we need to create a new instance
from olmocr.prompts.prompts import PageResponse
new_page_data = PageResponse(
primary_language=page_data.primary_language,
is_rotation_valid=page_data.is_rotation_valid,
@ -482,8 +484,6 @@ class RotationAugmentation(PipelineStep):
else: # 270
correction = 90
from olmocr.prompts.prompts import PageResponse
new_page_data = PageResponse(
primary_language=page_data.primary_language,
is_rotation_valid=False, # Mark as invalid since we rotated it
@ -539,20 +539,20 @@ class DatasetTextRuleFilter(PipelineStep):
# Look for pipe-separated table patterns
# Markdown tables have lines like: | col1 | col2 | col3 |
# And separator lines like: |------|------|------|
lines = text.split('\n')
lines = text.split("\n")
for i, line in enumerate(lines):
line = line.strip()
# Check if line looks like a table row
if line.startswith('|') and line.endswith('|') and line.count('|') >= 3:
if line.startswith("|") and line.endswith("|") and line.count("|") >= 3:
# Check if next line is a separator (for header rows)
if i + 1 < len(lines):
next_line = lines[i + 1].strip()
if next_line.startswith('|') and '-' in next_line:
if next_line.startswith("|") and "-" in next_line:
return True
# Check if previous line is a separator (for data rows)
if i > 0:
prev_line = lines[i - 1].strip()
if prev_line.startswith('|') and '-' in prev_line:
if prev_line.startswith("|") and "-" in prev_line:
return True
return False
@ -563,30 +563,74 @@ class DatasetTextRuleFilter(PipelineStep):
True if text contains any of the specified math symbols outside tables
False otherwise
"""
import re
# List of mathematical symbols to check for
math_symbols = [
# Set theory and logic
'', '', '', '', '', '', '', '', '', '', '', '¬',
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"¬",
# Common mathematical operators
'', '', '',
"",
"",
"",
# Calculus and analysis
'', '', '', '', '', '', '', '', '', '', '', '',
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
# Arrows and relations
'',
"",
# Other common math symbols
'', '', '', '', '', '', '', '', '', '', '', '', '', '', '',
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
# Matrix and vector notation
'', '', '', '', '', '', '', '', '',
"",
"",
"",
"",
"",
"",
"",
"",
"",
]
# First, remove all HTML tables from the text
text_without_tables = text
# Remove HTML tables
table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
text_without_tables = table_pattern.sub('', text_without_tables)
table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
text_without_tables = table_pattern.sub("", text_without_tables)
# Now check if any of these symbols appear in the text without tables
for symbol in math_symbols:
@ -606,8 +650,8 @@ class DatasetTextRuleFilter(PipelineStep):
# Check for various LaTeX table environments
latex_table_patterns = [
r'\\begin\{table\}',
r'\\begin\{tabular\}',
r"\\begin\{table\}",
r"\\begin\{tabular\}",
]
# Check if any LaTeX table pattern exists in the text
@ -629,47 +673,42 @@ class DatasetTextRuleFilter(PipelineStep):
# List of common LaTeX formatting commands to check for
latex_commands = [
# Lists & basic content
r'\begin{itemize}',
r'\begin{enumerate}',
r'\item',
r"\begin{itemize}",
r"\begin{enumerate}",
r"\item",
# Figures, tables, and captions
r'\begin{figure}',
r'\includegraphics',
r'\caption',
r'\label',
r'\ref',
r'\eqref',
r'\begin{table}',
r'\begin{tabular}',
r"\begin{figure}",
r"\includegraphics",
r"\caption",
r"\label",
r"\ref",
r"\eqref",
r"\begin{table}",
r"\begin{tabular}",
# Formatting,
# r'\textit',
# r'\textbb',
# Math (strong signals)
r'\begin{equation}',
r'\begin{align}',
r'\frac',
r'\sum',
r'\int',
r'\sqrt',
r'\prod',
r'\lim',
r'\binom',
r'\mathbb',
r'\mathcal',
r'\to',
r'\varphi',
r'\cdot',
r'\langle',
r'\rangle',
r"\begin{equation}",
r"\begin{align}",
r"\frac",
r"\sum",
r"\int",
r"\sqrt",
r"\prod",
r"\lim",
r"\binom",
r"\mathbb",
r"\mathcal",
r"\to",
r"\varphi",
r"\cdot",
r"\langle",
r"\rangle",
# Citations (bibliography stacks)
r'\cite',
r"\cite",
]
# First, remove all math equations from the text
text_without_math = text
@ -682,7 +721,7 @@ class DatasetTextRuleFilter(PipelineStep):
# Remove all math equations
for pattern in math_patterns:
text_without_math = re.sub(pattern, '', text_without_math, flags=re.DOTALL)
text_without_math = re.sub(pattern, "", text_without_math, flags=re.DOTALL)
# Check if any LaTeX commands appear in the remaining text
for command in latex_commands:
@ -730,7 +769,7 @@ class DatasetTextRuleFilter(PipelineStep):
rendered = render_equation(equation)
# Check if there was an error
if rendered is None or (hasattr(rendered, 'error') and rendered.error):
if rendered is None or (hasattr(rendered, "error") and rendered.error):
# Equation failed to render
logger.warning(f"Could not render equation '{repr(equation)}', skipping sample")
return False
@ -757,22 +796,22 @@ class DatasetTextRuleFilter(PipelineStep):
import re
# Check if there are any tables in the text
if '<table' not in text.lower() or '<br' not in text.lower():
if "<table" not in text.lower() or "<br" not in text.lower():
return False # No tables or no <br> tags at all
# Pattern to find HTML tables (case-insensitive)
table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
tables = table_pattern.findall(text)
# Check each table for <br> tags in cells
for table_html in tables:
# Pattern to find table cells (td and th tags)
cell_pattern = re.compile(r'<(td|th)\b[^>]*>(.*?)</\1>', re.IGNORECASE | re.DOTALL)
cell_pattern = re.compile(r"<(td|th)\b[^>]*>(.*?)</\1>", re.IGNORECASE | re.DOTALL)
cells = cell_pattern.findall(table_html)
for tag_type, cell_content in cells:
# Check if cell content contains <br> tags (any variation)
if re.search(r'<br\s*/?>', cell_content, re.IGNORECASE):
if re.search(r"<br\s*/?>", cell_content, re.IGNORECASE):
return True
return False
@ -788,17 +827,17 @@ class DatasetTextRuleFilter(PipelineStep):
import re
# Check if there are any <table> tags at all
if '<table' not in text.lower():
if "<table" not in text.lower():
return True # No tables, that's fine
# Pattern to find HTML tables (case-insensitive)
# Note: This pattern might not catch malformed tables where </table> is missing
table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.IGNORECASE | re.DOTALL)
table_pattern = re.compile(r"<table\b[^>]*>.*?</table>", re.IGNORECASE | re.DOTALL)
tables = table_pattern.findall(text)
# Also check for unclosed table tags
table_open_count = len(re.findall(r'<table\b[^>]*>', text, re.IGNORECASE))
table_close_count = len(re.findall(r'</table>', text, re.IGNORECASE))
table_open_count = len(re.findall(r"<table\b[^>]*>", text, re.IGNORECASE))
table_close_count = len(re.findall(r"</table>", text, re.IGNORECASE))
if table_open_count != table_close_count:
return False # Mismatched table tags
@ -808,7 +847,6 @@ class DatasetTextRuleFilter(PipelineStep):
return False
# Try to parse each table
from html.parser import HTMLParser
class TableValidator(HTMLParser):
def __init__(self):
@ -877,25 +915,25 @@ class DatasetTextRuleFilter(PipelineStep):
if text is None:
return sample
# Check for markdown tables
if self._contains_markdown_table(text):
return None # Filter out samples with markdown tables
# # Check for markdown tables
# if self._contains_markdown_table(text):
# return None # Filter out samples with markdown tables
# Check for HTML tables and validate them
if not self._extract_and_validate_html_tables(text):
return None # Filter out samples with malformed HTML tables
# # Check for HTML tables and validate them
# if not self._extract_and_validate_html_tables(text):
# return None # Filter out samples with malformed HTML tables
# Check for <br> tags in table cells
if self._contains_br_in_table_cells(text):
return None # Filter out samples with <br> tags in table cells
# # Check for <br> tags in table cells
# if self._contains_br_in_table_cells(text):
# return None # Filter out samples with <br> tags in table cells
# Check if all math equations can render without errors
if not self._validate_math_equations(text):
return None # Filter out samples with invalid math equations
# # Check if all math equations can render without errors
# if not self._validate_math_equations(text):
# return None # Filter out samples with invalid math equations
# Check for mathematical symbols
if self._contains_math_symbols(text):
return None # Filter out samples with mathematical symbols
# # Check for mathematical symbols
# if self._contains_math_symbols(text):
# return None # Filter out samples with mathematical symbols
# Check for LaTeX formatting outside math equations
if self._contains_latex_formatting_outside_math(text):
@ -958,7 +996,8 @@ class ReformatLatexBoldItalic(PipelineStep):
def replace_latex_command(text, command, markdown):
"""Replace LaTeX command with markdown, handling nested braces."""
import re
pattern = r'\\' + command + r'\{'
pattern = r"\\" + command + r"\{"
result = []
i = 0
@ -977,9 +1016,9 @@ class ReformatLatexBoldItalic(PipelineStep):
j = start_pos
while j < len(text) and brace_count > 0:
if text[j] == '{':
if text[j] == "{":
brace_count += 1
elif text[j] == '}':
elif text[j] == "}":
brace_count -= 1
j += 1
@ -993,20 +1032,20 @@ class ReformatLatexBoldItalic(PipelineStep):
result.append(text[i + match.start() : i + match.end()])
i = i + match.end()
return ''.join(result)
return "".join(result)
# Handle \textbf{...} -> **...**
preserved_text = replace_latex_command(preserved_text, 'textbf', '**')
preserved_text = replace_latex_command(preserved_text, "textbf", "**")
# Handle \textit{...} -> *...*
preserved_text = replace_latex_command(preserved_text, 'textit', '*')
preserved_text = replace_latex_command(preserved_text, "textit", "*")
# Restore math equations
for placeholder, original in math_placeholders:
preserved_text = preserved_text.replace(placeholder, original)
# Create a new PageResponse with the updated text (since it's frozen)
from dataclasses import replace
updated_page_data = replace(page_data, natural_text=preserved_text)
sample["page_data"] = updated_page_data
@ -1414,8 +1453,8 @@ if __name__ == "__main__":
if sample is None:
# This sample was filtered out - get the original paths
original_sample = dataset_samples[idx]
md_path = original_sample['markdown_path']
pdf_path = original_sample['pdf_path']
md_path = original_sample["markdown_path"]
pdf_path = original_sample["pdf_path"]
save_dir = Path(save_dir_str)
@ -1433,11 +1472,7 @@ if __name__ == "__main__":
target_pdf = target_subdir / pdf_path.name
shutil.copy2(pdf_path, target_pdf)
return {
'index': idx,
'markdown_path': str(md_path),
'pdf_path': str(pdf_path)
}
return {"index": idx, "markdown_path": str(md_path), "pdf_path": str(pdf_path)}
return None
except Exception as e:
print(f"Error processing sample {idx}: {e}")
@ -1449,10 +1484,7 @@ if __name__ == "__main__":
with ProcessPoolExecutor(max_workers=8) as executor:
# Submit all tasks
futures = {
executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx
for idx in range(len(dataset))
}
futures = {executor.submit(process_and_copy_sample, idx, dataset.samples, str(save_dir)): idx for idx in range(len(dataset))}
# Process results with progress bar
with tqdm(total=len(dataset), desc="Processing samples") as pbar:
@ -1463,14 +1495,14 @@ if __name__ == "__main__":
pbar.update(1)
# Sort filtered samples by index for consistent output
filtered_samples.sort(key=lambda x: x['index'])
filtered_samples.sort(key=lambda x: x["index"])
print(f"\nFound and copied {len(filtered_samples)} filtered samples to: {save_dir}")
if filtered_samples:
print(f"First 10 filtered samples:")
for i, sample_info in enumerate(filtered_samples[:10]):
md_name = Path(sample_info['markdown_path']).name
md_name = Path(sample_info["markdown_path"]).name
print(f" Sample {sample_info['index']}: {md_name}")
if len(filtered_samples) > 10:
print(f" ... and {len(filtered_samples) - 10} more")

View File

@ -430,10 +430,7 @@ class TestReformatLatexBoldItalic(unittest.TestCase):
}
result = self.reformatter(sample)
self.assertEqual(
result["page_data"].natural_text,
"**First** and **second** bold, *first* and *second* italic."
)
self.assertEqual(result["page_data"].natural_text, "**First** and **second** bold, *first* and *second* italic.")
def test_latex_in_parenthesis_delimiter(self):
"""Test LaTeX preserved in \\(...\\) math delimiter."""
@ -712,7 +709,7 @@ class TestFrontMatterParser(unittest.TestCase):
self.parser_with_class = FrontMatterParser(front_matter_class=PageResponse)
self.parser_without_class = FrontMatterParser(front_matter_class=None)
@patch.object(Path, 'read_text')
@patch.object(Path, "read_text")
def test_parse_yaml_front_matter(self, mock_read_text):
"""Test parsing of YAML front matter."""
mock_read_text.return_value = """---
@ -733,7 +730,7 @@ This is the document content.
self.assertEqual(result["page_data"].primary_language, "en")
self.assertEqual(result["page_data"].natural_text, "This is the document content.")
@patch.object(Path, 'read_text')
@patch.object(Path, "read_text")
def test_no_front_matter(self, mock_read_text):
"""Test handling of documents without front matter."""
mock_read_text.return_value = "Just regular content without front matter."
@ -744,7 +741,7 @@ This is the document content.
with self.assertRaises(ValueError):
self.parser_with_class(sample)
@patch.object(Path, 'read_text')
@patch.object(Path, "read_text")
def test_malformed_yaml(self, mock_read_text):
"""Test handling of malformed YAML."""
mock_read_text.return_value = """---
@ -760,7 +757,7 @@ Content
result = self.parser_without_class(sample)
self.assertEqual(result["page_data"], {})
@patch.object(Path, 'read_text')
@patch.object(Path, "read_text")
def test_preserve_existing_markdown_content(self, mock_read_text):
"""Test that existing markdown_content is preserved if present."""
sample = {
@ -772,7 +769,7 @@ rotation_correction: 0
is_table: true
is_diagram: false
---
French content."""
French content.""",
}
# Should not call read_text since markdown_content exists