mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-10 15:52:31 +00:00
Pipeline scales temperature automatically with each retry attempt, which increases performance by ~2%
This commit is contained in:
parent
4768ac4be5
commit
1f8cc59b22
@ -18,21 +18,19 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import boto3
|
||||
import pypdf
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from olmocr.bench.tests import TextPresenceTest, save_tests
|
||||
from olmocr.data.renderpdf import render_pdf_to_base64png
|
||||
from olmocr.filter import PdfFilter
|
||||
|
||||
|
||||
def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
|
||||
"""
|
||||
Download a PDF file from S3.
|
||||
@ -121,11 +119,7 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
|
||||
|
||||
# Render the PDF page as an image
|
||||
try:
|
||||
image_base64 = render_pdf_to_base64png(
|
||||
pdf_path,
|
||||
page_num=page_num + 1, # render_pdf_to_base64png is 1-indexed
|
||||
target_longest_image_dim=2048
|
||||
)
|
||||
image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num + 1, target_longest_image_dim=2048) # render_pdf_to_base64png is 1-indexed
|
||||
except Exception as e:
|
||||
print(f"Error rendering PDF page: {str(e)}")
|
||||
return None
|
||||
@ -137,7 +131,9 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
|
||||
role="user",
|
||||
parts=[
|
||||
image_part,
|
||||
types.Part.from_text(text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""),
|
||||
types.Part.from_text(
|
||||
text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
@ -149,44 +145,35 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
|
||||
max_output_tokens=8192,
|
||||
response_mime_type="application/json",
|
||||
response_schema=genai.types.Schema(
|
||||
type = genai.types.Type.OBJECT,
|
||||
properties = {
|
||||
type=genai.types.Type.OBJECT,
|
||||
properties={
|
||||
"headers": genai.types.Schema(
|
||||
type = genai.types.Type.ARRAY,
|
||||
items = genai.types.Schema(
|
||||
type = genai.types.Type.STRING,
|
||||
type=genai.types.Type.ARRAY,
|
||||
items=genai.types.Schema(
|
||||
type=genai.types.Type.STRING,
|
||||
),
|
||||
),
|
||||
"footers": genai.types.Schema(
|
||||
type = genai.types.Type.ARRAY,
|
||||
items = genai.types.Schema(
|
||||
type = genai.types.Type.STRING,
|
||||
type=genai.types.Type.ARRAY,
|
||||
items=genai.types.Schema(
|
||||
type=genai.types.Type.STRING,
|
||||
),
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
response = client.models.generate_content(model=model,
|
||||
contents=contents,
|
||||
config=generate_content_config)
|
||||
response = client.models.generate_content(model=model, contents=contents, config=generate_content_config)
|
||||
|
||||
assert len(response.candidates) > 0, "No candidates found"
|
||||
assert (
|
||||
response.candidates[0].finish_reason == types.FinishReason.STOP
|
||||
), "Finish reason was not STOP, likely a processing error or repetition failure"
|
||||
assert response.candidates[0].finish_reason == types.FinishReason.STOP, "Finish reason was not STOP, likely a processing error or repetition failure"
|
||||
|
||||
data = json.loads(response.candidates[0].content.parts[0].text)
|
||||
|
||||
return data.get("headers", []) + data.get("footers", [])
|
||||
|
||||
def process_pdf(
|
||||
s3_path: str,
|
||||
temp_dir: str,
|
||||
output_dir: str,
|
||||
api_key: str,
|
||||
tests: List[TextPresenceTest]
|
||||
) -> None:
|
||||
|
||||
def process_pdf(s3_path: str, temp_dir: str, output_dir: str, api_key: str, tests: List[TextPresenceTest]) -> None:
|
||||
"""
|
||||
Process a single PDF from S3.
|
||||
|
||||
@ -242,7 +229,6 @@ def process_pdf(
|
||||
|
||||
# TODO Now, process it again to make sure extracted headers/footers don't appear in the main body of the text
|
||||
|
||||
|
||||
# Create tests for each header/footer text
|
||||
for i, text in enumerate(header_footer_text):
|
||||
test_id = f"{pdf_basename}_pg{page_num+1}_header_{i:02d}"
|
||||
@ -300,11 +286,9 @@ def main():
|
||||
if tests:
|
||||
save_tests(tests, os.path.join(args.output_dir, "header_footer_tests.jsonl"))
|
||||
|
||||
|
||||
if len(tests) > 100:
|
||||
break
|
||||
|
||||
|
||||
print(f"Saved {len(tests)} tests to {os.path.join(args.output_dir, 'header_footer_tests.jsonl')}")
|
||||
|
||||
|
||||
|
@ -141,7 +141,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
|
||||
}
|
||||
],
|
||||
"max_tokens": MAX_TOKENS,
|
||||
"temperature": 0.8,
|
||||
"temperature": 0.0,
|
||||
}
|
||||
|
||||
|
||||
@ -213,7 +213,7 @@ async def apost(url, json_data):
|
||||
async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
|
||||
COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
|
||||
MAX_RETRIES = args.max_page_retries
|
||||
|
||||
TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
|
||||
exponential_backoffs = 0
|
||||
local_anchor_text_len = args.target_anchor_text_len
|
||||
local_image_rotation = 0
|
||||
@ -222,6 +222,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
|
||||
|
||||
while attempt < MAX_RETRIES:
|
||||
query = await build_page_query(pdf_local_path, page_num, args.target_longest_image_dim, local_anchor_text_len, image_rotation=local_image_rotation)
|
||||
query["temperature"] = TEMPERATURE_BY_ATTEMPT[
|
||||
min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
|
||||
] # Raise the temperature as the number of attempts increases, to overcome repetition issues at the expense of quality
|
||||
|
||||
logger.info(f"Built page query for {pdf_orig_path}-{page_num}")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user