mirror of
https://github.com/allenai/olmocr.git
synced 2025-09-26 08:54:01 +00:00
542 lines
18 KiB
Python
542 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Extract inner-most spans and their bounding boxes, and the MathML output,
|
|
from rendered LaTeX equations using Playwright and KaTeX.
|
|
Caching is maintained via a SHA1-based hash stored in a sqlite database.
|
|
|
|
Requirements:
|
|
pip install playwright
|
|
python -m playwright install chromium
|
|
|
|
Place katex.min.css and katex.min.js in the same directory as this script
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import pathlib
|
|
import re
|
|
import sqlite3
|
|
import threading
|
|
import unittest
|
|
from dataclasses import dataclass
|
|
from typing import List, Optional
|
|
|
|
from playwright.sync_api import Error as PlaywrightError
|
|
from playwright.sync_api import sync_playwright
|
|
|
|
# --- New SQLite Cache Implementation ---
|
|
|
|
|
|
class EquationCache:
    """SQLite-backed cache of rendered equations, keyed by equation hash.

    Each row of the ``equations`` table stores the MathML string, the spans
    serialised as JSON, and an optional error message. A fresh connection is
    opened for every operation, and ``self.lock`` serialises those operations
    across the threads of this process.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open (creating it if needed) the cache database.

        Args:
            db_path: Path of the SQLite file. When ``None`` the database lives
                at ``~/.cache/olmocr/bench/equations/cache.db`` and the
                directory is created if it does not exist.
        """
        if db_path is None:
            # Use the same cache directory as before
            cache_dir = pathlib.Path.home() / ".cache" / "olmocr" / "bench" / "equations"
            cache_dir.mkdir(parents=True, exist_ok=True)
            db_path = str(cache_dir / "cache.db")
        self.db_path = db_path
        self.lock = threading.Lock()
        self._init_db()

    def _run(self, sql: str, params: tuple = (), fetch_one: bool = False):
        """Execute one statement on a short-lived connection.

        The connection is closed in a ``finally`` clause so that a failing
        statement cannot leak it.

        Args:
            sql: Statement to execute.
            params: Values bound to the ``?`` placeholders in ``sql``.
            fetch_one: When true, return the first result row instead of
                committing.

        Returns:
            The first row (or ``None`` if there is none) when ``fetch_one`` is
            true, otherwise ``None``.
        """
        with self.lock:
            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute(sql, params)
                if fetch_one:
                    return cursor.fetchone()
                conn.commit()
                return None
            finally:
                conn.close()

    def _init_db(self):
        """Create the ``equations`` table if it does not exist yet."""
        # The 'error' column stores rendering errors
        self._run(
            """
            CREATE TABLE IF NOT EXISTS equations (
                eq_hash TEXT PRIMARY KEY,
                mathml TEXT,
                spans TEXT,
                error TEXT
            )
            """
        )

    def load(self, eq_hash: str) -> Optional["RenderedEquation"]:
        """Return the cached rendering for ``eq_hash``, or ``None`` on a miss.

        A row with a non-empty ``error`` column is returned as a
        ``RenderedEquation`` carrying that error and an empty span list.
        """
        row = self._run(
            "SELECT mathml, spans, error FROM equations WHERE eq_hash = ?",
            (eq_hash,),
            fetch_one=True,
        )
        if not row:
            return None

        mathml, spans_json, error = row
        if error:
            # In error cases, we return an instance with error set and no spans.
            return RenderedEquation(mathml=mathml, spans=[], error=error)

        spans = [
            SpanInfo(
                text=s["text"],
                bounding_box=BoundingBox(
                    x=s["boundingBox"]["x"],
                    y=s["boundingBox"]["y"],
                    width=s["boundingBox"]["width"],
                    height=s["boundingBox"]["height"],
                ),
            )
            for s in json.loads(spans_json)
        ]
        return RenderedEquation(mathml=mathml, spans=spans)

    def save(self, eq_hash: str, rendered_eq: "RenderedEquation"):
        """Insert or overwrite the cached rendering stored under ``eq_hash``."""
        spans_data = [
            {
                "text": span.text,
                "boundingBox": {
                    "x": span.bounding_box.x,
                    "y": span.bounding_box.y,
                    "width": span.bounding_box.width,
                    "height": span.bounding_box.height,
                },
            }
            for span in rendered_eq.spans
        ]
        # Serialise before taking the lock so it is held only for the DB write.
        spans_json = json.dumps(spans_data)
        self._run(
            "INSERT OR REPLACE INTO equations (eq_hash, mathml, spans, error) VALUES (?, ?, ?, ?)",
            (eq_hash, rendered_eq.mathml, spans_json, rendered_eq.error),
        )

    def clear(self):
        """Delete every cached equation."""
        self._run("DELETE FROM equations")
