olmocr/olmocr/bench/katex/render.py
2025-03-11 03:57:12 +00:00

390 lines
14 KiB
Python

#!/usr/bin/env python3
"""
Extract inner-most spans and their bounding boxes, and the mathML output,
from rendered LaTeX equations using Playwright and KaTeX.
Results are cached as JSON files named by a SHA1 hash of the equation and its rendering parameters.
Requirements:
pip install playwright
python -m playwright install chromium
Place katex.min.css and katex.min.js in the same directory as this script.
"""
import os
import hashlib
import pathlib
import json
import re
import shutil
from dataclasses import dataclass
from typing import List
import unittest
import xml.etree.ElementTree as ET
from playwright.sync_api import sync_playwright, Error as PlaywrightError
@dataclass
class BoundingBox:
    """Rectangle of a rendered element; render_equation fills it from the DOM's getBoundingClientRect()."""

    # Top-left corner of the rectangle.
    x: float
    y: float
    # Horizontal and vertical extent of the rectangle.
    width: float
    height: float
@dataclass
class SpanInfo:
    """A leaf span of KaTeX output: its trimmed text and where it was drawn on the page."""

    # Text content of the span, with surrounding whitespace trimmed.
    text: str
    # On-page rectangle occupied by the span.
    bounding_box: BoundingBox
@dataclass
class RenderedEquation:
    """Everything extracted from one KaTeX render: the MathML markup and the leaf spans."""

    # outerHTML of the <math> element KaTeX emits (empty string if none was found).
    mathml: str
    # Inner-most, non-blank spans with their bounding boxes, in document order.
    spans: List[SpanInfo]
def get_equation_hash(equation, bg_color="white", text_color="black", font_size=24):
    """
    Return the hex SHA1 digest that identifies an equation together with its
    rendering settings; used as the cache file name.

    The key is the four values formatted as text and joined with "|", then
    UTF-8 encoded before hashing.
    """
    cache_key = "|".join((f"{equation}", f"{bg_color}", f"{text_color}", f"{font_size}"))
    digest = hashlib.sha1(cache_key.encode('utf-8'))
    return digest.hexdigest()
def get_cache_dir():
    """
    Return the equation cache directory, ~/.cache/olmocr/bench/equations,
    creating it (and any missing parents) first.
    """
    equations_dir = pathlib.Path.home().joinpath('.cache', 'olmocr', 'bench', 'equations')
    equations_dir.mkdir(parents=True, exist_ok=True)
    return equations_dir
def clear_cache_dir():
    """
    Delete the equation cache directory with everything inside it, then
    recreate it empty.
    """
    target = get_cache_dir()
    # Path.is_dir() is already False for a path that does not exist.
    if not target.is_dir():
        return
    shutil.rmtree(target)
    target.mkdir(parents=True, exist_ok=True)
def _span_infos_from_dicts(raw_spans):
    """
    Convert span dicts (as produced by the in-page JS and stored in the cache)
    into SpanInfo objects.

    Each dict is expected to look like
    {"text": str, "boundingBox": {"x": ..., "y": ..., "width": ..., "height": ...}}.
    """
    return [
        SpanInfo(
            text=s["text"],
            bounding_box=BoundingBox(
                x=s["boundingBox"]["x"],
                y=s["boundingBox"]["y"],
                width=s["boundingBox"]["width"],
                height=s["boundingBox"]["height"],
            ),
        )
        for s in raw_spans
    ]


def _span_infos_to_dicts(spans):
    """Serialize SpanInfo objects into the JSON-friendly dict shape used by the cache."""
    return [
        {
            "text": span.text,
            "boundingBox": {
                "x": span.bounding_box.x,
                "y": span.bounding_box.y,
                "width": span.bounding_box.width,
                "height": span.bounding_box.height,
            },
        }
        for span in spans
    ]


def _load_cached_render(cache_file):
    """Return the RenderedEquation stored in cache_file, or None if it is missing or unreadable."""
    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return RenderedEquation(mathml=data["mathml"], spans=_span_infos_from_dicts(data["spans"]))
    except FileNotFoundError:
        return None
    except (OSError, ValueError, KeyError, TypeError) as ex:
        # A truncated or corrupt entry (e.g. left by an interrupted run) must not
        # break every later call for this equation; treat it as a miss and re-render.
        print(f"Ignoring unreadable cache entry {cache_file}: {ex}")
        return None


