mirror of
https://github.com/allenai/olmocr.git
synced 2025-09-28 01:41:27 +00:00
390 lines
14 KiB
Python
390 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Extract inner-most spans and their bounding boxes, and the MathML output,
|
|
from rendered LaTeX equations using Playwright and KaTeX.
|
|
Results are cached as JSON files named by a SHA1 hash of the equation and its rendering parameters.
|
|
|
|
Requirements:
|
|
pip install playwright
|
|
python -m playwright install chromium
|
|
|
|
Place katex.min.css and katex.min.js in the same directory as this script
|
|
"""
|
|
|
|
import os
|
|
import hashlib
|
|
import pathlib
|
|
import json
|
|
import re
|
|
import shutil
|
|
from dataclasses import dataclass
|
|
from typing import List
|
|
import unittest
|
|
import xml.etree.ElementTree as ET
|
|
|
|
from playwright.sync_api import sync_playwright, Error as PlaywrightError
|
|
|
|
@dataclass
class BoundingBox:
    """
    Rectangle occupied by one rendered span.

    In this module the four values are copied from the ``x``, ``y``,
    ``width`` and ``height`` of the span's ``getBoundingClientRect()``.
    """

    # Position of the rectangle's origin.
    x: float
    y: float
    # Extent of the rectangle.
    width: float
    height: float
|
|
|
|
@dataclass
class SpanInfo:
    """
    One inner-most text span of a rendered equation together with its location.
    """

    # Text content of the span (trimmed by the page script that extracts it).
    text: str
    # Where the span was laid out when the equation was rendered.
    bounding_box: BoundingBox
|
|
|
|
@dataclass
class RenderedEquation:
    """
    Result of rendering a single LaTeX equation.
    """

    # Serialized <math> element produced by KaTeX ("" when none was found).
    mathml: str
    # Inner-most text spans of the rendered output, with their bounding boxes.
    spans: List[SpanInfo]
|
|
|
|
def get_equation_hash(equation, bg_color="white", text_color="black", font_size=24):
    """
    Return the SHA1 hex digest identifying an equation and its rendering parameters.

    The equation and the three rendering parameters are joined with ``|`` and
    the UTF-8 encoding of that string is hashed.
    """
    fingerprint = "{}|{}|{}|{}".format(equation, bg_color, text_color, font_size)
    digest = hashlib.sha1()
    digest.update(fingerprint.encode('utf-8'))
    return digest.hexdigest()
|
|
|
|
def get_cache_dir():
    """
    Return the equation cache directory under the user's home, creating it on demand.

    The directory is ``~/.cache/olmocr/bench/equations``; missing parent
    directories are created as well.
    """
    location = pathlib.Path.home().joinpath('.cache', 'olmocr', 'bench', 'equations')
    location.mkdir(parents=True, exist_ok=True)
    return location
|
|
|
|
|
|
def clear_cache_dir():
    """
    Empty the equation cache by deleting the cache directory and recreating it.
    """
    target = get_cache_dir()
    # is_dir() is already False for a missing path, so it covers both checks.
    if target.is_dir():
        shutil.rmtree(target)
    # Leave an empty directory behind so later cache writes have a place to go.
    target.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def _span_from_dict(raw):
    """Build a SpanInfo from one ``{"text": ..., "boundingBox": {...}}`` dict."""
    box = raw["boundingBox"]
    return SpanInfo(
        text=raw["text"],
        bounding_box=BoundingBox(
            x=box["x"],
            y=box["y"],
            width=box["width"],
            height=box["height"],
        ),
    )


def render_equation(
    equation,
    bg_color="white",
    text_color="black",
    font_size=24,
    use_cache=True,
    debug_dom=False,
):
    """
    Render a LaTeX equation using Playwright and KaTeX, extract the inner-most span elements
    (those without child elements that contain non-whitespace text) along with their bounding boxes,
    and also extract the MathML output generated by KaTeX.

    Args:
        equation: LaTeX source of the equation.
        bg_color: CSS background colour of the page.
        text_color: CSS text colour of the page.
        font_size: Font size of the equation container, in CSS ``px``.
        use_cache: When true, a previously cached result (or cached failure) is
            returned without launching a browser. Fresh results and failure
            markers are written to the cache regardless of this flag.
        debug_dom: When true, print the KaTeX DOM and the extracted span data.

    Returns:
        RenderedEquation: A dataclass containing the mathml string and a list of SpanInfo dataclasses,
        or None when KaTeX reports an error for the equation (or such a failure is cached).

    Raises:
        FileNotFoundError: katex.min.css / katex.min.js are missing next to this script.
        RuntimeError: The KaTeX script did not define ``katex`` in the page.
        PlaywrightError: Evaluating the render script in the page failed.
    """
    # Calculate hash for caching
    eq_hash = get_equation_hash(equation, bg_color, text_color, font_size)
    cache_dir = get_cache_dir()
    cache_file = cache_dir / f"{eq_hash}.json"
    cache_error_file = cache_dir / f"{eq_hash}_error"

    if use_cache:
        # An "<hash>_error" marker means KaTeX already rejected this equation.
        if cache_error_file.exists():
            return None
        if cache_file.exists():
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return RenderedEquation(
                mathml=data["mathml"],
                spans=[_span_from_dict(s) for s in data["spans"]],
            )

    # JSON-encode the equation so it can be embedded as a JavaScript string literal
    escaped_equation = json.dumps(equation)

    # Get local paths for KaTeX files
    script_dir = os.path.dirname(os.path.abspath(__file__))
    katex_css_path = os.path.join(script_dir, "katex.min.css")
    katex_js_path = os.path.join(script_dir, "katex.min.js")

    if not os.path.exists(katex_css_path) or not os.path.exists(katex_js_path):
        raise FileNotFoundError(f"KaTeX files not found. Please ensure katex.min.css and katex.min.js are in {script_dir}")

    with sync_playwright() as p:
        browser = p.chromium.launch()
        # try/finally guarantees the browser is closed on every exit path,
        # including the exceptions raised below.
        try:
            page = browser.new_page(viewport={"width": 800, "height": 400})

            # Basic HTML structure
            html = f"""
            <!DOCTYPE html>
            <html>
            <head>
                <style>
                    body {{
                        display: flex;
                        justify-content: center;
                        align-items: center;
                        height: 100vh;
                        margin: 0;
                        background-color: {bg_color};
                        color: {text_color};
                    }}
                    #equation-container {{
                        padding: 0;
                        font-size: {font_size}px;
                    }}
                </style>
            </head>
            <body>
                <div id="equation-container"></div>
            </body>
            </html>
            """
            page.set_content(html)
            page.add_style_tag(path=katex_css_path)
            page.add_script_tag(path=katex_js_path)
            page.wait_for_load_state("networkidle")

            katex_loaded = page.evaluate("typeof katex !== 'undefined'")
            if not katex_loaded:
                raise RuntimeError("KaTeX library failed to load. Check your katex.min.js file.")

            try:
                # The page script returns null on success and the KaTeX error
                # message on failure, so parse errors do not surface as exceptions.
                error_message = page.evaluate(f"""
                () => {{
                    try {{
                        katex.render({escaped_equation}, document.getElementById("equation-container"), {{
                            displayMode: true,
                            throwOnError: true
                        }});
                        return null;
                    }} catch (error) {{
                        console.error("KaTeX error:", error.message);
                        return error.message;
                    }}
                }}
                """)
            except PlaywrightError:
                # Show the offending input before propagating the failure.
                print(escaped_equation)
                raise

            if error_message:
                print(f"Error rendering equation: '{equation}'")
                print(error_message)
                # Remember the failure so cached lookups can skip this equation.
                cache_error_file.touch()
                return None

            page.wait_for_selector(".katex", state="attached")

            if debug_dom:
                katex_dom_html = page.evaluate("""
                () => {
                    return document.getElementById("equation-container").innerHTML;
                }
                """)
                print("\n===== KaTeX DOM HTML =====")
                print(katex_dom_html)

            # Extract inner-most spans with non-whitespace text.
            # Raw string: the JavaScript regex /\S/ must reach the browser
            # unchanged, and "\S" is not a valid Python escape sequence.
            spans_info = page.evaluate(r"""
            () => {
                const spans = Array.from(document.querySelectorAll('span'));
                const list = [];
                spans.forEach(span => {
                    // Check if this span has no child elements and contains non-whitespace text
                    if (span.children.length === 0 && /\S/.test(span.textContent)) {
                        const rect = span.getBoundingClientRect();
                        list.push({
                            text: span.textContent.trim(),
                            boundingBox: {
                                x: rect.x,
                                y: rect.y,
                                width: rect.width,
                                height: rect.height
                            }
                        });
                    }
                });
                return list;
            }
            """)

            if debug_dom:
                print("\n===== Extracted Span Information =====")
                print(spans_info)

            # Extract mathML output (if available) from the KaTeX output.
            # We try to get the <math> element within an element with class "katex-mathml".
            mathml = page.evaluate("""
            () => {
                const mathElem = document.querySelector('.katex-mathml math');
                return mathElem ? mathElem.outerHTML : "";
            }
            """)
        finally:
            browser.close()

    # Build the result as a RenderedEquation dataclass
    rendered_eq = RenderedEquation(
        mathml=mathml,
        spans=[_span_from_dict(s) for s in spans_info],
    )

    # Save to cache (convert dataclasses to a JSON-serializable dict)
    cache_data = {
        "mathml": rendered_eq.mathml,
        "spans": [
            {
                "text": span.text,
                "boundingBox": {
                    "x": span.bounding_box.x,
                    "y": span.bounding_box.y,
                    "width": span.bounding_box.width,
                    "height": span.bounding_box.height,
                },
            }
            for span in rendered_eq.spans
        ],
    }
    with open(cache_file, 'w', encoding='utf-8') as f:
        json.dump(cache_data, f)
    return rendered_eq
|
|
|
|
|
|
def compare_rendered_equations(haystack: RenderedEquation, needle: RenderedEquation) -> bool:
    """
    Compare two rendered equations by cleaning the MathML (removing namespaces),
    extracting the inner content of any <semantics> element (ignoring <annotation>),
    normalizing whitespace, and checking if the needle's inner MathML is a substring
    of the haystack's inner MathML.

    If the cleaned needle is longer than the cleaned haystack the two are swapped,
    so the result does not depend on argument order. An equation whose cleaned
    MathML is empty never matches: the empty string is a substring of every
    string, which would otherwise report a match for any pair.

    Returns:
        True when the shorter cleaned MathML occurs inside the longer one,
        False otherwise (including when either side is empty).
    """

    def strip_namespaces(elem: ET.Element) -> ET.Element:
        """
        Remove the ``{namespace}`` prefix from the tag of elem and every descendant, in place.
        """
        for sub in elem.iter():
            if '}' in sub.tag:
                sub.tag = sub.tag.split('}', 1)[1]
        return elem

    def extract_inner(mathml: str) -> str:
        """
        Parse the MathML, remove namespaces, and if a <semantics> element exists,
        concatenate the string representations of its children (except <annotation>).
        Otherwise, return the whole cleaned MathML.
        Input that is not well-formed XML is returned unchanged.
        """
        try:
            root = strip_namespaces(ET.fromstring(mathml))
        except ET.ParseError as e:
            # For debugging purposes, print the error
            print("Error parsing MathML:", e)
            return mathml

        semantics = root.find('semantics')
        if semantics is None:
            return ET.tostring(root, encoding='unicode')
        return ''.join(
            ET.tostring(child, encoding='unicode')
            for child in semantics
            if child.tag != 'annotation'
        )

    def normalize(s: str) -> str:
        """
        Remove all whitespace from the string.
        """
        return re.sub(r'\s+', '', s)

    # Clean and extract the inner MathML for both haystack and needle.
    haystack_inner = normalize(extract_inner(haystack.mathml))
    needle_inner = normalize(extract_inner(needle.mathml))

    # For debugging: print the cleaned MathML strings.
    print("Cleaned haystack MathML:", haystack_inner)
    print("Cleaned needle MathML:", needle_inner)

    # "" is contained in every string; treat missing MathML as "no match"
    # instead of a vacuous success.
    if not haystack_inner or not needle_inner:
        return False

    # If needle is longer than haystack, swap them.
    if len(needle_inner) > len(haystack_inner):
        needle_inner, haystack_inner = haystack_inner, needle_inner

    return needle_inner in haystack_inner
|
|
|
|
class TestRenderedEquationComparison(unittest.TestCase):
    """
    Tests for compare_rendered_equations, run on equations rendered by render_equation.
    """

    # Keyword arguments for tests that must bypass cached results.
    _FRESH = {"use_cache": False}

    def test_exact_match(self):
        """Both calls with identical LaTeX should produce matching MathML output."""
        first = render_equation("a+b", **self._FRESH)
        second = render_equation("a+b", **self._FRESH)
        self.assertTrue(compare_rendered_equations(first, second))

    def test_whitespace_difference(self):
        """Differences in whitespace in the LaTeX input should not affect the MathML output."""
        compact = render_equation("a+b", **self._FRESH)
        spaced = render_equation("a + b", **self._FRESH)
        self.assertTrue(compare_rendered_equations(compact, spaced))

    def test_not_found(self):
        """Completely different equations should not match."""
        minus = render_equation("c-d", **self._FRESH)
        plus = render_equation("a+b", **self._FRESH)
        self.assertFalse(compare_rendered_equations(minus, plus))

    def test_align_block_contains_needle(self):
        """The MathML output of the plain equation should be found within the align block output."""
        plain = render_equation("a+b", **self._FRESH)
        aligned = render_equation(r"\begin{align*}a+b\end{align*}", **self._FRESH)
        self.assertTrue(compare_rendered_equations(aligned, plain))

    def test_align_block_needle_not_in(self):
        """An align block rendering a different equation should not contain the MathML of an unrelated equation."""
        aligned = render_equation(r"\begin{align*}a+b\end{align*}", **self._FRESH)
        unrelated = render_equation("c-d", **self._FRESH)
        self.assertFalse(compare_rendered_equations(aligned, unrelated))

    def test_big(self):
        """A larger equation is compared with the same equation wrapped in an align block."""
        reference = render_equation(
            r"\nabla \cdot \mathbf{E} = \frac{\rho}{\varepsilon_0}",
            use_cache=False,
            debug_dom=False,
        )
        aligned = render_equation(
            r"\begin{align*}\nabla \cdot \mathbf{E} = \frac{\rho}{\varepsilon_0}\end{align*}",
            use_cache=False,
            debug_dom=False,
        )
        self.assertTrue(compare_rendered_equations(reference, aligned))

    def test_dot_end1(self):
        """The second equation is the first one followed by a trailing period."""
        base_tex = r"\lambda_{g}=\sum_{s \in S} \zeta_{n}^{\psi(g s)}=\sum_{i=1}^{k}\left[\sum_{s, R s=\mathcal{I}_{i}} \zeta_{n}^{\varphi(g s)}\right]"
        reference = render_equation(base_tex)
        with_period = render_equation(base_tex + ".")
        self.assertTrue(compare_rendered_equations(reference, with_period))

    def test_dot_end2(self):
        """Two differently written LaTeX sources for the same sum are asserted to match."""
        reference = render_equation(
            r"\lambda_{g}=\sum_{s \in S} \zeta_{n}^{\psi(g s)}=\sum_{i=1}^{k}\left[\sum_{s, R s=\mathcal{I}_{i}} \zeta_{n}^{\psi(g s)}\right]"
        )
        variant = render_equation(
            r"\lambda_g = \sum_{s \in S} \zeta_n^{\psi(gs)} = \sum_{i=1}^{k} \left[ \sum_{s, Rs = I_i} \zeta_n^{\psi(gs)} \right]"
        )
        self.assertTrue(compare_rendered_equations(reference, variant))

    def test_lambda(self):
        """A prime written as ' and as ^{\\prime} is asserted to match."""
        reference = render_equation(r"\lambda_g = \lambda_{g'}")
        variant = render_equation(r"\lambda_{g}=\lambda_{g^{\prime}}")
        self.assertTrue(compare_rendered_equations(reference, variant))
|
|
|
|
# Run the test suite above when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|