#!/usr/bin/env python3
"""
Tagging pipeline for Dolma JSONL datasets .
For each . jsonl , . jsonl . gz , or . jsonl . ztd file under the dataset / documents folder ,
this script issues a simple SGLang completion per record ( e . g . , " Is this document in English? " ) ,
collects the yes / no answers , and writes corresponding Dolma attributes JSONL files under
scratch / attributes / , mirroring the input structure .
"""
import argparse
import asyncio
import atexit
import gzip
import json
import logging
import os
import random
import re
import sys
import time
from typing import Optional
from urllib.parse import urlparse

import boto3
import httpx
import zstandard as zstd
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, ValidationError

from olmocr.check import (
    check_sglang_version,
    check_torch_gpu_available,
)
from olmocr.metrics import MetricsKeeper
from olmocr.s3_utils import (
    download_directory,
    expand_s3_glob,
    get_s3_bytes_with_backoff,
    parse_s3_path,
)
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue

# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False

sglang_logger = logging.getLogger("sglang")
sglang_logger.propagate = False

file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))

# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
sglang_logger.addHandler(file_handler)

# Default port; overridden by --port
SGLANG_SERVER_PORT = 30024

# Global variables for token statistics
metrics = MetricsKeeper(window=60 * 5)

class PIIClassification(BaseModel):
    primary_language: str = Field(..., description="Primary language as a two-letter code")
    document_type: str = Field(..., description="Basic summary of document type classification")
    is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
    contains_pii: Optional[bool] = Field(..., description="True if document contains PII")

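# A well-formed model reply (field values below are illustrative) validates directly against this schema, e.g.:
#   PIIClassification.model_validate_json(
#       '{"primary_language": "en", "document_type": "newsletter", "is_resume_cv": false, "contains_pii": false}'
#   )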
async def _process_single_page(page_text: str) -> PIIClassification:
    """Helper function to process a single document or page."""
    text = page_text

    metrics.add_metrics(sglang_requests=1)

    query = {
        "model": "google/gemma-3-4b-it",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": (
                            f"{text}\n\n-----------\n"
                            "Given the text above, determine what type of document it is, and if it's a resume/CV. answer in JSON. "
                            "The format of your json object should be {'primary_language': str, 'document_type': str, 'is_resume_cv': bool, 'contains_pii': bool}"
                        ),
                    }
                ],
            }
        ],
        "max_tokens": 100,
        "temperature": 0.0,
        "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
    }

    url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"

    # ---------- HTTP call ---------------------------------------------------
    try:
        status, body = await apost(url, json_data=query)
    except Exception as e:
        logger.warning(f"SGLang network error: {e!s}")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

    if status != 200:
        logger.warning(f"SGLang HTTP {status}: {body[:250]!r}")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

    # ---------- Parse base JSON --------------------------------------------
    try:
        base = json.loads(body)
    except json.JSONDecodeError:
        logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

    # Token accounting if available
    if "usage" in base:
        metrics.add_metrics(
            sglang_input_tokens=base["usage"].get("prompt_tokens", 0),
            sglang_output_tokens=base["usage"].get("completion_tokens", 0),
        )

    # ---------- Extract the model message ----------------------------------
    try:
        content = base["choices"][0]["message"].get("content")
    except (KeyError, IndexError, AttributeError) as e:
        logger.warning(f"Missing fields in SGLang response: {e!s}")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

    if not isinstance(content, str):
        logger.warning("SGLang `content` is not a string; treating as error.")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

    try:
        pii_classification: PIIClassification = PIIClassification.model_validate_json(content)
        return pii_classification
    except ValidationError as e:
        logger.warning(f"Unable to parse pii classification object: {e!s}")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

# Manual, simple implementation of HTTP POST
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
# that at the scale of 100M+ requests, they deadlock in different strange ways
async def apost(url, json_data):
    parsed_url = urlparse(url)
    host = parsed_url.hostname
    port = parsed_url.port or 80
    path = parsed_url.path or "/"
    writer = None

    try:
        reader, writer = await asyncio.open_connection(host, port)

        json_payload = json.dumps(json_data)
        request = (
            f"POST {path} HTTP/1.1\r\n"
            f"Host: {host}\r\n"
            f"Content-Type: application/json\r\n"
            f"Content-Length: {len(json_payload)}\r\n"
            f"Connection: close\r\n\r\n"
            f"{json_payload}"
        )
        writer.write(request.encode())
        await writer.drain()

        # Read status line
        status_line = await reader.readline()
        if not status_line:
            raise ConnectionError("No response from server")
        status_parts = status_line.decode().strip().split(" ", 2)
        if len(status_parts) < 2:
            raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
        status_code = int(status_parts[1])

        # Read headers
        headers = {}
        while True:
            line = await reader.readline()
            if line in (b"\r\n", b"\n", b""):
                break
            key, _, value = line.decode().partition(":")
            headers[key.strip().lower()] = value.strip()

        # Read response body
        if "content-length" in headers:
            body_length = int(headers["content-length"])
            response_body = await reader.readexactly(body_length)
        else:
            raise ConnectionError("Anything other than fixed content length responses are not implemented yet")

        return status_code, response_body
    except Exception as e:
        # Pass through errors
        raise e
    finally:
        # But just make sure to close the socket on your way out
        if writer is not None:
            try:
                writer.close()
                await writer.wait_closed()
            except:
                pass

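# Usage sketch (this mirrors the call made in _process_single_page; the endpoint assumes the local SGLang server):
#   status, body = await apost(f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions", json_data=query)
# Only fixed Content-Length responses are supported, which is what the chat completions endpoint returns here.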
async def process_dolma_document(args, dolma_doc, sem):
    """
    Query SGLang to detect PII, enforcing a JSON schema.

    Resilient to:
      • Transport / HTTP errors
      • Missing or malformed fields in the response
      • Non-string or None `content`
      • Bad JSON in the model's answer

    Returns a dict of Dolma attributes mapping the attribute key to a list of
    [start_pos, end_pos, value] spans, one per page.
    """
    doc_id = dolma_doc.get("id")
    text = dolma_doc.get("text", "") or ""
    key_name = f"{args.model.replace('/', '_')}_pii_classification"

    result_attributes = {key_name: []}

    # If pdf_page_numbers is present, split the text and process each page separately
    if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
        page_numbers = dolma_doc["attributes"]["pdf_page_numbers"]

        logger.info(f"Document {doc_id} has {len(page_numbers)} pages, processing each individually")

        # Filter pages down to actual real content
        selected_page_numbers = [tuple(p) for p in page_numbers if p[0] < p[1]]

        # Sample 3 pages max per document
        random.shuffle(selected_page_numbers)
        selected_page_numbers = selected_page_numbers[:3]

        for start_pos, end_pos, page_num in page_numbers:
            if (start_pos, end_pos, page_num) in selected_page_numbers:
                page_text = text[start_pos:end_pos]

                # Process each page with the semaphore to limit concurrent requests
                async with sem:
                    pii_class = await _process_single_page(page_text)

                result_attributes[key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
            else:
                result_attributes[key_name].append([start_pos, end_pos, None])

        return result_attributes
    else:
        raise NotImplementedError("Missing code here, expecting this to be dolma docs made by olmocr....")

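# Illustrative result for a three-page olmocr document where only the first page was sampled
# (character offsets and values are examples, not from the original source); with the default
# model it is serialized into the attributes JSONL roughly as:
#   {"google_gemma-3-4b-it_pii_classification": [[0, 1512, false], [1512, 2980, null], [2980, 3675, null]]}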
async def process_file(args, worker_id: int, file_uri: str):
    """
    Download a JSONL file, query SGLang per record, and collect attributes.
    """
    # Fetch raw bytes (S3 or local)
    if file_uri.startswith("s3://"):
        raw = await asyncio.to_thread(get_s3_bytes_with_backoff, dataset_s3, file_uri)
    else:
        with open(file_uri, "rb") as f:
            raw = f.read()

    # Decompress if needed
    if file_uri.endswith(".gz"):
        file_bytes = gzip.decompress(raw)
    elif file_uri.endswith(".ztd") or file_uri.endswith(".zst") or file_uri.endswith(".zstd"):
        dctx = zstd.ZstdDecompressor()
        file_bytes = dctx.decompress(raw, max_output_size=1_000_000_000)
    else:
        file_bytes = raw

    lines = file_bytes.decode("utf-8").splitlines()

    page_tasks = {}

    # Send all records in parallel, max 500 queued at a time
    sem = asyncio.Semaphore(500)

    async with asyncio.TaskGroup() as tg:
        for line in lines:
            dolma_doc = json.loads(line)
            task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
            page_tasks[dolma_doc["id"]] = (task, dolma_doc)

    logger.info(f"Started taskgroup with {len(page_tasks)} items for {file_uri}")

    # Collect results and build attributes
    attributes = []
    for doc_id, (task, dolma_doc) in page_tasks.items():
        doc_attributes = task.result()
        attributes.append({"id": doc_id, "attributes": doc_attributes})

    return attributes

async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, worker_id: int):
    """
    Pop work items off the queue, run PII tagging, write the attributes file
    next to the dataset (keeping the original compression), mark the item done,
    and drop an empty sentinel file in <workspace>/results/.
    """
    while True:
        await semaphore.acquire()
        work_item = await work_queue.get_work()

        if work_item is None:
            logger.info(f"Worker {worker_id} exiting – queue empty")
            semaphore.release()
            break

        file_uri = work_item.work_paths[0]
        logger.info(f"Worker {worker_id} processing {file_uri}")

        try:
            # ------------------------------------------------------------------
            # Run the per-file pipeline
            # ------------------------------------------------------------------
            attributes = await process_file(args, worker_id, file_uri)

            # 1. Build the relative path that mirrors documents/…
            if file_uri.startswith("s3://"):
                _, key = parse_s3_path(file_uri)
                _, docs_prefix = parse_s3_path(args.dataset)
                rel_path = key[len(os.path.join(docs_prefix, "documents/")):]
            else:
                docs_root = os.path.join(args.dataset, "documents")
                rel_path = os.path.relpath(file_uri, docs_root)

            out_rel = os.path.join("attributes", args.attribute_name, rel_path)
            out_jsonl = "\n".join(json.dumps(x) for x in attributes) + "\n"

            # 2. Preserve compression type
            if rel_path.endswith(".gz"):
                payload = gzip.compress(out_jsonl.encode("utf-8"))
            elif rel_path.endswith((".zst", ".ztd")):
                payload = zstd.ZstdCompressor().compress(out_jsonl.encode("utf-8"))
            else:
                payload = out_jsonl.encode("utf-8")

            # 3. Write to args.dataset (local or S3)
            if args.dataset.startswith("s3://"):
                bucket, prefix = parse_s3_path(args.dataset)
                key = os.path.join(prefix, out_rel)
                workspace_s3.put_object(Bucket=bucket, Key=key, Body=payload)
            else:
                out_path = os.path.join(args.dataset, out_rel)
                os.makedirs(os.path.dirname(out_path), exist_ok=True)
                with open(out_path, "wb") as fh:
                    fh.write(payload)

            # 4. Mark queue item done
            await work_queue.mark_done(work_item)

            # 5. Drop empty sentinel file in <workspace>/results/
            sentinel_rel = os.path.join("results", f"output_{work_item.hash}.jsonl")
            if args.scratch.startswith("s3://"):
                bkt, pfx = parse_s3_path(args.scratch)
                key = os.path.join(pfx, sentinel_rel)
                workspace_s3.put_object(Bucket=bkt, Key=key, Body=b"")
            else:
                sentinel_path = os.path.join(args.scratch, sentinel_rel)
                os.makedirs(os.path.dirname(sentinel_path), exist_ok=True)
                open(sentinel_path, "w").close()
        except Exception as exc:
            logger.exception(f"Worker {worker_id} exception: {exc!s}")
        finally:
            semaphore.release()

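# Illustrative layout (paths are examples, not from the original source): a worker that processes
# documents/split_a/shard_0001.jsonl.gz writes its attributes to
# attributes/<attribute_name>/split_a/shard_0001.jsonl.gz under args.dataset, and drops an empty
# results/output_<work_item_hash>.jsonl sentinel under args.scratch.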
async def sglang_server_task(model_name_or_path, args, semaphore):
    # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
    # mem_fraction_arg = ["--mem-fraction-static", "0.80"]

    cmd = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model_name_or_path,
        "--port",
        str(SGLANG_SERVER_PORT),
        "--log-level-http",
        "warning",
        "--mem-fraction-static",
        "0.40",
    ]

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    # Ensure the subprocess is terminated on exit
    def _kill_proc():
        proc.terminate()

    atexit.register(_kill_proc)

    # Shared variables between tasks
    last_running_req, last_queue_req = 0, 0
    server_printed_ready_message = False
    last_semaphore_release = time.time()

    async def process_line(line):
        nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
        sglang_logger.info(line)

        # if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
        # can see any warnings/errors more easily
        if not server_printed_ready_message:
            logger.info(line)

        if "Detected errors during sampling" in line:
            logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
            sys.exit(1)

        # TODO, need to trace down this issue in sglang itself, but it will otherwise cause the server to lock up
        if "IndexError: list index out of range" in line:
            logger.error("IndexError in model, restarting server")
            proc.terminate()

        if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
            server_printed_ready_message = True
            last_semaphore_release = time.time()

        match = re.search(r"#running-req: (\d+)", line)
        if match:
            last_running_req = int(match.group(1))

        match = re.search(r"#queue-req: (\d+)", line)
        if match:
            last_queue_req = int(match.group(1))
            logger.info(f"sglang running req: {last_running_req} queue req: {last_queue_req}")

    async def read_stream(stream):
        while True:
            line = await stream.readline()
            if not line:
                break
            try:
                line = line.decode("utf-8").rstrip()
                await process_line(line)
            except Exception as ex:
                logger.warning(f"Got {ex} when reading log line from inference server, skipping")

    async def timeout_task():
        nonlocal last_running_req, last_queue_req, last_semaphore_release
        try:
            while True:
                await asyncio.sleep(1)
                if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
                    semaphore.release()
                    last_semaphore_release = time.time()
                    logger.info("Semaphore released, allowing a worker to proceed.")
        except asyncio.CancelledError:
            pass  # Clean up if the task is cancelled

    # Start tasks to read stdout, stderr, and handle timeout logic
    stdout_task = asyncio.create_task(read_stream(proc.stdout))
    stderr_task = asyncio.create_task(read_stream(proc.stderr))
    timeout_task = asyncio.create_task(timeout_task())

    try:
        await proc.wait()
    except asyncio.CancelledError:
        logger.info("Got cancellation request for SGLang server")
        proc.terminate()
        raise

    timeout_task.cancel()
    await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)

async def sglang_server_host(model_name_or_path, args, semaphore):
    MAX_RETRIES = 5
    retry = 0

    while retry < MAX_RETRIES:
        await sglang_server_task(model_name_or_path, args, semaphore)
        logger.warning("SGLang server task ended")
        retry += 1

    if retry >= MAX_RETRIES:
        logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline")
        logger.error("")
        logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
        sys.exit(1)

async def sglang_server_ready():
    max_attempts = 300
    delay_sec = 1
    url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/models"

    for attempt in range(1, max_attempts + 1):
        try:
            async with httpx.AsyncClient() as session:
                response = await session.get(url)

                if response.status_code == 200:
                    logger.info("sglang server is ready.")
                    return
                else:
                    logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
        except Exception:
            logger.warning(f"Attempt {attempt}: Please wait for sglang server to become ready...")

        await asyncio.sleep(delay_sec)

    raise Exception("sglang server did not become ready after waiting.")

async def download_model(model_name_or_path: str):
    if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"):
        logger.info(f"Downloading model directory from '{model_name_or_path}'")
        model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
        download_directory([model_name_or_path], model_cache_dir)
        return model_cache_dir
    elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
        logger.info(f"Using local model path at '{model_name_or_path}'")
        return model_name_or_path
    else:
        logger.info(f"Downloading model with hugging face '{model_name_or_path}'")
        snapshot_download(repo_id=model_name_or_path)
        return model_name_or_path

async def metrics_reporter(work_queue):
    while True:
        # Leading newlines preserve table formatting in logs
        logger.info(f"Queue remaining: {work_queue.size}")
        logger.info("\n" + str(metrics))
        await asyncio.sleep(10)

def submit_beaker_job(args):
    from beaker import (  # type: ignore
        Beaker,
        Constraints,
        EnvVar,
        ExperimentSpec,
        ImageSource,
        Priority,
        ResultSpec,
        SecretNotFound,
        TaskContext,
        TaskResources,
        TaskSpec,
    )

    b = Beaker.from_env(default_workspace=args.beaker_workspace)
    account = b.account.whoami()
    owner = account.name
    beaker_image = f"jakep/olmocr-inference-{VERSION}"

    task_name = f"olmocr-{os.path.basename(args.dataset.rstrip('/'))}"

    # Take out --beaker flag so the workers will just run things
    args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]

    # Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally
    args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i - 1] == "--pdfs"))]

    try:
        b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
        b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
        b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
    except SecretNotFound:
        print(
            f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]"
        )

        if input().strip().lower() != "y":
            print("Exiting...")
            sys.exit(1)

        b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
        b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
        b.secret.write(
            f"{owner}-AWS_CREDENTIALS_FILE",
            open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(),
            args.beaker_workspace,
        )

    env_var_secrets = [
        EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
        EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
        EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
    ]

    try:
        b.secret.get("OLMOCR_PREVIEW_HF_TOKEN", args.beaker_workspace)
        env_var_secrets.append(EnvVar(name="HF_TOKEN", secret="OLMOCR_PREVIEW_HF_TOKEN"))
    except SecretNotFound:
        pass

    try:
        b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
        env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
    except SecretNotFound:
        print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
        lines = []
        prev_empty = False
        for line in iter(input, None):
            if not line and prev_empty:
                break
            prev_empty = not line
            lines.append(line)
        gcs_sa_key = "\n".join(lines[:-1]).strip()  # Remove the last empty line
        if gcs_sa_key:
            b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
            env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))

    # Create the experiment spec
    experiment_spec = ExperimentSpec(
        budget="ai2/oe-data",
        description=task_name,
        tasks=[
            TaskSpec(
                name=task_name,
                propagate_failure=False,
                propagate_preemption=False,
                replicas=args.beaker_gpus,
                context=TaskContext(
                    priority=Priority(args.beaker_priority),
                    preemptible=True,
                ),
                image=ImageSource(beaker=beaker_image),
                command=["python", "scripts/tagging_pipeline.py"] + args_list,
                env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
                resources=TaskResources(gpu_count=1),
                constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
                result=ResultSpec(path="/noop-results"),
            )
        ],
    )

    experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)

    print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")

async def main():
    parser = argparse.ArgumentParser(description="Tagging pipeline for Dolma JSONL dataset")
    parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with documents/ folder")
    parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)")
    parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
    parser.add_argument("--model", default="google/gemma-3-4b-it", help="SGLang model path or name")
    parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")

    # Beaker/job running stuff
    parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
    parser.add_argument("--beaker_workspace", help="Beaker workspace to submit to", default="ai2/olmocr")
    parser.add_argument(
        "--beaker_cluster",
        help="Beaker clusters you want to run on",
        default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"],
    )
    parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
    parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
    parser.add_argument("--port", type=int, default=30024, help="Port for SGLang server")
    args = parser.parse_args()

    global SGLANG_SERVER_PORT, workspace_s3, dataset_s3
    SGLANG_SERVER_PORT = args.port
    workspace_s3 = boto3.client("s3")
    dataset_s3 = boto3.client("s3")

    # setup the job to work in beaker environment, load secrets, adjust logging, etc.
    if "BEAKER_JOB_ID" in os.environ:
        sglang_logger.addHandler(console_handler)

        if "AWS_CREDENTIALS_FILE" in os.environ:
            cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
            os.makedirs(os.path.dirname(cred_path), exist_ok=True)
            with open(cred_path, "w") as f:
                f.write(os.environ.get("AWS_CREDENTIALS_FILE"))

        if "GOOGLE_APPLICATION_CREDENTIALS_FILE" in os.environ:
            cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
            os.makedirs(os.path.dirname(cred_path), exist_ok=True)
            with open(cred_path, "w") as f:
                f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path

        workspace_s3 = boto3.client("s3")
        dataset_s3 = boto3.client("s3")

        # Wait a little bit so that not all beaker jobs in a task start at the same time and download the model at the same time
        replica_count = int(os.environ.get("BEAKER_REPLICA_COUNT", "1"))
        interval = 10 if (replica_count - 1) * 10 <= 240 else 240 / max(1, replica_count - 1)
        sleep_time = int(int(os.environ.get("BEAKER_REPLICA_RANK", "0")) * interval)
        logger.info(f"Beaker job sleeping for {sleep_time} seconds to stagger model downloads")
        await asyncio.sleep(sleep_time)

    # Initialize work queue
    if args.scratch.startswith("s3://"):
        work_queue = S3WorkQueue(workspace_s3, args.scratch)
    else:
        work_queue = LocalWorkQueue(args.scratch)

    # Discover input files
    files = set()
    if args.dataset.startswith("s3://"):
        pattern = args.dataset.rstrip("/") + "/documents/*.jsonl*"
        matched = expand_s3_glob(dataset_s3, pattern)
        files = set(matched.keys())
    else:
        docs_dir = os.path.join(args.dataset, "documents")
        for root, _, fns in os.walk(docs_dir):
            for fn in fns:
                if fn.endswith((".jsonl", ".jsonl.gz", ".jsonl.ztd")):
                    files.add(os.path.join(root, fn))

    # Populate the work queue if needed
    await work_queue.populate_queue(list(files), items_per_group=1)

    if args.beaker:
        submit_beaker_job(args)
        return

    # If you get this far, then you are doing inference and need a GPU
    check_sglang_version()
    check_torch_gpu_available()

    logger.info(f"Starting pipeline with PID {os.getpid()}")

    # Download the model before you do anything else
    model_name_or_path = await download_model(args.model)

    # Initialize the work queue
    qsize = await work_queue.initialize_queue()
    if qsize == 0:
        logger.info("No work to do, exiting")
        return

    # Create a semaphore to control worker access
    # We only allow one worker to move forward with requests, until the server has no more requests in its queue
    # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
    # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
    semaphore = asyncio.Semaphore(1)

    sglang_server = asyncio.create_task(sglang_server_host(model_name_or_path, args, semaphore))

    await sglang_server_ready()

    metrics_task = asyncio.create_task(metrics_reporter(work_queue))

    # Create worker tasks to process the queue concurrently.
    worker_tasks = []
    for i in range(args.workers):
        task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
        worker_tasks.append(task)

    # Wait for all worker tasks to finish
    await asyncio.gather(*worker_tasks)

    sglang_server.cancel()
    metrics_task.cancel()

    logger.info("Work done")


if __name__ == "__main__":
    asyncio.run(main())