def _write_cache_atomically(cache_file, data):
    """
    Write data as JSON to cache_file through a temporary file plus os.replace,
    so concurrent readers never observe a partially written cache entry.
    """
    # The PID suffix keeps parallel worker processes from sharing a temp file.
    tmp_file = cache_file.with_name(f"{cache_file.name}.{os.getpid()}.tmp")
    try:
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(data, f)
        os.replace(tmp_file, cache_file)
    finally:
        # Only still present if writing or renaming failed.
        if tmp_file.exists():
            tmp_file.unlink()


def render_equation(
    equation,
    bg_color="white",
    text_color="black",
    font_size=24,
    use_cache=True,
    debug_dom=False,
):
    """
    Render a LaTeX equation with KaTeX in headless Chromium and extract its layout.

    Extracts the inner-most span elements (spans with no child elements whose
    text contains a non-whitespace character) together with their bounding
    boxes, plus the MathML that KaTeX generates alongside its HTML output.

    Results are cached as ``<sha1>.json`` in get_cache_dir(); equations KaTeX
    rejects are remembered with an empty ``<sha1>_error`` marker file. The cache
    is always written, but only read when ``use_cache`` is True. An unreadable
    cache entry is ignored and the equation is re-rendered.

    Args:
        equation: LaTeX source, rendered in display mode.
        bg_color: CSS background colour of the page.
        text_color: CSS text colour of the page.
        font_size: Font size (px) of the equation container.
        use_cache: Return cached results / cached error markers when present.
        debug_dom: Print the KaTeX DOM HTML and the extracted span list.

    Returns:
        RenderedEquation: the mathml string and a list of SpanInfo dataclasses,
        or None if KaTeX reported a render error (or a cached error marker exists).

    Raises:
        FileNotFoundError: katex.min.css / katex.min.js are missing next to this file.
        RuntimeError: the ``katex`` global is undefined after loading katex.min.js.
        playwright.sync_api.Error: evaluating the render script in the page failed.
    """
    # Calculate hash for caching
    eq_hash = get_equation_hash(equation, bg_color, text_color, font_size)
    cache_dir = get_cache_dir()
    cache_file = cache_dir / f"{eq_hash}.json"
    cache_error_file = cache_dir / f"{eq_hash}_error"

    if use_cache:
        if cache_error_file.exists():
            return None
        cached = _load_cached_render(cache_file)
        if cached is not None:
            return cached

    # json.dumps produces a valid, fully escaped JavaScript string literal.
    escaped_equation = json.dumps(equation)

    # Get local paths for KaTeX files
    script_dir = os.path.dirname(os.path.abspath(__file__))
    katex_css_path = os.path.join(script_dir, "katex.min.css")
    katex_js_path = os.path.join(script_dir, "katex.min.js")
    if not os.path.exists(katex_css_path) or not os.path.exists(katex_js_path):
        raise FileNotFoundError(f"KaTeX files not found. Please ensure katex.min.css and katex.min.js are in {script_dir}")

    with sync_playwright() as p:
        browser = p.chromium.launch()
        # try/finally so the browser is closed on every exit path, including the
        # RuntimeError / re-raised PlaywrightError below.
        try:
            page = browser.new_page(viewport={"width": 800, "height": 400})

            # Basic HTML structure
            html = f"""
            <!DOCTYPE html>
            <html>
            <head>
                <style>
                    body {{
                        display: flex;
                        justify-content: center;
                        align-items: center;
                        height: 100vh;
                        margin: 0;
                        background-color: {bg_color};
                        color: {text_color};
                    }}
                    #equation-container {{
                        padding: 0;
                        font-size: {font_size}px;
                    }}
                </style>
            </head>
            <body>
                <div id="equation-container"></div>
            </body>
            </html>
            """
            page.set_content(html)
            page.add_style_tag(path=katex_css_path)
            page.add_script_tag(path=katex_js_path)
            page.wait_for_load_state("networkidle")

            katex_loaded = page.evaluate("typeof katex !== 'undefined'")
            if not katex_loaded:
                raise RuntimeError("KaTeX library failed to load. Check your katex.min.js file.")

            try:
                # Returns null on success or KaTeX's error message on failure.
                error_message = page.evaluate(f"""
                () => {{
                    try {{
                        katex.render({escaped_equation}, document.getElementById("equation-container"), {{
                            displayMode: true,
                            throwOnError: true
                        }});
                        return null;
                    }} catch (error) {{
                        console.error("KaTeX error:", error.message);
                        return error.message;
                    }}
                }}
                """)
            except PlaywrightError:
                print(escaped_equation)
                raise

            if error_message:
                print(f"Error rendering equation: '{equation}'")
                print(error_message)
                cache_error_file.touch()
                return None

            page.wait_for_selector(".katex", state="attached")

            if debug_dom:
                katex_dom_html = page.evaluate("""
                () => {
                    return document.getElementById("equation-container").innerHTML;
                }
                """)
                print("\n===== KaTeX DOM HTML =====")
                print(katex_dom_html)

            # Extract inner-most spans with non-whitespace text. Raw string so the
            # JS regex /\S/ reaches the browser without an invalid Python escape.
            spans_info = page.evaluate(r"""
            () => {
                const spans = Array.from(document.querySelectorAll('span'));
                const list = [];
                spans.forEach(span => {
                    // Check if this span has no child elements and contains non-whitespace text
                    if (span.children.length === 0 && /\S/.test(span.textContent)) {
                        const rect = span.getBoundingClientRect();
                        list.push({
                            text: span.textContent.trim(),
                            boundingBox: {
                                x: rect.x,
                                y: rect.y,
                                width: rect.width,
                                height: rect.height
                            }
                        });
                    }
                });
                return list;
            }
            """)

            if debug_dom:
                print("\n===== Extracted Span Information =====")
                print(spans_info)

            # Extract mathML output (if available) from the <math> element inside
            # the element with class "katex-mathml"; "" when absent.
            mathml = page.evaluate("""
            () => {
                const mathElem = document.querySelector('.katex-mathml math');
                return mathElem ? mathElem.outerHTML : "";
            }
            """)
        finally:
            browser.close()

    rendered_eq = RenderedEquation(mathml=mathml, spans=_span_infos_from_dicts(spans_info))

    _write_cache_atomically(
        cache_file,
        {"mathml": rendered_eq.mathml, "spans": _span_infos_to_dicts(rendered_eq.spans)},
    )
    return rendered_eq
