Mirror of https://github.com/allenai/olmocr.git (synced 2025-12-24 05:36:12 +00:00)

Better equation rendering checker with more tests.

commit d45c0323a4 (parent b8e3034847)
@@ -284,7 +284,7 @@ def render_equation(
     page.wait_for_selector(".katex", state="attached")

-    if debug_dom:
+    if True:
        katex_dom_html = page.evaluate(
            """
            () => {
@@ -394,8 +394,18 @@ def compare_rendered_equations(reference: RenderedEquation, hypothesis: Rendered

     def expand_span_info(span_info: SpanInfo) -> list[SpanInfo]:
         total_elems = len(span_info.text)
-        return [SpanInfo(c, BoundingBox(span_info.bounding_box.x + (span_info.bounding_box.width * index) / total_elems, span_info.bounding_box.y,
-                         span_info.bounding_box.width / total_elems, span_info.bounding_box.height)) for index, c in enumerate(span_info.text)]
+        return [
+            SpanInfo(
+                c,
+                BoundingBox(
+                    span_info.bounding_box.x + (span_info.bounding_box.width * index) / total_elems,
+                    span_info.bounding_box.y,
+                    span_info.bounding_box.width / total_elems,
+                    span_info.bounding_box.height,
+                ),
+            )
+            for index, c in enumerate(span_info.text)
+        ]

     H = [span for sublist in H for span in expand_span_info(sublist)]
     R = [span for sublist in R for span in expand_span_info(sublist)]
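A minimal sketch of the per-character split that expand_span_info performs, assuming SpanInfo and BoundingBox are the simple dataclasses below (stand-ins for the module's real types): an n-character span becomes n single-character spans, each taking an equal 1/n slice of the parent box's width.

from dataclasses import dataclass

# Assumed stand-ins for the real SpanInfo/BoundingBox types in this module.
@dataclass
class BoundingBox:
    x: float
    y: float
    width: float
    height: float

@dataclass
class SpanInfo:
    text: str
    bounding_box: BoundingBox

def expand_span_info(span_info: SpanInfo) -> list[SpanInfo]:
    # Split an n-character span into n single-character spans,
    # each assigned an equal 1/n slice of the original box width.
    total_elems = len(span_info.text)
    return [
        SpanInfo(
            c,
            BoundingBox(
                span_info.bounding_box.x + (span_info.bounding_box.width * index) / total_elems,
                span_info.bounding_box.y,
                span_info.bounding_box.width / total_elems,
                span_info.bounding_box.height,
            ),
        )
        for index, c in enumerate(span_info.text)
    ]

# A 30-unit-wide span holding "abc" becomes three 10-unit-wide spans.
span = SpanInfo("abc", BoundingBox(0.0, 0.0, 30.0, 12.0))
for s in expand_span_info(span):
    print(s.text, s.bounding_box.x, s.bounding_box.width)  # a 0.0 10.0 / b 10.0 10.0 / c 20.0 10.0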
@@ -560,6 +570,113 @@ class TestRenderedEquationComparison(unittest.TestCase):
         align_rendered = render_equation("u \\in\\left(R / \\operatorname{Ann}_{R}\\left(x_{i}\\right)\\right)^{\\times}")
         self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

+    def test_fraction_vs_divided_by(self):
+        eq1 = render_equation("\\frac{a}{b}", use_cache=False)
+        eq2 = render_equation("a / b", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_different_bracket_types(self):
+        eq1 = render_equation("\\left[ a + b \\right]", use_cache=False)
+        eq2 = render_equation("\\left\\{ a + b \\right\\}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_inline_vs_display_style_fraction(self):
+        eq1 = render_equation("\\frac{1}{2}", use_cache=False)
+        eq2 = render_equation("\\displaystyle\\frac{1}{2}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_matrix_equivalent_forms(self):
+        eq1 = render_equation("\\begin{pmatrix} a & b \\\\ c & d \\end{pmatrix}", use_cache=False)
+        eq2 = render_equation("\\begin{pmatrix} a & b \\\\ c & d \\end{pmatrix}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_different_matrix_types(self):
+        eq1 = render_equation("\\begin{pmatrix} a & b \\\\ c & d \\end{pmatrix}", use_cache=False)
+        eq2 = render_equation("\\begin{bmatrix} a & b \\\\ c & d \\end{bmatrix}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_thinspace_vs_regular_space(self):
+        eq1 = render_equation("a \\, b", use_cache=False)
+        eq2 = render_equation("a \\: b", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    @unittest.skip("Currently these compare to the same thing, because they use the symbol 'x' with a different span class and thus font")
+    def test_mathbf_vs_boldsymbol(self):
+        eq1 = render_equation("\\mathbf{x}", use_cache=False)
+        eq2 = render_equation("\\boldsymbol{x}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_tensor_notation_equivalent(self):
+        eq1 = render_equation("T_{ij}^{kl}", use_cache=False)
+        eq2 = render_equation("T^{kl}_{ij}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_partial_derivative_forms(self):
+        eq1 = render_equation("\\frac{\\partial f}{\\partial x}", use_cache=False)
+        eq2 = render_equation("\\frac{\\partial_f}{\\partial_x}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_equivalent_sin_forms_diff_parens(self):
+        eq1 = render_equation("\\sin(\\theta)", use_cache=False)
+        eq2 = render_equation("\\sin \\theta", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_aligned_multiline_equation(self):
+        eq1 = render_equation("\\begin{align*} a &= b \\\\ c &= d \\end{align*}", use_cache=False)
+        eq2 = render_equation("\\begin{aligned} a &= b \\\\ c &= d \\end{aligned}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_subscript_order_invariance(self):
+        eq1 = render_equation("x_{i,j}", use_cache=False)
+        eq2 = render_equation("x_{j,i}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_hat_vs_widehat(self):
+        eq1 = render_equation("\\hat{x}", use_cache=False)
+        eq2 = render_equation("\\widehat{x}", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_equivalent_integral_bounds(self):
+        eq1 = render_equation("\\int_{a}^{b} f(x) dx", use_cache=False)
+        eq2 = render_equation("\\int\\limits_{a}^{b} f(x) dx", use_cache=False)
+        # Could go either way honestly?
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_equivalent_summation_notation(self):
+        eq1 = render_equation("\\sum_{i=1}^{n} x_i", use_cache=False)
+        eq2 = render_equation("\\sum\\limits_{i=1}^{n} x_i", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_different_symbol_with_same_appearance(self):
+        eq1 = render_equation("\\phi", use_cache=False)
+        eq2 = render_equation("\\varphi", use_cache=False)
+        self.assertFalse(compare_rendered_equations(eq1, eq2))
+
+    def test_aligned_vs_gathered(self):
+        eq1 = render_equation("\\begin{aligned} a &= b \\\\ c &= d \\end{aligned}", use_cache=False)
+        eq2 = render_equation("\\begin{gathered} a = b \\\\ c = d \\end{gathered}", use_cache=False)
+        # Different whitespacing, should be invariant to that.
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_identical_but_with_color1(self):
+        eq1 = render_equation("a + b", use_cache=False)
+        eq2 = render_equation("\\color{black}{a + b}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_identical_but_with_color2(self):
+        eq1 = render_equation("a + b", use_cache=False)
+        eq2 = render_equation("\\color{black}{a} + \\color{black}{b}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+        eq1 = render_equation("a + b", use_cache=False)
+        eq2 = render_equation("\\color{red}{a} + \\color{black}{b}", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+
+    def test_newcommand_expansion(self):
+        eq1 = render_equation("\\alpha + \\beta", use_cache=False)
+        eq2 = render_equation("\\newcommand{\\ab}{\\alpha + \\beta}\\ab", use_cache=False)
+        self.assertTrue(compare_rendered_equations(eq1, eq2))
+

 if __name__ == "__main__":
     unittest.main()
@@ -6,20 +6,23 @@
 # Also take an argument args.repeats which will repeat this whole process N times

 import argparse
 import datetime
 import json
 import os
 import random
-import boto3
-import tempfile
-import datetime
 import re
 import sqlite3
-from pathlib import Path
-from tqdm import tqdm
 import tempfile
 from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Optional
+
+import boto3
+from tqdm import tqdm

 from olmocr.data.renderpdf import render_pdf_to_base64webp
-from olmocr.s3_utils import parse_s3_path, get_s3_bytes
-from typing import Optional, Tuple
+from olmocr.s3_utils import get_s3_bytes, parse_s3_path


 def parse_args():
     parser = argparse.ArgumentParser(description="Scan OLMO OCR workspace results and create visual samples")
@@ -29,10 +32,14 @@ def parse_args():
     parser.add_argument("--pdf_profile", help="AWS profile for accessing PDFs")
     parser.add_argument("--output_dir", default="dolma_samples", help="Directory to save output HTML files")
     parser.add_argument("--max_workers", type=int, default=4, help="Maximum number of worker threads")
-    parser.add_argument("--db_path", default="~/s2pdf_url_data/d65142df-6588-4b68-a12c-d468b3761189.csv.db",
-                        help="Path to the SQLite database containing PDF hash to URL mapping")
+    parser.add_argument(
+        "--db_path",
+        default="~/s2pdf_url_data/d65142df-6588-4b68-a12c-d468b3761189.csv.db",
+        help="Path to the SQLite database containing PDF hash to URL mapping",
+    )
     return parser.parse_args()


 def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
     pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf"
     match = re.match(pattern, pretty_pdf_path)
@@ -40,25 +47,26 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
         return match.group(1) + match.group(2)
     return None


 def get_original_url(pdf_hash: str, db_path: str) -> Optional[str]:
     """Look up the original URL for a PDF hash in the SQLite database."""
     if not pdf_hash:
         return None

     try:
         sqlite_db_path = os.path.expanduser(db_path)
         if not os.path.exists(sqlite_db_path):
             print(f"SQLite database not found at {sqlite_db_path}")
             return None

         conn = sqlite3.connect(sqlite_db_path)
         cursor = conn.cursor()

         cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
         result = cursor.fetchone()

         conn.close()

         if result:
             return result[0]
         return None
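A minimal sketch of the hash-to-URL flow above, assuming the schema implied by the query (a pdf_mapping table with pdf_hash and uri columns); the S3 path, hash, and URL are made up for illustration.

import re
import sqlite3

# Same pattern as parse_pdf_hash: a 4-hex-char prefix directory plus the rest of the hash.
pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf"
match = re.match(pattern, "s3://ai2-s2-pdfs/ab12/cd34ef56.pdf")  # hypothetical path
pdf_hash = match.group(1) + match.group(2) if match else None  # "ab12cd34ef56"

# In-memory database with the assumed schema, seeded with one made-up row.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE pdf_mapping (pdf_hash TEXT PRIMARY KEY, uri TEXT)")
conn.execute("INSERT INTO pdf_mapping VALUES (?, ?)", (pdf_hash, "https://example.com/paper.pdf"))

cursor = conn.cursor()
cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
result = cursor.fetchone()
conn.close()

print(result[0] if result else None)  # https://example.com/paper.pdf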
@@ -71,106 +79,104 @@ def list_result_files(s3_client, workspace_path):
     """List all JSON result files in the workspace results directory."""
     bucket, prefix = parse_s3_path(workspace_path)
     results_prefix = os.path.join(prefix, "results").rstrip("/") + "/"

     all_files = []
     paginator = s3_client.get_paginator("list_objects_v2")
     for page in paginator.paginate(Bucket=bucket, Prefix=results_prefix):
         if "Contents" in page:
-            all_files.extend([f"s3://{bucket}/{obj['Key']}" for obj in page["Contents"]
-                              if obj["Key"].endswith(".jsonl") or obj["Key"].endswith(".json")])
+            all_files.extend([f"s3://{bucket}/{obj['Key']}" for obj in page["Contents"] if obj["Key"].endswith(".jsonl") or obj["Key"].endswith(".json")])

         if len(all_files) > 1000:
             break

     return all_files


 def get_random_pages(s3_client, result_files, count=30):
     """Get random pages from the result files."""
     random_pages = []

     # Try to collect the requested number of pages
     attempts = 0
     max_attempts = count * 3  # Allow extra attempts to handle potential failures

     while len(random_pages) < count and attempts < max_attempts:
         attempts += 1

         # Pick a random result file
         if not result_files:
             print("No result files found!")
             break

         result_file = random.choice(result_files)

         try:
             # Get the content of the file
             content = get_s3_bytes(s3_client, result_file)
-            lines = content.decode('utf-8').strip().split('\n')
+            lines = content.decode("utf-8").strip().split("\n")

             if not lines:
                 continue

             # Pick a random line (which contains a complete document)
             line = random.choice(lines)
             doc = json.loads(line)

             # A Dolma document has "text", "metadata", and "attributes" fields
             if "text" not in doc or "metadata" not in doc or "attributes" not in doc:
                 print(f"Document in {result_file} is not a valid Dolma document")
                 continue

             # Get the original PDF path from metadata
             pdf_path = doc["metadata"].get("Source-File")
             if not pdf_path:
                 continue

             # Get page spans from attributes
             page_spans = doc["attributes"].get("pdf_page_numbers", [])
             if not page_spans:
                 continue

             # Pick a random page span
             page_span = random.choice(page_spans)
             if len(page_span) >= 3:
                 # Page spans are [start_pos, end_pos, page_num]
                 page_num = page_span[2]

                 # Extract text for this page
                 start_pos, end_pos = page_span[0], page_span[1]
                 page_text = doc["text"][start_pos:end_pos].strip()

                 # Include the text snippet with the page info
                 random_pages.append((pdf_path, page_num, page_text, result_file))

                 if len(random_pages) >= count:
                     break

         except Exception as e:
             print(f"Error processing {result_file}: {e}")
             continue

     print(f"Found {len(random_pages)} random pages from Dolma documents")
     return random_pages


-def create_presigned_url(s3_client, pdf_path, expiration=3600*24*7):
+def create_presigned_url(s3_client, pdf_path, expiration=3600 * 24 * 7):
     """Create a presigned URL for the given S3 path."""
     try:
         bucket, key = parse_s3_path(pdf_path)
-        url = s3_client.generate_presigned_url(
-            'get_object',
-            Params={'Bucket': bucket, 'Key': key},
-            ExpiresIn=expiration
-        )
+        url = s3_client.generate_presigned_url("get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=expiration)
         return url
     except Exception as e:
         print(f"Error creating presigned URL for {pdf_path}: {e}")
         return None


 def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path, db_path, resolution=2048):
     """Create an HTML file with rendered PDF pages."""
     # Get current date and time for the report
     current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

     html_content = f"""
 <!DOCTYPE html>
 <html lang="en">
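A minimal sketch of the page-span slicing that get_random_pages performs, using a made-up Dolma-style record; each span in pdf_page_numbers is [start_pos, end_pos, page_num], indexing into the document's "text" field.

# Hypothetical Dolma-style document, shaped like the records this script reads.
doc = {
    "text": "Page one text.Page two text.",
    "metadata": {"Source-File": "s3://ai2-s2-pdfs/ab12/cd34ef56.pdf"},
    "attributes": {"pdf_page_numbers": [[0, 14, 1], [14, 28, 2]]},
}

for page_span in doc["attributes"]["pdf_page_numbers"]:
    start_pos, end_pos, page_num = page_span
    # Slice the page's text out of the concatenated document text.
    page_text = doc["text"][start_pos:end_pos].strip()
    print(page_num, repr(page_text))
# 1 'Page one text.'
# 2 'Page two text.'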
@@ -369,7 +375,7 @@ def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path,

         <div class="page-grid">
 """

     for i, (pdf_path, page_num, page_text, result_file) in enumerate(tqdm(random_pages, desc="Rendering pages")):
         # Get original URL from PDF hash
         pdf_hash = parse_pdf_hash(pdf_path)
@@ -379,23 +385,22 @@ def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path,
         display_path = pdf_path
         if len(display_path) > 60:
             display_path = "..." + display_path[-57:]

         # Generate presigned URL
         presigned_url = create_presigned_url(pdf_s3_client, pdf_path)

         try:
             # Download PDF to temp file
             bucket, key = parse_s3_path(pdf_path)
             with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
-                pdf_data = pdf_s3_client.get_object(Bucket=bucket, Key=key)['Body'].read()
+                pdf_data = pdf_s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
                 temp_file.write(pdf_data)
                 temp_file_path = temp_file.name

             # Render PDF to base64 webp
             try:
                 base64_image = render_pdf_to_base64webp(temp_file_path, page_num, resolution)

                 # Add to HTML
                 html_content += f"""
                 <div class="page-container">
@@ -426,10 +431,10 @@ def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path,
                     </div>
                 </div>
 """

             # Clean up temp file
             os.unlink(temp_file_path)

         except Exception as e:
             html_content += f"""
             <div class="page-container">
@@ -444,7 +449,7 @@ def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path,
                 </div>
             </div>
 """

     html_content += """
         </div>
         <footer>
@@ -454,60 +459,59 @@ def create_html_output(random_pages, pdf_s3_client, output_path, workspace_path,
 </body>
 </html>
 """

-    with open(output_path, 'w') as f:
+    with open(output_path, "w") as f:
         f.write(html_content)

     print(f"Created HTML output at {output_path}")


 def generate_sample_set(args, i, s3_client, pdf_s3_client, result_files):
     """Generate a single sample set."""
     output_filename = Path(args.output_dir) / f"dolma_samples_{i+1}.html"

     print(f"\nGenerating sample set {i+1} of {args.repeats}")

     # Get random pages
     random_pages = get_random_pages(s3_client, result_files, args.pages_per_output)

     # Create HTML output
     create_html_output(random_pages, pdf_s3_client, output_filename, args.workspace, args.db_path)

     return output_filename


 def main():
     args = parse_args()

     # Set up S3 clients
-    s3_client = boto3.client('s3')
+    s3_client = boto3.client("s3")

     # Set up PDF S3 client with profile if specified
     if args.pdf_profile:
         pdf_session = boto3.Session(profile_name=args.pdf_profile)
         pdf_s3_client = pdf_session.client("s3")
     else:
         pdf_s3_client = s3_client

     # Create output directory
     output_dir = Path(args.output_dir)
     output_dir.mkdir(exist_ok=True, parents=True)

     # List all result files
     print(f"Listing result files in {args.workspace}/results...")
     result_files = list_result_files(s3_client, args.workspace)
     print(f"Found {len(result_files)} result files")

     # Use ThreadPoolExecutor to parallelize the generation of sample sets
     if args.repeats > 1:
         print(f"Using ThreadPoolExecutor with {min(args.max_workers, args.repeats)} workers")
         with ThreadPoolExecutor(max_workers=min(args.max_workers, args.repeats)) as executor:
             futures = []
             for i in range(args.repeats):
-                future = executor.submit(
-                    generate_sample_set,
-                    args, i, s3_client, pdf_s3_client, result_files
-                )
+                future = executor.submit(generate_sample_set, args, i, s3_client, pdf_s3_client, result_files)
                 futures.append(future)

             # Wait for all futures to complete and collect results
             for future in futures:
                 try:
@@ -519,5 +523,6 @@ def main():
         # If only one repeat, just run it directly
         generate_sample_set(args, 0, s3_client, pdf_s3_client, result_files)

+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()