Mirror of https://github.com/allenai/olmocr.git, synced 2025-07-31 04:46:33 +00:00
Organizing things for data entry

parent af02c63531
commit 9f12917e10
@@ -1,5 +1,6 @@
import os
import re
import time
import argparse
from difflib import SequenceMatcher
from collections import Counter
@@ -8,7 +9,6 @@ import syntok.segmenter as segmenter
import syntok.tokenizer as tokenizer

import base64
import os
from google import genai
from google.genai import types

@@ -18,6 +18,8 @@ from olmocr.bench.tests import TextPresenceTest, save_tests
LABEL_WIDTH = 8  # fixed width for printing labels

# Uses a gemini prompt to get the most likely clean sentence from a pdf page
last_gemini_call = time.perf_counter()

def clean_base_sentence(pdf_path: str, page_num: int, base_sentence: str) -> str:
    client = genai.Client(
        api_key=os.environ.get("GEMINI_API_KEY"),
@@ -58,8 +60,19 @@ Consider the sentence labeled "Base" above in the document image attached. What
        contents=contents,
        config=generate_content_config,
    )
    result = response.candidates[0].content.parts[0].text
    return result

    # Basic rate limitting
    global last_gemini_call
    if time.perf_counter() - last_gemini_call < 6:
        time.sleep(6 - (time.perf_counter() - last_gemini_call))

    last_gemini_call = time.perf_counter()

    # Return response
    if response is not None and response.candidates is not None and len(response.candidates) > 0:
        return response.candidates[0].content.parts[0].text
    else:
        return None


def parse_sentences(text: str) -> list[str]:
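The hunk above replaces the direct return of the first Gemini candidate with a simple time-based rate limiter plus a guarded return. A standalone sketch of that spacing pattern follows; the six-second interval comes from the diff, while the wrapper function and names are invented for illustration:

```python
import time

MIN_INTERVAL_S = 6.0  # same spacing the diff enforces between Gemini calls
_last_call = time.perf_counter()

def rate_limited(fn, *args, **kwargs):
    """Sleep just long enough that successive calls are at least MIN_INTERVAL_S apart."""
    global _last_call
    elapsed = time.perf_counter() - _last_call
    if elapsed < MIN_INTERVAL_S:
        time.sleep(MIN_INTERVAL_S - elapsed)
    _last_call = time.perf_counter()
    return fn(*args, **kwargs)

# Usage sketch (hypothetical call site):
# response = rate_limited(client.models.generate_content, model=..., contents=...)
```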
@@ -111,11 +124,9 @@ def compare_votes_for_file(base_pdf_file: str, base_pdf_page: int, base_text: st
                best_ratio = ratio
                best_candidate = c_sentence  # Keep original capitalization for output

        best_candidate = best_candidate.strip()

        # Append the candidate if it passes the similarity threshold (e.g., 0.7)
        if best_ratio > 0.7 and best_candidate is not None:
            votes.append(best_candidate)
            votes.append(best_candidate.strip())

        # Only consider variants that differ when compared case-insensitively
        variant_votes = [vote for vote in votes if vote.lower() != b_sentence.lower()]
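This hunk tightens the voting logic: the best-matching candidate is now stripped at the point it is appended, and it is only kept when its similarity ratio clears 0.7. The ratio comes from difflib.SequenceMatcher, imported at the top of the file. A small self-contained sketch of that thresholding, with made-up sentences:

```python
from difflib import SequenceMatcher

def best_vote(base: str, candidates: list[str], threshold: float = 0.7) -> str | None:
    """Return the candidate most similar to base, or None if nothing clears the threshold."""
    best_ratio, best_candidate = 0.0, None
    for cand in candidates:
        ratio = SequenceMatcher(None, base.lower(), cand.lower()).ratio()
        if ratio > best_ratio:
            best_ratio, best_candidate = ratio, cand
    if best_candidate is not None and best_ratio > threshold:
        return best_candidate.strip()
    return None

print(best_vote("The quick brown fox jumps.", ["The quick brown fox jumps. ", "Completely different"]))
```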
@@ -175,7 +186,7 @@ def main():
    parser.add_argument(
        "--max-diffs",
        type=int,
        default=3,
        default=4,
        help="Maximum number of diffs to display per file."
    )
    parser.add_argument(
@@ -215,10 +226,9 @@ def main():
        all_tests.extend(tests)
        print("")

        # Output test candidates for review after each file, in case there are errors
        save_tests(all_tests, args.output)
        break

    # Output test candidates for review
    save_tests(all_tests, args.output)

if __name__ == "__main__":
    main()
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Tuple
import json
from dataclasses import dataclass, asdict
from enum import Enum
from typing import List, Tuple, Optional

from fuzzysearch import find_near_matches
from rapidfuzz import fuzz
@@ -12,189 +12,189 @@ class TestType(str, Enum):
    ABSENT = "absent"
    ORDER = "order"

class TestChecked(str, Enum):
    VERIFIED = "verified"
    REJECTED = "rejected"


class ValidationError(Exception):
    """Exception raised for validation errors"""
    """Exception raised for validation errors."""
    pass


@dataclass
@dataclass(kw_only=True)
class BasePDFTest:
    """Base class for all PDF test types"""
    """
    Base class for all PDF test types.

    Attributes:
        pdf: The PDF filename.
        page: The page number for the test.
        id: Unique identifier for the test.
        type: The type of test.
        threshold: A float between 0 and 1 representing the threshold for fuzzy matching.
    """
    pdf: str
    page: int
    id: str
    type: str
    threshold: float

    threshold: float = 1.0
    checked: Optional[TestChecked] = None

    def __post_init__(self):
        # Validate common fields
        if not self.pdf:
            raise ValidationError("PDF filename cannot be empty")

        if not self.id:
            raise ValidationError("Test ID cannot be empty")

        if not isinstance(self.threshold, float) or not (0 <= self.threshold <= 1):
            raise ValidationError(f"Threshold must be a float between 0 and 1, got {self.threshold}")

        # Check that type is valid
        if self.type not in [t.value for t in TestType]:
        if self.type not in {t.value for t in TestType}:
            raise ValidationError(f"Invalid test type: {self.type}")


    def run(self, md_content: str) -> Tuple[bool, str]:
        """
        Run the test on the content of the provided .md file.
        Returns a tuple (passed, explanation) where 'passed' is True if the test passes,
        and 'explanation' is a short message explaining the failure when the test does not pass.
        Run the test on the provided markdown content.

        Args:
            md_content: The content of the .md file.

        Returns:
            A tuple (passed, explanation) where 'passed' is True if the test passes,
            and 'explanation' provides details when the test fails.
        """
        raise NotImplementedError("Subclasses must implement run method")
        raise NotImplementedError("Subclasses must implement the run method")


@dataclass
class TextPresenceTest(BasePDFTest):
    """Test for text presence or absence in a PDF"""
    text: str
    """
    Test to verify the presence or absence of specific text in a PDF.

    Attributes:
        text: The text string to search for.
    """
    text: str

    def __post_init__(self):
        super().__post_init__()

        # Additional validation for this specific test type
        if self.type not in [TestType.PRESENT.value, TestType.ABSENT.value]:
        if self.type not in {TestType.PRESENT.value, TestType.ABSENT.value}:
            raise ValidationError(f"Invalid type for TextPresenceTest: {self.type}")

        if not self.text.strip():
            raise ValidationError("Text field cannot be empty")


    def run(self, md_content: str) -> Tuple[bool, str]:
        reference_query = self.text
        threshold = self.threshold
        best_ratio = fuzz.partial_ratio(reference_query, md_content) / 100.0


        if self.type == TestType.PRESENT.value:
            if best_ratio >= threshold:
                return (True, "")
                return True, ""
            else:
                return (False, f"Expected '{reference_query[:40]}...' with threshold {threshold} but best match ratio was {best_ratio:.3f}")
        else:  # absent
                msg = (
                    f"Expected '{reference_query[:40]}...' with threshold {threshold} "
                    f"but best match ratio was {best_ratio:.3f}"
                )
                return False, msg
        else:  # ABSENT
            if best_ratio < threshold:
                return (True, "")
                return True, ""
            else:
                return (False, f"Expected absence of '{reference_query[:40]}...' with threshold {threshold} but best match ratio was {best_ratio:.3f}")
                msg = (
                    f"Expected absence of '{reference_query[:40]}...' with threshold {threshold} "
                    f"but best match ratio was {best_ratio:.3f}"
                )
                return False, msg
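TextPresenceTest above scores the reference text against the whole markdown with rapidfuzz's partial_ratio (scaled to 0-1) and compares it to the test's threshold: "present" tests pass at or above the threshold, "absent" tests pass below it. A tiny standalone illustration of that scoring; the strings and threshold here are invented:

```python
from rapidfuzz import fuzz

md_content = "The results are summarized in Table 3 of the attached report."
reference_query = "summarized in Table 3"
threshold = 0.9

# partial_ratio finds the best-matching substring, so a short query embedded
# in a long document can still score close to 100.
best_ratio = fuzz.partial_ratio(reference_query, md_content) / 100.0
print(f"ratio={best_ratio:.3f}", "present: pass" if best_ratio >= threshold else "present: fail")
```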
@dataclass
class TextOrderTest(BasePDFTest):
    """Test for text order in a PDF"""
    """
    Test to verify that one text appears before another in a PDF.

    Attributes:
        before: The text expected to appear first.
        after: The text expected to appear after the 'before' text.
    """
    before: str
    after: str


    def __post_init__(self):
        super().__post_init__()

        # Additional validation for this specific test type
        if self.type != TestType.ORDER.value:
            raise ValidationError(f"Invalid type for TextOrderTest: {self.type}")

        if not self.before.strip():
            raise ValidationError("Before field cannot be empty")

        if not self.after.strip():
            raise ValidationError("After field cannot be empty")


    def run(self, md_content: str) -> Tuple[bool, str]:
        before = self.before
        after = self.after
        threshold = self.threshold
        max_l_dist = round((1.0 - threshold) * len(before))

        before_matches = find_near_matches(before, md_content, max_l_dist=max_l_dist)
        after_matches = find_near_matches(after, md_content, max_l_dist=max_l_dist)

        max_l_dist = round((1.0 - threshold) * len(self.before))
        before_matches = find_near_matches(self.before, md_content, max_l_dist=max_l_dist)
        after_matches = find_near_matches(self.after, md_content, max_l_dist=max_l_dist)

        if not before_matches:
            return (False, f"'before' search text '{before[:40]}...' not found with max_l_dist {max_l_dist}")
            return False, f"'before' text '{self.before[:40]}...' not found with max_l_dist {max_l_dist}"
        if not after_matches:
            return (False, f"'after' search text '{after[:40]}...' not found with max_l_dist {max_l_dist}")

            return False, f"'after' text '{self.after[:40]}...' not found with max_l_dist {max_l_dist}"

        for before_match in before_matches:
            for after_match in after_matches:
                if before_match.start < after_match.start:
                    return (True, "")

        return (False, f"Could not find a location where '{before[:40]}...' appears before '{after[:40]}...'.")
                    return True, ""
        return False, (
            f"Could not find a location where '{self.before[:40]}...' appears before "
            f"'{self.after[:40]}...'."
        )
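TextOrderTest converts its threshold into an allowed Levenshtein distance (max_l_dist), locates both snippets with fuzzysearch's find_near_matches, and passes when some 'before' match starts earlier than some 'after' match. A minimal sketch of that ordering check on invented text:

```python
from fuzzysearch import find_near_matches

md_content = "Introduction comes first, then Methods, and finally the Results section."
before, after = "Introduction", "Results"
threshold = 0.9
max_l_dist = round((1.0 - threshold) * len(before))  # edit-distance budget, as in the diff

before_matches = find_near_matches(before, md_content, max_l_dist=max_l_dist)
after_matches = find_near_matches(after, md_content, max_l_dist=max_l_dist)

# Pass when any 'before' occurrence starts ahead of any 'after' occurrence.
in_order = any(b.start < a.start for b in before_matches for a in after_matches)
print("order test passes" if in_order else "order test fails")
```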
def load_tests(jsonl_file: str) -> list[BasePDFTest]:
    """Load tests from a JSONL file"""
    tests = []
def load_tests(jsonl_file: str) -> List[BasePDFTest]:
    """
    Load tests from a JSONL file.

    with open(jsonl_file, 'r') as file:
        for line_number, line in enumerate(file, 1):
    Args:
        jsonl_file: Path to the JSONL file containing test definitions.

    Returns:
        A list of test objects.
    """
    tests: List[BasePDFTest] = []
    with open(jsonl_file, "r") as file:
        for line_number, line in enumerate(file, start=1):
            line = line.strip()
            if not line:  # Skip empty lines
            if not line:
                continue


            try:
                # Parse the JSON object
                data = json.loads(line)

                # Based on the type field, create the appropriate test object
                if data["type"] in [TestType.PRESENT.value, TestType.ABSENT.value]:
                    test = TextPresenceTest(
                        pdf=data["pdf"],
                        page=data["page"],
                        id=data["id"],
                        type=data["type"],
                        threshold=data["threshold"],
                        text=data["text"]
                    )
                elif data["type"] == TestType.ORDER.value:
                    test = TextOrderTest(
                        pdf=data["pdf"],
                        page=data["page"],
                        id=data["id"],
                        type=data["type"],
                        threshold=data["threshold"],
                        before=data["before"],
                        after=data["after"]
                    )
                test_type = data.get("type")
                if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}:
                    test = TextPresenceTest(**data)
                elif test_type == TestType.ORDER.value:
                    test = TextOrderTest(**data)
                else:
                    raise ValidationError(f"Unknown test type: {data['type']}")

                    raise ValidationError(f"Unknown test type: {test_type}")

                tests.append(test)

            except json.JSONDecodeError as e:
                print(f"Error parsing JSON on line {line_number}: {e}")
            except ValidationError as e:
                print(f"Validation error on line {line_number}: {e}")
            except KeyError as e:
                print(f"Missing required field on line {line_number}: {e}")
            except (ValidationError, KeyError) as e:
                print(f"Error on line {line_number}: {e}")
            except Exception as e:
                print(f"Unexpected error on line {line_number}: {e}")


    return tests


def save_tests(tests: list[BasePDFTest], jsonl_file: str) -> None:
    """Save tests to a JSONL file"""
    with open(jsonl_file, 'w') as file:
def save_tests(tests: List[BasePDFTest], jsonl_file: str) -> None:
    """
    Save tests to a JSONL file using asdict for conversion.

    Args:
        tests: A list of test objects.
        jsonl_file: Path to the output JSONL file.
    """
    with open(jsonl_file, "w") as file:
        for test in tests:
            # Convert dataclass to dict
            if isinstance(test, TextPresenceTest):
                data = {
                    "pdf": test.pdf,
                    "id": test.id,
                    "type": test.type,
                    "threshold": test.threshold,
                    "text": test.text
                }
            elif isinstance(test, TextOrderTest):
                data = {
                    "pdf": test.pdf,
                    "id": test.id,
                    "type": test.type,
                    "threshold": test.threshold,
                    "before": test.before,
                    "after": test.after
                }
            file.write(json.dumps(data) + '\n')
            file.write(json.dumps(asdict(test)) + "\n")
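The load/save refactor above drops the hand-written per-field dictionaries in favour of dataclasses.asdict on write and keyword-argument unpacking of each JSON object on read, which is why BasePDFTest becomes kw_only with defaults for threshold and checked. A compact sketch of that round trip using a simplified, hypothetical dataclass (not the repository's classes; kw_only needs Python 3.10+):

```python
import json
from dataclasses import dataclass, asdict

@dataclass(kw_only=True)
class MiniTest:
    pdf: str
    id: str
    type: str
    text: str
    threshold: float = 1.0  # default keeps older JSONL lines without the field loadable

line = json.dumps(asdict(MiniTest(pdf="doc.pdf", id="t1", type="present", text="hello")))
restored = MiniTest(**json.loads(line))  # every stored key maps straight back onto a field
print(line)
print(restored)
```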
@@ -2,10 +2,14 @@
import json
import sys
import os
import re
import argparse
from collections import defaultdict
from olmocr.data.renderpdf import render_pdf_to_base64png
import requests

from collections import defaultdict
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit

from olmocr.data.renderpdf import render_pdf_to_base64png


def parse_rules_file(file_path):
@@ -31,6 +35,7 @@ def parse_rules_file(file_path):

    return pdf_rules


def get_rule_html(rule, rule_index):
    """Generate HTML representation for a rule with interactive elements."""
    rule_type = rule.get('type', 'unknown')
@@ -38,7 +43,6 @@ def get_rule_html(rule, rule_index):

    # Determine status button class based on 'checked' value
    checked_status = rule.get('checked')
    # We won't set active class here; it'll be updated by JS upon interaction.
    thumbs_up_class = "active" if checked_status == "verified" else ""
    thumbs_down_class = "active" if checked_status == "rejected" else ""

@@ -121,6 +125,7 @@ def get_rule_html(rule, rule_index):
    </tr>
    """


def generate_html(pdf_rules, rules_file_path):
    """Generate the HTML page with PDF renderings and interactive rules."""
    # Limit to 10 unique PDFs
@@ -380,28 +385,24 @@ def generate_html(pdf_rules, rules_file_path):
    </div>
    """

    # Add JavaScript to manage interactivity
    # Add JavaScript to manage interactivity and datastore integration
    html += f"""
    </div>

    <script>
        // Store all rules data
        // Store all rules data (initially injected from the JSON file)
        let rulesData = {rules_json};

        // Function to toggle status button
        function toggleStatus(button) {{
            // Find the closest rule row which holds the rule index
            const ruleRow = button.closest('.rule-row');
            const ruleIndex = parseInt(ruleRow.dataset.ruleIndex);
            // Determine which action was clicked (either 'verified' or 'rejected')
            const action = button.dataset.action;

            // Toggle the rule's checked state: if already in that state, set to null; otherwise, set to the clicked action.
            const currentState = rulesData[ruleIndex].checked;
            const newState = (currentState === action) ? null : action;
            rulesData[ruleIndex].checked = newState;

            // Update the UI: adjust active classes on buttons in this row
            // Update UI for status buttons
            const buttons = ruleRow.querySelectorAll('.status-button');
            buttons.forEach(btn => {{
                if (btn.dataset.action === newState) {{
@@ -411,6 +412,8 @@ def generate_html(pdf_rules, rules_file_path):
                }}
            }});

            // Upload updated data to datastore
            uploadRulesData();
            outputJSON();
        }}

@@ -421,10 +424,11 @@ def generate_html(pdf_rules, rules_file_path):
            const field = element.dataset.field;
            const newText = element.innerText.trim();

            // Update rules data
            // Update the rules data
            rulesData[ruleIndex][field] = newText;

            // Output updated JSONL to console
            // Upload updated data to datastore
            uploadRulesData();
            outputJSON();
        }}

@@ -437,8 +441,53 @@ def generate_html(pdf_rules, rules_file_path):
            }});
        }}

        // Output initial JSONL when page loads
        document.addEventListener('DOMContentLoaded', outputJSON);
        // Function to upload rulesData to datastore using putDatastore
        async function uploadRulesData() {{
            try {{
                await putDatastore(rulesData);
                console.log("Datastore updated successfully");
            }} catch (error) {{
                console.error("Failed to update datastore", error);
            }}
        }}

        // Function to update UI from rulesData (used after fetching datastore state)
        function updateUIFromRulesData() {{
            document.querySelectorAll('.rule-row').forEach(ruleRow => {{
                const ruleIndex = parseInt(ruleRow.dataset.ruleIndex);
                const rule = rulesData[ruleIndex];
                // Update status buttons
                const buttons = ruleRow.querySelectorAll('.status-button');
                buttons.forEach(btn => {{
                    if (btn.dataset.action === rule.checked) {{
                        btn.classList.add('active');
                    }} else {{
                        btn.classList.remove('active');
                    }}
                }});
                // Update editable text fields
                ruleRow.querySelectorAll('.editable-text').forEach(div => {{
                    const field = div.dataset.field;
                    if (rule[field] !== undefined) {{
                        div.innerText = rule[field];
                    }}
                }});
            }});
        }}

        // On page load, fetch data from the datastore and update UI accordingly
        document.addEventListener('DOMContentLoaded', async function() {{
            try {{
                const datastoreState = await fetchDatastore();
                if (datastoreState.length) {{
                    rulesData = datastoreState;
                    updateUIFromRulesData();
                    outputJSON();
                }}
            }} catch (error) {{
                console.error("Error fetching datastore", error);
            }}
        }});
    </script>
</body>
</html>
@@ -446,6 +495,30 @@ def generate_html(pdf_rules, rules_file_path):

    return html

def get_page_datastore(html: str):
    """
    Fetch the JSON datastore from the presigned URL.
    Returns a dict. If any error or no content, returns {}.
    """
    match = re.search(r"const presignedGetUrl = \"(.*?)\";", html)
    if not match:
        return None
    presigned_url = match.group(1)

    try:
        # Clean up the presigned URL (sometimes the signature may need re-encoding)
        url_parts = urlsplit(presigned_url)
        query_params = parse_qs(url_parts.query)
        encoded_query = urlencode(query_params, doseq=True)
        cleaned_url = urlunsplit((url_parts.scheme, url_parts.netloc, url_parts.path, encoded_query, url_parts.fragment))

        resp = requests.get(cleaned_url)
        resp.raise_for_status()
        return resp.json()
    except Exception as e:
        print(f"Error fetching datastore from {presigned_url}: {e}")
        return None

def main():
    parser = argparse.ArgumentParser(description='Generate an interactive HTML visualization of PDF rules.')
    parser.add_argument('rules_file', help='Path to the rules file (JSON lines format)')
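get_page_datastore above recovers the saved review state by scraping the presignedGetUrl constant out of previously generated HTML, re-encoding its query string, and fetching the JSON with requests. A trimmed sketch of the same extraction and fetch; the HTML snippet and URL are placeholders, not real endpoints:

```python
import re
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit

import requests

html = 'const presignedGetUrl = "https://example.com/datastore.json?X-Sig=abc%2Fdef";'

match = re.search(r'const presignedGetUrl = "(.*?)";', html)
if match:
    parts = urlsplit(match.group(1))
    # Re-encode the query so percent-escaped signature characters survive the round trip.
    query = urlencode(parse_qs(parts.query), doseq=True)
    url = urlunsplit((parts.scheme, parts.netloc, parts.path, query, parts.fragment))
    try:
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        print(resp.json())
    except requests.RequestException as err:
        print(f"Could not fetch datastore: {err}")
```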
@@ -459,8 +532,21 @@ def main():

    if os.path.exists(args.output):
        print(f"Output file {args.output} already exists, attempting to reload it's datastore")
        with open(args.output, "r") as df:
            datastore = get_page_datastore(df.read())

        if datastore is None:
            print(f"Datastore for {args.output} is empty, please run tinyhost and verify your rules and then rerun the script")
            sys.exit(1)

        print(f"Loaded {len(datastore)} entries from datastore, updating {args.rules_file}")

        with open(args.rules_file, 'w') as of:
            for rule in datastore:
                of.write(json.dumps(rule) + "\n")

        return


    pdf_rules = parse_rules_file(args.rules_file)
    html = generate_html(pdf_rules, args.rules_file)

@@ -469,5 +555,6 @@ def main():

    print(f"Interactive HTML visualization created: {args.output}")


if __name__ == "__main__":
    main()