From 8ef7e56c86eebad0250e758eb0a626e95a278c12 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Wed, 30 Apr 2025 21:18:22 +0000
Subject: [PATCH] Trying a new rich tagging pipeline for PII

---
 scripts/rich_tagging_pipeline.py | 849 +++++++++++++++++++++++++++++++
 1 file changed, 849 insertions(+)
 create mode 100644 scripts/rich_tagging_pipeline.py

diff --git a/scripts/rich_tagging_pipeline.py b/scripts/rich_tagging_pipeline.py
new file mode 100644
index 0000000..7df964a
--- /dev/null
+++ b/scripts/rich_tagging_pipeline.py
@@ -0,0 +1,849 @@
+#!/usr/bin/env python3
+"""
+Tagging pipeline for Dolma JSONL datasets.
+
+For each .jsonl, .jsonl.gz, or .jsonl.ztd file under the dataset/documents folder,
+this script issues a model prompt per record, collects the per-page PII
+classifications, and writes corresponding Dolma attributes JSONL files under the
+dataset's attributes/ folder, mirroring the input structure (the scratch
+workspace holds the work queue and per-file completion sentinels).
+"""
+import argparse
+import asyncio
+import atexit
+import gzip
+import json
+import logging
+import os
+import random
+import re
+import sys
+import time
+from typing import Optional
+from urllib.parse import urlparse
+
+import boto3
+import httpx
+import zstandard as zstd
+from huggingface_hub import snapshot_download
+from pydantic import BaseModel, Field, ValidationError
+
+from olmocr.check import (
+    check_torch_gpu_available,
+)
+from olmocr.metrics import MetricsKeeper
+from olmocr.s3_utils import (
+    download_directory,
+    expand_s3_glob,
+    get_s3_bytes_with_backoff,
+    parse_s3_path,
+)
+from olmocr.version import VERSION
+from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
+
+# Initialize logger
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+logger.propagate = False
+
+server_logger = logging.getLogger("vllm")
+server_logger.propagate = False
+
+file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
+file_handler.setLevel(logging.DEBUG)
+file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
+
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)
+console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
+
+# Add handlers to the logger
+logger.addHandler(file_handler)
+logger.addHandler(console_handler)
+server_logger.addHandler(file_handler)
+
+
+# Default port; overridden by --port
+SERVER_PORT = 30024
+
+# Global variables for token statistics
+metrics = MetricsKeeper(window=60 * 5)
+
+
+class PIIClassification(BaseModel):
+    primary_language: str = Field(..., description="Primary language as a two-letter code")
+    document_type: str = Field(..., description="Basic summary of document type classification")
+    is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
+    contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
+
+class RichPIIClassification(BaseModel):
+    primary_language: str = Field(..., description="Primary language as a two-letter code")
+    document_type: str = Field(..., description="Basic summary of document type classification")
+
+    is_public_document: bool = Field(..., description="True if the document is meant for public dissemination")
+
+    contains_pii_government_id: Optional[bool] = Field(...)
+    contains_pii_financial_info: Optional[bool] = Field(...)
+    contains_pii_biometric_data: Optional[bool] = Field(...)
+    contains_pii_login_info: Optional[bool] = Field(...)
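+    # The four fields above are categories that count as PII on their own.
+    # The identifier fields below (name / email / phone) only make a document
+    # PII when they co-occur with one of the "contains_identifier_with_*"
+    # categories, mirroring the guidelines spelled out in the prompt below.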
+ + contains_identifier_name: Optional[bool] = Field(..., description="True if the document contains a name") + contains_identifier_email: Optional[bool] = Field(..., description="True if the document contains an email address") + contains_identifier_phone_number: Optional[bool] = Field(..., description="True if the document contains phone numbers") + + contains_identifier_with_address: Optional[bool] = Field(...) + contains_identifier_with_biographical_info: Optional[bool] = Field(...) + contains_identifier_with_location_info: Optional[bool] = Field(...) + contains_identifier_with_employment_info: Optional[bool] = Field(...) + contains_identifier_with_education_info: Optional[bool] = Field(...) + contains_identifier_with_medical_info: Optional[bool] = Field(...) + + def contains_any_pii(self) -> bool: + if self.is_public_document: + return False + + if self.contains_pii_government_id or self.contains_pii_financial_info or self.contains_pii_biometric_data or self.contains_pii_login_info: + return True + + if self.contains_identifier_name or self.contains_identifier_email or self.contains_identifier_phone_number: + return self.contains_identifier_with_address or self.contains_identifier_with_biographical_info or self.contains_identifier_with_location_info or self.contains_identifier_with_employment_info or self.contains_identifier_with_education_info or self.contains_identifier_with_medical_info + else: + return False + + +async def _process_single_page(page_text: str) -> RichPIIClassification: + """Helper function to process a single document or page.""" + text = page_text + + basic_prompt = "Given the text above, determine what type of document it is, and if it's a resume/CV. answer in JSON. The format of your json object should be {'primary_language': str, 'document_type': str, 'is_resume_cv': bool, 'contains_pii': bool}" + + rich_prompt = """You are a document analyzer that identifies Personally Identifiable Information (PII) in documents. +Your task is to analyze the provided document image and determine: +1. Whether the document is intended for public release or dissemination (e.g., research paper, public report, etc.) +2. If the document contains any PII + +For PII identification, follow these specific guidelines: + +PII THAT OCCURS EVEN WITHOUT AN IDENTIFIER: +The following should ALWAYS be marked as PII even if they do not occur alongside an identifier: +- Government IDs (Social Security Numbers, passport numbers, driver's license numbers, tax IDs) +- Financial Information (credit card numbers, bank account/routing numbers) +- Biometric Data (fingerprints, retina scans, facial recognition data, voice signatures) +- Login information (ONLY mark as PII when a username, password, and login location are present together) + +IDENTIFIERS FOR PII: +The following are considered identifiers that can make information PII: +- Names (full names, first names, last names, nicknames) +- Email addresses +- Phone numbers + +PII THAT MUST CO-OCCUR WITH AN IDENTIFIER: +The following types of information should ONLY be marked as PII if they occur ALONGSIDE an identifier (commonly, a person's name): +- Addresses (street address, postal code, etc.) 
+- Biographical Information (date of birth, place of birth, gender, sexual orientation, race, ethnicity, citizenship/immigration status, religion)
+- Location Information (geolocations, specific coordinates)
+- Employment Information (job titles, workplace names, employment history)
+- Education Information (school names, degrees, transcripts)
+- Medical Information (health records, diagnoses, genetic or neural data)
+
+If the document is a form, then only consider fields which are filled out with specific values as potential PII.
+If this page does not itself contain PII, but references documents (such as curriculum vitae, personal statements) that typically contain PII, then do not mark it as PII.
+Only consider actual occurrences of the PII within the document shown."""
+
+    query = {
+        "model": "google/gemma-3-4b-it",
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": (
+                            f"{text}\n\n-----------\n"
+                            f"{rich_prompt}"
+                        ),
+                    }
+                ],
+            }
+        ],
+        "max_tokens": 100,
+        "temperature": 0.0,
+        "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": RichPIIClassification.model_json_schema()}},
+    }
+
+    url = f"http://localhost:{SERVER_PORT}/v1/chat/completions"
+
+    # ---------- HTTP call ---------------------------------------------------
+    try:
+        status, body = await apost(url, json_data=query)
+    except Exception as e:
+        logger.warning(f"Server network error: {e!s}")
+        metrics.add_metrics(server_errors=1)
+        return RichPIIClassification(primary_language="en", document_type="unknown", is_public_document=False)
+
+    metrics.add_metrics(server_requests=1)
+
+    if status != 200:
+        logger.warning(f"Server HTTP {status}: {body[:250]!r}")
+        metrics.add_metrics(server_errors=1)
+        return RichPIIClassification(primary_language="en", document_type="unknown", is_public_document=False)
+
+    # ---------- Parse base JSON --------------------------------------------
+    try:
+        base = json.loads(body)
+    except json.JSONDecodeError:
+        logger.warning(f"Server response is not valid JSON: {body[:250]!r}")
+        metrics.add_metrics(server_errors=1)
+        return RichPIIClassification(primary_language="en", document_type="unknown", is_public_document=False)
+
+    # Token accounting if available
+    if "usage" in base:
+        metrics.add_metrics(
+            server_input_tokens=base["usage"].get("prompt_tokens", 0),
+            server_output_tokens=base["usage"].get("completion_tokens", 0),
+        )
+
+    # ---------- Extract the model message ----------------------------------
+    try:
+        content = base["choices"][0]["message"].get("content")
+    except (KeyError, IndexError, AttributeError) as e:
+        logger.warning(f"Missing fields in Server response: {e!s}")
+        metrics.add_metrics(server_errors=1)
+        return RichPIIClassification(primary_language="en", document_type="unknown", is_public_document=False)
+
+    if not isinstance(content, str):
+        logger.warning("Server `content` is not a string; treating as error.")
+        metrics.add_metrics(server_errors=1)
+        return RichPIIClassification(primary_language="en", document_type="unknown", is_public_document=False)
+
+    try:
+        pii_classification: RichPIIClassification = RichPIIClassification.model_validate_json(content)
+        return pii_classification
+    except ValidationError as e:
+        logger.warning(f"Unable to parse pii classification object: {e!s}")
+        metrics.add_metrics(server_errors=1)
+        return RichPIIClassification(primary_language="en", document_type="unknown", is_public_document=False)
+
+
+# Manual simple implementation of HTTP Post
+# It feels strange perhaps, but
httpx and aiohttp are very complex beasts +# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed +# that at the scale of 100M+ requests, that they deadlock in different strange ways +async def apost(url, json_data): + parsed_url = urlparse(url) + host = parsed_url.hostname + port = parsed_url.port or 80 + path = parsed_url.path or "/" + + writer = None + try: + reader, writer = await asyncio.open_connection(host, port) + + json_payload = json.dumps(json_data) + request = ( + f"POST {path} HTTP/1.1\r\n" + f"Host: {host}\r\n" + f"Content-Type: application/json\r\n" + f"Content-Length: {len(json_payload)}\r\n" + f"Connection: close\r\n\r\n" + f"{json_payload}" + ) + writer.write(request.encode()) + await writer.drain() + + # Read status line + status_line = await reader.readline() + if not status_line: + raise ConnectionError("No response from server") + status_parts = status_line.decode().strip().split(" ", 2) + if len(status_parts) < 2: + raise ValueError(f"Malformed status line: {status_line.decode().strip()}") + status_code = int(status_parts[1]) + + # Read headers + headers = {} + while True: + line = await reader.readline() + if line in (b"\r\n", b"\n", b""): + break + key, _, value = line.decode().partition(":") + headers[key.strip().lower()] = value.strip() + + # Read response body + if "content-length" in headers: + body_length = int(headers["content-length"]) + response_body = await reader.readexactly(body_length) + else: + raise ConnectionError("Anything other than fixed content length responses are not implemented yet") + + return status_code, response_body + except Exception as e: + # Pass through errors + raise e + finally: + # But just make sure to close the socket on your way out + if writer is not None: + try: + writer.close() + await writer.wait_closed() + except: + pass + + +async def process_dolma_document(args, dolma_doc, sem): + """ + Query model to detect PII, enforcing a JSON schema. 
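+    At most three pages are sampled per document (the first page is always
+    included since it is a strong signal for resume/CV detection); pages that
+    are not sampled get null attribute spans.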
+ + Resilient to: + • Transport / HTTP errors + • Missing or malformed fields in the response + • Non-string or None `content` + • Bad JSON in the model's answer + + Always returns: (doc_id, contains_pii: bool, text_length: int) + """ + doc_id = dolma_doc.get("id") + text = dolma_doc.get("text", "") or "" + + language_key_name = f"{args.model.replace('/', '_')}_language" + contains_pii_key_name = f"{args.model.replace('/', '_')}_contains_pii" + + result_attributes = {contains_pii_key_name: [], language_key_name: []} + + # If pdf_page_numbers is present, split the text and process each page separately + if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]: + page_numbers = dolma_doc["attributes"]["pdf_page_numbers"] + + logger.info(f"Document {doc_id} has {len(page_numbers)} pages, processing each individually") + + # Filter pages down to actual real content + selected_page_numbers = [tuple(p) for p in page_numbers if p[0] < p[1]] + first_page_number = selected_page_numbers[0] + + # Sample 3 pages max per document, but always include the first page, it's a good signal for CV classification + random.shuffle(selected_page_numbers) + selected_page_numbers = selected_page_numbers[:3] + + if first_page_number not in selected_page_numbers: + selected_page_numbers[0] = first_page_number + + for start_pos, end_pos, page_num in page_numbers: + if (start_pos, end_pos, page_num) in selected_page_numbers: + page_text = text[start_pos:end_pos] + + # Process each page with the semaphore to limit concurrent requests + async with sem: + pii_class = await _process_single_page(page_text) + + result_attributes[contains_pii_key_name].append([start_pos, end_pos, pii_class.contains_any_pii()]) + result_attributes[language_key_name].append([start_pos, end_pos, pii_class.primary_language]) + else: + result_attributes[contains_pii_key_name].append([start_pos, end_pos, None]) + result_attributes[language_key_name].append([start_pos, end_pos, None]) + + return result_attributes + else: + raise NotImplementedError("Missing code here, expecting this to be dolma docs made by olmocr....") + + +async def process_file(args, worker_id: int, file_uri: str): + """ + Download a JSONL file, query model per record, and collect attributes. 
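+
+    Handles plain, gzip (.gz), and zstandard (.zst/.ztd/.zstd) compressed input,
+    and fans records out concurrently, limited by a semaphore of size
+    args.parallel_requests.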
+ """ + # Fetch raw bytes (S3 or local) + if file_uri.startswith("s3://"): + raw = await asyncio.to_thread(get_s3_bytes_with_backoff, dataset_s3, file_uri) + else: + with open(file_uri, "rb") as f: + raw = f.read() + + # Decompress if needed + if file_uri.endswith(".gz"): + file_bytes = gzip.decompress(raw) + elif file_uri.endswith(".ztd") or file_uri.endswith(".zst") or file_uri.endswith(".zstd"): + dctx = zstd.ZstdDecompressor() + file_bytes = dctx.decompress(raw, max_output_size=1_000_000_000) + else: + file_bytes = raw + + lines = file_bytes.decode("utf-8").splitlines() + page_tasks = {} + + # Send all records in parallel, max N queued at a time + sem = asyncio.Semaphore(args.parallel_requests) + + async with asyncio.TaskGroup() as tg: + for line in lines: + dolma_doc = json.loads(line) + task = tg.create_task(process_dolma_document(args, dolma_doc, sem)) + page_tasks[dolma_doc["id"]] = (task, dolma_doc) + + logger.info(f"Finished taskgroup with {len(page_tasks)} items for {file_uri}") + + # Collect results and build attributes + attributes = [] + for doc_id, (task, dolma_doc) in page_tasks.items(): + doc_attributes = task.result() + + attributes.append({"id": doc_id, "attributes": doc_attributes}) + + return attributes + + +async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, worker_id: int): + """ + Pop work-items off the queue, run PII tagging, write the attributes file + next to the dataset (keeping the original compression), mark the item done, + and drop an empty sentinel file in /results/. + """ + while True: + await semaphore.acquire() + work_item = await work_queue.get_work() + + if work_item is None: + logger.info(f"Worker {worker_id} exiting – queue empty") + semaphore.release() + break + + file_uri = work_item.work_paths[0] + logger.info(f"Worker {worker_id} processing {file_uri}") + + try: + # ------------------------------------------------------------------ + # Run the per-file pipeline + # ------------------------------------------------------------------ + attributes = await process_file(args, worker_id, file_uri) + + # 1. Build the relative path that mirrors documents/… + if file_uri.startswith("s3://"): + _, key = parse_s3_path(file_uri) + _, docs_prefix = parse_s3_path(args.dataset) + rel_path = key[len(os.path.join(docs_prefix, "documents/")) :] + else: + docs_root = os.path.join(args.dataset, "documents") + rel_path = os.path.relpath(file_uri, docs_root) + + out_rel = os.path.join("attributes", args.attribute_name, rel_path) + out_jsonl = "\n".join(json.dumps(x) for x in attributes) + "\n" + + # 2. Preserve compression type + if rel_path.endswith(".gz"): + payload = gzip.compress(out_jsonl.encode("utf-8")) + elif rel_path.endswith((".zst", ".ztd")): + payload = zstd.ZstdCompressor().compress(out_jsonl.encode("utf-8")) + else: + payload = out_jsonl.encode("utf-8") + + # 3. Write to args.dataset (local or S3) + if args.dataset.startswith("s3://"): + bucket, prefix = parse_s3_path(args.dataset) + key = os.path.join(prefix, out_rel) + workspace_s3.put_object(Bucket=bucket, Key=key, Body=payload) + else: + out_path = os.path.join(args.dataset, out_rel) + os.makedirs(os.path.dirname(out_path), exist_ok=True) + with open(out_path, "wb") as fh: + fh.write(payload) + + # 4. Mark queue item done + await work_queue.mark_done(work_item) + + # 5. 
Drop empty sentinel file in /results/ + sentinel_rel = os.path.join("results", f"output_{work_item.hash}.jsonl") + if args.scratch.startswith("s3://"): + bkt, pfx = parse_s3_path(args.scratch) + key = os.path.join(pfx, sentinel_rel) + workspace_s3.put_object(Bucket=bkt, Key=key, Body=b"") + else: + sentinel_path = os.path.join(args.scratch, sentinel_rel) + os.makedirs(os.path.dirname(sentinel_path), exist_ok=True) + open(sentinel_path, "w").close() + + except Exception as exc: + logger.exception(f"Worker {worker_id} exception: {exc!s}") + finally: + semaphore.release() + + +async def server_task(model_name_or_path, args, semaphore): + # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory + # mem_fraction_arg = ["--mem-fraction-static", "0.80"] + + cmd = [ + "vllm", + "serve", + model_name_or_path, + "--port", + str(SERVER_PORT), + "--uvicorn-log-level", + "warning", + "--disable-log-requests", + ] + + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + # Ensure the subprocess is terminated on exit + def _kill_proc(): + proc.terminate() + + atexit.register(_kill_proc) + + # Shared variables between tasks + last_running_req, last_queue_req = 0, 0 + server_printed_ready_message = False + last_semaphore_release = time.time() + + async def process_line(line): + nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message + server_logger.info(line) + + # if the server hasn't initialized yet, log all the lines to the main logger also, so that the user + # can see any warnings/errors more easily + if not server_printed_ready_message: + logger.info(line) + + if not server_printed_ready_message and "The server is fired up and ready to roll!" 
in line: + server_printed_ready_message = True + last_semaphore_release = time.time() + + match = re.search(r"Running: (\d+) reqs", line) + if match: + last_running_req = int(match.group(1)) + + match = re.search(r"Waiting: (\d+) reqs", line) + if match: + last_queue_req = int(match.group(1)) + logger.info(f"running req: {last_running_req} queue req: {last_queue_req}") + + async def read_stream(stream): + while True: + line = await stream.readline() + if not line: + break + try: + line = line.decode("utf-8").rstrip() + await process_line(line) + except Exception as ex: + logger.warning(f"Got {ex} when reading log line from inference server, skipping") + + async def timeout_task(): + nonlocal last_running_req, last_queue_req, last_semaphore_release + try: + while True: + await asyncio.sleep(1) + if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked(): + semaphore.release() + last_semaphore_release = time.time() + logger.info("Semaphore released, allowing a worker to proceed.") + except asyncio.CancelledError: + pass # Clean up if the task is cancelled + + # Start tasks to read stdout, stderr, and handle timeout logic + stdout_task = asyncio.create_task(read_stream(proc.stdout)) + stderr_task = asyncio.create_task(read_stream(proc.stderr)) + timeout_task = asyncio.create_task(timeout_task()) + + try: + await proc.wait() + except asyncio.CancelledError: + logger.info("Got cancellation request for server") + proc.terminate() + raise + + timeout_task.cancel() + await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True) + + +async def server_host(model_name_or_path, args, semaphore): + MAX_RETRIES = 5 + retry = 0 + + while retry < MAX_RETRIES: + await server_task(model_name_or_path, args, semaphore) + logger.warning("Server task ended") + retry += 1 + + if retry >= MAX_RETRIES: + logger.error(f"Ended up starting the server more than {retry} times, cancelling pipeline") + logger.error("") + logger.error("Please make sure vllm is installed according to the latest instructions for 0.8.4") + sys.exit(1) + + +async def check_server_ready(): + max_attempts = 300 + delay_sec = 1 + url = f"http://localhost:{SERVER_PORT}/v1/models" + + for attempt in range(1, max_attempts + 1): + try: + async with httpx.AsyncClient() as session: + response = await session.get(url) + + if response.status_code == 200: + logger.info("server is ready.") + return + else: + logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}") + except Exception: + logger.warning(f"Attempt {attempt}: Please wait for model server to become ready...") + + await asyncio.sleep(delay_sec) + + raise Exception("model server did not become ready after waiting.") + + +async def download_model(model_name_or_path: str): + if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"): + logger.info(f"Downloading model directory from '{model_name_or_path}'") + model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model") + download_directory([model_name_or_path], model_cache_dir) + return model_cache_dir + elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path): + logger.info(f"Using local model path at '{model_name_or_path}'") + return model_name_or_path + else: + logger.info(f"Downloading model with hugging face '{model_name_or_path}'") + snapshot_download(repo_id=model_name_or_path) + return model_name_or_path + + +async def 
metrics_reporter(work_queue): + while True: + # Leading newlines preserve table formatting in logs + logger.info(f"Queue remaining: {work_queue.size}") + logger.info("\n" + str(metrics)) + await asyncio.sleep(10) + + +def submit_beaker_job(args): + from beaker import ( # type: ignore + Beaker, + Constraints, + EnvVar, + ExperimentSpec, + ImageSource, + Priority, + ResultSpec, + SecretNotFound, + TaskContext, + TaskResources, + TaskSpec, + ) + + b = Beaker.from_env(default_workspace=args.beaker_workspace) + account = b.account.whoami() + owner = account.name + beaker_image = f"jakep/olmocr-tagging-{VERSION}" + + task_name = f"olmocr-{os.path.basename(args.dataset.rstrip('/'))}" + + # Take out --beaker flag so the workers will just run things + args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"] + + # Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally + args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i - 1] == "--pdfs"))] + + try: + b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace) + b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace) + b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace) + except SecretNotFound: + print( + f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]" + ) + + if input().strip().lower() != "y": + print("Exiting...") + sys.exit(1) + + b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace) + b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace) + b.secret.write( + f"{owner}-AWS_CREDENTIALS_FILE", + open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(), + args.beaker_workspace, + ) + + env_var_secrets = [ + EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"), + EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"), + EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"), + ] + + try: + b.secret.get("OLMOCR_PREVIEW_HF_TOKEN", args.beaker_workspace) + env_var_secrets.append(EnvVar(name="HF_TOKEN", secret="OLMOCR_PREVIEW_HF_TOKEN")) + except SecretNotFound: + pass + + try: + b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace) + env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY")) + except SecretNotFound: + print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):") + lines = [] + prev_empty = False + for line in iter(input, None): + if not line and prev_empty: + break + prev_empty = not line + lines.append(line) + gcs_sa_key = "\n".join(lines[:-1]).strip() # Remove the last empty line + if gcs_sa_key: + b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace) + env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY")) + + # Create the experiment spec + experiment_spec = ExperimentSpec( + budget="ai2/oe-data", + description=task_name, + tasks=[ + TaskSpec( + name=task_name, + propagate_failure=False, + propagate_preemption=False, + replicas=args.beaker_gpus, + context=TaskContext( + priority=Priority(args.beaker_priority), + preemptible=True, + ), + image=ImageSource(beaker=beaker_image), + command=["python", "scripts/rich_tagging_pipeline.py"] + args_list, + 
env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets, + resources=TaskResources(gpu_count=1), + constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]), + result=ResultSpec(path="/noop-results"), + ) + ], + ) + + experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace) + + print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}") + + +async def main(): + parser = argparse.ArgumentParser(description="Tagging pipeline for Dolma JSONL dataset") + parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with documents/ folder") + parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)") + parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers") + parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model") + parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format") + parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming") + + # Beaker/job running stuff + parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally") + parser.add_argument("--beaker_workspace", help="Beaker workspace to submit to", default="ai2/olmocr") + parser.add_argument( + "--beaker_cluster", + help="Beaker clusters you want to run on", + default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"], + ) + parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run") + parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job") + + parser.add_argument("--port", type=int, default=30024, help="Port for Model server") + args = parser.parse_args() + + global SERVER_PORT, workspace_s3, dataset_s3 + SERVER_PORT = args.port + workspace_s3 = boto3.client("s3") + dataset_s3 = boto3.client("s3") + + # setup the job to work in beaker environment, load secrets, adjust logging, etc. 
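+    # (credentials arrive as env vars / Beaker secrets and are written to the
+    # on-disk paths the boto3 and GCS clients expect; replicas also sleep a
+    # staggered amount so they do not all download the model at the same time)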
+ if "BEAKER_JOB_ID" in os.environ: + server_logger.addHandler(console_handler) + if "AWS_CREDENTIALS_FILE" in os.environ: + cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials") + os.makedirs(os.path.dirname(cred_path), exist_ok=True) + with open(cred_path, "w") as f: + f.write(os.environ.get("AWS_CREDENTIALS_FILE")) + if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ: + cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials") + os.makedirs(os.path.dirname(cred_path), exist_ok=True) + with open(cred_path, "w") as f: + f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE")) + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path + workspace_s3 = boto3.client("s3") + dataset_s3 = boto3.client("s3") + + # Wait a little bit so that not all beaker jobs in a task start at the same time and download the model at the same time + replica_count = int(os.environ.get("BEAKER_REPLICA_COUNT", "1")) + interval = 10 if (replica_count - 1) * 10 <= 240 else 240 / max(1, replica_count - 1) + sleep_time = int(int(os.environ.get("BEAKER_REPLICA_RANK", "0")) * interval) + logger.info(f"Beaker job sleeping for {sleep_time} seconds to stagger model downloads") + await asyncio.sleep(sleep_time) + + # Initialize work queue + if args.scratch.startswith("s3://"): + work_queue = S3WorkQueue(workspace_s3, args.scratch) + else: + work_queue = LocalWorkQueue(args.scratch) + + # Discover input files + files = set() + if args.dataset.startswith("s3://"): + pattern = args.dataset.rstrip("/") + "/documents/*.jsonl*" + matched = expand_s3_glob(dataset_s3, pattern) + files = set(matched.keys()) + else: + docs_dir = os.path.join(args.dataset, "documents") + for root, _, fns in os.walk(docs_dir): + for fn in fns: + if fn.endswith((".jsonl", ".jsonl.gz", ".jsonl.ztd")): + files.add(os.path.join(root, fn)) + + # Populate the work queue if needed + await work_queue.populate_queue(list(files), items_per_group=1) + + if args.beaker: + submit_beaker_job(args) + return + + # If you get this far, then you are doing inference and need a GPU + check_torch_gpu_available() + + logger.info(f"Starting pipeline with PID {os.getpid()}") + + # Download the model before you do anything else + model_name_or_path = await download_model(args.model) + + # Initialize the work queue + qsize = await work_queue.initialize_queue() + + if qsize == 0: + logger.info("No work to do, exiting") + return + + # Create a semaphore to control worker access + # We only allow one worker to move forward with requests, until the server has no more requests in its queue + # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible + # As soon as one worker is no longer saturating the gpu, the next one can start sending requests + semaphore = asyncio.Semaphore(1) + + model_server = asyncio.create_task(server_host(model_name_or_path, args, semaphore)) + + await check_server_ready() + + metrics_task = asyncio.create_task(metrics_reporter(work_queue)) + + # Create worker tasks to process the queue concurrently. + worker_tasks = [] + for i in range(args.workers): + task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i)) + worker_tasks.append(task) + + # Wait for all worker tasks to finish + await asyncio.gather(*worker_tasks) + + model_server.cancel() + metrics_task.cancel() + logger.info("Work done") + + +if __name__ == "__main__": + asyncio.run(main())