def compare_rendered_equations(haystack: "RenderedEquation", needle: "RenderedEquation", *, debug: bool = False) -> bool:
    """
    Return True if one rendered equation's MathML is contained in the other's.

    Both MathML strings are parsed, stripped of XML namespaces, reduced to the
    children of their <semantics> element (the <annotation> holding the LaTeX
    source is ignored), and stripped of all whitespace. The shorter normalized
    string is then tested as a substring of the longer one, so the result does
    not depend on argument order.

    An empty normalized MathML (no MathML output, or unparsable markup that
    happened to be empty) never matches. Without this check the empty string,
    being a substring of everything, would produce a spurious match.

    Args:
        haystack: Rendered equation expected to contain the other.
        needle: Rendered equation expected to be contained.
        debug: If True, print both normalized MathML strings.

    Returns:
        bool: whether the shorter normalized MathML occurs in the longer one.
    """
    def strip_namespaces(elem: ET.Element) -> ET.Element:
        """
        Recursively remove namespace prefixes from an ElementTree element.
        """
        for sub in elem.iter():
            # Comment/PI nodes carry a callable tag; only rewrite string tags.
            if isinstance(sub.tag, str) and '}' in sub.tag:
                sub.tag = sub.tag.split('}', 1)[1]
        return elem

    def extract_inner(mathml: str) -> str:
        """
        Parse the MathML, remove namespaces, and if a <semantics> element exists,
        concatenate the string representations of its children (except <annotation>).
        Otherwise, return the whole cleaned MathML. Unparsable input is returned unchanged.
        """
        try:
            root = strip_namespaces(ET.fromstring(mathml))
        except ET.ParseError as e:
            print("Error parsing MathML:", e)
            return mathml
        semantics = root.find('semantics')
        if semantics is None:
            return ET.tostring(root, encoding='unicode')
        return ''.join(
            ET.tostring(child, encoding='unicode')
            for child in semantics
            if child.tag != 'annotation'
        )

    def normalize(s: str) -> str:
        """
        Remove all whitespace from the string.
        """
        return re.sub(r'\s+', '', s)

    # Clean and extract the inner MathML for both haystack and needle.
    haystack_inner = normalize(extract_inner(haystack.mathml))
    needle_inner = normalize(extract_inner(needle.mathml))

    if debug:
        print("Cleaned haystack MathML:", haystack_inner)
        print("Cleaned needle MathML:", needle_inner)

    # "" is a substring of everything; a missing rendering must not count as a match.
    if not haystack_inner or not needle_inner:
        return False

    # If needle is longer than haystack, swap them (comparison is symmetric).
    if len(needle_inner) > len(haystack_inner):
        needle_inner, haystack_inner = haystack_inner, needle_inner

    return needle_inner in haystack_inner
