Pipeline scales temperature automatically, increases performance ~2%

Jake Poznanski 2025-03-14 22:27:51 -07:00
parent 4768ac4be5
commit 1f8cc59b22
2 changed files with 57 additions and 70 deletions

View File

@@ -18,29 +18,27 @@ import base64
import json
import os
import random
import time
from pathlib import Path
-from typing import Dict, List, Optional
+from typing import List, Optional
import boto3
import pypdf
from google import genai
from google.genai import types
from tqdm import tqdm
from olmocr.bench.tests import TextPresenceTest, save_tests
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter import PdfFilter
def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
"""
Download a PDF file from S3.
Args:
s3_path: The S3 path (s3://bucket/path/to/file.pdf)
local_path: The local path to save the file
Returns:
bool: True if download was successful, False otherwise
"""
@@ -49,13 +47,13 @@ def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
parts = s3_path.replace("s3://", "").split("/", 1)
bucket = parts[0]
key = parts[1]
# Create S3 client
s3 = boto3.client("s3")
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(local_path), exist_ok=True)
# Download file
s3.download_file(bucket, key, local_path)
return True
@@ -67,35 +65,35 @@ def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
def extract_page_from_pdf(input_path: str, output_path: str, page_num: int) -> bool:
"""
Extract a specific page from a PDF and save it as a new PDF.
Args:
input_path: Path to the input PDF
output_path: Path to save the extracted page
page_num: The page number to extract (0-indexed)
Returns:
bool: True if extraction was successful, False otherwise
"""
try:
# Ensure output directory exists
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# Read the input PDF
reader = pypdf.PdfReader(input_path)
# Check if page number is valid
if page_num >= len(reader.pages):
print(f"Page number {page_num} out of range for {input_path} with {len(reader.pages)} pages")
return False
# Create a new PDF with just the selected page
writer = pypdf.PdfWriter()
writer.add_page(reader.pages[page_num])
# Write the output PDF
with open(output_path, "wb") as output_file:
writer.write(output_file)
return True
except Exception as e:
print(f"Error extracting page {page_num} from {input_path}: {str(e)}")
@@ -105,12 +103,12 @@ def extract_page_from_pdf(input_path: str, output_path: str, page_num: int) -> b
def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Optional[List[str]]:
"""
Use Gemini to detect headers and footers in a rendered PDF page.
Args:
pdf_path: Path to the PDF file
page_num: The page number to analyze (0-indexed)
api_key: Gemini API key
Returns:
Optional[List[str]]: List of detected header/footer texts, or None if detection failed
"""
@@ -121,23 +119,21 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
# Render the PDF page as an image
try:
-image_base64 = render_pdf_to_base64png(
-pdf_path,
-page_num=page_num + 1, # render_pdf_to_base64png is 1-indexed
-target_longest_image_dim=2048
-)
+image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num + 1, target_longest_image_dim=2048) # render_pdf_to_base64png is 1-indexed
except Exception as e:
print(f"Error rendering PDF page: {str(e)}")
return None
image_part = types.Part(inline_data=types.Blob(mime_type="image/png", data=base64.b64decode(image_base64)))
contents = [
types.Content(
role="user",
parts=[
image_part,
-types.Part.from_text(text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""),
+types.Part.from_text(
+text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""
+),
],
),
]
@@ -149,47 +145,38 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Option
max_output_tokens=8192,
response_mime_type="application/json",
response_schema=genai.types.Schema(
-type = genai.types.Type.OBJECT,
-properties = {
+type=genai.types.Type.OBJECT,
+properties={
"headers": genai.types.Schema(
-type = genai.types.Type.ARRAY,
-items = genai.types.Schema(
-type = genai.types.Type.STRING,
+type=genai.types.Type.ARRAY,
+items=genai.types.Schema(
+type=genai.types.Type.STRING,
),
),
"footers": genai.types.Schema(
-type = genai.types.Type.ARRAY,
-items = genai.types.Schema(
-type = genai.types.Type.STRING,
+type=genai.types.Type.ARRAY,
+items=genai.types.Schema(
+type=genai.types.Type.STRING,
),
),
},
),
)
-response = client.models.generate_content(model=model,
-contents=contents,
-config=generate_content_config)
+response = client.models.generate_content(model=model, contents=contents, config=generate_content_config)
assert len(response.candidates) > 0, "No candidates found"
-assert (
-response.candidates[0].finish_reason == types.FinishReason.STOP
-), "Finish reason was not STOP, likely a processing error or repetition failure"
+assert response.candidates[0].finish_reason == types.FinishReason.STOP, "Finish reason was not STOP, likely a processing error or repetition failure"
data = json.loads(response.candidates[0].content.parts[0].text)
return data.get("headers", []) + data.get("footers", [])
-def process_pdf(
-s3_path: str,
-temp_dir: str,
-output_dir: str,
-api_key: str,
-tests: List[TextPresenceTest]
-) -> None:
+def process_pdf(s3_path: str, temp_dir: str, output_dir: str, api_key: str, tests: List[TextPresenceTest]) -> None:
"""
Process a single PDF from S3.
Args:
s3_path: S3 path to the PDF
temp_dir: Directory for temporary files
@@ -200,7 +187,7 @@ def process_pdf(
# Extract filename from S3 path
pdf_filename = os.path.basename(s3_path)
local_pdf_path = os.path.join(temp_dir, pdf_filename)
# Download PDF from S3
if not download_pdf_from_s3(s3_path, local_pdf_path):
return
@@ -210,16 +197,16 @@ def process_pdf(
if pdf_filter.filter_out_pdf(local_pdf_path):
print("Filtering out", pdf_filename)
return
try:
# Read the PDF to get the number of pages
reader = pypdf.PdfReader(local_pdf_path)
num_pages = len(reader.pages)
if num_pages == 0:
print(f"PDF {pdf_filename} has no pages")
return
all_pages = list(range(len(reader.pages)))
random.shuffle(all_pages)
@@ -229,20 +216,19 @@ def process_pdf(
# Only stick with headers and footers that have some actual data in them
header_footer_text = [x for x in header_footer_text if len(x.strip()) > 3]
if not header_footer_text:
print(f"No headers/footers detected in {pdf_filename} page {page_num}")
continue
# Extract the page and save to output dir
pdf_basename = os.path.splitext(pdf_filename)[0]
output_pdf_path = os.path.join(output_dir, "pdfs", f"{pdf_basename}_pg{page_num+1}.pdf")
extract_page_from_pdf(local_pdf_path, output_pdf_path, page_num)
# TODO Now, process it again to make sure extracted headers/footers don't appear in the main body of the text
# Create tests for each header/footer text
for i, text in enumerate(header_footer_text):
test_id = f"{pdf_basename}_pg{page_num+1}_header_{i:02d}"
@@ -255,10 +241,10 @@ def process_pdf(
max_diffs=0,
)
tests.append(test)
print(f"Processed {pdf_filename} page {page_num+1}, found {len(header_footer_text)} headers/footers")
return
except Exception as e:
print(f"Error processing {pdf_filename}: {str(e)}")
finally:
@@ -274,39 +260,37 @@ def main():
parser.add_argument("--api_key", help="Gemini API key (if not provided, will use GEMINI_API_KEY environment variable)")
parser.add_argument("--temp_dir", default="/tmp/mine_headers_footers", help="Directory for temporary files")
args = parser.parse_args()
# Get API key
api_key = args.api_key or os.environ.get("GEMINI_API_KEY")
if not api_key:
print("Error: Gemini API key not provided. Use --api_key or set GEMINI_API_KEY environment variable.")
return
# Create directories
os.makedirs(args.temp_dir, exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "pdfs"), exist_ok=True)
# Read input list
with open(args.input_list, "r") as f:
s3_paths = [line.strip() for line in f if line.strip()]
print(f"Found {len(s3_paths)} PDF paths in input list")
# Process each PDF
tests = []
for s3_path in tqdm(s3_paths, desc="Processing PDFs"):
process_pdf(s3_path, args.temp_dir, args.output_dir, api_key, tests)
# Save tests after each PDF to avoid losing data in case of crashes
if tests:
save_tests(tests, os.path.join(args.output_dir, "header_footer_tests.jsonl"))
if len(tests) > 100:
break
print(f"Saved {len(tests)} tests to {os.path.join(args.output_dir, 'header_footer_tests.jsonl')}")
if __name__ == "__main__":
main()
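
The hunks above are formatting-only (multi-line calls collapsed and keyword spacing normalized); the behavior of the header/footer miner is unchanged. For reference, here is a condensed, self-contained sketch of the structured-output request those hunks touch. The model name and prompt text are placeholder assumptions, not the script's actual values, and the helper name is hypothetical:

import json
from typing import List
from google import genai
from google.genai import types

def detect_headers_footers_sketch(image_part: types.Part, api_key: str) -> List[str]:
    # Ask Gemini for strict JSON containing "headers" and "footers" string arrays.
    client = genai.Client(api_key=api_key)
    config = types.GenerateContentConfig(
        max_output_tokens=8192,
        response_mime_type="application/json",
        response_schema=types.Schema(
            type=types.Type.OBJECT,
            properties={
                "headers": types.Schema(type=types.Type.ARRAY, items=types.Schema(type=types.Type.STRING)),
                "footers": types.Schema(type=types.Type.ARRAY, items=types.Schema(type=types.Type.STRING)),
            },
        ),
    )
    contents = [
        types.Content(
            role="user",
            parts=[image_part, types.Part.from_text(text="List any header/footer text on this page.")],  # placeholder prompt
        )
    ]
    response = client.models.generate_content(
        model="gemini-2.0-flash",  # placeholder; the real model name is outside the captured hunks
        contents=contents,
        config=config,
    )
    data = json.loads(response.candidates[0].content.parts[0].text)
    return data.get("headers", []) + data.get("footers", [])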

View File

@@ -141,7 +141,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
}
],
"max_tokens": MAX_TOKENS,
"temperature": 0.8,
"temperature": 0.0,
}
@@ -213,7 +213,7 @@ async def apost(url, json_data):
async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
MAX_RETRIES = args.max_page_retries
TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
exponential_backoffs = 0
local_anchor_text_len = args.target_anchor_text_len
local_image_rotation = 0
@@ -222,6 +222,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
while attempt < MAX_RETRIES:
query = await build_page_query(pdf_local_path, page_num, args.target_longest_image_dim, local_anchor_text_len, image_rotation=local_image_rotation)
query["temperature"] = TEMPERATURE_BY_ATTEMPT[
min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
] # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
logger.info(f"Built page query for {pdf_orig_path}-{page_num}")