Pipeline scales temperature automatically, increases performance ~2%

This commit is contained in:
Jake Poznanski 2025-03-14 22:27:51 -07:00
parent 4768ac4be5
commit 1f8cc59b22
2 changed files with 57 additions and 70 deletions

View File

@ -18,21 +18,19 @@ import base64
import json import json
import os import os
import random import random
import time from typing import List, Optional
from pathlib import Path
from typing import Dict, List, Optional
import boto3 import boto3
import pypdf import pypdf
from google import genai from google import genai
from google.genai import types from google.genai import types
from tqdm import tqdm from tqdm import tqdm
from olmocr.bench.tests import TextPresenceTest, save_tests from olmocr.bench.tests import TextPresenceTest, save_tests
from olmocr.data.renderpdf import render_pdf_to_base64png from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter import PdfFilter from olmocr.filter import PdfFilter
def download_pdf_from_s3(s3_path: str, local_path: str) -> bool: def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
""" """
Download a PDF file from S3. Download a PDF file from S3.
@ -121,11 +119,7 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
# Render the PDF page as an image # Render the PDF page as an image
try: try:
image_base64 = render_pdf_to_base64png( image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num + 1, target_longest_image_dim=2048) # render_pdf_to_base64png is 1-indexed
pdf_path,
page_num=page_num + 1, # render_pdf_to_base64png is 1-indexed
target_longest_image_dim=2048
)
except Exception as e: except Exception as e:
print(f"Error rendering PDF page: {str(e)}") print(f"Error rendering PDF page: {str(e)}")
return None return None
@ -137,7 +131,9 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
role="user", role="user",
parts=[ parts=[
image_part, image_part,
types.Part.from_text(text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""), types.Part.from_text(
text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""
),
], ],
), ),
] ]
@ -149,44 +145,35 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
max_output_tokens=8192, max_output_tokens=8192,
response_mime_type="application/json", response_mime_type="application/json",
response_schema=genai.types.Schema( response_schema=genai.types.Schema(
type = genai.types.Type.OBJECT, type=genai.types.Type.OBJECT,
properties = { properties={
"headers": genai.types.Schema( "headers": genai.types.Schema(
type = genai.types.Type.ARRAY, type=genai.types.Type.ARRAY,
items = genai.types.Schema( items=genai.types.Schema(
type = genai.types.Type.STRING, type=genai.types.Type.STRING,
), ),
), ),
"footers": genai.types.Schema( "footers": genai.types.Schema(
type = genai.types.Type.ARRAY, type=genai.types.Type.ARRAY,
items = genai.types.Schema( items=genai.types.Schema(
type = genai.types.Type.STRING, type=genai.types.Type.STRING,
), ),
), ),
}, },
), ),
) )
response = client.models.generate_content(model=model, response = client.models.generate_content(model=model, contents=contents, config=generate_content_config)
contents=contents,
config=generate_content_config)
assert len(response.candidates) > 0, "No candidates found" assert len(response.candidates) > 0, "No candidates found"
assert ( assert response.candidates[0].finish_reason == types.FinishReason.STOP, "Finish reason was not STOP, likely a processing error or repetition failure"
response.candidates[0].finish_reason == types.FinishReason.STOP
), "Finish reason was not STOP, likely a processing error or repetition failure"
data = json.loads(response.candidates[0].content.parts[0].text) data = json.loads(response.candidates[0].content.parts[0].text)
return data.get("headers", []) + data.get("footers", []) return data.get("headers", []) + data.get("footers", [])
def process_pdf(
s3_path: str, def process_pdf(s3_path: str, temp_dir: str, output_dir: str, api_key: str, tests: List[TextPresenceTest]) -> None:
temp_dir: str,
output_dir: str,
api_key: str,
tests: List[TextPresenceTest]
) -> None:
""" """
Process a single PDF from S3. Process a single PDF from S3.
@ -242,7 +229,6 @@ def process_pdf(
# TODO Now, process it again to make sure extracted headers/footers don't appear in the main body of the text # TODO Now, process it again to make sure extracted headers/footers don't appear in the main body of the text
# Create tests for each header/footer text # Create tests for each header/footer text
for i, text in enumerate(header_footer_text): for i, text in enumerate(header_footer_text):
test_id = f"{pdf_basename}_pg{page_num+1}_header_{i:02d}" test_id = f"{pdf_basename}_pg{page_num+1}_header_{i:02d}"
@ -300,11 +286,9 @@ def main():
if tests: if tests:
save_tests(tests, os.path.join(args.output_dir, "header_footer_tests.jsonl")) save_tests(tests, os.path.join(args.output_dir, "header_footer_tests.jsonl"))
if len(tests) > 100: if len(tests) > 100:
break break
print(f"Saved {len(tests)} tests to {os.path.join(args.output_dir, 'header_footer_tests.jsonl')}") print(f"Saved {len(tests)} tests to {os.path.join(args.output_dir, 'header_footer_tests.jsonl')}")

View File

@ -141,7 +141,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
} }
], ],
"max_tokens": MAX_TOKENS, "max_tokens": MAX_TOKENS,
"temperature": 0.8, "temperature": 0.0,
} }
@ -213,7 +213,7 @@ async def apost(url, json_data):
async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult: async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions" COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
MAX_RETRIES = args.max_page_retries MAX_RETRIES = args.max_page_retries
TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
exponential_backoffs = 0 exponential_backoffs = 0
local_anchor_text_len = args.target_anchor_text_len local_anchor_text_len = args.target_anchor_text_len
local_image_rotation = 0 local_image_rotation = 0
@ -222,6 +222,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
while attempt < MAX_RETRIES: while attempt < MAX_RETRIES:
query = await build_page_query(pdf_local_path, page_num, args.target_longest_image_dim, local_anchor_text_len, image_rotation=local_image_rotation) query = await build_page_query(pdf_local_path, page_num, args.target_longest_image_dim, local_anchor_text_len, image_rotation=local_image_rotation)
query["temperature"] = TEMPERATURE_BY_ATTEMPT[
min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
] # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
logger.info(f"Built page query for {pdf_orig_path}-{page_num}") logger.info(f"Built page query for {pdf_orig_path}-{page_num}")