|
|
|
|
|
|
# Global instance of EquationCache
|
|
# NOTE: constructing the cache at import time creates the default cache
# directory and the SQLite file (see EquationCache.__init__ / _init_db).
equation_cache = EquationCache()
|
|
|
|
# --- End SQLite Cache Implementation ---
|
|
|
|
|
|
# Thread-local storage for Playwright and browser instances
|
|
# init_browser() lazily stores `playwright` and `browser` attributes on this
# object, so every thread ends up with its own pair.
_thread_local = threading.local()
|
|
|
|
|
|
@dataclass
class BoundingBox:
    """Rectangle occupied by a rendered span.

    render_equation() fills these fields from the browser's
    ``getBoundingClientRect()`` result for the span.
    """

    x: float  # rect.x of the span
    y: float  # rect.y of the span
    width: float  # rect.width of the span
    height: float  # rect.height of the span
|
|
|
|
|
|
@dataclass
class SpanInfo:
    """Text and position of one inner-most span of a rendered equation."""

    text: str  # trimmed text content of the span
    bounding_box: BoundingBox  # where the span was laid out on the page
|
|
|
|
|
|
@dataclass
class RenderedEquation:
    """Result of rendering one equation with KaTeX.

    When rendering fails, render_equation() stores the error message in both
    ``mathml`` and ``error`` and leaves ``spans`` empty.
    """

    mathml: str  # outer HTML of KaTeX's <math> element, or "" if none was found
    spans: List[SpanInfo]  # inner-most spans that contain non-whitespace text
    error: Optional[str] = None  # error message if rendering failed, otherwise None
