mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-01 10:33:57 +00:00
Adding some more options to prompt chatgpt
This commit is contained in:
parent eabbe279fb
commit 7a638c74c9
@@ -199,7 +199,7 @@ def main():
     parser.add_argument("--test_report", type=str, default=None, help="Generate an HTML report of test results. Provide a filename (e.g., results.html).")
     args = parser.parse_args()

-    input_folder = args.dir
+    input_folder = args.dir if os.path.isdir(args.dir) else os.path.dirname(args.dir)
     n_bootstrap = args.bootstrap_samples
     ci_level = args.confidence_level
     pdf_folder = os.path.join(input_folder, "pdfs")
@@ -216,7 +216,11 @@ def main():
     pdf_basenames = [os.path.relpath(p, pdf_folder) for p in all_pdf_files]

-    jsonl_files = glob.glob(os.path.join(input_folder, "*.jsonl"))
+    if os.path.isfile(args.dir):
+        jsonl_files = [args.dir]
+    else:
+        jsonl_files = glob.glob(os.path.join(input_folder, "*.jsonl"))

     if not jsonl_files:
         print(f"Error: No .jsonl files found in {input_folder}.", file=sys.stderr)
         sys.exit(1)
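
Taken together, these two hunks let --dir point at either a results folder or a single .jsonl file inside one. A minimal standalone sketch of the new resolution logic (the paths in the example calls are hypothetical):

import glob
import os

def resolve_inputs(dir_arg: str):
    # A directory keeps the old behavior; a single file narrows the run to one .jsonl.
    input_folder = dir_arg if os.path.isdir(dir_arg) else os.path.dirname(dir_arg)
    if os.path.isfile(dir_arg):
        jsonl_files = [dir_arg]
    else:
        jsonl_files = glob.glob(os.path.join(input_folder, "*.jsonl"))
    return input_folder, jsonl_files

print(resolve_inputs("bench_data"))                      # whole folder
print(resolve_inputs("bench_data/chatgpt_tests.jsonl"))  # just one file
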
@@ -1,18 +1,29 @@
 import json
 import os
+from typing import Literal

 from openai import OpenAI

+from olmocr.bench.prompts import build_basic_prompt
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts.anchor import get_anchor_text
 from olmocr.prompts.prompts import (
     PageResponse,
     build_finetuning_prompt,
     build_openai_silver_data_prompt,
     openai_response_format_schema,
 )


-def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-06", temperature: float = 0.1) -> str:
+def run_chatgpt(
+    pdf_path: str,
+    page_num: int = 1,
+    model: str = "gpt-4o-2024-08-06",
+    temperature: float = 0.1,
+    target_longest_image_dim: int = 2048,
+    prompt_template: Literal["full", "basic", "finetune"] = "finetune",
+    response_template: Literal["plain", "json"] = "json",
+) -> str:
     """
     Convert a page of a PDF file to markdown using the commercial OpenAI APIs.
@@ -25,7 +36,7 @@ def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-0
         str: The OCR result in markdown format.
     """
     # Convert the first page of the PDF to a base64-encoded PNG image.
-    image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=2048)
+    image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=target_longest_image_dim)
     anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport")

     if not os.getenv("OPENAI_API_KEY"):
@@ -33,20 +44,29 @@ def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-0

     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

+    if prompt_template == "full":
+        prompt = build_openai_silver_data_prompt(anchor_text)
+    elif prompt_template == "finetune":
+        prompt = build_finetuning_prompt(anchor_text)
+    elif prompt_template == "basic":
+        prompt = build_basic_prompt()
+    else:
+        raise ValueError("Unknown prompt template")
+
     response = client.chat.completions.create(
         model=model,
         messages=[
             {
                 "role": "user",
                 "content": [
-                    {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
+                    {"type": "text", "text": prompt},
                     {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
                 ],
             }
         ],
         temperature=temperature,
         max_tokens=3000,
-        response_format=openai_response_format_schema(),
+        response_format=openai_response_format_schema() if response_template == "json" else None,
     )

     raw_response = response.choices[0].message.content
@@ -55,7 +75,10 @@ def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-0
     assert response.choices[0].message.refusal is None
     assert response.choices[0].finish_reason == "stop"

-    data = json.loads(raw_response)
-    data = PageResponse(**data)
+    if response_template == "json":
+        data = json.loads(raw_response)
+        data = PageResponse(**data)

-    return data.natural_text
+        return data.natural_text
+    else:
+        return raw_response
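
A hedged usage sketch for the extended runner. The import path and "paper.pdf" are assumptions (the diff does not show the module's file name), and OPENAI_API_KEY must be set:

from olmocr.bench.runners.run_chatgpt import run_chatgpt  # module path assumed

markdown = run_chatgpt(
    "paper.pdf",                    # placeholder path
    page_num=1,
    prompt_template="basic",        # "full" | "basic" | "finetune"
    response_template="plain",      # return raw text instead of parsing the JSON schema
    target_longest_image_dim=1024,  # render smaller than the 2048 default to save tokens
)
print(markdown)

Note that with response_template="plain" the response is returned as-is, so no PageResponse validation happens on that path.
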

54 olmocr/bench/templates/all_done_latex.html Normal file
@@ -0,0 +1,54 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>All Done!</title>
+    <style>
+        body {
+            font-family: Arial, sans-serif;
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            height: 100vh;
+            margin: 0;
+            background-color: #f5f5f5;
+        }
+        .container {
+            text-align: center;
+            padding: 30px;
+            background-color: white;
+            border-radius: 10px;
+            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
+        }
+        h1 {
+            color: #28a745;
+            margin-bottom: 20px;
+        }
+        p {
+            font-size: 18px;
+            margin-bottom: 20px;
+        }
+        button {
+            padding: 10px 20px;
+            background-color: #007bff;
+            color: white;
+            border: none;
+            border-radius: 5px;
+            cursor: pointer;
+            font-size: 16px;
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <h1>All Done! 🎉</h1>
+        <p>You have reviewed all equations in the dataset.</p>
+        <form method="post" action="/next_pdf">
+            <button type="submit">Start Over</button>
+        </form>
+    </div>
+</body>
+</html>
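
A hypothetical sketch of how this page might be served. The bench app's actual routes are not part of this diff, so everything below except /next_pdf (which the template's form posts to) is an assumption:

from flask import Flask, render_template

app = Flask(__name__, template_folder="olmocr/bench/templates")

@app.route("/all_done")  # route name assumed
def all_done():
    # Rendered once every equation in the current dataset has been reviewed;
    # the template's form posts to /next_pdf to start over.
    return render_template("all_done_latex.html")

if __name__ == "__main__":
    app.run(debug=True)
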
@@ -1,8 +1,9 @@
 <!DOCTYPE html>
 <html lang="en">
 <head>
     <meta charset="UTF-8">

+    <!-- You can adjust the viewport settings as needed -->
     <meta name="viewport" content="width=1200, initial-scale=1.0">
     <title>Equation Verification</title>
     <script src="https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.4.120/pdf.min.js"></script>
@@ -30,15 +31,15 @@
             overflow: hidden;
         }
         .pdf-viewer {
-            flex: 2;
+            flex: 2; /* Increased from 1 to 2 to make PDF larger */
             border-right: 1px solid #ddd;
             overflow: hidden;
             position: relative;
         }
-
+        /* Updated PDF container size */
         #pdf-container {
-            width: 200%;
-            height: 200%;
+            width: 200%; /* New fixed width */
+            height: 200%; /* New fixed height */
             overflow: auto;
         }
         #zoom-controls {
@@ -74,7 +75,7 @@
         .test-item.rejected {
             background-color: #f8d7da;
         }
-
+        /* The equation-display now stores the raw LaTeX in a data attribute */
         .equation-display {
             padding: 10px;
             margin: 5px 0;
@@ -82,7 +83,7 @@
             border-radius: 4px;
             background-color: #f9f9f9;
             overflow-x: auto;
-            font-size: 1.2em;
+            font-size: 1.2em; /* Larger font for equations */
         }
         .button-group {
             display: flex;
@@ -126,7 +127,7 @@
             background-color: #007bff;
             width: 0%;
         }
-
+        /* Make MathJax equations more visible */
         .MathJax {
             font-size: 120% !important;
         }
@@ -168,7 +169,7 @@
         <div class="tests-panel">
             <h3>Equations ({{ tests|length }})</h3>
             {% for test in tests %}
-
+            <!-- Added data-latex attribute to store raw LaTeX -->
             <div class="test-item {% if test.checked == 'verified' %}verified{% elif test.checked == 'rejected' %}rejected{% endif %}" id="test-{{ test.id }}">
                 <div class="equation-display" data-latex="{{ test.text|e }}">
                     {{ test.text|safe }}
@@ -176,6 +177,7 @@
                 <div class="button-group">
                     <button class="verify-button" onclick="updateTest('{{ test.id }}', '{{ test.pdf }}', 'checked', 'verified')">Verify</button>
                     <button class="reject-button" onclick="updateTest('{{ test.id }}', '{{ test.pdf }}', 'checked', 'rejected')">Reject</button>
-
+                    <!-- New Edit button -->
+                    <button class="edit-button" onclick="enableEdit('{{ test.id }}', '{{ test.pdf }}')">Edit</button>
                 </div>
             </div>
@@ -184,19 +186,23 @@
     </div>

     <script>
         // Set up PDF.js
         pdfjsLib.GlobalWorkerOptions.workerSrc = 'https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.4.120/pdf.worker.min.js';

+        // Track current zoom level
+        let currentScale = 2.0; // Initial larger scale
         let pdfDoc = null;
         let pageNum = 1;
+        let canvas = null;

         // Load the PDF
         const loadingTask = pdfjsLib.getDocument('{{ pdf_path }}');
         loadingTask.promise.then(function(pdf) {
             pdfDoc = pdf;
             renderPage(pageNum);
         });

+        // Function to render a page with the current scale
         function renderPage(num) {
             pdfDoc.getPage(num).then(function(page) {
                 const viewport = page.getViewport({ scale: currentScale });
@@ -351,4 +357,5 @@
         });
     </script>
 </body>
-</html>
+</html>
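
For reference, the template above consumes a tests list (each item carrying id, pdf, text, and checked) plus a pdf_path for the PDF.js viewer. A hypothetical render call, assuming a Flask app like the sketch earlier (the template's filename is an assumption, since this diff does not show it):

from flask import render_template

def review_page():
    tests = [
        {"id": "eq-001", "pdf": "doc.pdf", "text": r"\(E = mc^2\)", "checked": "verified"},
        {"id": "eq-002", "pdf": "doc.pdf", "text": r"\(\int_0^1 x\,dx\)", "checked": None},
    ]
    return render_template("review_equations.html", tests=tests, pdf_path="/static/doc.pdf")
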

86 olmocr/loadertest.py Normal file
@@ -0,0 +1,86 @@
+import json
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+import boto3
+from tqdm import tqdm
+
+# Configuration
+BUCKET = "ai2-llm"
+PREFIX = "pretraining-data/sources/soldni-open-access-books/v0/pipeline/results"
+OUTPUT_FILENAME = "all_completed_files.txt"
+
+
+def process_file(key: str):
+    """
+    Process a single S3 file given by its key.
+    Reads a jsonl file from S3, decodes each line,
+    extracts the 'Source-File' from the 'metadata' field,
+    and returns a list of these source file strings.
+    """
+    # Create a new S3 client in the worker process (clients are not shared across workers)
+    s3 = boto3.client("s3")
+    extracted_lines = []
+    try:
+        response = s3.get_object(Bucket=BUCKET, Key=key)
+        for raw_line in response["Body"].iter_lines():
+            try:
+                # Decode the line from bytes to text
+                line_str = raw_line.decode("utf-8")
+            except UnicodeDecodeError as e:
+                print(f"Skipping a line in {key} due to decode error: {e}")
+                continue
+            try:
+                data = json.loads(line_str)
+            except json.JSONDecodeError as e:
+                print(f"Skipping a malformed json line in {key}: {e}")
+                continue
+            # Extract 'Source-File' from metadata if present
+            metadata = data.get("metadata", {})
+            source_file = metadata.get("Source-File")
+            if source_file:
+                extracted_lines.append(source_file)
+    except Exception as e:
+        print(f"Error processing file {key}: {e}")
+    return extracted_lines
+
+
+def main():
+    s3 = boto3.client("s3")
+    paginator = s3.get_paginator("list_objects_v2")
+    page_iterator = paginator.paginate(Bucket=BUCKET, Prefix=PREFIX)
+
+    # Gather all S3 object keys under the specified prefix
+    keys = []
+    for page in page_iterator:
+        if "Contents" not in page:
+            continue
+        for obj in page["Contents"]:
+            keys.append(obj["Key"])
+
+    print(f"Found {len(keys)} files to process.")
+
+    # Open the output file for writing
+    with open(OUTPUT_FILENAME, "w", encoding="utf-8") as output_file:
+        # Create a process pool to process files concurrently.
+        # Adjust max_workers based on your environment and workload.
+        with ProcessPoolExecutor() as executor:
+            # Submit all processing jobs and map each future to its key
+            future_to_key = {executor.submit(process_file, key): key for key in keys}
+            # Use tqdm to wrap the as_completed iterator for progress display
+            for future in tqdm(as_completed(future_to_key), total=len(future_to_key), desc="Processing files"):
+                try:
+                    source_files = future.result()
+                    # Write each extracted line to the output file as soon as the future completes
+                    for source in source_files:
+                        output_file.write(source + "\n")
+                    # Optionally flush after each completed task
+                    output_file.flush()
+                except Exception as e:
+                    key = future_to_key[future]
+                    print(f"Exception occurred for file {key}: {e}")
+
+    print(f"Finished writing the source file names to {OUTPUT_FILENAME}")
+
+
+if __name__ == "__main__":
+    main()
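
Using ProcessPoolExecutor here parallelizes the per-file JSON decoding across CPUs, at the cost of one boto3 client per worker. Once loadertest.py has run, the output can be post-processed with a few lines; a small sketch for counting unique sources (the filename comes from OUTPUT_FILENAME above):

with open("all_completed_files.txt", encoding="utf-8") as f:
    sources = {line.strip() for line in f if line.strip()}
print(f"{len(sources)} unique source PDFs completed")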