import argparse
import asyncio
import atexit
import base64
import datetime
import hashlib
import json
import logging
import multiprocessing
import os
import random
import re
import shutil
import sys
import tempfile
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures.process import BrokenProcessPool
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from urllib.parse import urlparse

import boto3
import httpx
from botocore.exceptions import ClientError
from huggingface_hub import snapshot_download
from PIL import Image
from pypdf import PdfReader
from tqdm import tqdm

from olmocr.check import (
    check_poppler_version,
    check_torch_gpu_available,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter.filter import Language, PdfFilter
from olmocr.image_utils import convert_image_to_pdf_bytes, is_jpeg, is_png
from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.prompts import PageResponse, build_no_anchoring_yaml_prompt
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import (
    download_directory,
    download_zstd_csv,
    expand_s3_glob,
    get_s3_bytes,
    get_s3_bytes_with_backoff,
    parse_s3_path,
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue

# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False

server_logger = logging.getLogger("vllm")
server_logger.propagate = False

file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))

# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
server_logger.addHandler(file_handler)

# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)

# Global s3 clients for the whole script; we keep two separate ones in case your workspace and your pdfs are in different accounts
workspace_s3 = boto3.client("s3")
pdf_s3 = boto3.client("s3")

# Global variables for token statistics
metrics = MetricsKeeper(window=60 * 5)
tracker = WorkerTracker()

# Process pool for offloading cpu-bound work, like calculating anchor texts; capped at 32 workers, otherwise it can spawn way too many workers on a big machine
process_pool = ProcessPoolExecutor(max_workers=min(multiprocessing.cpu_count() // 2 + 1, 32), mp_context=multiprocessing.get_context("spawn"))
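
# Note: process_pool is the shared executor for CPU-heavy helpers. A typical (illustrative) pattern is
#   result = await asyncio.get_running_loop().run_in_executor(process_pool, some_cpu_bound_fn, arg)
# so that per-page CPU work does not block the asyncio event loop; the exact call sites live elsewhere in this module.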

# Filter object, cached so it will only get loaded when/if you need it
get_pdf_filter = cache(lambda: PdfFilter(languages_to_keep={Language.ENGLISH, None}, apply_download_spam_check=True, apply_form_check=True))

# Specify a default port, but it can be overridden by args
BASE_SERVER_PORT = 30024


@dataclass(frozen=True)
class PageResult:
    s3_path: str
    page_num: int
    response: PageResponse
    input_tokens: int
    output_tokens: int
    is_fallback: bool


async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, image_rotation: int = 0) -> dict:
    MAX_TOKENS = 4500
    assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"

    # Render the page to an image in a background thread so we don't block the event loop
    image_base64 = await asyncio.to_thread(render_pdf_to_base64png, local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)

    if image_rotation != 0:
        image_bytes = base64.b64decode(image_base64)
        with Image.open(BytesIO(image_bytes)) as img:
            if image_rotation == 90:
                transpose = Image.Transpose.ROTATE_90
            elif image_rotation == 180:
                transpose = Image.Transpose.ROTATE_180
            else:
                transpose = Image.Transpose.ROTATE_270

            rotated_img = img.transpose(transpose)

            # Save the rotated image to a bytes buffer
            buffered = BytesIO()
            rotated_img.save(buffered, format="PNG")

            # Encode the rotated image back to base64
            image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return {
        "model": "olmocr",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
                    {"type": "text", "text": build_no_anchoring_yaml_prompt()},
                ],
            }
        ],
        "max_tokens": MAX_TOKENS,
        "temperature": 0.0,
    }
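
# Illustrative usage (not part of the pipeline flow), e.g. for page 1 of a local file at the default render size:
#   query = await build_page_query("/tmp/example.pdf", page=1, target_longest_image_dim=1288)
# The returned dict is a chat-completions style payload (model, messages, max_tokens, temperature)
# that process_page() sends to the local inference server via apost() below.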


# Manual simple implementation of HTTP Post
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
# that at the scale of 100M+ requests, that they deadlock in different strange ways
async def apost(url, json_data):
    parsed_url = urlparse(url)
    host = parsed_url.hostname
    port = parsed_url.port or 80
    path = parsed_url.path or "/"

    writer = None
    try:
        reader, writer = await asyncio.open_connection(host, port)

        json_payload = json.dumps(json_data)
        request = (
            f"POST {path} HTTP/1.1\r\n"
            f"Host: {host}\r\n"
            f"Content-Type: application/json\r\n"
            f"Content-Length: {len(json_payload)}\r\n"
            f"Connection: close\r\n\r\n"
            f"{json_payload}"
        )
        writer.write(request.encode())
        await writer.drain()

        # Read status line
        status_line = await reader.readline()
        if not status_line:
            raise ConnectionError("No response from server")
        status_parts = status_line.decode().strip().split(" ", 2)
        if len(status_parts) < 2:
            raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
        status_code = int(status_parts[1])

        # Read headers
        headers = {}
        while True:
            line = await reader.readline()
            if line in (b"\r\n", b"\n", b""):
                break
            key, _, value = line.decode().partition(":")
            headers[key.strip().lower()] = value.strip()

        # Read response body
        if "content-length" in headers:
            body_length = int(headers["content-length"])
            response_body = await reader.readexactly(body_length)
        else:
            raise ConnectionError("Anything other than fixed content length responses are not implemented yet")

        return status_code, response_body
    except Exception as e:
        # Pass through errors
        raise e
    finally:
        # But just make sure to close the socket on your way out
        if writer is not None:
            try:
                writer.close()
                await writer.wait_closed()
            except:
                pass
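
# Illustrative usage, as done in process_page() below:
#   status_code, response_body = await apost(url, json_data=query)
# Each call opens a fresh connection and sends "Connection: close", so there is no pooled state to deadlock;
# only fixed Content-Length responses are supported (anything else raises ConnectionError).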


async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
    COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
    MAX_RETRIES = args.max_page_retries
    MODEL_MAX_CONTEXT = 16384
    TEMPERATURE_BY_ATTEMPT = [0.1, 0.1, 0.2, 0.3, 0.5, 0.8, 0.9, 1.0]

    exponential_backoffs = 0
    cumulative_rotation = 0  # Track cumulative rotation instead of a per-attempt local value
    attempt = 0
    await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "started")

    while attempt < MAX_RETRIES:
        lookup_attempt = min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
        query = await build_page_query(
            pdf_local_path,
            page_num,
            args.target_longest_image_dim,
            image_rotation=cumulative_rotation,
        )
        # Raise the temperature as the number of attempts increases to overcome repetition issues, at the expense of quality
        query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt]

        # Enable guided decoding regex if needed
        if args.guided_decoding:
            query["guided_regex"] = (
                r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
            )

        logger.info(f"Built page query for {pdf_orig_path}-{page_num}")

        try:
            status_code, response_body = await apost(COMPLETION_URL, json_data=query)

            if status_code == 400:
                raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
            elif status_code == 500:
                raise ValueError(f"Got InternalServerError from server: {response_body}, skipping this response")
            elif status_code != 200:
                raise ValueError(f"Error http status {status_code}")

            base_response_data = json.loads(response_body)

            if base_response_data["usage"]["total_tokens"] > MODEL_MAX_CONTEXT:
                raise ValueError(f"Response exceeded model_max_context of {MODEL_MAX_CONTEXT}, cannot use this response")

            if base_response_data["choices"][0]["finish_reason"] != "stop":
                raise ValueError("Response did not finish with reason code 'stop', cannot use this response")

            metrics.add_metrics(
                server_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                server_output_tokens=base_response_data["usage"].get("completion_tokens", 0),
            )

            model_response_markdown = base_response_data["choices"][0]["message"]["content"]
            parser = FrontMatterParser(front_matter_class=PageResponse)
            front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
            page_response = parser._parse_front_matter(front_matter, text)

            if not page_response.is_rotation_valid and attempt < MAX_RETRIES - 1:
                logger.info(
                    f"Got invalid_page rotation for {pdf_orig_path}-{page_num} attempt {attempt}, retrying with {page_response.rotation_correction} rotation"
                )
                # Add the rotation correction to the cumulative rotation
                cumulative_rotation = (cumulative_rotation + page_response.rotation_correction) % 360
                logger.info(f"Cumulative rotation is now {cumulative_rotation} degrees")
                raise ValueError(f"invalid_page rotation for {pdf_orig_path}-{page_num}")

            metrics.add_metrics(**{"completed_pages": 1, f"finished_on_attempt_{attempt}": 1})
            await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "finished")

            return PageResult(
                pdf_orig_path,
                page_num,
                page_response,
                input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                output_tokens=base_response_data["usage"].get("completion_tokens", 0),
                is_fallback=False,
            )
        except (ConnectionError, OSError, asyncio.TimeoutError) as e:
            logger.warning(f"Client error on attempt {attempt} for {pdf_orig_path}-{page_num}: {type(e)} {e}")

            # Now we want to do exponential backoff, and not count this as an actual page retry
            # Page retries are supposed to be for fixing bad results from the model, but actual requests to vllm
            # are supposed to work. Probably this means that the server is just restarting
            sleep_delay = 10 * (2**exponential_backoffs)
            exponential_backoffs += 1
            logger.info(f"Sleeping for {sleep_delay} seconds on {pdf_orig_path}-{page_num} to allow server restart")
            await asyncio.sleep(sleep_delay)
        except asyncio.CancelledError:
            logger.info(f"Process page {pdf_orig_path}-{page_num} cancelled")
            await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "cancelled")
            raise
        except json.JSONDecodeError as e:
            logger.warning(f"JSON decode error on attempt {attempt} for {pdf_orig_path}-{page_num}: {e}")
            attempt += 1
        except ValueError as e:
            logger.warning(f"ValueError on attempt {attempt} for {pdf_orig_path}-{page_num}: {type(e)} - {e}")
            attempt += 1
        except Exception as e:
            logger.exception(f"Unexpected error on attempt {attempt} for {pdf_orig_path}-{page_num}: {type(e)} - {e}")
            attempt += 1

    logger.error(f"Failed to process {pdf_orig_path}-{page_num} after {MAX_RETRIES} attempts.")
    metrics.add_metrics(failed_pages=1)
    await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "errored")

    return PageResult(
        pdf_orig_path,
        page_num,
        PageResponse(
            natural_text=get_anchor_text(pdf_local_path, page_num, pdf_engine="pdftotext"),
            primary_language=None,
            is_rotation_valid=True,
            rotation_correction=0,
            is_table=False,
            is_diagram=False,
        ),
        input_tokens=0,
        output_tokens=0,
        is_fallback=True,
    )
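
# Retry semantics for process_page, as implemented above:
# - Model-quality failures (bad JSON/YAML, invalid rotation, context overflow) consume one of MAX_RETRIES attempts,
#   with the sampling temperature stepping through TEMPERATURE_BY_ATTEMPT.
# - Connection-level failures do not consume an attempt; they trigger exponential backoff of 10, 20, 40, ... seconds
#   on the assumption that the vllm server is restarting.
# - After MAX_RETRIES, the page falls back to pdftotext anchor text and is marked is_fallback=True.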


async def process_pdf(args, worker_id: int, pdf_orig_path: str):
    with tempfile.NamedTemporaryFile("wb+", suffix=".pdf", delete=False) as tf:
        try:
            data = await asyncio.to_thread(lambda: get_s3_bytes_with_backoff(pdf_s3, pdf_orig_path))
            tf.write(data)
            tf.flush()
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                logger.info(f"S3 File Not found, skipping it completely {pdf_orig_path}")
                return None
            else:
                raise

        if is_png(tf.name) or is_jpeg(tf.name):
            logger.info(f"Converting {pdf_orig_path} from image to PDF format...")
            tf.seek(0)
            tf.write(convert_image_to_pdf_bytes(tf.name))
            tf.flush()

    try:
        try:
            reader = PdfReader(tf.name)
            num_pages = reader.get_num_pages()
        except:
            logger.exception(f"Could not count number of pages for {pdf_orig_path}, aborting document")
            return None

        logger.info(f"Got {num_pages} pages to do for {pdf_orig_path} in worker {worker_id}")

        if args.apply_filter and get_pdf_filter().filter_out_pdf(tf.name):
            logger.info(f"Filtering out pdf {pdf_orig_path}")
            return None

        # List to hold the tasks for processing each page
        page_tasks = []
        page_results = []

        try:
            async with asyncio.TaskGroup() as tg:
                for page_num in range(1, num_pages + 1):
                    task = tg.create_task(process_page(args, worker_id, pdf_orig_path, tf.name, page_num))
                    page_tasks.append(task)

            # Collect the results from the entire task group, assuming no exceptions
            page_results = [task.result() for task in page_tasks]

            num_fallback_pages = sum(page_result.is_fallback for page_result in page_results)

            if num_fallback_pages / num_pages > args.max_page_error_rate:
                logger.error(
                    f"Document {pdf_orig_path} has {num_fallback_pages} fallback pages out of {num_pages} exceeding max_page_error_rate of {args.max_page_error_rate}, discarding document."
                )
                return None
            elif num_fallback_pages > 0:
                logger.warning(
                    f"Document {pdf_orig_path} processed with {num_fallback_pages} fallback pages out of {num_pages}, proceeding to build Dolma document."
                )

            return build_dolma_document(pdf_orig_path, page_results)
        except Exception as e:
            # Check for ExceptionGroup with BrokenProcessPool
            if isinstance(e, ExceptionGroup):
                broken_pool, other = e.split(BrokenProcessPool)
                if broken_pool is not None:  # Found at least one BrokenProcessPool
                    logger.critical("Encountered BrokenProcessPool, exiting process.")
                    sys.exit(1)

            logger.exception(f"Exception in process_pdf for {pdf_orig_path}: {e}")
            # You can't build a dolma doc with even 1 failed page, so just get out of here
            # However, you don't want to propagate an exception higher up and cancel the entire work_group
            return None
    finally:
        if os.path.exists(tf.name):
            os.unlink(tf.name)


def build_dolma_document(pdf_orig_path, page_results):
    # Build the document text and page spans
    document_text = ""
    pdf_page_spans = []
    current_char_pos = 0

    for index, page_result in enumerate(page_results):
        if page_result.response.natural_text is not None:
            content = page_result.response.natural_text + ("\n" if index < len(page_results) - 1 else "")
        else:
            content = ""

        start_pos = current_char_pos
        document_text += content
        current_char_pos = len(document_text)
        pdf_page_spans.append([start_pos, current_char_pos, page_result.page_num])

    if not document_text:
        logger.info(f"No document text for {pdf_orig_path}")
        return None  # Return None if the document text is empty

    # Build the Dolma document
    metadata = {
        "Source-File": pdf_orig_path,
        "olmocr-version": VERSION,
        "pdf-total-pages": len(page_results),
        "total-input-tokens": sum(page.input_tokens for page in page_results),
        "total-output-tokens": sum(page.output_tokens for page in page_results),
        "total-fallback-pages": sum(page.is_fallback for page in page_results),
    }

    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
        "id": id_,
        "text": document_text,
        "source": "olmocr",
        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
        "metadata": metadata,
        "attributes": {"pdf_page_numbers": pdf_page_spans},
    }
    return dolma_doc
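
# Shape of the Dolma document produced above (sketch, values illustrative):
# {
#   "id": "<sha1 of text>",
#   "text": "...concatenated page text...",
#   "source": "olmocr",
#   "added": "YYYY-MM-DD",
#   "created": "YYYY-MM-DD",
#   "metadata": {"Source-File": ..., "olmocr-version": ..., "pdf-total-pages": ..., "total-input-tokens": ...,
#                "total-output-tokens": ..., "total-fallback-pages": ...},
#   "attributes": {"pdf_page_numbers": [[start_char, end_char, page_num], ...]},
# }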


async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
    while True:
        # Wait until allowed to proceed
        await semaphore.acquire()

        work_item = await work_queue.get_work()

        if work_item is None:
            logger.info(f"Worker {worker_id} exiting due to empty queue")
            semaphore.release()
            break

        logger.info(f"Worker {worker_id} processing work item {work_item.hash}")
        await tracker.clear_work(worker_id)

        try:
            async with asyncio.TaskGroup() as tg:
                dolma_tasks = [tg.create_task(process_pdf(args, worker_id, pdf)) for pdf in work_item.work_paths]
                logger.info(f"Created all tasks for {work_item.hash}")

            logger.info(f"Finished TaskGroup for worker on {work_item.hash}")

            dolma_docs = []
            for task in dolma_tasks:
                try:
                    result = task.result()
                except:
                    # Some dolma doc creations may have failed, skip those tasks
                    continue

                if result is not None:
                    dolma_docs.append(result)

            logger.info(f"Got {len(dolma_docs)} docs for {work_item.hash}")

            # Write the Dolma documents to a local temporary file in JSONL format
            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tf:
                for doc in dolma_docs:
                    tf.write(json.dumps(doc))
                    tf.write("\n")
                tf.flush()
                temp_path = tf.name

            try:
                # Define the output S3 path using the work_hash
                output_final_path = os.path.join(args.workspace, "results", f"output_{work_item.hash}.jsonl")

                if output_final_path.startswith("s3://"):
                    bucket, key = parse_s3_path(output_final_path)
                    workspace_s3.upload_file(temp_path, bucket, key)
                else:
                    shutil.copyfile(temp_path, output_final_path)
            finally:
                # Clean up the temporary file
                if os.path.exists(temp_path):
                    os.unlink(temp_path)

            # If --markdown flag is set, also write the natural text to markdown files
            if args.markdown:
                logger.info(f"Writing {len(dolma_docs)} markdown files for {work_item.hash}")
                for doc in dolma_docs:
                    source_file = doc["metadata"]["Source-File"]
                    natural_text = doc["text"]

                    # Create the output markdown path that preserves the folder structure
                    if source_file.startswith("s3://"):
                        # Extract the path after the bucket name for S3 sources
                        parsed = urlparse(source_file)
                        relative_path = parsed.path.lstrip("/")
                    else:
                        # For local files, use the full path
                        relative_path = source_file

                    # Change the extension to .md
                    md_filename = os.path.splitext(os.path.basename(relative_path))[0] + ".md"

                    # Get the directory path without the filename
                    dir_path = os.path.dirname(relative_path)

                    # Create the output markdown path
                    markdown_dir = os.path.join(args.workspace, "markdown", dir_path)
                    markdown_path = os.path.join(markdown_dir, md_filename)

                    # Create the directory structure if it doesn't exist
                    if markdown_path.startswith("s3://"):
                        # For S3 paths, we'll create a temporary file and upload it
                        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as md_tf:
                            md_tf.write(natural_text)
                            md_tf.flush()
                            md_temp_path = md_tf.name

                        try:
                            md_bucket, md_key = parse_s3_path(markdown_path)
                            workspace_s3.upload_file(md_temp_path, md_bucket, md_key)
                        finally:
                            # Make sure to clean up the temporary file even if upload fails
                            if os.path.exists(md_temp_path):
                                os.unlink(md_temp_path)
                    else:
                        # For local paths, create the directory structure and write the file
                        os.makedirs(markdown_dir, exist_ok=True)
                        with open(markdown_path, "w") as md_f:
                            md_f.write(natural_text)

            # Update finished token counts from successful documents
            metrics.add_metrics(
                finished_input_tokens=sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs),
                finished_output_tokens=sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs),
            )

            await work_queue.mark_done(work_item)
        except Exception as e:
            logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}")
        finally:
            semaphore.release()
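
# Output layout produced by worker(): each completed work item is written to
#   <workspace>/results/output_<work_item.hash>.jsonl   (one Dolma document per line)
# and, when --markdown is set, each document's text is mirrored to
#   <workspace>/markdown/<original pdf path>.md
# preserving the source folder structure for both local and s3:// workspaces.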


async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None):
    cmd = [
        "vllm",
        "serve",
        model_name_or_path,
        "--port",
        str(BASE_SERVER_PORT),
        "--disable-log-requests",
        "--uvicorn-log-level",
        "warning",
        "--served-model-name",
        "olmocr",
        "--tensor-parallel-size",
        str(args.tensor_parallel_size),
        "--data-parallel-size",
        str(args.data_parallel_size),
    ]

    if args.gpu_memory_utilization is not None:
        cmd.extend(["--gpu-memory-utilization", str(args.gpu_memory_utilization)])

    if args.max_model_len is not None:
        cmd.extend(["--max-model-len", str(args.max_model_len)])

    if unknown_args:
        cmd.extend(unknown_args)

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    # Ensure the subprocess is terminated on exit
    def _kill_proc():
        try:
            proc.terminate()
        except:
            logger.info("VLLM Process already terminated")

    atexit.register(_kill_proc)

    # Shared variables between tasks
    last_running_req, last_queue_req = 0, 0
    server_printed_ready_message = False
    last_semaphore_release = time.time()

    async def process_line(line):
        nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
        server_logger.info(line)

        # If the server hasn't initialized yet, log all the lines to the main logger also, so that the user
        # can see any warnings/errors more easily
        if not server_printed_ready_message:
            logger.info(line)

        if "Detected errors during sampling" in line:
            logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
            sys.exit(1)

        if not server_printed_ready_message and ("The server is fired up and ready to roll!" in line or "Starting vLLM API server" in line):
            server_printed_ready_message = True
            last_semaphore_release = time.time()

        match = re.search(r"Running: (\d+)", line)
        if match:
            last_running_req = int(match.group(1))

        match = re.search(r"(?:Waiting|Pending):\s*(\d+)", line)
        if match:
            last_queue_req = int(match.group(1))
            logger.info(f"vllm running req: {last_running_req} queue req: {last_queue_req}")

    async def read_stream(stream):
        while True:
            line = await stream.readline()
            if not line:
                break
            try:
                line = line.decode("utf-8").rstrip()
                await process_line(line)
            except Exception as ex:
                logger.warning(f"Got {ex} when reading log line from inference server, skipping")

    async def timeout_task():
        nonlocal last_running_req, last_queue_req, last_semaphore_release
        try:
            while True:
                await asyncio.sleep(1)
                if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
                    semaphore.release()
                    last_semaphore_release = time.time()
                    logger.info("Semaphore released, allowing a worker to proceed.")
        except asyncio.CancelledError:
            pass  # Clean up if the task is cancelled

    # Start tasks to read stdout, stderr, and handle timeout logic
    stdout_task = asyncio.create_task(read_stream(proc.stdout))
    stderr_task = asyncio.create_task(read_stream(proc.stderr))
    timeout_task = asyncio.create_task(timeout_task())

    try:
        await proc.wait()
    except asyncio.CancelledError:
        logger.info("Got cancellation request for VLLM server")
        proc.terminate()
        try:
            await asyncio.wait_for(proc.wait(), timeout=10.0)
        except asyncio.TimeoutError:
            logger.warning("VLLM server did not terminate within 10 seconds")
        raise

    timeout_task.cancel()
    await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
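
# The semaphore passed into vllm_server_task gates worker start-up: timeout_task() releases it only once the
# server has printed its ready message, the request queue has drained, and at least 30 seconds have passed since
# the last release, which effectively lets one additional worker through at a time as server capacity frees up.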


async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=None):
    MAX_RETRIES = 5
    retry = 0

    while retry < MAX_RETRIES:
        await vllm_server_task(model_name_or_path, args, semaphore, unknown_args)
        logger.warning("VLLM server task ended")
        retry += 1

    if retry >= MAX_RETRIES:
        logger.error(f"Ended up starting the vllm server more than {retry} times, cancelling pipeline")
        logger.error("")
        logger.error(
            "Please make sure vllm is installed according to the latest instructions here: https://docs.vllm.ai/en/stable/getting_started/installation/gpu.html"
        )
        sys.exit(1)


async def vllm_server_ready():
    max_attempts = 300
    delay_sec = 1
    url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"

    for attempt in range(1, max_attempts + 1):
        try:
            async with httpx.AsyncClient() as session:
                response = await session.get(url)

                if response.status_code == 200:
                    logger.info("vllm server is ready.")
                    return
                else:
                    logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
        except Exception:
            logger.warning(f"Attempt {attempt}: Please wait for vllm server to become ready...")

        await asyncio.sleep(delay_sec)

    raise Exception("vllm server did not become ready after waiting.")


async def download_model(model_name_or_path: str, max_retries: int = 5):
    for retry in range(max_retries):
        try:
            if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"):
                logger.info(f"Downloading model directory from '{model_name_or_path}'")
                model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
                # Delete existing model cache directory if it exists
                if os.path.exists(model_cache_dir):
                    shutil.rmtree(model_cache_dir)
                download_directory([model_name_or_path], model_cache_dir)
                return model_cache_dir
            elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
                logger.info(f"Using local model path at '{model_name_or_path}'")
                return model_name_or_path
            else:
                logger.info(f"Downloading model with hugging face '{model_name_or_path}'")
                snapshot_download(repo_id=model_name_or_path)
                return model_name_or_path
        except Exception:
            if retry == max_retries - 1:
                raise  # Raise on final attempt and fail the job

            sleep_time = random.randrange(2, 20) * 2**retry
            logger.exception(f"Could not download model, sleeping for {sleep_time} seconds to retry ({retry + 1}/{max_retries})")
            await asyncio.sleep(sleep_time)
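
# download_model() accepts three kinds of sources: s3://, gs://, or weka:// directories (downloaded to
# ~/.cache/olmocr/model), an absolute local directory (used as-is), or a Hugging Face repo id
# (fetched with snapshot_download). Failures are retried up to max_retries times with exponential backoff.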


async def metrics_reporter(work_queue):
    while True:
        # Leading newlines preserve table formatting in logs
        logger.info(f"Queue remaining: {work_queue.size}")
        logger.info("\n" + str(metrics))
        logger.info("\n" + str(await tracker.get_status_table()))
        await asyncio.sleep(10)


def submit_beaker_job(args):
    from beaker import (  # type: ignore
        Beaker,
        Constraints,
        EnvVar,
        ExperimentSpec,
        ImageSource,
        Priority,
        ResultSpec,
        SecretNotFound,
        TaskContext,
        TaskResources,
        TaskSpec,
    )

    b = Beaker.from_env(default_workspace=args.beaker_workspace)
    account = b.account.whoami()
    owner = account.name
    beaker_image = f"jakep/olmocr-inference-{VERSION}"

    task_name = f"olmocr-{os.path.basename(args.workspace.rstrip('/'))}"

    # Take out --beaker flag so the workers will just run things
    args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]

    # Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally
    args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i - 1] == "--pdfs"))]

    try:
        b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
        b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
        b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
    except SecretNotFound:
        print(
            f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]"
        )

        if input().strip().lower() != "y":
            print("Exiting...")
            sys.exit(1)

        b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
        b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
        b.secret.write(
            f"{owner}-AWS_CREDENTIALS_FILE",
            open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(),
            args.beaker_workspace,
        )

    env_var_secrets = [
        EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
        EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
        EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
    ]

    try:
        b.secret.get("OLMOCR_PREVIEW_HF_TOKEN", args.beaker_workspace)
        env_var_secrets.append(EnvVar(name="HF_TOKEN", secret="OLMOCR_PREVIEW_HF_TOKEN"))
    except SecretNotFound:
        pass

    try:
        b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
        env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
    except SecretNotFound:
        print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
        lines = []
        prev_empty = False
        for line in iter(input, None):
            if not line and prev_empty:
                break
            prev_empty = not line
            lines.append(line)
        gcs_sa_key = "\n".join(lines[:-1]).strip()  # Remove the last empty line
        if gcs_sa_key:
            b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
            env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))

    # Create the experiment spec
    experiment_spec = ExperimentSpec(
        budget="ai2/oe-data",
        description=task_name,
        tasks=[
            TaskSpec(
                name=task_name,
                propagate_failure=False,
                propagate_preemption=False,
                replicas=args.beaker_gpus,
                context=TaskContext(
                    priority=Priority(args.beaker_priority),
                    preemptible=True,
                ),
                image=ImageSource(beaker=beaker_image),
                command=["python", "-m", "olmocr.pipeline"] + args_list,
                env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
                resources=TaskResources(gpu_count=1),
                constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
                result=ResultSpec(path="/noop-results"),
            )
        ],
    )

    experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)

    print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")
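
# submit_beaker_job() re-submits this same command line as a Beaker experiment: it strips --beaker and --pdfs
# (the work queue is populated locally before submission), wires up Weka/S3/GCS credentials as Beaker secrets,
# and launches args.beaker_gpus single-GPU replicas running `python -m olmocr.pipeline`.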


def print_stats(args, root_work_queue):
    LONG_CONTEXT_THRESHOLD = 32768

    assert args.workspace.startswith("s3://"), "Printing stats functionality only works with s3 workspaces for now."

    # Get total work items and completed items
    index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
    output_glob = os.path.join(args.workspace, "results", "*.jsonl")

    done_work_items = expand_s3_glob(workspace_s3, output_glob)

    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
    work_queue = {}
    for line in work_queue_lines:
        if line.strip():
            parts = root_work_queue._decode_csv_row(line.strip())
            if parts:  # Ensure we have at least one part
                work_queue[parts[0]] = parts[1:]

    total_items = len(work_queue)
    completed_items = len(done_work_items)

    def process_output_file(s3_path):
        try:
            data = get_s3_bytes(workspace_s3, s3_path)
            doc_count = 0
            total_input_tokens = 0
            total_output_tokens = 0
            total_pages = 0
            total_fallback_pages = 0
            processed_paths = set()

            # Counters for long context docs within a single file
            long_context_docs = 0
            long_context_tokens = 0

            for line in data.decode("utf-8").splitlines():
                if line.strip():
                    doc = json.loads(line)
                    doc_count += 1

                    doc_input_tokens = doc["metadata"].get("total-input-tokens", 0)
                    doc_output_tokens = doc["metadata"].get("total-output-tokens", 0)
                    doc_pages = doc["metadata"].get("pdf-total-pages", 0)
                    doc_fallback_pages = doc["metadata"].get("total-fallback-pages", 0)

                    total_input_tokens += doc_input_tokens
                    total_output_tokens += doc_output_tokens
                    total_pages += doc_pages
                    total_fallback_pages += doc_fallback_pages
                    processed_paths.add(doc["metadata"]["Source-File"])

                    # Check if this doc exceeds the long context threshold
                    if doc_output_tokens > LONG_CONTEXT_THRESHOLD:
                        long_context_docs += 1
                        long_context_tokens += doc_output_tokens

            return (
                doc_count,
                total_input_tokens,
                total_output_tokens,
                total_pages,
                total_fallback_pages,
                processed_paths,
                long_context_docs,
                long_context_tokens,
            )
        except Exception as e:
            logger.warning(f"Error processing {s3_path}: {e}")
            return 0, 0, 0, 0, 0, set(), 0, 0

    print(f"\nCompleted work items {completed_items:,} out of {total_items:,}: {completed_items / total_items * 100:.2f}%")
    print("\nProcessing output files...")

    docs_total = 0
    input_tokens_total = 0
    output_tokens_total = 0
    pages_total = 0
    fallback_pages_total = 0
    all_processed_paths = set()
    original_paths = set()

    # Counters for long context documents across all files
    long_context_docs_count = 0
    long_context_tokens_total = 0

    # First collect all original PDF paths
    for done_work_item in done_work_items:
        if match := re.search(r"output_(\w+).jsonl", done_work_item):
            done_work_hash = match.group(1)
            if done_work_hash in work_queue:
                original_paths.update(work_queue[done_work_hash])

    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(process_output_file, item): item for item in done_work_items}

        for future in tqdm(as_completed(futures), total=len(futures)):
            (doc_count, input_tokens, output_tokens, pages, fallback_pages, processed_paths, long_context_docs, long_context_tokens) = future.result()
            docs_total += doc_count
            input_tokens_total += input_tokens
            output_tokens_total += output_tokens
            pages_total += pages
            fallback_pages_total += fallback_pages
            all_processed_paths.update(processed_paths)
            long_context_docs_count += long_context_docs
            long_context_tokens_total += long_context_tokens

    skipped_paths = original_paths - all_processed_paths

    print("\nWork Items Status:")
    print(f"Total work items: {total_items:,}")
    print(f"Completed items: {completed_items:,}")
    print(f"Remaining items: {total_items - completed_items:,}")

    print("\nResults:")
    print(f"Total documents processed: {docs_total:,}")
    print(f"Total documents skipped: {len(skipped_paths):,}")
    print(f"Total pages on fallback: {fallback_pages_total:,}")
    print(f"Total pages processed: {pages_total:,}")

    print(f"\nTotal output tokens: {output_tokens_total:,}")
    print(f"Projected output tokens: {round((output_tokens_total / max(1, completed_items)) * total_items):,}")

    print(f"\nAverage pages per doc: {pages_total / max(1, docs_total):,.1f}")
    print(f"Average output tokens per doc: {output_tokens_total / max(1, docs_total):,.1f}")
    print(f"Average output tokens per page: {output_tokens_total / max(1, pages_total):,.1f}")

    # Print long context documents stats
    print(f"\nLong Context Documents (>{LONG_CONTEXT_THRESHOLD} tokens): {long_context_docs_count:,}")
    print(f"Total tokens in long context documents: {long_context_tokens_total:,}")
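
# print_stats() projects total output tokens by scaling the average per completed work item to the full queue:
#   projected = round((output_tokens_total / max(1, completed_items)) * total_items)
# so the projection is only meaningful once a representative sample of work items has completed.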
2024-11-08 08:14:20 -08:00
async def main ( ) :
2025-08-03 23:25:57 -04:00
parser = argparse . ArgumentParser ( description = " Manager for running millions of PDFs through a batch inference pipeline. " )
2025-01-29 15:30:39 -08:00
parser . add_argument (
" workspace " ,
help = " The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ " ,
)
parser . add_argument (
" --pdfs " ,
2025-01-30 12:48:10 -08:00
nargs = " * " ,
2025-01-29 15:30:39 -08:00
help = " Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths " ,
default = None ,
)
2025-07-24 18:47:03 +00:00
parser . add_argument (
" --model " ,
help = " Path where the model is located, allenai/olmOCR-7B-0725-FP8 is the default, can be local, s3, or hugging face. " ,
default = " allenai/olmOCR-7B-0725-FP8 " ,
)
# More detailed config options, usually you shouldn't have to change these
2025-01-29 15:30:39 -08:00
parser . add_argument ( " --workspace_profile " , help = " S3 configuration profile for accessing the workspace " , default = None )
parser . add_argument ( " --pdf_profile " , help = " S3 configuration profile for accessing the raw pdf documents " , default = None )
parser . add_argument ( " --pages_per_group " , type = int , default = 500 , help = " Aiming for this many pdf pages per work item group " )
parser . add_argument ( " --max_page_retries " , type = int , default = 8 , help = " Max number of times we will retry rendering a page " )
parser . add_argument ( " --max_page_error_rate " , type = float , default = 0.004 , help = " Rate of allowable failed pages in a document, 1/250 by default " )
2025-06-13 03:50:21 +00:00
parser . add_argument ( " --workers " , type = int , default = 20 , help = " Number of workers to run at a time " )
2025-01-29 15:30:39 -08:00
parser . add_argument ( " --apply_filter " , action = " store_true " , help = " Apply basic filtering to English pdfs which are not forms, and not likely seo spam " )
parser . add_argument ( " --stats " , action = " store_true " , help = " Instead of running any job, reports some statistics about the current workspace " )
2025-05-19 19:42:48 +00:00
parser . add_argument ( " --markdown " , action = " store_true " , help = " Also write natural text to markdown files preserving the folder structure of the input pdfs " )
2024-11-07 00:03:30 +00:00
2025-07-17 19:46:35 +00:00
parser . add_argument ( " --target_longest_image_dim " , type = int , help = " Dimension on longest side to use for rendering the pdf pages " , default = 1288 )
2025-07-15 18:00:01 +00:00
parser . add_argument ( " --target_anchor_text_len " , type = int , help = " Maximum amount of anchor text to use (characters), not used for new models " , default = - 1 )
2025-07-01 17:44:02 +00:00
parser . add_argument ( " --guided_decoding " , action = " store_true " , help = " Enable guided decoding for model YAML type outputs " )
2024-11-12 15:56:51 -08:00
2025-08-03 23:00:06 -04:00
vllm_group = parser . add_argument_group (
2025-08-04 18:21:47 +00:00
" VLLM arguments " , " These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM. "
2025-08-03 23:00:06 -04:00
)
2025-07-24 18:50:30 +00:00
vllm_group . add_argument (
" --gpu-memory-utilization " , type = float , help = " Fraction of VRAM vLLM may pre-allocate for KV-cache " " (passed through to vllm serve). "
)
2025-07-24 18:47:03 +00:00
vllm_group . add_argument ( " --max_model_len " , type = int , default = 16384 , help = " Upper bound (tokens) vLLM will allocate KV-cache for, lower if VLLM won ' t start " )
vllm_group . add_argument ( " --tensor-parallel-size " , " -tp " , type = int , default = 1 , help = " Tensor parallel size for vLLM " )
vllm_group . add_argument ( " --data-parallel-size " , " -dp " , type = int , default = 1 , help = " Data parallel size for vLLM " )
vllm_group . add_argument ( " --port " , type = int , default = 30024 , help = " Port to use for the VLLM server " )
2024-11-12 15:56:51 -08:00
# Beaker/job running stuff
2025-07-24 18:47:03 +00:00
beaker_group = parser . add_argument_group ( " beaker/cluster execution " )
beaker_group . add_argument ( " --beaker " , action = " store_true " , help = " Submit this job to beaker instead of running locally " )
beaker_group . add_argument ( " --beaker_workspace " , help = " Beaker workspace to submit to " , default = " ai2/olmocr " )
beaker_group . add_argument (
2025-01-29 15:30:39 -08:00
" --beaker_cluster " ,
help = " Beaker clusters you want to run on " ,
default = [ " ai2/jupiter-cirrascale-2 " , " ai2/ceres-cirrascale " , " ai2/neptune-cirrascale " , " ai2/saturn-cirrascale " , " ai2/augusta-google-1 " ] ,
)
2025-07-24 18:47:03 +00:00
beaker_group . add_argument ( " --beaker_gpus " , type = int , default = 1 , help = " Number of gpu replicas to run " )
beaker_group . add_argument ( " --beaker_priority " , type = str , default = " normal " , help = " Beaker priority level for the job " )
2025-08-03 23:00:06 -04:00
args , unknown_args = parser . parse_known_args ( )
2024-11-07 00:03:30 +00:00
2025-07-05 10:42:52 +02:00
logger . info (
" If you run out of GPU memory during start-up or get ' KV cache is larger than available memory ' errors, retry with lower values, e.g. --gpu_memory_utilization 0.80 --max_model_len 16384 "
)
global workspace_s3, pdf_s3

# set the global BASE_SERVER_PORT from args
global BASE_SERVER_PORT
BASE_SERVER_PORT = args.port

# Set up the job to work in the beaker environment: load secrets, adjust logging, etc.
if "BEAKER_JOB_NAME" in os.environ:
    server_logger.addHandler(console_handler)

    cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
    os.makedirs(os.path.dirname(cred_path), exist_ok=True)
    with open(cred_path, "w") as f:
        f.write(os.environ.get("AWS_CREDENTIALS_FILE"))

    cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
    os.makedirs(os.path.dirname(cred_path), exist_ok=True)
    with open(cred_path, "w") as f:
        f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
    workspace_s3 = boto3.client("s3")
    pdf_s3 = boto3.client("s3")

    # Stagger start-up so that the beaker replicas in a task don't all download the model at the same time
    replica_count = int(os.environ.get("BEAKER_REPLICA_COUNT", "1"))
    interval = 10 if (replica_count - 1) * 10 <= 30 else 30 / max(1, replica_count - 1)
    sleep_time = int(os.environ.get("BEAKER_REPLICA_RANK", "0")) * interval
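    # Each replica sleeps rank * interval seconds: a 10s step when there are few replicas, otherwise the total stagger is capped at roughly 30 seconds.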
    logger.info(f"Beaker job sleeping for {sleep_time} seconds to stagger model downloads")
    await asyncio.sleep(sleep_time)

if args.workspace_profile:
    workspace_session = boto3.Session(profile_name=args.workspace_profile)
    workspace_s3 = workspace_session.client("s3")

if args.pdf_profile:
    pdf_session = boto3.Session(profile_name=args.pdf_profile)
    pdf_s3 = pdf_session.client("s3")

# We need poppler to load the initial pdfs, even if we are not processing them here
check_poppler_version()

# Create work queue
if args.workspace.startswith("s3://"):
    work_queue = S3WorkQueue(workspace_s3, args.workspace)
else:
    work_queue = LocalWorkQueue(args.workspace)

if args.pdfs:
    logger.info("Got --pdfs argument, going to add to the work queue")

    pdf_work_paths = set()
    for pdf_path in args.pdfs:
        # Expand s3 paths
        if pdf_path.startswith("s3://"):
            logger.info(f"Expanding s3 glob at {pdf_path}")
            pdf_work_paths |= set(expand_s3_glob(pdf_s3, pdf_path))
        elif os.path.exists(pdf_path):
            if (
                pdf_path.lower().endswith(".pdf")
                or pdf_path.lower().endswith(".png")
                or pdf_path.lower().endswith(".jpg")
                or pdf_path.lower().endswith(".jpeg")
            ):
                if open(pdf_path, "rb").read(4) == b"%PDF":
                    logger.info(f"Loading file at {pdf_path} as PDF document")
                    pdf_work_paths.add(pdf_path)
                elif is_png(pdf_path) or is_jpeg(pdf_path):
                    logger.info(f"Loading file at {pdf_path} as image document")
                    pdf_work_paths.add(pdf_path)
                else:
                    logger.warning(f"File at {pdf_path} is not a valid PDF or image")
            elif pdf_path.lower().endswith(".txt"):
                logger.info(f"Loading file at {pdf_path} as list of paths")
                with open(pdf_path, "r") as f:
                    pdf_work_paths |= set(filter(None, (line.strip() for line in f)))
            else:
                raise ValueError(f"Unsupported file extension for {pdf_path}")
        else:
            raise ValueError("pdfs argument needs to be either a local path, an s3 path, or an s3 glob pattern...")

    logger.info(f"Found {len(pdf_work_paths):,} total pdf paths to add")

    # Estimate average pages per pdf
    sample_size = min(100, len(pdf_work_paths))
    sampled_pdfs = random.sample(list(pdf_work_paths), sample_size)
    page_counts = []

    for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to estimate pages per PDF"):
        try:
            # Download the PDF to a temp file
            with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
                tmp_file.write(get_s3_bytes(pdf_s3, pdf))
                tmp_file.flush()

                if is_png(tmp_file.name) or is_jpeg(tmp_file.name):
                    page_counts.append(1)
                else:
                    reader = PdfReader(tmp_file.name)
                    page_counts.append(len(reader.pages))
        except Exception as e:
            logger.warning(f"Failed to read {pdf}: {e}")

    if page_counts:
        avg_pages_per_pdf = sum(page_counts) / len(page_counts)
    else:
        logger.warning("Could not read any PDFs to estimate average page count.")
        avg_pages_per_pdf = 10  # Default to 10 pages per PDF if sampling fails

    items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf))
    logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")

    # Group the pdf paths into work items of roughly args.pages_per_group pages each and add them to the queue
    await work_queue.populate_queue(pdf_work_paths, items_per_group)

if args.stats:
    print_stats(args, work_queue)
    return

if args.beaker:
    submit_beaker_job(args)
    return

# If you get this far, then you are doing inference and need a GPU
# check_sglang_version()
check_torch_gpu_available()

logger.info(f"Starting pipeline with PID {os.getpid()}")

# Download the model before you do anything else
model_name_or_path = await download_model(args.model)

# Initialize the work queue
qsize = await work_queue.initialize_queue()
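# qsize is the number of work items still pending in the queue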

if qsize == 0:
    logger.info("No work to do, exiting")
    return

# Create a semaphore to control worker access
# We only allow one worker to move forward with requests until the server has no more requests in its queue
# This lets us get full utilization by having many workers, while still outputting dolma docs as soon as possible
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
semaphore = asyncio.Semaphore(1)

vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))

await vllm_server_ready()
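# The server has reported itself ready, so the workers created below can start sending page requests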

metrics_task = asyncio.create_task(metrics_reporter(work_queue))

# Create worker tasks to process the queue concurrently.
worker_tasks = []
for i in range(args.workers):
    task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
    worker_tasks.append(task)

# Wait for all worker tasks to finish
await asyncio.gather(*worker_tasks)

# All work is done: shut down the process pool and cancel the server and metrics tasks
process_pool.shutdown(wait=False)

vllm_server.cancel()
metrics_task.cancel()

# Wait for cancelled tasks to complete
await asyncio.gather(vllm_server, metrics_task, return_exceptions=True)
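# return_exceptions=True means the CancelledError from the cancelled tasks is collected rather than re-raised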

# Output final metrics summary
metrics_summary = metrics.get_metrics_summary()
logger.info("=" * 80)
logger.info("FINAL METRICS SUMMARY")
logger.info("=" * 80)
logger.info(f"Total elapsed time: {metrics_summary['elapsed_time_seconds']:.2f} seconds")

# Output token counts and rates
total_metrics = metrics_summary["total_metrics"]
rates = metrics_summary["rates"]

logger.info(f"Total Server Input tokens: {total_metrics.get('server_input_tokens', 0):,}")
logger.info(f"Total Server Output tokens: {total_metrics.get('server_output_tokens', 0):,}")

logger.info(f"Finished input tokens: {total_metrics.get('finished_input_tokens', 0):,}")
logger.info(f"Finished output tokens: {total_metrics.get('finished_output_tokens', 0):,}")

logger.info(f"Completed pages: {total_metrics.get('completed_pages', 0):,}")
logger.info(f"Failed pages: {total_metrics.get('failed_pages', 0):,}")
logger.info(
    f"Page Failure rate: {total_metrics.get('failed_pages', 0) / max(total_metrics.get('completed_pages', 0) + total_metrics.get('failed_pages', 0), 1) * 100:.2f}%"
)

# Output finished_on_attempt statistics
logger.info("")
logger.info("Pages finished by attempt number:")
total_finished = sum(total_metrics.get(f"finished_on_attempt_{i}", 0) for i in range(args.max_page_retries))
cumulative = 0

for i in range(args.max_page_retries):
    if f"finished_on_attempt_{i}" in total_metrics:
        count = total_metrics[f"finished_on_attempt_{i}"]
        cumulative += count
        percentage = (count / total_finished * 100) if total_finished > 0 else 0
        cumulative_percentage = (cumulative / total_finished * 100) if total_finished > 0 else 0
        logger.info(f"  Attempt {i}: {count:,} pages ({percentage:.1f}%) - Cumulative: {cumulative:,} ({cumulative_percentage:.1f}%)")

# Output rates
if "server_input_tokens_per_sec" in rates:
    logger.info(f"Server Input tokens/sec rate: {rates['server_input_tokens_per_sec']:.2f}")
if "server_output_tokens_per_sec" in rates:
    logger.info(f"Server Output tokens/sec rate: {rates['server_output_tokens_per_sec']:.2f}")
if "finished_input_tokens_per_sec" in rates:
    logger.info(f"Finished Input tokens/sec rate: {rates['finished_input_tokens_per_sec']:.2f}")
if "finished_output_tokens_per_sec" in rates:
    logger.info(f"Finished Output tokens/sec rate: {rates['finished_output_tokens_per_sec']:.2f}")

logger.info("=" * 80)
logger.info("Work done")


if __name__ == "__main__":
    asyncio.run(main())