|
|
|
|
|
|
def get_equation_hash(equation, bg_color="white", text_color="black", font_size=24):
    """
    Return the SHA1 hex digest identifying an equation and its rendering parameters.

    The digest is computed over the UTF-8 encoding of
    ``equation|bg_color|text_color|font_size``.
    """
    key_parts = (equation, bg_color, text_color, font_size)
    cache_key = "|".join(format(part) for part in key_parts)
    digest = hashlib.sha1()
    digest.update(cache_key.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
|
|
def init_browser():
    """
    Initialize the Playwright and browser instance for the current thread if not already done.

    The thread-local state is only updated once the browser has launched
    successfully. If the launch fails, a Playwright driver started by this call
    is stopped again and the exception propagates, so a later call can retry
    instead of finding a half-initialised thread.
    """
    if hasattr(_thread_local, "browser"):
        return

    playwright = getattr(_thread_local, "playwright", None)
    started_here = playwright is None
    if started_here:
        playwright = sync_playwright().start()

    try:
        browser = playwright.chromium.launch()
    except Exception:
        # Do not leak the driver process we just started.
        if started_here:
            playwright.stop()
        raise

    _thread_local.playwright = playwright
    _thread_local.browser = browser
|
|
|
|
|
|
def get_browser():
    """
    Fetch the calling thread's browser, initializing it first when needed.
    """
    init_browser()
    browser = _thread_local.browser
    return browser
|
|
|
|
|
|
def render_equation(
    equation,
    bg_color="white",
    text_color="black",
    font_size=24,
    use_cache=True,
    debug_dom=False,
):
    """
    Render a LaTeX equation using Playwright and KaTeX, extract the inner-most span elements
    along with their bounding boxes, and extract the MathML output generated by KaTeX.

    Args:
        equation: LaTeX source of the equation.
        bg_color: CSS background colour of the page.
        text_color: CSS text colour of the page.
        font_size: Font size of the equation container, in pixels.
        use_cache: When true, return a cached result if one exists and store
            the new result (including KaTeX errors) in the SQLite cache.
        debug_dom: When true, print the KaTeX DOM and the extracted spans.

    Returns:
        A RenderedEquation. If KaTeX rejects the equation, the result has
        ``error`` set, ``mathml`` holding the same message, and no spans.

    Raises:
        FileNotFoundError: katex.min.css or katex.min.js is missing next to this script.
        RuntimeError: ``katex`` is undefined in the page after loading the script.
        PlaywrightError: evaluating the render script in the page failed.
    """
    # Calculate hash for caching.
    eq_hash = get_equation_hash(equation, bg_color, text_color, font_size)

    # Try to load from SQLite cache.
    if use_cache:
        cached = equation_cache.load(eq_hash)
        if cached is not None:
            return cached

    # Escape the equation for use in a JavaScript string.
    escaped_equation = json.dumps(equation)

    # Get local paths for KaTeX files.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    katex_css_path = os.path.join(script_dir, "katex.min.css")
    katex_js_path = os.path.join(script_dir, "katex.min.js")

    if not os.path.exists(katex_css_path) or not os.path.exists(katex_js_path):
        raise FileNotFoundError(f"KaTeX files not found. Please ensure katex.min.css and katex.min.js are in {script_dir}")

    # Get the browser instance for the current thread.
    browser = get_browser()

    # Create a new page.
    page = browser.new_page(viewport={"width": 800, "height": 400})

    # Everything that touches the page runs inside try/finally so the page is
    # closed on every exit path, including timeouts and evaluation errors.
    try:
        # Basic HTML structure for rendering.
        page_html = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <style>
                body {{
                    display: flex;
                    justify-content: center;
                    align-items: center;
                    height: 100vh;
                    margin: 0;
                    background-color: {bg_color};
                    color: {text_color};
                }}
                #equation-container {{
                    padding: 0;
                    font-size: {font_size}px;
                }}
            </style>
        </head>
        <body>
            <div id="equation-container"></div>
        </body>
        </html>
        """
        page.set_content(page_html)
        page.add_style_tag(path=katex_css_path)
        page.add_script_tag(path=katex_js_path)
        page.wait_for_load_state("networkidle")

        katex_loaded = page.evaluate("typeof katex !== 'undefined'")
        if not katex_loaded:
            raise RuntimeError("KaTeX library failed to load. Check your katex.min.js file.")

        try:
            # KaTeX errors are caught inside the page and returned as a message,
            # so a PlaywrightError here means the evaluation itself failed.
            error_message = page.evaluate(
                f"""
                () => {{
                    try {{
                        katex.render({escaped_equation}, document.getElementById("equation-container"), {{
                            displayMode: true,
                            throwOnError: true
                        }});
                        return null;
                    }} catch (error) {{
                        console.error("KaTeX error:", error.message);
                        return error.message;
                    }}
                }}
                """
            )
        except PlaywrightError:
            print(escaped_equation)
            raise

        if error_message:
            print(f"Error rendering equation: '{equation}'")
            print(error_message)
            # Cache the error result so we don't retry it next time.
            rendered_eq = RenderedEquation(mathml=error_message, spans=[], error=error_message)
            if use_cache:
                equation_cache.save(eq_hash, rendered_eq)
            return rendered_eq

        page.wait_for_selector(".katex", state="attached")

        if debug_dom:
            katex_dom_html = page.evaluate(
                """
                () => {
                    return document.getElementById("equation-container").innerHTML;
                }
                """
            )
            print("\n===== KaTeX DOM HTML =====")
            print(katex_dom_html)

        # Extract inner-most spans with non-whitespace text.
        spans_info = page.evaluate(
            """
            () => {
                const spans = Array.from(document.querySelectorAll('span'));
                const list = [];
                spans.forEach(span => {
                    if (span.children.length === 0 && /\\S/.test(span.textContent)) {
                        const rect = span.getBoundingClientRect();
                        list.push({
                            text: span.textContent.trim(),
                            boundingBox: {
                                x: rect.x,
                                y: rect.y,
                                width: rect.width,
                                height: rect.height
                            }
                        });
                    }
                });
                return list;
            }
            """
        )

        if debug_dom:
            print("\n===== Extracted Span Information =====")
            print(spans_info)

        # Extract MathML output (if available) from the KaTeX output.
        mathml = page.evaluate(
            """
            () => {
                const mathElem = document.querySelector('.katex-mathml math');
                return mathElem ? mathElem.outerHTML : "";
            }
            """
        )
    finally:
        page.close()

    rendered_eq = RenderedEquation(
        mathml=mathml,
        spans=[
            SpanInfo(
                text=s["text"],
                bounding_box=BoundingBox(
                    x=s["boundingBox"]["x"],
                    y=s["boundingBox"]["y"],
                    width=s["boundingBox"]["width"],
                    height=s["boundingBox"]["height"],
                ),
            )
            for s in spans_info
        ],
    )

    # Save the successfully rendered equation to the SQLite cache.
    if use_cache:
        equation_cache.save(eq_hash, rendered_eq)
    return rendered_eq
|
|
|
|
|
|
def compare_rendered_equations(reference: RenderedEquation, hypothesis: RenderedEquation) -> bool:
    """
    Decide whether the reference equation appears inside the hypothesis equation.

    First, check if the normalized MathML of the reference is contained within that of the
    hypothesis. If not, perform a neighbor-based matching on the spans: every reference span
    must map to a distinct hypothesis span with the same text whose up/down/left/right
    neighbors agree with the reference layout.

    Args:
        reference: The equation that must be found.
        hypothesis: The equation that is searched.

    Returns:
        True if the reference is matched, False otherwise. A failed render on either
        side, or a reference without any spans, never matches.
    """
    from bs4 import BeautifulSoup

    # A failed render carries an error message instead of MathML and has no
    # spans; without this guard the span matching below would succeed
    # vacuously for an errored reference.
    if getattr(reference, "error", None) or getattr(hypothesis, "error", None):
        return False

    def extract_inner(mathml: str) -> str:
        """Return the children of <semantics> except <annotation>, or the whole document."""
        try:
            soup = BeautifulSoup(mathml, "xml")
            semantics = soup.find("semantics")
            if semantics:
                inner_parts = [str(child) for child in semantics.contents if getattr(child, "name", None) != "annotation"]
                return "".join(inner_parts)
            else:
                return str(soup)
        except Exception as e:
            # Best effort: fall back to the raw string if it cannot be parsed.
            print("Error parsing MathML with BeautifulSoup:", e)
            print(mathml)
            return mathml

    def normalize(s: str) -> str:
        """Strip all whitespace."""
        return re.sub(r"\s+", "", s)

    reference_inner = normalize(extract_inner(reference.mathml))
    hypothesis_inner = normalize(extract_inner(hypothesis.mathml))
    if reference_inner in hypothesis_inner:
        return True

    # Zero-width spaces carry no visible content, so ignore them on both sides.
    ref_spans = [span for span in reference.spans if span.text != "\u200b"]
    hyp_spans = [span for span in hypothesis.spans if span.text != "\u200b"]

    # With nothing to look for, the backtracking search would trivially succeed.
    if not ref_spans:
        return False

    # For every reference span, the hypothesis spans with identical text.
    candidate_map = {}
    for i, ref_span in enumerate(ref_spans):
        candidate_map[i] = [j for j, hyp_span in enumerate(hyp_spans) if hyp_span.text == ref_span.text]
        if not candidate_map[i]:
            return False

    def compute_neighbors(spans, tol=5):
        """Map each span index to its nearest neighbor index (or None) in each direction.

        A span is an up/down neighbor when its center x is within ``tol`` of
        this span's center x, and a left/right neighbor when its center y is
        within ``tol`` of this span's center y. The closest one wins; on a
        tie the earlier span is kept.
        """
        centers = [
            (
                span.bounding_box.x + span.bounding_box.width / 2,
                span.bounding_box.y + span.bounding_box.height / 2,
            )
            for span in spans
        ]
        neighbors = {}
        for i, (cx, cy) in enumerate(centers):
            nearest = {"up": None, "down": None, "left": None, "right": None}
            nearest_dist = dict.fromkeys(nearest)
            for j, (ocx, ocy) in enumerate(centers):
                if i == j:
                    continue
                same_column = abs(ocx - cx) <= tol
                same_row = abs(ocy - cy) <= tol
                offsets = (
                    ("up", same_column, cy - ocy),
                    ("down", same_column, ocy - cy),
                    ("left", same_row, cx - ocx),
                    ("right", same_row, ocx - cx),
                )
                for direction, aligned, dist in offsets:
                    if not aligned or not dist > 0:
                        continue
                    if nearest[direction] is None or dist < nearest_dist[direction]:
                        nearest[direction] = j
                        nearest_dist[direction] = dist
            neighbors[i] = nearest
        return neighbors

    ref_neighbors = compute_neighbors(ref_spans)
    hyp_neighbors = compute_neighbors(hyp_spans)

    n = len(ref_spans)
    used = [False] * len(hyp_spans)
    assignment = {}  # reference span index -> hypothesis span index

    def neighbors_agree(i, cand):
        """Check that every neighbor of reference span i has a counterpart around cand."""
        for direction in ["up", "down", "left", "right"]:
            ref_nb = ref_neighbors[i].get(direction)
            if ref_nb is None:
                continue
            hyp_nb = hyp_neighbors[cand].get(direction)
            if hyp_nb is None:
                return False
            if ref_nb in assignment:
                # Already placed: it must be exactly this hypothesis span.
                if assignment[ref_nb] != hyp_nb:
                    return False
            elif hyp_spans[hyp_nb].text != ref_spans[ref_nb].text:
                return False
        return True

    def backtrack(i):
        """Try to assign reference spans i..n-1 to unused hypothesis spans."""
        if i == n:
            return True
        for cand in candidate_map[i]:
            if used[cand]:
                continue
            assignment[i] = cand
            used[cand] = True
            if neighbors_agree(i, cand) and backtrack(i + 1):
                return True
            used[cand] = False
            del assignment[i]
        return False

    return backtrack(0)
|
|
|
|
|
|
class TestRenderedEquationComparison(unittest.TestCase):
    """End-to-end checks of compare_rendered_equations() on freshly rendered equations.

    Every test renders with ``use_cache=False`` so that results never come
    from, or get written to, the persistent on-disk cache.
    """

    def test_exact_match(self):
        """Identical sources match."""
        eq1 = render_equation("a+b", use_cache=False)
        eq2 = render_equation("a+b", use_cache=False)
        self.assertTrue(compare_rendered_equations(eq1, eq2))

    def test_whitespace_difference(self):
        """Source whitespace does not affect the match."""
        eq1 = render_equation("a+b", use_cache=False)
        eq2 = render_equation("a + b", use_cache=False)
        self.assertTrue(compare_rendered_equations(eq1, eq2))

    def test_not_found(self):
        """Unrelated equations do not match."""
        eq1 = render_equation("c-d", use_cache=False)
        eq2 = render_equation("a+b", use_cache=False)
        self.assertFalse(compare_rendered_equations(eq1, eq2))

    def test_align_block_contains_needle(self):
        """A plain equation is found inside an align* block."""
        eq_plain = render_equation("a+b", use_cache=False)
        eq_align = render_equation("\\begin{align*}a+b\\end{align*}", use_cache=False)
        self.assertTrue(compare_rendered_equations(eq_plain, eq_align))

    def test_align_block_needle_not_in(self):
        """A different equation is not found inside an align* block."""
        eq_align = render_equation("\\begin{align*}a+b\\end{align*}", use_cache=False)
        eq_diff = render_equation("c-d", use_cache=False)
        self.assertFalse(compare_rendered_equations(eq_diff, eq_align))

    def test_big(self):
        """A fraction equation is found inside an align* block."""
        ref_rendered = render_equation("\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}", use_cache=False, debug_dom=False)
        align_rendered = render_equation(
            "\\begin{align*}\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}\\end{align*}", use_cache=False, debug_dom=False
        )
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_dot_end1(self):
        """A trailing period in the hypothesis does not prevent a match."""
        ref_rendered = render_equation(
            "\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\varphi(g s)}\\right]",
            use_cache=False,
        )
        align_rendered = render_equation(
            "\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\varphi(g s)}\\right].",
            use_cache=False,
        )
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_dot_end2(self):
        """Brace and spacing variations of the same equation match."""
        ref_rendered = render_equation(
            "\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\psi(g s)}\\right]",
            use_cache=False,
        )
        align_rendered = render_equation(
            "\\lambda_g = \\sum_{s \\in S} \\zeta_n^{\\psi(gs)} = \\sum_{i=1}^{k} \\left[ \\sum_{s, Rs = \\mathcal{I}_i} \\zeta_n^{\\psi(gs)} \\right]",
            use_cache=False,
        )
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_lambda(self):
        """An apostrophe prime matches an explicit ^{\\prime}."""
        ref_rendered = render_equation("\\lambda_g = \\lambda_{g'}", use_cache=False)
        align_rendered = render_equation("\\lambda_{g}=\\lambda_{g^{\\prime}}", use_cache=False)
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_gemini(self):
        """Plain parentheses match \\left( ... \\right) delimiters."""
        ref_rendered = render_equation("u \\in (R/\\operatorname{Ann}_R(x_i))^{\\times}", use_cache=False)
        align_rendered = render_equation("u \\in\\left(R / \\operatorname{Ann}_{R}\\left(x_{i}\\right)\\right)^{\\times}", use_cache=False)
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
|
|
|
|
|
|
# Run the comparison test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|