class TestRenderedEquationComparison(unittest.TestCase):
    """
    End-to-end checks of compare_rendered_equations on real KaTeX renders.

    Every test renders through render_equation, so Playwright's Chromium and
    the KaTeX assets next to this file must be available.
    """

    def _render(self, equation):
        """Render equation bypassing the cache, failing the test clearly if KaTeX rejects it."""
        # use_cache=False so results (or cached error markers) from an earlier run
        # cannot mask a regression in the current rendering pipeline.
        rendered = render_equation(equation, use_cache=False)
        # render_equation returns None on a KaTeX error; fail here with the
        # equation instead of an AttributeError deep inside the comparison.
        self.assertIsNotNone(rendered, f"KaTeX failed to render: {equation!r}")
        return rendered

    def test_exact_match(self):
        # Both calls with identical LaTeX should produce matching MathML output.
        eq1 = self._render("a+b")
        eq2 = self._render("a+b")
        self.assertTrue(compare_rendered_equations(eq1, eq2))

    def test_whitespace_difference(self):
        # Differences in whitespace in the LaTeX input should not affect the MathML output.
        eq1 = self._render("a+b")
        eq2 = self._render("a + b")
        self.assertTrue(compare_rendered_equations(eq1, eq2))

    def test_not_found(self):
        # Completely different equations should not match.
        eq1 = self._render("c-d")
        eq2 = self._render("a+b")
        self.assertFalse(compare_rendered_equations(eq1, eq2))

    def test_align_block_contains_needle(self):
        # The MathML output of the plain equation should be found within the align block output.
        eq_plain = self._render("a+b")
        eq_align = self._render("\\begin{align*}a+b\\end{align*}")
        self.assertTrue(compare_rendered_equations(eq_align, eq_plain))

    def test_align_block_needle_not_in(self):
        # An align block rendering one equation should not contain the MathML of an unrelated equation.
        eq_align = self._render("\\begin{align*}a+b\\end{align*}")
        eq_diff = self._render("c-d")
        self.assertFalse(compare_rendered_equations(eq_align, eq_diff))

    def test_big(self):
        ref_rendered = self._render("\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}")
        align_rendered = self._render("\\begin{align*}\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}\\end{align*}")
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_dot_end1(self):
        ref_rendered = self._render("\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\varphi(g s)}\\right]")
        align_rendered = self._render("\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\varphi(g s)}\\right].")
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_dot_end2(self):
        ref_rendered = self._render("\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\psi(g s)}\\right]")
        align_rendered = self._render("\\lambda_g = \\sum_{s \\in S} \\zeta_n^{\\psi(gs)} = \\sum_{i=1}^{k} \\left[ \\sum_{s, Rs = I_i} \\zeta_n^{\\psi(gs)} \\right]")
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))

    def test_lambda(self):
        ref_rendered = self._render("\\lambda_g = \\lambda_{g'}")
        align_rendered = self._render("\\lambda_{g}=\\lambda_{g^{\\prime}}")
        self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
if __name__ == "__main__":
    # Executing this file as a script runs its unittest test cases.
    unittest.main()