mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-11 07:58:10 +00:00
Adding some more options to prompt chatgpt
This commit is contained in:
parent
eabbe279fb
commit
7a638c74c9
@ -199,7 +199,7 @@ def main():
|
|||||||
parser.add_argument("--test_report", type=str, default=None, help="Generate an HTML report of test results. Provide a filename (e.g., results.html).")
|
parser.add_argument("--test_report", type=str, default=None, help="Generate an HTML report of test results. Provide a filename (e.g., results.html).")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
input_folder = args.dir
|
input_folder = args.dir if os.path.isdir(args.dir) else os.path.dirname(args.dir)
|
||||||
n_bootstrap = args.bootstrap_samples
|
n_bootstrap = args.bootstrap_samples
|
||||||
ci_level = args.confidence_level
|
ci_level = args.confidence_level
|
||||||
pdf_folder = os.path.join(input_folder, "pdfs")
|
pdf_folder = os.path.join(input_folder, "pdfs")
|
||||||
@ -216,7 +216,11 @@ def main():
|
|||||||
|
|
||||||
pdf_basenames = [os.path.relpath(p, pdf_folder) for p in all_pdf_files]
|
pdf_basenames = [os.path.relpath(p, pdf_folder) for p in all_pdf_files]
|
||||||
|
|
||||||
|
if os.path.isfile(args.dir):
|
||||||
|
jsonl_files = [args.dir]
|
||||||
|
else:
|
||||||
jsonl_files = glob.glob(os.path.join(input_folder, "*.jsonl"))
|
jsonl_files = glob.glob(os.path.join(input_folder, "*.jsonl"))
|
||||||
|
|
||||||
if not jsonl_files:
|
if not jsonl_files:
|
||||||
print(f"Error: No .jsonl files found in {input_folder}.", file=sys.stderr)
|
print(f"Error: No .jsonl files found in {input_folder}.", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|||||||
@ -1,18 +1,29 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from olmocr.bench.prompts import build_basic_prompt
|
||||||
from olmocr.data.renderpdf import render_pdf_to_base64png
|
from olmocr.data.renderpdf import render_pdf_to_base64png
|
||||||
from olmocr.prompts.anchor import get_anchor_text
|
from olmocr.prompts.anchor import get_anchor_text
|
||||||
from olmocr.prompts.prompts import (
|
from olmocr.prompts.prompts import (
|
||||||
PageResponse,
|
PageResponse,
|
||||||
|
build_finetuning_prompt,
|
||||||
build_openai_silver_data_prompt,
|
build_openai_silver_data_prompt,
|
||||||
openai_response_format_schema,
|
openai_response_format_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-06", temperature: float = 0.1) -> str:
|
def run_chatgpt(
|
||||||
|
pdf_path: str,
|
||||||
|
page_num: int = 1,
|
||||||
|
model: str = "gpt-4o-2024-08-06",
|
||||||
|
temperature: float = 0.1,
|
||||||
|
target_longest_image_dim: int = 2048,
|
||||||
|
prompt_template: Literal["full", "basic", "finetune"] = "finetune",
|
||||||
|
response_template: Literal["plain", "json"] = "json",
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Convert page of a PDF file to markdown using the commercial openAI APIs.
|
Convert page of a PDF file to markdown using the commercial openAI APIs.
|
||||||
|
|
||||||
@ -25,7 +36,7 @@ def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-0
|
|||||||
str: The OCR result in markdown format.
|
str: The OCR result in markdown format.
|
||||||
"""
|
"""
|
||||||
# Convert the first page of the PDF to a base64-encoded PNG image.
|
# Convert the first page of the PDF to a base64-encoded PNG image.
|
||||||
image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=2048)
|
image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=target_longest_image_dim)
|
||||||
anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport")
|
anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport")
|
||||||
|
|
||||||
if not os.getenv("OPENAI_API_KEY"):
|
if not os.getenv("OPENAI_API_KEY"):
|
||||||
@ -33,20 +44,29 @@ def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-0
|
|||||||
|
|
||||||
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||||
|
|
||||||
|
if prompt_template == "full":
|
||||||
|
prompt = build_openai_silver_data_prompt(anchor_text)
|
||||||
|
elif prompt_template == "finetune":
|
||||||
|
prompt = build_finetuning_prompt(anchor_text)
|
||||||
|
elif prompt_template == "basic":
|
||||||
|
prompt = build_basic_prompt()
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown prompt template")
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
|
{"type": "text", "text": prompt},
|
||||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
|
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=3000,
|
max_tokens=3000,
|
||||||
response_format=openai_response_format_schema(),
|
response_format=openai_response_format_schema() if response_template == "json" else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_response = response.choices[0].message.content
|
raw_response = response.choices[0].message.content
|
||||||
@ -55,7 +75,10 @@ def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-0
|
|||||||
assert response.choices[0].message.refusal is None
|
assert response.choices[0].message.refusal is None
|
||||||
assert response.choices[0].finish_reason == "stop"
|
assert response.choices[0].finish_reason == "stop"
|
||||||
|
|
||||||
|
if response_template == "json":
|
||||||
data = json.loads(raw_response)
|
data = json.loads(raw_response)
|
||||||
data = PageResponse(**data)
|
data = PageResponse(**data)
|
||||||
|
|
||||||
return data.natural_text
|
return data.natural_text
|
||||||
|
else:
|
||||||
|
return raw_response
|
||||||
|
|||||||
54
olmocr/bench/templates/all_done_latex.html
Normal file
54
olmocr/bench/templates/all_done_latex.html
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>All Done!</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: Arial, sans-serif;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
height: 100vh;
|
||||||
|
margin: 0;
|
||||||
|
background-color: #f5f5f5;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
text-align: center;
|
||||||
|
padding: 30px;
|
||||||
|
background-color: white;
|
||||||
|
border-radius: 10px;
|
||||||
|
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
||||||
|
}
|
||||||
|
h1 {
|
||||||
|
color: #28a745;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
p {
|
||||||
|
font-size: 18px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
button {
|
||||||
|
padding: 10px 20px;
|
||||||
|
background-color: #007bff;
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
border-radius: 5px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 16px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h1>All Done! 🎉</h1>
|
||||||
|
<p>You have reviewed all equations in the dataset.</p>
|
||||||
|
<form method="post" action="/next_pdf">
|
||||||
|
<button type="submit">Start Over</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
|
||||||
@ -1,8 +1,9 @@
|
|||||||
|
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html lang="en">
|
<html lang="en">
|
||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
|
<!-- You can adjust the viewport settings as needed -->
|
||||||
<meta name="viewport" content="width=1200, initial-scale=1.0">
|
<meta name="viewport" content="width=1200, initial-scale=1.0">
|
||||||
<title>Equation Verification</title>
|
<title>Equation Verification</title>
|
||||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.4.120/pdf.min.js"></script>
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.4.120/pdf.min.js"></script>
|
||||||
@ -30,15 +31,15 @@
|
|||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
}
|
}
|
||||||
.pdf-viewer {
|
.pdf-viewer {
|
||||||
flex: 2;
|
flex: 2; /* Increased from 1 to 2 to make PDF larger */
|
||||||
border-right: 1px solid #ddd;
|
border-right: 1px solid #ddd;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
position: relative;
|
position: relative;
|
||||||
}
|
}
|
||||||
|
/* Updated PDF container size */
|
||||||
#pdf-container {
|
#pdf-container {
|
||||||
width: 200%;
|
width: 200%; /* New fixed width */
|
||||||
height: 200%;
|
height: 200%; /* New fixed height */
|
||||||
overflow: auto;
|
overflow: auto;
|
||||||
}
|
}
|
||||||
#zoom-controls {
|
#zoom-controls {
|
||||||
@ -74,7 +75,7 @@
|
|||||||
.test-item.rejected {
|
.test-item.rejected {
|
||||||
background-color: #f8d7da;
|
background-color: #f8d7da;
|
||||||
}
|
}
|
||||||
|
/* The equation-display now stores the raw LaTeX in a data attribute */
|
||||||
.equation-display {
|
.equation-display {
|
||||||
padding: 10px;
|
padding: 10px;
|
||||||
margin: 5px 0;
|
margin: 5px 0;
|
||||||
@ -82,7 +83,7 @@
|
|||||||
border-radius: 4px;
|
border-radius: 4px;
|
||||||
background-color: #f9f9f9;
|
background-color: #f9f9f9;
|
||||||
overflow-x: auto;
|
overflow-x: auto;
|
||||||
font-size: 1.2em;
|
font-size: 1.2em; /* Larger font for equations */
|
||||||
}
|
}
|
||||||
.button-group {
|
.button-group {
|
||||||
display: flex;
|
display: flex;
|
||||||
@ -126,7 +127,7 @@
|
|||||||
background-color: #007bff;
|
background-color: #007bff;
|
||||||
width: 0%;
|
width: 0%;
|
||||||
}
|
}
|
||||||
|
/* Make MathJax equations more visible */
|
||||||
.MathJax {
|
.MathJax {
|
||||||
font-size: 120% !important;
|
font-size: 120% !important;
|
||||||
}
|
}
|
||||||
@ -168,7 +169,7 @@
|
|||||||
<div class="tests-panel">
|
<div class="tests-panel">
|
||||||
<h3>Equations ({{ tests|length }})</h3>
|
<h3>Equations ({{ tests|length }})</h3>
|
||||||
{% for test in tests %}
|
{% for test in tests %}
|
||||||
|
<!-- Added data-latex attribute to store raw LaTeX -->
|
||||||
<div class="test-item {% if test.checked == 'verified' %}verified{% elif test.checked == 'rejected' %}rejected{% endif %}" id="test-{{ test.id }}">
|
<div class="test-item {% if test.checked == 'verified' %}verified{% elif test.checked == 'rejected' %}rejected{% endif %}" id="test-{{ test.id }}">
|
||||||
<div class="equation-display" data-latex="{{ test.text|e }}">
|
<div class="equation-display" data-latex="{{ test.text|e }}">
|
||||||
{{ test.text|safe }}
|
{{ test.text|safe }}
|
||||||
@ -176,6 +177,7 @@
|
|||||||
<div class="button-group">
|
<div class="button-group">
|
||||||
<button class="verify-button" onclick="updateTest('{{ test.id }}', '{{ test.pdf }}', 'checked', 'verified')">Verify</button>
|
<button class="verify-button" onclick="updateTest('{{ test.id }}', '{{ test.pdf }}', 'checked', 'verified')">Verify</button>
|
||||||
<button class="reject-button" onclick="updateTest('{{ test.id }}', '{{ test.pdf }}', 'checked', 'rejected')">Reject</button>
|
<button class="reject-button" onclick="updateTest('{{ test.id }}', '{{ test.pdf }}', 'checked', 'rejected')">Reject</button>
|
||||||
|
<!-- New Edit button -->
|
||||||
<button class="edit-button" onclick="enableEdit('{{ test.id }}', '{{ test.pdf }}')">Edit</button>
|
<button class="edit-button" onclick="enableEdit('{{ test.id }}', '{{ test.pdf }}')">Edit</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@ -184,19 +186,23 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
|
// Set up PDF.js
|
||||||
pdfjsLib.GlobalWorkerOptions.workerSrc = 'https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.4.120/pdf.worker.min.js';
|
pdfjsLib.GlobalWorkerOptions.workerSrc = 'https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.4.120/pdf.worker.min.js';
|
||||||
|
|
||||||
|
// Track current zoom level
|
||||||
let currentScale = 2.0; // Initial larger scale
|
let currentScale = 2.0; // Initial larger scale
|
||||||
let pdfDoc = null;
|
let pdfDoc = null;
|
||||||
let pageNum = 1;
|
let pageNum = 1;
|
||||||
let canvas = null;
|
let canvas = null;
|
||||||
|
|
||||||
|
// Load the PDF
|
||||||
const loadingTask = pdfjsLib.getDocument('{{ pdf_path }}');
|
const loadingTask = pdfjsLib.getDocument('{{ pdf_path }}');
|
||||||
loadingTask.promise.then(function(pdf) {
|
loadingTask.promise.then(function(pdf) {
|
||||||
pdfDoc = pdf;
|
pdfDoc = pdf;
|
||||||
renderPage(pageNum);
|
renderPage(pageNum);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Function to render a page with the current scale
|
||||||
function renderPage(num) {
|
function renderPage(num) {
|
||||||
pdfDoc.getPage(num).then(function(page) {
|
pdfDoc.getPage(num).then(function(page) {
|
||||||
const viewport = page.getViewport({ scale: currentScale });
|
const viewport = page.getViewport({ scale: currentScale });
|
||||||
@ -352,3 +358,4 @@
|
|||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|
||||||
86
olmocr/loadertest.py
Normal file
86
olmocr/loadertest.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
import json
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
BUCKET = "ai2-llm"
|
||||||
|
PREFIX = "pretraining-data/sources/soldni-open-access-books/v0/pipeline/results"
|
||||||
|
OUTPUT_FILENAME = "all_completed_files.txt"
|
||||||
|
|
||||||
|
|
||||||
|
def process_file(key: str):
    """
    Collect the 'Source-File' metadata values from one jsonl object in S3.

    The object at ``key`` in ``BUCKET`` is streamed line by line. Each line is
    decoded as UTF-8 and parsed as JSON; when the parsed record is a JSON
    object whose ``metadata`` field is itself an object with a truthy
    ``Source-File`` entry, that entry is collected.

    Lines that cannot be decoded or parsed, and records that do not have the
    expected shape, are skipped individually so that one bad record does not
    discard the rest of the file.

    Args:
        key: S3 object key of the jsonl file to read.

    Returns:
        A list with the ``Source-File`` values found, in file order. If the
        download itself fails part-way through, the values collected so far
        are returned and the error is printed.
    """
    # main() submits this function to a ProcessPoolExecutor, so it runs in a
    # worker process; each call builds its own client rather than sharing one.
    s3 = boto3.client("s3")
    extracted_lines = []
    try:
        response = s3.get_object(Bucket=BUCKET, Key=key)
        body = response["Body"]
        try:
            for raw_line in body.iter_lines():
                try:
                    # Decode the line from bytes to text
                    line_str = raw_line.decode("utf-8")
                except UnicodeDecodeError as e:
                    print(f"Skipping a line in {key} due to decode error: {e}")
                    continue
                try:
                    data = json.loads(line_str)
                except json.JSONDecodeError as e:
                    print(f"Skipping a malformed json line in {key}: {e}")
                    continue
                # Valid JSON is not necessarily an object (it may be a list,
                # string, number or null); only objects can carry metadata.
                if not isinstance(data, dict):
                    print(f"Skipping a non-object json line in {key}")
                    continue
                # A missing or non-object 'metadata' (e.g. null) simply means
                # there is no 'Source-File' to extract from this record.
                metadata = data.get("metadata")
                if not isinstance(metadata, dict):
                    continue
                source_file = metadata.get("Source-File")
                if source_file:
                    extracted_lines.append(source_file)
        finally:
            # Release the underlying HTTP stream even if iteration aborts early.
            body.close()
    except Exception as e:
        # Best-effort: report the failure and return whatever was collected.
        print(f"Error processing file {key}: {e}")
    return extracted_lines
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """
    List every object under PREFIX in BUCKET, extract the source file names
    from each one in parallel worker processes, and write them, one per line,
    to OUTPUT_FILENAME.
    """
    client = boto3.client("s3")
    pages = client.get_paginator("list_objects_v2").paginate(Bucket=BUCKET, Prefix=PREFIX)

    # Flatten the paginated listing into a single list of object keys;
    # pages without a "Contents" entry contribute nothing.
    keys = [entry["Key"] for page in pages if "Contents" in page for entry in page["Contents"]]

    print(f"Found {len(keys)} files to process.")

    # The output file stays open for the whole run; the process pool is
    # shut down before the file is closed.
    with open(OUTPUT_FILENAME, "w", encoding="utf-8") as sink, ProcessPoolExecutor() as pool:
        # Remember which key each future belongs to so failures can be reported.
        pending = {}
        for key in keys:
            pending[pool.submit(process_file, key)] = key

        # tqdm wraps the completion iterator to show progress.
        progress = tqdm(as_completed(pending), total=len(pending), desc="Processing files")
        for done in progress:
            try:
                # Write results as soon as each task finishes.
                for source in done.result():
                    sink.write(source + "\n")
                # Flush after each completed task so partial output is kept.
                sink.flush()
            except Exception as e:
                print(f"Exception occurred for file {pending[done]}: {e}")

    print(f"Finished writing the source file names to {OUTPUT_FILENAME}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
x
Reference in New Issue
Block a user