Pipeline scales temperature automatically, increases performance ~2%

This commit is contained in:
Jake Poznanski 2025-03-14 22:27:51 -07:00
parent 4768ac4be5
commit 1f8cc59b22
2 changed files with 57 additions and 70 deletions

View File

@ -18,21 +18,19 @@ import base64
import json
import os
import random
import time
from pathlib import Path
from typing import Dict, List, Optional
from typing import List, Optional
import boto3
import pypdf
from google import genai
from google.genai import types
from tqdm import tqdm
from olmocr.bench.tests import TextPresenceTest, save_tests
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter import PdfFilter
def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
"""
Download a PDF file from S3.
@ -121,11 +119,7 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
# Render the PDF page as an image
try:
image_base64 = render_pdf_to_base64png(
pdf_path,
page_num=page_num + 1, # render_pdf_to_base64png is 1-indexed
target_longest_image_dim=2048
)
image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num + 1, target_longest_image_dim=2048) # render_pdf_to_base64png is 1-indexed
except Exception as e:
print(f"Error rendering PDF page: {str(e)}")
return None
@ -137,7 +131,9 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
role="user",
parts=[
image_part,
types.Part.from_text(text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""),
types.Part.from_text(
text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""
),
],
),
]
@ -149,44 +145,35 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
max_output_tokens=8192,
response_mime_type="application/json",
response_schema=genai.types.Schema(
type = genai.types.Type.OBJECT,
properties = {
type=genai.types.Type.OBJECT,
properties={
"headers": genai.types.Schema(
type = genai.types.Type.ARRAY,
items = genai.types.Schema(
type = genai.types.Type.STRING,
type=genai.types.Type.ARRAY,
items=genai.types.Schema(
type=genai.types.Type.STRING,
),
),
"footers": genai.types.Schema(
type = genai.types.Type.ARRAY,
items = genai.types.Schema(
type = genai.types.Type.STRING,
type=genai.types.Type.ARRAY,
items=genai.types.Schema(
type=genai.types.Type.STRING,
),
),
},
),
)
response = client.models.generate_content(model=model,
contents=contents,
config=generate_content_config)
response = client.models.generate_content(model=model, contents=contents, config=generate_content_config)
assert len(response.candidates) > 0, "No candidates found"
assert (
response.candidates[0].finish_reason == types.FinishReason.STOP
), "Finish reason was not STOP, likely a processing error or repetition failure"
assert response.candidates[0].finish_reason == types.FinishReason.STOP, "Finish reason was not STOP, likely a processing error or repetition failure"
data = json.loads(response.candidates[0].content.parts[0].text)
return data.get("headers", []) + data.get("footers", [])
def process_pdf(
s3_path: str,
temp_dir: str,
output_dir: str,
api_key: str,
tests: List[TextPresenceTest]
) -> None:
def process_pdf(s3_path: str, temp_dir: str, output_dir: str, api_key: str, tests: List[TextPresenceTest]) -> None:
"""
Process a single PDF from S3.
@ -242,7 +229,6 @@ def process_pdf(
# TODO Now, process it again to make sure extracted headers/footers don't appear in the main body of the text
# Create tests for each header/footer text
for i, text in enumerate(header_footer_text):
test_id = f"{pdf_basename}_pg{page_num+1}_header_{i:02d}"
@ -300,11 +286,9 @@ def main():
if tests:
save_tests(tests, os.path.join(args.output_dir, "header_footer_tests.jsonl"))
if len(tests) > 100:
break
print(f"Saved {len(tests)} tests to {os.path.join(args.output_dir, 'header_footer_tests.jsonl')}")

View File

@ -141,7 +141,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
}
],
"max_tokens": MAX_TOKENS,
"temperature": 0.8,
"temperature": 0.0,
}
@ -213,7 +213,7 @@ async def apost(url, json_data):
async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
MAX_RETRIES = args.max_page_retries
TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
exponential_backoffs = 0
local_anchor_text_len = args.target_anchor_text_len
local_image_rotation = 0
@ -222,6 +222,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
while attempt < MAX_RETRIES:
query = await build_page_query(pdf_local_path, page_num, args.target_longest_image_dim, local_anchor_text_len, image_rotation=local_image_rotation)
query["temperature"] = TEMPERATURE_BY_ATTEMPT[
min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
] # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
logger.info(f"Built page query for {pdf_orig_path}-{page_num}")