#!/usr/bin/env python3
"""
Tagging pipeline for Dolma JSONL datasets.

For each .jsonl, .jsonl.gz, or .jsonl.ztd file under the dataset's documents/ folder,
this script issues a model prompt completion for each document, collects the
classification answers, and writes corresponding Dolma attributes JSONL files under
<dataset>/attributes/<attribute_name>/, mirroring the input structure.
"""
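
# Example invocation (illustrative; the bucket and path names are hypothetical, the flags
# are defined in main() below):
#
#   python scripts/tagging_pipeline.py s3://my-bucket/my-dolma-dataset s3://my-bucket/scratch \
#       --model google/gemma-3-4b-it --workers 4 --parallel_requests 800
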
import argparse
import asyncio
import atexit
import gzip
import json
import logging
import os
import random
import re
import sys
import time
from typing import Optional
from urllib.parse import urlparse

import boto3
import httpx
import zstandard as zstd
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, ValidationError

from olmocr.check import check_torch_gpu_available
from olmocr.metrics import MetricsKeeper
from olmocr.s3_utils import (
    download_directory,
    expand_s3_glob,
    get_s3_bytes_with_backoff,
    parse_s3_path,
)
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue

# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False

server_logger = logging.getLogger("vllm")
server_logger.propagate = False

file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))

# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
server_logger.addHandler(file_handler)


# Default port; overridden by --port
SERVER_PORT = 30024

# Global variables for token statistics
metrics = MetricsKeeper(window=60 * 5)


class PIIClassification(BaseModel):
    primary_language: str = Field(..., description="Primary language as a two-letter code")
    document_type: str = Field(..., description="Basic summary of document type classification")
    is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
    contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
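
# Illustrative example of the JSON object the model is asked to produce for one page
# (field values here are made up; the schema above is what is actually enforced via response_format):
#   {"primary_language": "en", "document_type": "academic paper", "is_resume_cv": false, "contains_pii": false}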


async def _process_single_page(page_text: str) -> PIIClassification:
    """Helper function to process a single document or page."""
    text = page_text

    query = {
        "model": "google/gemma-3-4b-it",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": (
                            f"{text}\n\n-----------\n"
                            "Given the text above, determine what type of document it is, and if it's a resume/CV. answer in JSON. The format of your json object should be {'primary_language': str, 'document_type': str, 'is_resume_cv': bool, 'contains_pii': bool}"
                        ),
                    }
                ],
            }
        ],
        "max_tokens": 100,
        "temperature": 0.0,
        "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
    }

    url = f"http://localhost:{SERVER_PORT}/v1/chat/completions"

    # ---------- HTTP call ---------------------------------------------------
    try:
        status, body = await apost(url, json_data=query)
    except Exception as e:
        logger.warning(f"Server network error: {e!s}")
        metrics.add_metrics(server_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

    metrics.add_metrics(server_requests=1)

    if status != 200:
        logger.warning(f"Server HTTP {status}: {body[:250]!r}")
        metrics.add_metrics(server_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

    # ---------- Parse base JSON --------------------------------------------
    try:
        base = json.loads(body)
    except json.JSONDecodeError:
        logger.warning(f"Server response is not valid JSON: {body[:250]!r}")
        metrics.add_metrics(server_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

    # Token accounting if available
    if "usage" in base:
        metrics.add_metrics(
            server_input_tokens=base["usage"].get("prompt_tokens", 0),
            server_output_tokens=base["usage"].get("completion_tokens", 0),
        )

    # ---------- Extract the model message ----------------------------------
    try:
        content = base["choices"][0]["message"].get("content")
    except (KeyError, IndexError, AttributeError) as e:
        logger.warning(f"Missing fields in Server response: {e!s}")
        metrics.add_metrics(server_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

    if not isinstance(content, str):
        logger.warning("Server `content` is not a string; treating as error.")
        metrics.add_metrics(server_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

    try:
        pii_classification: PIIClassification = PIIClassification.model_validate_json(content)
        return pii_classification
    except ValidationError as e:
        logger.warning(f"Unable to parse pii classification object: {e!s}")
        metrics.add_metrics(server_errors=1)
        return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
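
# For reference, the fields read above come from an OpenAI-style chat completion response,
# roughly shaped like this (illustrative values only, other fields omitted):
#   {"choices": [{"message": {"content": "{\"primary_language\": \"en\", ...}"}}],
#    "usage": {"prompt_tokens": 1234, "completion_tokens": 40}}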


# Manual simple implementation of HTTP Post
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
# that at the scale of 100M+ requests, that they deadlock in different strange ways
async def apost(url, json_data):
    parsed_url = urlparse(url)
    host = parsed_url.hostname
    port = parsed_url.port or 80
    path = parsed_url.path or "/"

    writer = None
    try:
        reader, writer = await asyncio.open_connection(host, port)

        json_payload = json.dumps(json_data)
        request = (
            f"POST {path} HTTP/1.1\r\n"
            f"Host: {host}\r\n"
            f"Content-Type: application/json\r\n"
            f"Content-Length: {len(json_payload)}\r\n"
            f"Connection: close\r\n\r\n"
            f"{json_payload}"
        )
        writer.write(request.encode())
        await writer.drain()

        # Read status line
        status_line = await reader.readline()
        if not status_line:
            raise ConnectionError("No response from server")
        status_parts = status_line.decode().strip().split(" ", 2)
        if len(status_parts) < 2:
            raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
        status_code = int(status_parts[1])

        # Read headers
        headers = {}
        while True:
            line = await reader.readline()
            if line in (b"\r\n", b"\n", b""):
                break
            key, _, value = line.decode().partition(":")
            headers[key.strip().lower()] = value.strip()

        # Read response body
        if "content-length" in headers:
            body_length = int(headers["content-length"])
            response_body = await reader.readexactly(body_length)
        else:
            raise ConnectionError("Anything other than fixed content length responses are not implemented yet")

        return status_code, response_body
    except Exception as e:
        # Pass through errors
        raise e
    finally:
        # But just make sure to close the socket on your way out
        if writer is not None:
            try:
                writer.close()
                await writer.wait_closed()
            except:
                pass


async def process_dolma_document(args, dolma_doc, sem):
    """
    Query model to detect PII, enforcing a JSON schema.

    Resilient to:
      • Transport / HTTP errors
      • Missing or malformed fields in the response
      • Non-string or None `content`
      • Bad JSON in the model's answer

    Always returns a dict mapping attribute names to per-page [start, end, value] spans.
    """
    doc_id = dolma_doc.get("id")
    text = dolma_doc.get("text", "") or ""

    language_key_name = f"{args.model.replace('/', '_')}_language"
    resume_cv_key_name = f"{args.model.replace('/', '_')}_is_resume_cv"

    result_attributes = {resume_cv_key_name: [], language_key_name: []}

    # If pdf_page_numbers is present, split the text and process each page separately
    if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
        page_numbers = dolma_doc["attributes"]["pdf_page_numbers"]

        logger.info(f"Document {doc_id} has {len(page_numbers)} pages, processing each individually")

        # Filter pages down to actual real content
        selected_page_numbers = [tuple(p) for p in page_numbers if p[0] < p[1]]
        first_page_number = selected_page_numbers[0]

        # Sample 3 pages max per document, but always include the first page, it's a good signal for CV classification
        random.shuffle(selected_page_numbers)
        selected_page_numbers = selected_page_numbers[:3]

        if first_page_number not in selected_page_numbers:
            selected_page_numbers[0] = first_page_number

        for start_pos, end_pos, page_num in page_numbers:
            if (start_pos, end_pos, page_num) in selected_page_numbers:
                page_text = text[start_pos:end_pos]

                # Process each page with the semaphore to limit concurrent requests
                async with sem:
                    pii_class = await _process_single_page(page_text)

                result_attributes[resume_cv_key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
                result_attributes[language_key_name].append([start_pos, end_pos, pii_class.primary_language])
            else:
                result_attributes[resume_cv_key_name].append([start_pos, end_pos, None])
                result_attributes[language_key_name].append([start_pos, end_pos, None])

        return result_attributes
    else:
        raise NotImplementedError("Missing code here, expecting these to be Dolma docs produced by olmocr")
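
# The attribute dict returned above uses [start, end, value] spans over the document text.
# Illustrative shape with --model google/gemma-3-4b-it, shown as the JSON that eventually lands
# in the attributes file (values are made up; spans for pages that were not sampled carry null):
#   {"google_gemma-3-4b-it_is_resume_cv": [[0, 1200, false], [1200, 2405, null], ...],
#    "google_gemma-3-4b-it_language": [[0, 1200, "en"], [1200, 2405, null], ...]}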


async def process_file(args, worker_id: int, file_uri: str):
    """
    Download a JSONL file, query model per record, and collect attributes.
    """
    # Fetch raw bytes (S3 or local)
    if file_uri.startswith("s3://"):
        raw = await asyncio.to_thread(get_s3_bytes_with_backoff, dataset_s3, file_uri)
    else:
        with open(file_uri, "rb") as f:
            raw = f.read()

    # Decompress if needed
    if file_uri.endswith(".gz"):
        file_bytes = gzip.decompress(raw)
    elif file_uri.endswith(".ztd") or file_uri.endswith(".zst") or file_uri.endswith(".zstd"):
        dctx = zstd.ZstdDecompressor()
        file_bytes = dctx.decompress(raw, max_output_size=1_000_000_000)
    else:
        file_bytes = raw

    lines = file_bytes.decode("utf-8").splitlines()
    page_tasks = {}

    # Send all records in parallel, max N queued at a time
    sem = asyncio.Semaphore(args.parallel_requests)

    async with asyncio.TaskGroup() as tg:
        for line in lines:
            dolma_doc = json.loads(line)
            task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
            page_tasks[dolma_doc["id"]] = (task, dolma_doc)

    logger.info(f"Finished taskgroup with {len(page_tasks)} items for {file_uri}")

    # Collect results and build attributes
    attributes = []
    for doc_id, (task, dolma_doc) in page_tasks.items():
        doc_attributes = task.result()

        attributes.append({"id": doc_id, "attributes": doc_attributes})

    return attributes


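# The worker below writes one attributes record per input document, one JSON object per line,
# e.g. (illustrative): {"id": "doc-123", "attributes": {...}}. Output paths mirror the input and
# keep its compression, e.g. <dataset>/documents/shard_0001.jsonl.gz ->
# <dataset>/attributes/<attribute_name>/shard_0001.jsonl.gz (the shard name is hypothetical).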
async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, worker_id: int):
    """
    Pop work-items off the queue, run PII tagging, write the attributes file
    next to the dataset (keeping the original compression), mark the item done,
    and drop an empty sentinel file in <workspace>/results/.
    """
    while True:
        await semaphore.acquire()
        work_item = await work_queue.get_work()

        if work_item is None:
            logger.info(f"Worker {worker_id} exiting – queue empty")
            semaphore.release()
            break

        file_uri = work_item.work_paths[0]
        logger.info(f"Worker {worker_id} processing {file_uri}")

        try:
            # ------------------------------------------------------------------
            # Run the per-file pipeline
            # ------------------------------------------------------------------
            attributes = await process_file(args, worker_id, file_uri)

            # 1. Build the relative path that mirrors documents/…
            if file_uri.startswith("s3://"):
                _, key = parse_s3_path(file_uri)
                _, docs_prefix = parse_s3_path(args.dataset)
                rel_path = key[len(os.path.join(docs_prefix, "documents/")) :]
            else:
                docs_root = os.path.join(args.dataset, "documents")
                rel_path = os.path.relpath(file_uri, docs_root)

            out_rel = os.path.join("attributes", args.attribute_name, rel_path)
            out_jsonl = "\n".join(json.dumps(x) for x in attributes) + "\n"

            # 2. Preserve compression type
            if rel_path.endswith(".gz"):
                payload = gzip.compress(out_jsonl.encode("utf-8"))
            elif rel_path.endswith((".zst", ".ztd")):
                payload = zstd.ZstdCompressor().compress(out_jsonl.encode("utf-8"))
            else:
                payload = out_jsonl.encode("utf-8")

            # 3. Write to args.dataset (local or S3)
            if args.dataset.startswith("s3://"):
                bucket, prefix = parse_s3_path(args.dataset)
                key = os.path.join(prefix, out_rel)
                workspace_s3.put_object(Bucket=bucket, Key=key, Body=payload)
            else:
                out_path = os.path.join(args.dataset, out_rel)
                os.makedirs(os.path.dirname(out_path), exist_ok=True)
                with open(out_path, "wb") as fh:
                    fh.write(payload)

            # 4. Mark queue item done
            await work_queue.mark_done(work_item)

            # 5. Drop empty sentinel file in <workspace>/results/
            sentinel_rel = os.path.join("results", f"output_{work_item.hash}.jsonl")
            if args.scratch.startswith("s3://"):
                bkt, pfx = parse_s3_path(args.scratch)
                key = os.path.join(pfx, sentinel_rel)
                workspace_s3.put_object(Bucket=bkt, Key=key, Body=b"")
            else:
                sentinel_path = os.path.join(args.scratch, sentinel_rel)
                os.makedirs(os.path.dirname(sentinel_path), exist_ok=True)
                open(sentinel_path, "w").close()

        except Exception as exc:
            logger.exception(f"Worker {worker_id} exception: {exc!s}")
        finally:
            semaphore.release()


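# server_task launches the inference server as a subprocess; the command it builds is roughly
# equivalent to running:
#   vllm serve <model_name_or_path> --port 30024 --uvicorn-log-level warning --disable-log-requests
# (30024 is the default SERVER_PORT; it is overridden by --port.)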
async def server_task(model_name_or_path, args, semaphore):
    # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
    # mem_fraction_arg = ["--mem-fraction-static", "0.80"]

    cmd = [
        "vllm",
        "serve",
        model_name_or_path,
        "--port",
        str(SERVER_PORT),
        "--uvicorn-log-level",
        "warning",
        "--disable-log-requests",
    ]

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    # Ensure the subprocess is terminated on exit
    def _kill_proc():
        proc.terminate()

    atexit.register(_kill_proc)

    # Shared variables between tasks
    last_running_req, last_queue_req = 0, 0
    server_printed_ready_message = False
    last_semaphore_release = time.time()

    async def process_line(line):
        nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
        server_logger.info(line)

        # if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
        # can see any warnings/errors more easily
        if not server_printed_ready_message:
            logger.info(line)

        if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
            server_printed_ready_message = True
            last_semaphore_release = time.time()

        match = re.search(r"Running: (\d+) reqs", line)
        if match:
            last_running_req = int(match.group(1))

        match = re.search(r"Waiting: (\d+) reqs", line)
        if match:
            last_queue_req = int(match.group(1))
            logger.info(f"running req: {last_running_req} queue req: {last_queue_req}")

    async def read_stream(stream):
        while True:
            line = await stream.readline()
            if not line:
                break
            try:
                line = line.decode("utf-8").rstrip()
                await process_line(line)
            except Exception as ex:
                logger.warning(f"Got {ex} when reading log line from inference server, skipping")

    async def timeout_task():
        nonlocal last_running_req, last_queue_req, last_semaphore_release
        try:
            while True:
                await asyncio.sleep(1)
                if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
                    semaphore.release()
                    last_semaphore_release = time.time()
                    logger.info("Semaphore released, allowing a worker to proceed.")
        except asyncio.CancelledError:
            pass  # Clean up if the task is cancelled

    # Start tasks to read stdout, stderr, and handle timeout logic
    stdout_task = asyncio.create_task(read_stream(proc.stdout))
    stderr_task = asyncio.create_task(read_stream(proc.stderr))
    timeout_task = asyncio.create_task(timeout_task())

    try:
        await proc.wait()
    except asyncio.CancelledError:
        logger.info("Got cancellation request for server")
        proc.terminate()
        raise

    timeout_task.cancel()
    await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)


async def server_host(model_name_or_path, args, semaphore):
    MAX_RETRIES = 5
    retry = 0

    while retry < MAX_RETRIES:
        await server_task(model_name_or_path, args, semaphore)
        logger.warning("Server task ended")
        retry += 1

    if retry >= MAX_RETRIES:
        logger.error(f"Ended up starting the server more than {retry} times, cancelling pipeline")
        logger.error("")
        logger.error("Please make sure vllm is installed according to the latest instructions for 0.8.4")
        sys.exit(1)


async def check_server_ready():
    max_attempts = 300
    delay_sec = 1
    url = f"http://localhost:{SERVER_PORT}/v1/models"

    for attempt in range(1, max_attempts + 1):
        try:
            async with httpx.AsyncClient() as session:
                response = await session.get(url)

                if response.status_code == 200:
                    logger.info("server is ready.")
                    return
                else:
                    logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
        except Exception:
            logger.warning(f"Attempt {attempt}: Please wait for model server to become ready...")

        await asyncio.sleep(delay_sec)

    raise Exception("model server did not become ready after waiting.")


async def download_model(model_name_or_path: str):
    if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"):
        logger.info(f"Downloading model directory from '{model_name_or_path}'")
        model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
        download_directory([model_name_or_path], model_cache_dir)
        return model_cache_dir
    elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
        logger.info(f"Using local model path at '{model_name_or_path}'")
        return model_name_or_path
    else:
        logger.info(f"Downloading model with hugging face '{model_name_or_path}'")
        snapshot_download(repo_id=model_name_or_path)
        return model_name_or_path


async def metrics_reporter(work_queue):
    while True:
        # Leading newlines preserve table formatting in logs
        logger.info(f"Queue remaining: {work_queue.size}")
        logger.info("\n" + str(metrics))
        await asyncio.sleep(10)


def submit_beaker_job(args):
    from beaker import (  # type: ignore
        Beaker,
        Constraints,
        EnvVar,
        ExperimentSpec,
        ImageSource,
        Priority,
        ResultSpec,
        SecretNotFound,
        TaskContext,
        TaskResources,
        TaskSpec,
    )

    b = Beaker.from_env(default_workspace=args.beaker_workspace)
    account = b.account.whoami()
    owner = account.name
    beaker_image = f"jakep/olmocr-tagging-{VERSION}"

    task_name = f"olmocr-{os.path.basename(args.dataset.rstrip('/'))}"

    # Take out --beaker flag so the workers will just run things
    args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]

    # Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally
    args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i - 1] == "--pdfs"))]

    try:
        b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
        b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
        b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
    except SecretNotFound:
        print(
            f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]"
        )

        if input().strip().lower() != "y":
            print("Exiting...")
            sys.exit(1)

        b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
        b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
        b.secret.write(
            f"{owner}-AWS_CREDENTIALS_FILE",
            open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(),
            args.beaker_workspace,
        )

    env_var_secrets = [
        EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
        EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
        EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
    ]

    try:
        b.secret.get("OLMOCR_PREVIEW_HF_TOKEN", args.beaker_workspace)
        env_var_secrets.append(EnvVar(name="HF_TOKEN", secret="OLMOCR_PREVIEW_HF_TOKEN"))
    except SecretNotFound:
        pass

    try:
        b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
        env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
    except SecretNotFound:
        print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
        lines = []
        prev_empty = False
        for line in iter(input, None):
            if not line and prev_empty:
                break
            prev_empty = not line
            lines.append(line)
        gcs_sa_key = "\n".join(lines[:-1]).strip()  # Remove the last empty line
        if gcs_sa_key:
            b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
            env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))

    # Create the experiment spec
    experiment_spec = ExperimentSpec(
        budget="ai2/oe-data",
        description=task_name,
        tasks=[
            TaskSpec(
                name=task_name,
                propagate_failure=False,
                propagate_preemption=False,
                replicas=args.beaker_gpus,
                context=TaskContext(
                    priority=Priority(args.beaker_priority),
                    preemptible=True,
                ),
                image=ImageSource(beaker=beaker_image),
                command=["python", "scripts/tagging_pipeline.py"] + args_list,
                env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
                resources=TaskResources(gpu_count=1),
                constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
                result=ResultSpec(path="/noop-results"),
            )
        ],
    )

    experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)

    print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")


async def main():
    parser = argparse.ArgumentParser(description="Tagging pipeline for Dolma JSONL dataset")
    parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with documents/ folder")
    parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)")
    parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
    parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
    parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
    parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")

    # Beaker/job running stuff
    parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
    parser.add_argument("--beaker_workspace", help="Beaker workspace to submit to", default="ai2/olmocr")
    parser.add_argument(
        "--beaker_cluster",
        help="Beaker clusters you want to run on",
        default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"],
    )
    parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
    parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")

    parser.add_argument("--port", type=int, default=30024, help="Port for Model server")
    args = parser.parse_args()

    global SERVER_PORT, workspace_s3, dataset_s3
    SERVER_PORT = args.port
    workspace_s3 = boto3.client("s3")
    dataset_s3 = boto3.client("s3")

    # Set up the job to work in the beaker environment: load secrets, adjust logging, etc.
    if "BEAKER_JOB_ID" in os.environ:
        server_logger.addHandler(console_handler)
        if "AWS_CREDENTIALS_FILE" in os.environ:
            cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
            os.makedirs(os.path.dirname(cred_path), exist_ok=True)
            with open(cred_path, "w") as f:
                f.write(os.environ.get("AWS_CREDENTIALS_FILE"))
        if "GOOGLE_APPLICATION_CREDENTIALS_FILE" in os.environ:
            cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
            os.makedirs(os.path.dirname(cred_path), exist_ok=True)
            with open(cred_path, "w") as f:
                f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
        workspace_s3 = boto3.client("s3")
        dataset_s3 = boto3.client("s3")

        # Wait a little bit so that all the beaker jobs in a task don't start downloading the model at the same time
        replica_count = int(os.environ.get("BEAKER_REPLICA_COUNT", "1"))
        interval = 10 if (replica_count - 1) * 10 <= 240 else 240 / max(1, replica_count - 1)
        sleep_time = int(int(os.environ.get("BEAKER_REPLICA_RANK", "0")) * interval)
        logger.info(f"Beaker job sleeping for {sleep_time} seconds to stagger model downloads")
        await asyncio.sleep(sleep_time)

    # Initialize work queue
    if args.scratch.startswith("s3://"):
        work_queue = S3WorkQueue(workspace_s3, args.scratch)
    else:
        work_queue = LocalWorkQueue(args.scratch)

    # Discover input files
    files = set()
    if args.dataset.startswith("s3://"):
        pattern = args.dataset.rstrip("/") + "/documents/*.jsonl*"
        matched = expand_s3_glob(dataset_s3, pattern)
        files = set(matched.keys())
    else:
        docs_dir = os.path.join(args.dataset, "documents")
        for root, _, fns in os.walk(docs_dir):
            for fn in fns:
                if fn.endswith((".jsonl", ".jsonl.gz", ".jsonl.ztd")):
                    files.add(os.path.join(root, fn))

    # Populate the work queue if needed
    await work_queue.populate_queue(list(files), items_per_group=1)

    if args.beaker:
        submit_beaker_job(args)
        return

    # If you get this far, then you are doing inference and need a GPU
    check_torch_gpu_available()

    logger.info(f"Starting pipeline with PID {os.getpid()}")

    # Download the model before you do anything else
    model_name_or_path = await download_model(args.model)

    # Initialize the work queue
    qsize = await work_queue.initialize_queue()

    if qsize == 0:
        logger.info("No work to do, exiting")
        return

    # Create a semaphore to control worker access
    # We only allow one worker to move forward with requests, until the server has no more requests in its queue
    # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
    # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
    semaphore = asyncio.Semaphore(1)

    model_server = asyncio.create_task(server_host(model_name_or_path, args, semaphore))

    await check_server_ready()

    metrics_task = asyncio.create_task(metrics_reporter(work_queue))

    # Create worker tasks to process the queue concurrently.
    worker_tasks = []
    for i in range(args.workers):
        task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
        worker_tasks.append(task)

    # Wait for all worker tasks to finish
    await asyncio.gather(*worker_tasks)

    model_server.cancel()
    metrics_task.cancel()
    logger.info("Work done")


if __name__ == "__main__":
    asyncio.run(main())