mirror of
https://github.com/allenai/olmocr.git
synced 2025-09-20 05:58:29 +00:00
Adding in eval scripts from oe-data-internal now all in one place
This commit is contained in:
parent
e64d4f7103
commit
a50ffe27c9
0
pdelfin/eval/__init__.py
Normal file
0
pdelfin/eval/__init__.py
Normal file
94
pdelfin/eval/evalhtml.py
Normal file
94
pdelfin/eval/evalhtml.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
from jinja2 import Template
|
||||||
|
import random
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import boto3
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Module-level AWS client pinned to the "s2" profile; used for downloading
# PDFs and generating presigned review links throughout this module.
session = boto3.Session(profile_name='s2')
s3_client = session.client('s3')
|
||||||
|
|
||||||
|
|
||||||
|
def render_pdf_to_base64png(s3_path: str, page: int) -> str:
    """Download a PDF from S3 and render one page as a base64-encoded WEBP.

    Args:
        s3_path: "s3://bucket/key" location of the source PDF.
        page: 1-based page number to render (pdftoppm numbering).

    Returns:
        Base64 (ascii) string of the WEBP-encoded page image.

    Raises:
        AssertionError: if pdftoppm exits non-zero (stderr is attached).
        subprocess.TimeoutExpired: if rendering exceeds 120 seconds.
    """
    # Reserve a temp filename; close the handle before downloading into it so
    # this also works on platforms that disallow concurrent opens (Windows).
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_pdf:
        pdf_path = tmp_pdf.name

    try:
        bucket, key = s3_path.replace("s3://", "").split('/', 1)
        s3_client.download_file(bucket, key, pdf_path)

        # Render only the requested page; with no output root argument,
        # pdftoppm writes the PNG bytes to stdout.
        pdftoppm_result = subprocess.run(
            ["pdftoppm",
             "-png",
             "-f", str(page),
             "-l", str(page),
             pdf_path],
            timeout=120,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        assert pdftoppm_result.returncode == 0, pdftoppm_result.stderr
    finally:
        # delete=False means cleanup is our responsibility — previously the
        # temp PDF was leaked on every call.
        os.remove(pdf_path)

    # Transcode PNG -> WEBP in memory to shrink the embedded data URI.
    png_image = Image.open(io.BytesIO(pdftoppm_result.stdout))
    webp_output = io.BytesIO()
    png_image.save(webp_output, format="WEBP")

    image_base64 = base64.b64encode(webp_output.getvalue()).decode("utf-8")

    return image_base64
|
||||||
|
|
||||||
|
|
||||||
|
def create_review_html(data, filename="review_page.html"):
    """Render the blind side-by-side gold/eval review page.

    Args:
        data: iterable of per-page dicts with keys "s3_path", "page",
            "gold_text", "eval_text", "alignment".
        filename: path of the HTML file to write.

    Side effects: downloads/renders each page image from S3, generates
    presigned PDF links, and writes the rendered template to *filename*.
    """
    # Load the Jinja2 template that ships alongside this module
    with open(os.path.join(os.path.dirname(__file__), "evalhtml_template.html"), "r") as f:
        template = Template(f.read())

    entries = []
    for i, entry in tqdm(enumerate(data)):
        # Randomly decide whether to display gold on the left or right, so
        # reviewers cannot learn which side is which.
        if random.choice([True, False]):
            left_text, right_text = entry["gold_text"], entry["eval_text"]
            left_class, right_class = "gold", "eval"
        else:
            left_text, right_text = entry["eval_text"], entry["gold_text"]
            left_class, right_class = "eval", "gold"
        # Both sides carry the same page-level alignment score (this was
        # duplicated identically in both branches before).
        left_alignment, right_alignment = entry["alignment"], entry["alignment"]

        # Convert newlines to <p> tags for proper formatting
        left_text = "<p>" + left_text.replace("\n", "</p><p>") + "</p>"
        right_text = "<p>" + right_text.replace("\n", "</p><p>") + "</p>"

        parsed_url = urlparse(entry["s3_path"])
        bucket = parsed_url.netloc
        s3_key = parsed_url.path.lstrip('/')
        # Presigned link valid for one week (604800 seconds)
        signed_pdf_link = s3_client.generate_presigned_url("get_object", Params={"Bucket": bucket, "Key": s3_key}, ExpiresIn=604800)

        # Create a dictionary for each entry
        entries.append({
            "entry_id": i,
            "page_image": render_pdf_to_base64png(entry["s3_path"], entry["page"]),
            "s3_path": entry["s3_path"],
            "page": entry["page"],
            "signed_pdf_link": signed_pdf_link,
            "left_text": left_text,
            "right_text": right_text,
            "left_alignment": left_alignment,
            "right_alignment": right_alignment,
            "left_class": left_class,
            "right_class": right_class,
            "gold_class": "gold" if left_class == "gold" else "eval",
            "eval_class": "eval" if right_class == "eval" else "gold"
        })

    # Render the template with the entries
    final_html = template.render(entries=entries)

    # Write the HTML content to the specified file
    with open(filename, "w") as f:
        f.write(final_html)

    # BUG FIX: this message previously interpolated a stray literal instead of
    # the actual output filename.
    print(f"HTML file '{filename}' created successfully!")
|
397
pdelfin/eval/evalhtml_template.html
Normal file
397
pdelfin/eval/evalhtml_template.html
Normal file
@ -0,0 +1,397 @@
|
|||||||
|
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Text Evaluation Review</title>
    <style>
        /* Page chrome */
        body { font-family: Arial, sans-serif; background-color: #f9f9f9; margin: 0; padding: 20px; }
        h1 { text-align: center; font-size: 2em; color: #333; }
        .container { width: 100%; max-width: 1200px; margin: 0 auto; }
        /* One row per evaluated page: page image | left text | right text */
        .entry { display: grid; grid-template-columns: 1fr 1fr 1fr; grid-gap: 20px; margin-bottom: 20px; padding: 20px; background-color: #fff; border-radius: 8px; box-shadow: 0 0 10px rgba(0,0,0,0.1); transition: background-color 0.3s ease; }
        .text-block { padding: 10px; background-color: #f1f1f1; border-radius: 6px; min-height: 100px; display: flex; flex-direction: column; justify-content: space-between; cursor: pointer; }
        .text-block:hover { background-color: #e0e0e0; }
        .text-block.selected { background-color: lightgreen; border: 2px solid black; }
        .alignment { font-size: 0.9em; color: #777; margin-top: 10px; }
        /* Floating control box (top right): reveal toggle + vote tally */
        .reveal-box { position: fixed; top: 20px; right: 20px; padding: 15px; background-color: white; border: 1px solid #ccc; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); z-index: 1000; width: 200px; }
        .reveal-box input { margin-right: 10px; }
        .reveal-info { margin-top: 10px; font-size: 0.9em; color: #333; }
        /* With body.revealed set, tint gold vs. eval sides so they are identifiable */
        .revealed .gold { background-color: #fff9e6; }
        .revealed .eval { background-color: #e6f3ff; }
        .revealed .text-block.selected { border: 2px solid black; }
        .entry > div:first-child img { width: 100%; height: auto; object-fit: cover; border-radius: 6px; cursor: pointer; transition: transform 0.3s ease, box-shadow 0.3s ease; }
        /* Full-screen preview mode */
        .full-screen img { position: fixed; top: 50%; left: 50%; transform: translate(-50%, -50%); width: unset !important; max-width: 90vw; height: auto; max-height: 90vh; z-index: 1001; box-shadow: 0 0 20px rgba(0, 0, 0, 0.5); }
        .overlay { position: fixed; top: 0; left: 0; width: 100vw; height: 100vh; background-color: rgba(0, 0, 0, 0.7); z-index: 1000; display: none; }
        .overlay.active { display: block; }
        /* Voting Buttons Styles */
        .voting-buttons { margin-top: 10px; display: flex; flex-direction: column; gap: 5px; }
        .voting-buttons button { padding: 8px 12px; border: none; border-radius: 4px; cursor: pointer; background-color: #007BFF; color: white; transition: background-color 0.3s ease, border 0.3s ease; }
        .voting-buttons button:hover { background-color: #0056b3; }
        .voting-buttons button.invalid { background-color: #dc3545; }
        .voting-buttons button.invalid:hover { background-color: #a71d2a; }
        .voting-buttons button.both-good { background-color: #28a745; }
        .voting-buttons button.both-good:hover { background-color: #1e7e34; }
        .voting-buttons button.both-bad { background-color: #ffc107; color: #212529; }
        .voting-buttons button.both-bad:hover { background-color: #e0a800; }
        /* Selected State for Voting Buttons */
        .voting-buttons button.selected { border: 3px solid #000; }
    </style>
</head>
<body>
    <h1>Text Evaluation Review</h1>

    <!-- Floating Reveal Box -->
    <div class="reveal-box">
        <input type="checkbox" id="reveal-toggle" />
        <label for="reveal-toggle">Reveal Gold/Eval</label>
        <div class="reveal-info" id="vote-info">Votes</div>
    </div>

    <div class="container">
        {% for entry in entries %}
        <!-- Vote-datastore keys derive from this data-entry-id; the JS
             sanitizeKey() normalizes it further before use. -->
        <div class="entry {{ entry.gold_class }} {{ entry.eval_class }}" data-entry-id="{{ entry.s3_path | replace('/', '_') }}">
            <div class="image-container">
                <!-- NOTE(review): the Python side encodes this image as WEBP but the
                     data URI declares image/png; browsers render it anyway via content
                     sniffing, but the declared MIME type looks wrong - confirm. -->
                <img src="data:image/png;base64,{{ entry.page_image }}" alt="Render">

                <div class="alignment">Alignment: {{ entry.left_alignment }}</div>
                <a href="{{entry.signed_pdf_link}}#page={{ entry.page }}" target="_blank">{{ entry.s3_path }} (Page {{ entry.page }})</a>

                <!-- Voting Buttons -->
                <div class="voting-buttons">
                    <button class="both-good" data-vote="both_good">Both Good</button>
                    <button class="both-bad" data-vote="both_bad">Both Bad</button>
                    <button class="invalid" data-vote="invalid_pdf">Invalid PDF</button>
                </div>
            </div>
            <div class="text-block {{ entry.left_class }}" data-choice="left">
                <div>{{ entry.left_text|safe }}</div>
            </div>
            <div class="text-block {{ entry.right_class }}" data-choice="right">
                <div>{{ entry.right_text|safe }}</div>
            </div>
        </div>
        {% endfor %}
    </div>

    <!-- Overlay for full-screen preview -->
    <div class="overlay"></div>

    <script>
|
||||||
|
    // Wire up all page interactions once the DOM is ready.
    document.addEventListener('DOMContentLoaded', () => {
        // Restore previously saved votes/selections from the datastore.
        fetchDataAndUpdatePage();

        // Toggle the full-screen image preview
        const overlay = document.querySelector('.overlay');
        document.querySelectorAll('.image-container img').forEach(img => {
            img.addEventListener('click', () => {
                const entry = img.closest('.entry');
                entry.classList.toggle('full-screen');
                overlay.classList.toggle('active');
            });
        });

        // Clicking the dimmed backdrop exits any active full-screen preview.
        overlay.addEventListener('click', () => {
            document.querySelectorAll('.full-screen').forEach(entry => {
                entry.classList.remove('full-screen');
            });
            overlay.classList.remove('active');
        });

        // "Reveal Gold/Eval" checkbox: tint sides and show percentage stats.
        document.getElementById('reveal-toggle').addEventListener('change', (e) => {
            document.body.classList.toggle('revealed', e.target.checked);
            updateReveal();
        });

        // Handle text-block selections
        document.querySelectorAll('.text-block').forEach(block => {
            block.addEventListener('click', () => selectChoice(block));
        });

        // Handle voting buttons
        document.querySelectorAll('.voting-buttons button').forEach(button => {
            button.addEventListener('click', () => handleVote(button));
        });
    });
|
||||||
|
|
||||||
|
// Utility function to sanitize s3_path for use as a key
|
||||||
|
function sanitizeKey(key) {
|
||||||
|
return key.replace(/[^a-zA-Z0-9-_]/g, '_');
|
||||||
|
}
|
||||||
|
|
||||||
|
async function fetchDataAndUpdatePage() {
|
||||||
|
let datastore = await fetchDatastore();
|
||||||
|
|
||||||
|
document.querySelectorAll('.entry').forEach(entry => {
|
||||||
|
const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
|
||||||
|
const leftBlock = entry.querySelector('.text-block[data-choice="left"]');
|
||||||
|
const rightBlock = entry.querySelector('.text-block[data-choice="right"]');
|
||||||
|
const voteButtons = entry.querySelectorAll('.voting-buttons button');
|
||||||
|
|
||||||
|
if (datastore[entryKey]) {
|
||||||
|
const choice = datastore[entryKey];
|
||||||
|
if (choice === 'left' || choice === 'right') {
|
||||||
|
const selectedBlock = choice === 'left' ? leftBlock : rightBlock;
|
||||||
|
selectChoice(selectedBlock, false);
|
||||||
|
} else {
|
||||||
|
// Handle additional voting choices
|
||||||
|
handleAdditionalVote(entry, choice, false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
updateVoteInfo(datastore);
|
||||||
|
}
|
||||||
|
|
||||||
|
    /**
     * Record a preference for one side (left/right) of an entry.
     *
     * @param block  the clicked .text-block element
     * @param save   when false (page-load restore), update the UI only and
     *               skip writing the datastore back
     */
    async function selectChoice(block, save = true) {
        let datastore = await fetchDatastore();

        const entry = block.closest('.entry');
        const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
        const blocks = entry.querySelectorAll('.text-block');

        // Only one side may be selected per entry.
        blocks.forEach(b => b.classList.remove('selected'));
        block.classList.add('selected');

        // The stored vote is "left" or "right" (blinded - not gold/eval).
        datastore[entryKey] = block.getAttribute('data-choice');

        const numVotes = Object.keys(datastore).length;
        document.getElementById("vote-info").innerText = `Total Votes: ${numVotes}`;

        if (save) {
            putDatastore(datastore); // Save entire datastore
        }
    }
|
||||||
|
|
||||||
|
    /**
     * Record one of the auxiliary votes (both_good / both_bad / invalid_pdf)
     * for an entry, replacing any prior left/right selection.
     *
     * @param button the clicked voting button (carries data-vote)
     * @param save   when false, update the UI only without persisting
     */
    async function handleVote(button, save = true) {
        let datastore = await fetchDatastore();

        const entry = button.closest('.entry');
        const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
        const choice = button.getAttribute('data-vote');

        // Deselect any selected voting buttons within this entry
        const voteButtons = entry.querySelectorAll('.voting-buttons button');
        voteButtons.forEach(b => b.classList.remove('selected'));

        // Select the clicked button
        button.classList.add('selected');

        // Deselect any selected text-blocks (a vote supersedes a side pick)
        const textBlocks = entry.querySelectorAll('.text-block');
        textBlocks.forEach(b => b.classList.remove('selected'));

        datastore[entryKey] = choice;

        const numVotes = Object.keys(datastore).length;
        document.getElementById("vote-info").innerText = `Total Votes: ${numVotes}`;

        if (save) {
            putDatastore(datastore); // Save entire datastore
        }
    }
|
||||||
|
|
||||||
|
    /**
     * Restore an auxiliary vote (both_good / both_bad / invalid_pdf) onto an
     * entry's buttons - used when re-applying saved state on page load.
     *
     * @param entry  the .entry container element
     * @param choice the stored data-vote value
     * @param save   when false, update the UI only without persisting
     */
    async function handleAdditionalVote(entry, choice, save = true) {
        let datastore = await fetchDatastore();

        const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));

        // Select the appropriate voting button based on the choice
        const voteButton = entry.querySelector(`.voting-buttons button[data-vote="${choice}"]`);
        if (voteButton) {
            // Deselect other voting buttons
            const voteButtons = entry.querySelectorAll('.voting-buttons button');
            voteButtons.forEach(b => b.classList.remove('selected'));

            // Select the current button
            voteButton.classList.add('selected');
        }

        datastore[entryKey] = choice;

        if (save) {
            putDatastore(datastore);
        }
    }
|
||||||
|
|
||||||
|
    /**
     * Tally all saved votes and show the gold/eval breakdown in the floating
     * box. A "left"/"right" vote is attributed to gold or eval by checking
     * which class the voted .text-block carries; re-marks voted side blocks
     * as selected.
     */
    async function updateReveal() {
        let datastore = await fetchDatastore();
        let goldVotes = 0;
        let evalVotes = 0;
        let bothGoodVotes = 0;
        let bothBadVotes = 0;
        let invalidPdfVotes = 0;

        document.querySelectorAll('.entry').forEach(entry => {
            const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
            const leftBlock = entry.querySelector('.text-block[data-choice="left"]');
            const rightBlock = entry.querySelector('.text-block[data-choice="right"]');

            const vote = datastore[entryKey];
            if (vote === 'left') {
                // The voted block's gold/eval class reveals which system won.
                if (leftBlock.classList.contains('gold')) {
                    goldVotes++;
                } else {
                    evalVotes++;
                }
            } else if (vote === 'right') {
                if (rightBlock.classList.contains('gold')) {
                    goldVotes++;
                } else {
                    evalVotes++;
                }
            } else if (vote === 'both_good') {
                bothGoodVotes++;
            } else if (vote === 'both_bad') {
                bothBadVotes++;
            } else if (vote === 'invalid_pdf') {
                invalidPdfVotes++;
            }
        });

        // Percentages are over all cast votes (0 when nothing is voted yet).
        const totalVotes = goldVotes + evalVotes + bothGoodVotes + bothBadVotes + invalidPdfVotes;
        const goldPercentage = totalVotes > 0 ? Math.round((goldVotes / totalVotes) * 100) : 0;
        const evalPercentage = totalVotes > 0 ? Math.round((evalVotes / totalVotes) * 100) : 0;
        const bothGoodPercentage = totalVotes > 0 ? Math.round((bothGoodVotes / totalVotes) * 100) : 0;
        const bothBadPercentage = totalVotes > 0 ? Math.round((bothBadVotes / totalVotes) * 100) : 0;
        const invalidPdfPercentage = totalVotes > 0 ? Math.round((invalidPdfVotes / totalVotes) * 100) : 0;

        document.getElementById("vote-info").innerText = `Gold: ${goldPercentage}% | Eval: ${evalPercentage}% | Both Good: ${bothGoodPercentage}% | Both Bad: ${bothBadPercentage}% | Invalid PDF: ${invalidPdfPercentage}%`;

        // Re-mark each voted side block as selected.
        document.querySelectorAll('.entry').forEach(entry => {
            const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
            const vote = datastore[entryKey];
            if (vote === 'left' || vote === 'right') {
                const selectedBlock = vote === 'left' ? entry.querySelector('.text-block[data-choice="left"]') : entry.querySelector('.text-block[data-choice="right"]');
                selectedBlock.classList.add('selected');
            }
            // Additional votes already handled in handleAdditionalVote
        });
    }
|
||||||
|
|
||||||
|
    </script>
</body>
<!-- Rendered by pdelfin/eval/evalhtml.py (create_review_html) -->
</html>
|
258
pdelfin/eval/runeval.py
Normal file
258
pdelfin/eval/runeval.py
Normal file
@ -0,0 +1,258 @@
|
|||||||
|
# This script will build a set of scores for the accuracy of a given pdf conversion tactic against a gold dataset
|
||||||
|
#
|
||||||
|
# You might need to pip install git+https://github.com/allenai/refine.git@soldni/eval-m
|
||||||
|
# in order to use some of the existing aligner scoring that was developed as part
|
||||||
|
# of the refiner pipeline
|
||||||
|
import boto3
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
import random
|
||||||
|
import zstandard
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
||||||
|
from pathlib import Path
|
||||||
|
from smart_open import smart_open, register_compressor
|
||||||
|
from dolma_refine.evaluate.metrics import DocumentEditSimilarity
|
||||||
|
from dolma_refine.evaluate.segmenters import SpacySegmenter
|
||||||
|
from dolma_refine.evaluate.aligners import HirschbergAligner
|
||||||
|
|
||||||
|
from .evalhtml import create_review_html
|
||||||
|
|
||||||
|
|
||||||
|
# Local cache for gold-data JSON files downloaded from S3; entries are
# validated against S3 ETags (see load_gold_data / compute_file_hash).
CACHE_DIR = os.path.join(Path.home(), ".cache", "pdf_gold_data_cache")

# Default-profile S3 client used for listing/downloading gold and eval data.
s3_client = boto3.client('s3')
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_zst(file_obj, mode):
    # smart_open compression hook: wrap the raw stream in a zstandard codec.
    return zstandard.open(file_obj, mode)


# Teach smart_open to transparently (de)compress both common zstd suffixes.
register_compressor(".zstd", _handle_zst)
register_compressor(".zst", _handle_zst)
|
||||||
|
|
||||||
|
# Helper function to download files from S3
def download_from_s3(s3_path: str, local_path: str):
    """Download the object at ``s3://bucket/key`` in *s3_path* to *local_path*."""
    without_scheme = s3_path.replace("s3://", "")
    bucket_name, key = without_scheme.split("/", 1)
    s3_client.download_file(bucket_name, key, local_path)
|
||||||
|
|
||||||
|
def is_debugging() -> bool:
    """Return True when a trace function is installed (i.e. a debugger is attached)."""
    trace_fn = sys.gettrace()
    return trace_fn is not None
|
||||||
|
|
||||||
|
# Create a hash to store file contents and check for changes
def compute_file_hash(file_path: str) -> str:
    """Return the hex MD5 digest of *file_path*'s contents.

    Used to compare cached gold files against their S3 ETags, so the file is
    read incrementally in 4 KiB chunks rather than loaded whole.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as fh:
        while chunk := fh.read(4096):
            digest.update(chunk)
    return digest.hexdigest()
|
||||||
|
|
||||||
|
# Load every .json file from GOLD_DATA_S3_PATH (and saves it to some temp folder for quick loading next time)
# returns map from "custom_id" ex. "s3://ai2-s2-pdfs/39ce/3db4516cd6e7d7f8e580a494c7a665a6a16a.pdf-4" (where the -4 means page 4)
# to the gold standard text
def load_gold_data(gold_data_path: str) -> dict:
    """Download (with local ETag-checked caching) and parse all gold .json
    files under *gold_data_path*, returning {custom_id: gold_text}.

    Raises:
        ValueError: if a cached file's MD5 no longer matches its S3 ETag.
            NOTE(review): this assumes the ETag is a plain MD5, which holds
            only for non-multipart uploads - confirm for this bucket.
    """
    if not os.path.exists(CACHE_DIR):
        os.makedirs(CACHE_DIR)

    gold_data = {}

    # List the contents of the S3 bucket
    bucket_name, prefix = gold_data_path.replace("s3://", "").split("/", 1)
    paginator = s3_client.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    for page in pages:
        for obj in page.get('Contents', []):
            s3_key = obj['Key']
            if s3_key.endswith('.json'):
                local_file_path = os.path.join(CACHE_DIR, os.path.basename(s3_key))
                etag = obj['ETag'].strip('"')  # ETag is the checksum

                # Check if the file is already cached and verify its checksum
                if os.path.exists(local_file_path):
                    cached_file_hash = compute_file_hash(local_file_path)
                    if cached_file_hash != etag:
                        raise ValueError(f"File {local_file_path} has changed on S3. Clear the cache in {CACHE_DIR} and reload.")
                else:
                    # Download the file from S3 if not cached
                    download_from_s3(f"s3://{bucket_name}/{s3_key}", local_file_path)

                # Load the JSON file (one JSON object per line)
                with smart_open(local_file_path, 'r') as f:
                    for line in f:
                        data = json.loads(line)

                        if "custom_id" in data:
                            # This is for loading gold data that came out of openai's batch API directly
                            custom_id = data["custom_id"]
                            text = data["response"]["body"]["choices"][0]["message"]["content"]
                        else:
                            # This is for loading gold data that went through the mise pdf refine pipeline
                            custom_id = data["s3_path"] + "-" + str(data["page"])
                            text = data["outputs"][0]["text"]

                        gold_data[custom_id] = text

    print(f"Loaded {len(gold_data):,} gold data entries for comparison")

    return gold_data
|
||||||
|
|
||||||
|
# Helper function to list all .jsonl files from a directory or an S3 bucket
def list_jsonl_files(path: str) -> list:
    """Enumerate evaluation data files under *path*.

    Args:
        path: either a local directory (walked recursively) or an
            "s3://bucket/prefix" location.

    Returns:
        List of file paths (s3:// URLs or local paths) whose names end with a
        recognized JSON/JSONL suffix, optionally zstandard-compressed.

    FIX: both ".zstd" and ".zst" suffixes are accepted, matching the two
    compressors registered with smart_open above - previously ".zst" files
    were silently skipped.
    """
    valid_endings = [".json", ".jsonl", ".json.zstd", ".jsonl.zstd", ".json.zst", ".jsonl.zst"]
    jsonl_files = []

    if path.startswith("s3://"):
        bucket_name, prefix = path.replace("s3://", "").split("/", 1)
        paginator = s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

        for page in pages:
            for obj in page.get('Contents', []):
                if any(obj['Key'].endswith(ending) for ending in valid_endings):
                    jsonl_files.append(f"s3://{bucket_name}/{obj['Key']}")

    else:
        # If it's a local directory, list all matching files recursively
        for root, _, files in os.walk(path):
            for file in files:
                if any(file.endswith(ending) for ending in valid_endings):
                    jsonl_files.append(os.path.join(root, file))

    return jsonl_files
|
||||||
|
|
||||||
|
# Takes in a path to a local directory or s3://[bucket]/[prefix path] where your jsonl files are stored
# This is most likely the output location of the refiner
# Expecting each jsonl line to include {s3_path: [path to original pdf], page: [pagenum], text: [proper page text]}
# Returns the average Levenshtein distance match between the data
def process_jsonl_file(jsonl_file, gold_data, comparer):
    """Score one eval .jsonl file against the gold data.

    Returns a 5-tuple:
        (sum of per-page alignment scores,
         sum of alignment * len(gold_text)   # char-weighted score
         total gold chars considered,
         number of pages compared,
         {goldkey: per-page detail dict})
    """
    page_data = {}
    total_alignment_score = 0
    char_weighted_alignment_score = 0
    total_pages = 0
    total_chars = 0

    with smart_open(jsonl_file, 'r') as f:
        for line in f:
            data = json.loads(line)

            # Two row formats: openai batch rows carry a "custom_id" shaped
            # "<s3_path>-<page>"; refine-pipeline rows carry explicit fields.
            if "custom_id" in data:
                goldkey = data["custom_id"]
                data["s3_path"] = goldkey[:goldkey.rindex("-")]
                data["page"] = int(goldkey[goldkey.rindex("-") + 1:])
            else:
                goldkey = data["s3_path"] + "-" + str(data["page"])

            # Only pages present in the gold set are scored.
            if goldkey not in gold_data:
                continue

            gold_text = gold_data[goldkey]

            # You need to consider the case when no input is provided to the refiner, it will hallucinate
            # So in that case we say there is no eval text
            if len(data["text"].strip()) == 0:
                eval_text = ""
            else:
                # NOTE(review): the trailing [0] takes the first element of
                # "text" - if "text" were a plain string this would yield a
                # single character; presumably it is a list in this output
                # schema. Confirm against the eval pipeline's format.
                eval_text = data["outputs"][0]["text"][0]

            # If the eval text or gold text is empty, we skip this page and don't use it for comparison
            # It means that something was an OCR page, and the text-based pipeline just won't be able to handle that
            if len(eval_text.strip()) < 10 or len(gold_text.strip()) < 10:
                continue

            #eval_text = data["text"] # Uncomment to measure the raw input text to the refiner, without any refining happening

            alignment = comparer.compute(gold_text, eval_text)

            # print("GOLD_______________________________________")
            # print(gold_text)
            # print("EVAL________________________________________")
            # print(eval_text)
            # print("")
            # print(f"Alignment: {alignment:.3f}")
            # print("")
            # input()

            page_data[goldkey] = {
                "s3_path": data["s3_path"],
                "page": data["page"],
                "gold_text": gold_text,
                "eval_text": eval_text,
                "alignment": alignment
            }

            total_alignment_score += alignment
            char_weighted_alignment_score += alignment * len(gold_text)
            total_chars += len(gold_text)
            total_pages += 1

    return total_alignment_score, char_weighted_alignment_score, total_chars, total_pages, page_data
|
||||||
|
|
||||||
|
def do_eval(gold_data_path: str, eval_data_path: str) -> tuple[float, list[dict]]:
    """Score an eval run against the gold dataset and emit a review page.

    Args:
        gold_data_path: "s3://..." prefix holding the gold .json files.
        eval_data_path: local dir or "s3://..." prefix of eval .jsonl files.

    Returns:
        (char-weighted mean alignment, sampled list of per-page review dicts)

    Raises:
        ValueError: if no eval files are found, or no pages overlapped the
            gold set (previously a ZeroDivisionError).
    """
    gold_data = load_gold_data(gold_data_path)

    total_alignment_score = 0  # char-weighted running sum
    total_weight = 0           # total gold chars compared
    total_pages_compared = set()

    page_eval_data = []

    segmenter = SpacySegmenter("spacy")
    aligner = HirschbergAligner(match_score=1,
                                mismatch_score=-1,
                                indel_score=-1)
    comparer = DocumentEditSimilarity(segmenter=segmenter, aligner=aligner)

    # List all .jsonl files in the directory or S3 bucket
    jsonl_files = list_jsonl_files(eval_data_path)

    if not jsonl_files:
        raise ValueError("No .jsonl files found in the specified path.")

    print(f"Found {len(jsonl_files):,} files to evaluate")

    # Use threads while debugging so breakpoints work; processes otherwise.
    with ProcessPoolExecutor() if not is_debugging() else ThreadPoolExecutor() as executor:
        # Prepare the future tasks
        futures = [executor.submit(process_jsonl_file, jsonl_file, gold_data, comparer) for jsonl_file in jsonl_files]

        # Process each future as it completes
        for future in tqdm(as_completed(futures), total=len(jsonl_files)):
            alignment_score, char_weighted_score, chars, pages, page_data = future.result()

            # Aggregate statistics (char-weighted so long pages count more)
            total_alignment_score += char_weighted_score
            total_weight += chars
            total_pages_compared |= page_data.keys()

            # Collect review candidates: skip near-perfect pages and pages
            # where both sides are too short to be interesting.
            for pd_key, pd in page_data.items():
                if pd["alignment"] > 0.97:
                    continue

                if len(pd["gold_text"]) < 200 and len(pd["eval_text"]) < 200:
                    continue

                page_eval_data.append(pd)

    if total_weight == 0:
        raise ValueError("No eval pages matched the gold data; nothing to score.")

    # Select random entries to return in the page_eval_data.
    # BUG FIX: random.sample raises ValueError when the population is smaller
    # than the requested size, so cap the sample at the available count.
    page_eval_data = random.sample(page_eval_data, min(20, len(page_eval_data)))

    # Alternative: select the 20 lowest alignments instead
    # page_eval_data.sort(key=lambda x: x["alignment"])
    # page_eval_data = page_eval_data[:20]

    # Generate a nice review page to use with tinyhost
    create_review_html(page_eval_data, filename="review_page.html")

    print(f"Compared {len(total_pages_compared):,} pages")
    print(f"Total corpus alignment: {total_alignment_score:.2f}")
    print(f"Mean alignment: {total_alignment_score / total_weight:.3f}")

    return total_alignment_score / total_weight, page_eval_data
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Ad-hoc entry point: compare the Qwen2-VL 2B eval run against the
    # openai batch gold set (both hard-coded S3 prefixes).
    result = do_eval(gold_data_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/",
                     eval_data_path="s3://ai2-oe-data/birr-dev/qwen2-vl/outputs/for-jake/2b/2024-09-24/")
|
Loading…
x
Reference in New Issue
Block a user