Better equation rendering checker with more tests.

Jake Poznanski 2025-03-26 18:49:48 +00:00
parent b8e3034847
commit d45c0323a4
2 changed files with 197 additions and 75 deletions

View File

@@ -284,7 +284,7 @@ def render_equation(
     page.wait_for_selector(".katex", state="attached")

-    if debug_dom:
+    if True:
         katex_dom_html = page.evaluate(
             """
             () => {
@@ -394,8 +394,18 @@ def compare_rendered_equations(reference: RenderedEquation, hypothesis: Rendered
     def expand_span_info(span_info: SpanInfo) -> list[SpanInfo]:
         total_elems = len(span_info.text)
-        return [SpanInfo(c, BoundingBox(span_info.bounding_box.x + (span_info.bounding_box.width * index) / total_elems, span_info.bounding_box.y,
-                                        span_info.bounding_box.width / total_elems, span_info.bounding_box.height)) for index, c in enumerate(span_info.text)]
+        return [
+            SpanInfo(
+                c,
+                BoundingBox(
+                    span_info.bounding_box.x + (span_info.bounding_box.width * index) / total_elems,
+                    span_info.bounding_box.y,
+                    span_info.bounding_box.width / total_elems,
+                    span_info.bounding_box.height,
+                ),
+            )
+            for index, c in enumerate(span_info.text)
+        ]

     H = [span for sublist in H for span in expand_span_info(sublist)]
     R = [span for sublist in R for span in expand_span_info(sublist)]
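The comprehension above splits one rendered span into per-character spans, giving each character an equal horizontal slice of the original bounding box so equations can be compared character by character. A minimal runnable sketch, assuming SpanInfo and BoundingBox are plain dataclasses with the fields used here (the real classes are defined elsewhere in this file):

# Sketch only: SpanInfo and BoundingBox are assumed dataclasses with these fields.
from dataclasses import dataclass

@dataclass
class BoundingBox:
    x: float
    y: float
    width: float
    height: float

@dataclass
class SpanInfo:
    text: str
    bounding_box: BoundingBox

def expand_span_info(span_info: SpanInfo) -> list[SpanInfo]:
    # Split an N-character span into N one-character spans, each taking
    # an equal-width slice of the original bounding box.
    total_elems = len(span_info.text)
    return [
        SpanInfo(
            c,
            BoundingBox(
                span_info.bounding_box.x + (span_info.bounding_box.width * index) / total_elems,
                span_info.bounding_box.y,
                span_info.bounding_box.width / total_elems,
                span_info.bounding_box.height,
            ),
        )
        for index, c in enumerate(span_info.text)
    ]

# Example: a 2-character span at x=0 with width 10 expands into two width-5 spans.
spans = expand_span_info(SpanInfo("ab", BoundingBox(0, 0, 10, 4)))
assert spans[0].bounding_box.x == 0 and spans[1].bounding_box.x == 5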
@@ -560,6 +570,113 @@ class TestRenderedEquationComparison(unittest.TestCase):
         align_rendered = render_equation("u \\in\\left(R / \\operatorname{Ann}_{R}\\left(x_{i}\\right)\\right)^{\\times}")

         self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

+    def test_fraction_vs_divided_by(self):
+        eq1 = render_equation("\\frac{a}{b}", use_cache=False)
+        eq2 = render_equation("a / b", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_different_bracket_types(self):
+        eq1 = render_equation("\\left[ a + b \\right]", use_cache=False)
+        eq2 = render_equation("\\left\\{ a + b \\right\\}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_inline_vs_display_style_fraction(self):
+        eq1 = render_equation("\\frac{1}{2}", use_cache=False)
+        eq2 = render_equation("\\displaystyle\\frac{1}{2}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_matrix_equivalent_forms(self):
+        eq1 = render_equation("\\begin{pmatrix} a & b \\\\ c & d \\end{pmatrix}", use_cache=False)
+        eq2 = render_equation("\\begin{pmatrix} a & b \\\\ c & d \\end{pmatrix}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_different_matrix_types(self):
+        eq1 = render_equation("\\begin{pmatrix} a & b \\\\ c & d \\end{pmatrix}", use_cache=False)
+        eq2 = render_equation("\\begin{bmatrix} a & b \\\\ c & d \\end{bmatrix}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_thinspace_vs_regular_space(self):
+        eq1 = render_equation("a \\, b", use_cache=False)
+        eq2 = render_equation("a \\: b", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    @unittest.skip("Currently these compare to the same thing, because they use the symbol 'x' with a different span class and thus font")
+    def test_mathbf_vs_boldsymbol(self):
+        eq1 = render_equation("\\mathbf{x}", use_cache=False)
+        eq2 = render_equation("\\boldsymbol{x}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_tensor_notation_equivalent(self):
+        eq1 = render_equation("T_{ij}^{kl}", use_cache=False)
+        eq2 = render_equation("T^{kl}_{ij}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_partial_derivative_forms(self):
+        eq1 = render_equation("\\frac{\\partial f}{\\partial x}", use_cache=False)
+        eq2 = render_equation("\\frac{\\partial_f}{\\partial_x}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_equivalent_sin_forms_diff_parens(self):
+        eq1 = render_equation("\\sin(\\theta)", use_cache=False)
+        eq2 = render_equation("\\sin \\theta", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_aligned_multiline_equation(self):
+        eq1 = render_equation("\\begin{align*} a &= b \\\\ c &= d \\end{align*}", use_cache=False)
+        eq2 = render_equation("\\begin{aligned} a &= b \\\\ c &= d \\end{aligned}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_subscript_order_invariance(self):
+        eq1 = render_equation("x_{i,j}", use_cache=False)
+        eq2 = render_equation("x_{j,i}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_hat_vs_widehat(self):
+        eq1 = render_equation("\\hat{x}", use_cache=False)
+        eq2 = render_equation("\\widehat{x}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_equivalent_integral_bounds(self):
+        eq1 = render_equation("\\int_{a}^{b} f(x) dx", use_cache=False)
+        eq2 = render_equation("\\int\\limits_{a}^{b} f(x) dx", use_cache=False)
+        # Could go either way honestly?
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_equivalent_summation_notation(self):
+        eq1 = render_equation("\\sum_{i=1}^{n} x_i", use_cache=False)
+        eq2 = render_equation("\\sum\\limits_{i=1}^{n} x_i", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_different_symbol_with_same_appearance(self):
+        eq1 = render_equation("\\phi", use_cache=False)
+        eq2 = render_equation("\\varphi", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_aligned_vs_gathered(self):
+        eq1 = render_equation("\\begin{aligned} a &= b \\\\ c &= d \\end{aligned}", use_cache=False)
+        eq2 = render_equation("\\begin{gathered} a = b \\\\ c = d \\end{gathered}", use_cache=False)
+        # Different whitespacing, should be invariant to that.
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_identical_but_with_color1(self):
+        eq1 = render_equation("a + b", use_cache=False)
+        eq2 = render_equation("\\color{black}{a + b}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_identical_but_with_color2(self):
+        eq1 = render_equation("a + b", use_cache=False)
+        eq2 = render_equation("\\color{black}{a} + \\color{black}{b}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+        eq1 = render_equation("a + b", use_cache=False)
+        eq2 = render_equation("\\color{red}{a} + \\color{black}{b}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_newcommand_expansion(self):
+        eq1 = render_equation("\\alpha + \\beta", use_cache=False)
+        eq2 = render_equation("\\newcommand{\\ab}{\\alpha + \\beta}\\ab", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+

 if __name__ == "__main__":
     unittest.main()
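The new tests pin down the intended semantics of compare_rendered_equations: changes that leave the rendered output visually identical (spacing commands, black color wrappers, macro expansion, aligned vs. align*) compare equal, while changes to the visible glyphs or layout (bracket shapes, matrix fences, phi vs. varphi) do not. A sketch of the call pattern the tests exercise; the import path is an assumption, since the diff does not show the file name:

# Hedged sketch; the module path below is assumed, not confirmed by this diff.
from olmocr.bench.katex.render import compare_rendered_equations, render_equation

reference = render_equation("\\frac{1}{2}", use_cache=False)
hypothesis = render_equation("\\displaystyle\\frac{1}{2}", use_cache=False)
print(compare_rendered_equations(reference, hypothesis))  # True: visually identical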

View File

@@ -6,20 +6,23 @@
 # Also take an argument args.repeats which will repeat this whole process N times
 import argparse
 import datetime
 import json
 import os
 import random
-import boto3
-import tempfile
-import datetime
 import re
 import sqlite3
-from pathlib import Path
-from tqdm import tqdm
+import tempfile
 from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Optional
+
+import boto3
+from tqdm import tqdm

 from olmocr.data.renderpdf import render_pdf_to_base64webp
-from olmocr.s3_utils import parse_s3_path, get_s3_bytes
-from typing import Optional, Tuple
+from olmocr.s3_utils import get_s3_bytes, parse_s3_path


 def parse_args():
     parser = argparse.ArgumentParser(description="Scan OLMO OCR workspace results and create visual samples")
@@ -29,10 +32,14 @@ def parse_args():
     parser.add_argument("--pdf_profile", help="AWS profile for accessing PDFs")
     parser.add_argument("--output_dir", default="dolma_samples", help="Directory to save output HTML files")
     parser.add_argument("--max_workers", type=int, default=4, help="Maximum number of worker threads")
-    parser.add_argument("--db_path", default="~/s2pdf_url_data/d65142df-6588-4b68-a12c-d468b3761189.csv.db",
-                        help="Path to the SQLite database containing PDF hash to URL mapping")
+    parser.add_argument(
+        "--db_path",
+        default="~/s2pdf_url_data/d65142df-6588-4b68-a12c-d468b3761189.csv.db",
+        help="Path to the SQLite database containing PDF hash to URL mapping",
+    )
     return parser.parse_args()


 def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
     pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf"
     match = re.match(pattern, pretty_pdf_path)
@@ -40,25 +47,26 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
         return match.group(1) + match.group(2)
     return None


 def get_original_url(pdf_hash: str, db_path: str) -> Optional[str]:
     """Look up the original URL for a PDF hash in the SQLite database."""
     if not pdf_hash:
         return None

     try:
         sqlite_db_path = os.path.expanduser(db_path)
         if not os.path.exists(sqlite_db_path):
             print(f"SQLite database not found at {sqlite_db_path}")
             return None

         conn = sqlite3.connect(sqlite_db_path)
         cursor = conn.cursor()
         cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
         result = cursor.fetchone()
         conn.close()

         if result:
             return result[0]
         return None
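Together, parse_pdf_hash and get_original_url map an S3 PDF path back to its source URL: the regex concatenates the two hex groups of an s3://ai2-s2-pdfs/ key into one hash, which keys into a pdf_mapping(pdf_hash, uri) table. A runnable sketch using an in-memory SQLite database in place of the real mapping file:

import re
import sqlite3
from typing import Optional

def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    # "s3://ai2-s2-pdfs/0a1b/2c3d4e.pdf" -> "0a1b2c3d4e"
    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf"
    match = re.match(pattern, pretty_pdf_path)
    if match:
        return match.group(1) + match.group(2)
    return None

# Throwaway table matching the schema queried by get_original_url above.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE pdf_mapping (pdf_hash TEXT PRIMARY KEY, uri TEXT)")
conn.execute("INSERT INTO pdf_mapping VALUES (?, ?)", ("0a1b2c3d4e", "https://example.com/paper.pdf"))

pdf_hash = parse_pdf_hash("s3://ai2-s2-pdfs/0a1b/2c3d4e.pdf")
row = conn.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,)).fetchone()
print(row[0] if row else None)  # https://example.com/paper.pdf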
@@ -71,106 +79,104 @@ def list_result_files(s3_client, workspace_path):
     """List all JSON result files in the workspace results directory."""
     bucket, prefix = parse_s3_path(workspace_path)
     results_prefix = os.path.join(prefix, "results").rstrip("/") + "/"

     all_files = []
     paginator = s3_client.get_paginator("list_objects_v2")
     for page in paginator.paginate(Bucket=bucket, Prefix=results_prefix):
         if "Contents" in page:
-            all_files.extend([f"s3://{bucket}/{obj['Key']}" for obj in page["Contents"]
-                              if obj["Key"].endswith(".jsonl") or obj["Key"].endswith(".json")])
+            all_files.extend([f"s3://{bucket}/{obj['Key']}" for obj in page["Contents"] if obj["Key"].endswith(".jsonl") or obj["Key"].endswith(".json")])

         if len(all_files) > 1000:
             break

     return all_files


 def get_random_pages(s3_client, result_files, count=30):
     """Get random pages from the result files."""
     random_pages = []

     # Try to collect the requested number of pages
     attempts = 0
     max_attempts = count * 3  # Allow extra attempts to handle potential failures

     while len(random_pages) < count and attempts < max_attempts:
         attempts += 1

         # Pick a random result file
         if not result_files:
             print("No result files found!")
             break

         result_file = random.choice(result_files)

         try:
             # Get the content of the file
             content = get_s3_bytes(s3_client, result_file)
-            lines = content.decode('utf-8').strip().split('\n')
+            lines = content.decode("utf-8").strip().split("\n")

             if not lines:
                 continue

             # Pick a random line (which contains a complete document)
             line = random.choice(lines)
             doc = json.loads(line)

             # A Dolma document has "text", "metadata", and "attributes" fields
             if "text" not in doc or "metadata" not in doc or "attributes" not in doc:
                 print(f"Document in {result_file} is not a valid Dolma document")
                 continue

             # Get the original PDF path from metadata
             pdf_path = doc["metadata"].get("Source-File")
             if not pdf_path:
                 continue

             # Get page spans from attributes
             page_spans = doc["attributes"].get("pdf_page_numbers", [])
             if not page_spans:
                 continue

             # Pick a random page span
             page_span = random.choice(page_spans)
             if len(page_span) >= 3:
                 # Page spans are [start_pos, end_pos, page_num]
                 page_num = page_span[2]

                 # Extract text for this page
                 start_pos, end_pos = page_span[0], page_span[1]
                 page_text = doc["text"][start_pos:end_pos].strip()

                 # Include the text snippet with the page info
                 random_pages.append((pdf_path, page_num, page_text, result_file))

                 if len(random_pages) >= count:
                     break

         except Exception as e:
             print(f"Error processing {result_file}: {e}")
             continue

     print(f"Found {len(random_pages)} random pages from Dolma documents")
     return random_pages


-def create_presigned_url(s3_client, pdf_path, expiration=3600*24*7):
+def create_presigned_url(s3_client, pdf_path, expiration=3600 * 24 * 7):
     """Create a presigned URL for the given S3 path."""
     try:
         bucket, key = parse_s3_path(pdf_path)
-        url = s3_client.generate_presigned_url(
-            'get_object',
-            Params={'Bucket': bucket, 'Key': key},
-            ExpiresIn=expiration
-        )
+        url = s3_client.generate_presigned_url("get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=expiration)
         return url
     except Exception as e:
         print(f"Error creating presigned URL for {pdf_path}: {e}")
         return None


 def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path, db_path, resolution=2048):
     """Create an HTML file with rendered PDF pages."""
     # Get current date and time for the report
     current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

     html_content = f"""
     <!DOCTYPE html>
     <html lang="en">
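The sampling loop above relies on the Dolma span convention: attributes["pdf_page_numbers"] holds [start_pos, end_pos, page_num] triples that index into doc["text"]. A self-contained sketch of slicing one page out of a document under that convention, with toy data:

import random

# Toy Dolma-style document with the three fields the loop above validates.
doc = {
    "text": "Page one text.Page two text.",
    "metadata": {"Source-File": "s3://ai2-s2-pdfs/0a1b/2c3d4e.pdf"},
    "attributes": {"pdf_page_numbers": [[0, 14, 1], [14, 28, 2]]},
}

page_span = random.choice(doc["attributes"]["pdf_page_numbers"])
start_pos, end_pos, page_num = page_span  # [start_pos, end_pos, page_num]
page_text = doc["text"][start_pos:end_pos].strip()
print(page_num, repr(page_text))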
@@ -369,7 +375,7 @@ def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path,
     <div class="page-grid">
 """

     for i, (pdf_path, page_num, page_text, result_file) in enumerate(tqdm(random_pages, desc="Rendering pages")):
         # Get original URL from PDF hash
         pdf_hash = parse_pdf_hash(pdf_path)
@@ -379,23 +385,22 @@ def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path,
         display_path = pdf_path
         if len(display_path) > 60:
             display_path = "..." + display_path[-57:]

         # Generate presigned URL
         presigned_url = create_presigned_url(pdf_s3_client, pdf_path)

         try:
             # Download PDF to temp file
             bucket, key = parse_s3_path(pdf_path)
             with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
-                pdf_data = pdf_s3_client.get_object(Bucket=bucket, Key=key)['Body'].read()
+                pdf_data = pdf_s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
                 temp_file.write(pdf_data)
                 temp_file_path = temp_file.name

             # Render PDF to base64 webp
             try:
                 base64_image = render_pdf_to_base64webp(temp_file_path, page_num, resolution)

                 # Add to HTML
                 html_content += f"""
                 <div class="page-container">
@@ -426,10 +431,10 @@ def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path,
                     </div>
                 </div>
 """

                 # Clean up temp file
                 os.unlink(temp_file_path)

             except Exception as e:
                 html_content += f"""
                 <div class="page-container">
@@ -444,7 +449,7 @@ def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path,
                 </div>
             </div>
 """

     html_content += """
     </div>
     <footer>
@@ -454,60 +459,59 @@ def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path,
     </body>
     </html>
 """

-    with open(output_path, 'w') as f:
+    with open(output_path, "w") as f:
         f.write(html_content)

     print(f"Created HTML output at {output_path}")


 def generate_sample_set(args, i, s3_client, pdf_s3_client, result_files):
     """Generate a single sample set."""
     output_filename = Path(args.output_dir) / f"dolma_samples_{i+1}.html"

     print(f"\nGenerating sample set {i+1} of {args.repeats}")

     # Get random pages
     random_pages = get_random_pages(s3_client, result_files, args.pages_per_output)

     # Create HTML output
     create_html_output(random_pages, pdf_s3_client, output_filename, args.workspace, args.db_path)

     return output_filename


 def main():
     args = parse_args()

     # Set up S3 clients
-    s3_client = boto3.client('s3')
+    s3_client = boto3.client("s3")

     # Set up PDF S3 client with profile if specified
     if args.pdf_profile:
         pdf_session = boto3.Session(profile_name=args.pdf_profile)
         pdf_s3_client = pdf_session.client("s3")
     else:
         pdf_s3_client = s3_client

     # Create output directory
     output_dir = Path(args.output_dir)
     output_dir.mkdir(exist_ok=True, parents=True)

     # List all result files
     print(f"Listing result files in {args.workspace}/results...")
     result_files = list_result_files(s3_client, args.workspace)
     print(f"Found {len(result_files)} result files")

     # Use ThreadPoolExecutor to parallelize the generation of sample sets
     if args.repeats > 1:
         print(f"Using ThreadPoolExecutor with {min(args.max_workers, args.repeats)} workers")
         with ThreadPoolExecutor(max_workers=min(args.max_workers, args.repeats)) as executor:
             futures = []
             for i in range(args.repeats):
-                future = executor.submit(
-                    generate_sample_set,
-                    args, i, s3_client, pdf_s3_client, result_files
-                )
+                future = executor.submit(generate_sample_set, args, i, s3_client, pdf_s3_client, result_files)
                 futures.append(future)

             # Wait for all futures to complete and collect results
             for future in futures:
                 try:
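The executor block above fans out one generate_sample_set call per repeat and then drains the futures in submission order. A self-contained sketch of the same pattern with a stand-in worker (generate_sample_set_stub is hypothetical, standing in for the real function):

from concurrent.futures import ThreadPoolExecutor

def generate_sample_set_stub(i: int) -> str:
    # Stand-in for generate_sample_set(args, i, ...); returns the output filename.
    return f"dolma_samples_{i + 1}.html"

repeats, max_workers = 4, 4
with ThreadPoolExecutor(max_workers=min(max_workers, repeats)) as executor:
    futures = [executor.submit(generate_sample_set_stub, i) for i in range(repeats)]
    for future in futures:
        try:
            print(future.result())
        except Exception as e:
            print(f"Sample set generation failed: {e}")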
@@ -519,5 +523,6 @@ def main():
         # If only one repeat, just run it directly
         generate_sample_set(args, 0, s3_client, pdf_s3_client, result_files)


 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()