olmocr/olmocr/bench/katex/render.py
Jake Poznanski f2951f3f78 Lints
2025-03-17 08:47:57 -07:00

542 lines
18 KiB
Python

#!/usr/bin/env python3
"""
Extract inner-most spans and their bounding boxes, and the MathML output,
from rendered LaTeX equations using Playwright and KaTeX.
Caching is maintained via a SHA1-based hash stored in a sqlite database.
Requirements:
pip install playwright
python -m playwright install chromium
Place katex.min.css and katex.min.js in the same directory as this script
"""
import hashlib
import json
import os
import pathlib
import re
import sqlite3
import threading
import unittest
from dataclasses import dataclass
from typing import List, Optional
from playwright.sync_api import Error as PlaywrightError
from playwright.sync_api import sync_playwright
# --- New SQLite Cache Implementation ---
class EquationCache:
def __init__(self, db_path: Optional[str] = None):
if db_path is None:
# Use the same cache directory as before
cache_dir = pathlib.Path.home() / ".cache" / "olmocr" / "bench" / "equations"
cache_dir.mkdir(parents=True, exist_ok=True)
db_path = str(cache_dir / "cache.db")
self.db_path = db_path
self.lock = threading.Lock()
self._init_db()
def _init_db(self):
with self.lock:
conn = sqlite3.connect(self.db_path)
c = conn.cursor()
# Added an 'error' column to store rendering errors
c.execute(
"""
CREATE TABLE IF NOT EXISTS equations (
eq_hash TEXT PRIMARY KEY,
mathml TEXT,
spans TEXT,
error TEXT
)
"""
)
conn.commit()
conn.close()
def load(self, eq_hash: str) -> Optional["RenderedEquation"]:
with self.lock:
conn = sqlite3.connect(self.db_path)
c = conn.cursor()
c.execute("SELECT mathml, spans, error FROM equations WHERE eq_hash = ?", (eq_hash,))
row = c.fetchone()
conn.close()
if row:
mathml, spans_json, error = row
if error:
# In error cases, we return an instance with error set and no spans.
return RenderedEquation(mathml=mathml, spans=[], error=error)
else:
spans_data = json.loads(spans_json)
spans = [
SpanInfo(
text=s["text"],
bounding_box=BoundingBox(
x=s["boundingBox"]["x"],
y=s["boundingBox"]["y"],
width=s["boundingBox"]["width"],
height=s["boundingBox"]["height"],
),
)
for s in spans_data
]
return RenderedEquation(mathml=mathml, spans=spans)
return None
def save(self, eq_hash: str, rendered_eq: "RenderedEquation"):
spans_data = [
{
"text": span.text,
"boundingBox": {
"x": span.bounding_box.x,
"y": span.bounding_box.y,
"width": span.bounding_box.width,
"height": span.bounding_box.height,
},
}
for span in rendered_eq.spans
]
spans_json = json.dumps(spans_data)
with self.lock:
conn = sqlite3.connect(self.db_path)
c = conn.cursor()
c.execute(
"INSERT OR REPLACE INTO equations (eq_hash, mathml, spans, error) VALUES (?, ?, ?, ?)",
(eq_hash, rendered_eq.mathml, spans_json, rendered_eq.error),
)
conn.commit()
conn.close()
def clear(self):
with self.lock:
conn = sqlite3.connect(self.db_path)
c = conn.cursor()
c.execute("DELETE FROM equations")
conn.commit()
conn.close()
# Global instance of EquationCache
equation_cache = EquationCache()
# --- End SQLite Cache Implementation ---
# Thread-local storage for Playwright and browser instances
_thread_local = threading.local()
@dataclass
class BoundingBox:
x: float
y: float
width: float
height: float
@dataclass
class SpanInfo:
text: str
bounding_box: BoundingBox
@dataclass
class RenderedEquation:
mathml: str
spans: List[SpanInfo]
error: Optional[str] = None # New field to store error messages if rendering fails
def get_equation_hash(equation, bg_color="white", text_color="black", font_size=24):
"""
Calculate SHA1 hash of the equation string and rendering parameters.
"""
params_str = f"{equation}|{bg_color}|{text_color}|{font_size}"
return hashlib.sha1(params_str.encode("utf-8")).hexdigest()
def init_browser():
"""
Initialize the Playwright and browser instance for the current thread if not already done.
"""
if not hasattr(_thread_local, "playwright"):
_thread_local.playwright = sync_playwright().start()
_thread_local.browser = _thread_local.playwright.chromium.launch()
def get_browser():
"""
Return the browser instance for the current thread.
"""
init_browser()
return _thread_local.browser
def render_equation(
equation,
bg_color="white",
text_color="black",
font_size=24,
use_cache=True,
debug_dom=False,
):
"""
Render a LaTeX equation using Playwright and KaTeX, extract the inner-most span elements
along with their bounding boxes, and extract the MathML output generated by KaTeX.
"""
# Calculate hash for caching.
eq_hash = get_equation_hash(equation, bg_color, text_color, font_size)
# Try to load from SQLite cache.
if use_cache:
cached = equation_cache.load(eq_hash)
if cached is not None:
return cached
# Escape the equation for use in a JavaScript string.
escaped_equation = json.dumps(equation)
# Get local paths for KaTeX files.
script_dir = os.path.dirname(os.path.abspath(__file__))
katex_css_path = os.path.join(script_dir, "katex.min.css")
katex_js_path = os.path.join(script_dir, "katex.min.js")
if not os.path.exists(katex_css_path) or not os.path.exists(katex_js_path):
raise FileNotFoundError(f"KaTeX files not found. Please ensure katex.min.css and katex.min.js are in {script_dir}")
# Get the browser instance for the current thread.
browser = get_browser()
# Create a new page.
page = browser.new_page(viewport={"width": 800, "height": 400})
# Basic HTML structure for rendering.
page_html = f"""
<!DOCTYPE html>
<html>
<head>
<style>
body {{
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background-color: {bg_color};
color: {text_color};
}}
#equation-container {{
padding: 0;
font-size: {font_size}px;
}}
</style>
</head>
<body>
<div id="equation-container"></div>
</body>
</html>
"""
page.set_content(page_html)
page.add_style_tag(path=katex_css_path)
page.add_script_tag(path=katex_js_path)
page.wait_for_load_state("networkidle")
katex_loaded = page.evaluate("typeof katex !== 'undefined'")
if not katex_loaded:
page.close()
raise RuntimeError("KaTeX library failed to load. Check your katex.min.js file.")
try:
error_message = page.evaluate(
f"""
() => {{
try {{
katex.render({escaped_equation}, document.getElementById("equation-container"), {{
displayMode: true,
throwOnError: true
}});
return null;
}} catch (error) {{
console.error("KaTeX error:", error.message);
return error.message;
}}
}}
"""
)
except PlaywrightError as ex:
print(escaped_equation)
error_message = str(ex)
page.close()
raise
if error_message:
print(f"Error rendering equation: '{equation}'")
print(error_message)
# Cache the error result so we don't retry it next time.
rendered_eq = RenderedEquation(mathml=error_message, spans=[], error=error_message)
if use_cache:
equation_cache.save(eq_hash, rendered_eq)
page.close()
return rendered_eq
page.wait_for_selector(".katex", state="attached")
if debug_dom:
katex_dom_html = page.evaluate(
"""
() => {
return document.getElementById("equation-container").innerHTML;
}
"""
)
print("\n===== KaTeX DOM HTML =====")
print(katex_dom_html)
# Extract inner-most spans with non-whitespace text.
spans_info = page.evaluate(
"""
() => {
const spans = Array.from(document.querySelectorAll('span'));
const list = [];
spans.forEach(span => {
if (span.children.length === 0 && /\\S/.test(span.textContent)) {
const rect = span.getBoundingClientRect();
list.push({
text: span.textContent.trim(),
boundingBox: {
x: rect.x,
y: rect.y,
width: rect.width,
height: rect.height
}
});
}
});
return list;
}
"""
)
if debug_dom:
print("\n===== Extracted Span Information =====")
print(spans_info)
# Extract MathML output (if available) from the KaTeX output.
mathml = page.evaluate(
"""
() => {
const mathElem = document.querySelector('.katex-mathml math');
return mathElem ? mathElem.outerHTML : "";
}
"""
)
page.close()
rendered_eq = RenderedEquation(
mathml=mathml,
spans=[
SpanInfo(
text=s["text"],
bounding_box=BoundingBox(
x=s["boundingBox"]["x"],
y=s["boundingBox"]["y"],
width=s["boundingBox"]["width"],
height=s["boundingBox"]["height"],
),
)
for s in spans_info
],
)
# Save the successfully rendered equation to the SQLite cache.
if use_cache:
equation_cache.save(eq_hash, rendered_eq)
return rendered_eq
def compare_rendered_equations(reference: RenderedEquation, hypothesis: RenderedEquation) -> bool:
"""
Compare two RenderedEquation objects.
First, check if the normalized MathML of the hypothesis is contained within that of the reference.
If not, perform a neighbor-based matching on the spans.
"""
from bs4 import BeautifulSoup
def extract_inner(mathml: str) -> str:
try:
soup = BeautifulSoup(mathml, "xml")
semantics = soup.find("semantics")
if semantics:
inner_parts = [str(child) for child in semantics.contents if getattr(child, "name", None) != "annotation"]
return "".join(inner_parts)
else:
return str(soup)
except Exception as e:
print("Error parsing MathML with BeautifulSoup:", e)
print(mathml)
return mathml
def normalize(s: str) -> str:
return re.sub(r"\s+", "", s)
reference_inner = normalize(extract_inner(reference.mathml))
hypothesis_inner = normalize(extract_inner(hypothesis.mathml))
if reference_inner in hypothesis_inner:
return True
H, R = reference.spans, hypothesis.spans
H = [span for span in H if span.text != "\u200b"]
R = [span for span in R if span.text != "\u200b"]
candidate_map = {}
for i, hspan in enumerate(H):
candidate_map[i] = [j for j, rsp in enumerate(R) if rsp.text == hspan.text]
if not candidate_map[i]:
return False
def compute_neighbors(spans, tol=5):
neighbors = {}
for i, span in enumerate(spans):
cx = span.bounding_box.x + span.bounding_box.width / 2
cy = span.bounding_box.y + span.bounding_box.height / 2
up = down = left = right = None
up_dist = down_dist = left_dist = right_dist = None
for j, other in enumerate(spans):
if i == j:
continue
ocx = other.bounding_box.x + other.bounding_box.width / 2
ocy = other.bounding_box.y + other.bounding_box.height / 2
if ocy < cy and abs(ocx - cx) <= tol:
dist = cy - ocy
if up is None or dist < up_dist:
up = j
up_dist = dist
if ocy > cy and abs(ocx - cx) <= tol:
dist = ocy - cy
if down is None or dist < down_dist:
down = j
down_dist = dist
if ocx < cx and abs(ocy - cy) <= tol:
dist = cx - ocx
if left is None or dist < left_dist:
left = j
left_dist = dist
if ocx > cx and abs(ocy - cy) <= tol:
dist = ocx - cx
if right is None or dist < right_dist:
right = j
right_dist = dist
neighbors[i] = {"up": up, "down": down, "left": left, "right": right}
return neighbors
hyp_neighbors = compute_neighbors(H)
ref_neighbors = compute_neighbors(R)
n = len(H)
used = [False] * len(R)
assignment = {}
def backtrack(i):
if i == n:
return True
for cand in candidate_map[i]:
if used[cand]:
continue
assignment[i] = cand
used[cand] = True
valid = True
for direction in ["up", "down", "left", "right"]:
hyp_nb = hyp_neighbors[i].get(direction)
ref_nb = ref_neighbors[cand].get(direction)
if hyp_nb is not None:
expected_text = H[hyp_nb].text
if ref_nb is None:
valid = False
break
if hyp_nb in assignment:
if assignment[hyp_nb] != ref_nb:
valid = False
break
else:
if R[ref_nb].text != expected_text:
valid = False
break
if valid:
if backtrack(i + 1):
return True
used[cand] = False
del assignment[i]
return False
return backtrack(0)
class TestRenderedEquationComparison(unittest.TestCase):
def test_exact_match(self):
eq1 = render_equation("a+b", use_cache=False)
eq2 = render_equation("a+b", use_cache=False)
self.assertTrue(compare_rendered_equations(eq1, eq2))
def test_whitespace_difference(self):
eq1 = render_equation("a+b", use_cache=False)
eq2 = render_equation("a + b", use_cache=False)
self.assertTrue(compare_rendered_equations(eq1, eq2))
def test_not_found(self):
eq1 = render_equation("c-d", use_cache=False)
eq2 = render_equation("a+b", use_cache=False)
self.assertFalse(compare_rendered_equations(eq1, eq2))
def test_align_block_contains_needle(self):
eq_plain = render_equation("a+b", use_cache=False)
eq_align = render_equation("\\begin{align*}a+b\\end{align*}", use_cache=False)
self.assertTrue(compare_rendered_equations(eq_plain, eq_align))
def test_align_block_needle_not_in(self):
eq_align = render_equation("\\begin{align*}a+b\\end{align*}", use_cache=False)
eq_diff = render_equation("c-d", use_cache=False)
self.assertFalse(compare_rendered_equations(eq_diff, eq_align))
def test_big(self):
ref_rendered = render_equation("\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}", use_cache=False, debug_dom=False)
align_rendered = render_equation(
"""\\begin{align*}\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}\\end{align*}""", use_cache=False, debug_dom=False
)
self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
def test_dot_end1(self):
ref_rendered = render_equation(
"\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\varphi(g s)}\\right]"
)
align_rendered = render_equation(
"\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\varphi(g s)}\\right]."
)
self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
def test_dot_end2(self):
ref_rendered = render_equation(
"\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\psi(g s)}\\right]"
)
align_rendered = render_equation(
"\\lambda_g = \\sum_{s \\in S} \\zeta_n^{\\psi(gs)} = \\sum_{i=1}^{k} \\left[ \\sum_{s, Rs = \\mathcal{I}_i} \\zeta_n^{\\psi(gs)} \\right]"
)
self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
def test_lambda(self):
ref_rendered = render_equation("\\lambda_g = \\lambda_{g'}")
align_rendered = render_equation("\\lambda_{g}=\\lambda_{g^{\\prime}}")
self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
def test_gemini(self):
ref_rendered = render_equation("u \\in (R/\\operatorname{Ann}_R(x_i))^{\\times}")
align_rendered = render_equation("u \\in\\left(R / \\operatorname{Ann}_{R}\\left(x_{i}\\right)\\right)^{\\times}")
self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
if __name__ == "__main__":
unittest.main()