A bit more work on tagging

Jake Poznanski 2025-05-09 19:31:07 +00:00
parent 72bcfd8f31
commit 1854ae1269
4 changed files with 850 additions and 66 deletions

View File

@@ -94,10 +94,10 @@ class DocumentClassification(pydantic.BaseModel):
is_correspondence_or_letter: bool
is_public_order: bool
is_court_notice: bool
# Additional context or notes
classification_notes: str
@property
def get_document_types(self) -> List[str]:
"""Get a list of all document types identified"""
@@ -396,7 +396,7 @@ def process_single_page(args, doc_info, pdf_s3_client=None):
# Generate attribute key names using model name
model_prefix = args.openai_model.replace("/", "_").replace("-", "_").replace(".", "_")
language_key_name = f"{model_prefix}_language"
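# e.g. (hypothetical model name) --openai_model "gpt-4o-mini" yields
# model_prefix "gpt_4o_mini" and language_key_name "gpt_4o_mini_language"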
# Initialize result attributes with all DocumentClassification fields
result_attributes = {
language_key_name: [[start_pos, end_pos, classification.language_code.name]],

View File

@@ -526,12 +526,12 @@ def apply_rule(doc, rule):
def calculate_attribute_aggregate(doc, attribute_name, operation_type):
"""
Calculate an aggregate value for a numeric attribute.
Args:
doc: The document JSON object
attribute_name: The attribute field to aggregate (e.g., "pii_tagging_ratio")
operation_type: The type of aggregation to perform (e.g., "avg")
Returns:
The aggregated value, or None if calculation is not possible
"""
@@ -539,33 +539,33 @@ def calculate_attribute_aggregate(doc, attribute_name, operation_type):
if "attributes" not in doc or not doc["attributes"]:
logger.debug(f"Document {doc.get('id', 'unknown')} has no attributes")
return None
attributes = doc["attributes"]
# Check if the specific attribute exists
if attribute_name not in attributes:
logger.debug(f"Document {doc.get('id', 'unknown')} doesn't have attribute: {attribute_name}")
return None
if not attributes[attribute_name]:
logger.debug(f"Document {doc.get('id', 'unknown')} has empty attribute: {attribute_name}")
return None
# Extract the numeric values from the attribute spans
# Each span is formatted as [start_pos, end_pos, value]
values = [span[2] for span in attributes[attribute_name] if len(span) >= 3 and span[2] is not None]
if not values:
logger.debug(f"Document {doc.get('id', 'unknown')} has no valid values for attribute: {attribute_name}")
return None
# Convert all values to float
try:
numeric_values = [float(value) for value in values]
except (ValueError, TypeError):
logger.debug(f"Document {doc.get('id', 'unknown')} has non-numeric values for attribute: {attribute_name}")
return None
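# Worked example (hypothetical spans): attributes["pii_tagging_ratio"] =
# [[0, 100, 0.2], [100, 200, 0.4]] yields values [0.2, 0.4], so "avg" gives 0.3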
# Perform the aggregation
if operation_type == "avg":
if not numeric_values:
@@ -606,32 +606,32 @@ def apply_simple_rule(doc, attribute_name, rule_type):
return False
# Handle numeric comparison rules (e.g., 'avg>0.3')
- if any(op in rule_type for op in ['>', '<', '>=', '<=', '==']):
+ if any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
# Parse the rule type into operation and comparison
- operation_parts = rule_type.split('>')
+ operation_parts = rule_type.split(">")
if len(operation_parts) == 2:
operation_type, threshold = operation_parts
- comparison_op = '>'
+ comparison_op = ">"
else:
- operation_parts = rule_type.split('<')
+ operation_parts = rule_type.split("<")
if len(operation_parts) == 2:
operation_type, threshold = operation_parts
- comparison_op = '<'
+ comparison_op = "<"
else:
- operation_parts = rule_type.split('>=')
+ operation_parts = rule_type.split(">=")
if len(operation_parts) == 2:
operation_type, threshold = operation_parts
- comparison_op = '>='
+ comparison_op = ">="
else:
- operation_parts = rule_type.split('<=')
+ operation_parts = rule_type.split("<=")
if len(operation_parts) == 2:
operation_type, threshold = operation_parts
- comparison_op = '<='
+ comparison_op = "<="
else:
- operation_parts = rule_type.split('==')
+ operation_parts = rule_type.split("==")
if len(operation_parts) == 2:
operation_type, threshold = operation_parts
- comparison_op = '=='
+ comparison_op = "=="
else:
raise ValueError(f"Invalid rule type: {rule_type}")
@@ -648,15 +648,15 @@ def apply_simple_rule(doc, attribute_name, rule_type):
return False
# Apply the comparison
- if comparison_op == '>':
+ if comparison_op == ">":
result = aggregate_value > threshold
- elif comparison_op == '<':
+ elif comparison_op == "<":
result = aggregate_value < threshold
- elif comparison_op == '>=':
+ elif comparison_op == ">=":
result = aggregate_value >= threshold
- elif comparison_op == '<=':
+ elif comparison_op == "<=":
result = aggregate_value <= threshold
- elif comparison_op == '==':
+ elif comparison_op == "==":
result = aggregate_value == threshold
else:
raise ValueError(f"Invalid comparison operator: {comparison_op}")
@@ -686,7 +686,7 @@ def apply_simple_rule(doc, attribute_name, rule_type):
if result:
logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (all {len(values)} values are True)")
return result
raise ValueError(f"Unknown rule type: {rule_type}")
@@ -939,8 +939,7 @@ class Parser:
attribute_name, rule_type = rule_tuple
# Validate rule type
if (rule_type not in ["any", "all"] and
not any(op in rule_type for op in ['>', '<', '>=', '<=', '=='])):
if rule_type not in ["any", "all"] and not any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
raise ValueError(f"Invalid rule type: {rule_type}. Supported types: 'any', 'all', or numeric comparison (e.g., 'avg>0.3')")
return RuleNode(attribute_name, rule_type)
@@ -1018,9 +1017,9 @@ def parse_rule(rule_string):
raise ValueError(f"Invalid rule format: {rule_string}. Expected format: 'attribute_name:rule_type'")
attribute_name, rule_type = parts
# Check for numeric comparison rule_type
- if any(op in rule_type for op in ['>', '<', '>=', '<=', '==']):
+ if any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
# This is a numeric comparison rule - we'll validate it in apply_simple_rule
return attribute_name, rule_type
elif rule_type not in ["any", "all"]:
@@ -1290,13 +1289,13 @@ def format_rule_stats(rule_stats):
def generate_html_report(docs, title, summary, output_path):
"""
Generate an HTML report file with document texts
Args:
docs: List of documents to include in the report
title: Title of the report
summary: Summary statistics to include at the top
output_path: Path to save the HTML file
Returns:
None
"""
@@ -1476,10 +1475,10 @@ def generate_html_report(docs, title, summary, output_path):
doc_id = doc.get("id", f"unknown_{i}")
# Get document text, falling back to JSON representation if not available
doc_text = doc.get("text", json.dumps(doc, indent=2))
# The first document gets the "selected" class
selected_class = " selected" if i == 1 else ""
html += f"""
<div id="doc-{i}" class="document{selected_class}" tabindex="0">
<div class="document-id">Document ID: {doc_id}</div>
@@ -1601,11 +1600,11 @@ def generate_html_report(docs, title, summary, output_path):
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
# Write HTML to file
with open(output_path, "w", encoding="utf-8") as f:
f.write(html)
logger.info(f"Generated HTML report: {output_path}")
@@ -1661,40 +1660,40 @@ def main():
# Get document IDs for each category
ref_matches = set()
hyp_matches = set()
# Create mappings from document IDs to documents
doc_map = {doc["id"]: doc for doc in all_docs if "id" in doc}
# Find documents that match the reference and hypothesis rules
for doc_id, doc in doc_map.items():
if apply_rule(doc, ref_rule):
ref_matches.add(doc_id)
if apply_rule(doc, hyp_rule):
hyp_matches.add(doc_id)
# Calculate document sets for each category
true_positives_ids = ref_matches.intersection(hyp_matches)
true_negatives_ids = set(doc_map.keys()) - ref_matches - hyp_matches
false_positives_ids = hyp_matches - ref_matches
false_negatives_ids = ref_matches - hyp_matches
# Create document lists for each category
true_positives = [doc_map[doc_id] for doc_id in true_positives_ids]
true_negatives = [doc_map[doc_id] for doc_id in true_negatives_ids]
false_positives = [doc_map[doc_id] for doc_id in false_positives_ids]
false_negatives = [doc_map[doc_id] for doc_id in false_negatives_ids]
# Calculate metrics
tp = len(true_positives)
tn = len(true_negatives)
fp = len(false_positives)
fn = len(false_negatives)
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0
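# Worked example (hypothetical counts): tp=8, fp=2, fn=4 gives
# precision = 8/10 = 0.80, recall = 8/12 ~ 0.667, f1 ~ 0.727, iou = 8/14 ~ 0.571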
# Prepare overall statistics
overall_stats = {
"total_docs": len(doc_map),
@@ -1733,37 +1732,29 @@ IoU: {iou:.4f}
# Generate HTML reports for each category
logger.info("Generating HTML reports...")
# True Positives
generate_html_report(
- true_positives,
- "True Positives - Documents matching both Reference and Hypothesis Rules",
- summary,
- os.path.join(args.output_dir, "true_positives.html")
+ true_positives, "True Positives - Documents matching both Reference and Hypothesis Rules", summary, os.path.join(args.output_dir, "true_positives.html")
)
# True Negatives
- generate_html_report(
- true_negatives,
- "True Negatives - Documents not matching either Rule",
- summary,
- os.path.join(args.output_dir, "true_negatives.html")
- )
+ generate_html_report(true_negatives, "True Negatives - Documents not matching either Rule", summary, os.path.join(args.output_dir, "true_negatives.html"))
# False Positives
generate_html_report(
false_positives,
"False Positives - Documents matching Hypothesis but not Reference Rule",
summary,
os.path.join(args.output_dir, "false_positives.html")
os.path.join(args.output_dir, "false_positives.html"),
)
# False Negatives
generate_html_report(
false_negatives,
"False Negatives - Documents matching Reference but not Hypothesis Rule",
summary,
os.path.join(args.output_dir, "false_negatives.html")
os.path.join(args.output_dir, "false_negatives.html"),
)
# Generate index.html file that links to all reports
@@ -1857,7 +1848,7 @@ IoU: {iou:.4f}
with open(os.path.join(args.output_dir, "index.html"), "w", encoding="utf-8") as f:
f.write(index_html)
# Print summary
logger.info("\n--- COMPARISON SUMMARY ---")
logger.info(f"Documents Folder: {args.docs_folder}")
@@ -1887,7 +1878,7 @@ IoU: {iou:.4f}
logger.info(f"Recall: {recall:.4f}")
logger.info(f"F1 Score: {f1:.4f}")
logger.info(f"IoU: {iou:.4f}")
logger.info(f"\nResults saved to: {args.output_dir}/index.html")

View File

@@ -0,0 +1,788 @@
#!/usr/bin/env python3
"""
Tagging pipeline for Dolma JSONL datasets.
For each .jsonl, .jsonl.gz, or .jsonl.ztd file under the dataset/documents folder,
this script issues a model prompt completion per document, collects the yes/no
answers, and writes corresponding Dolma attributes JSONL files under
scratch/attributes/, mirroring the input structure.
"""
import argparse
import asyncio
import atexit
import gzip
import json
import logging
import os
import re
import sys
import time
from typing import Optional
from urllib.parse import urlparse
import boto3
import httpx
import zstandard as zstd
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, ValidationError
from olmocr.check import (
check_torch_gpu_available,
)
from olmocr.metrics import MetricsKeeper
from olmocr.s3_utils import (
download_directory,
expand_s3_glob,
get_s3_bytes_with_backoff,
parse_s3_path,
)
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False
server_logger = logging.getLogger("vllm")
server_logger.propagate = False
file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
server_logger.addHandler(file_handler)
# Default port; overridden by --port
SERVER_PORT = 30024
# Global variables for token statistics
metrics = MetricsKeeper(window=60 * 5)
class PIIClassification(BaseModel):
primary_language: str = Field(..., description="Primary language as a two-letter code")
document_type: str = Field(..., description="Basic summary of document type classification")
is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
# Default to False so the fallback objects constructed on error paths below validate
is_academic_paper: bool = False
is_textbook: bool = False
is_news_article: bool = False
is_test_or_quiz: bool = False
is_homework_assignment: bool = False
is_class_syllabus: bool = False
is_meeting_minutes: bool = False
is_legal_contract: bool = False
is_form: bool = False
is_correspondence_or_letter: bool = False
is_public_order: bool = False
is_court_notice: bool = False
contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
async def _process_single_page(page_text: str) -> PIIClassification:
"""Helper function to process a single document or page."""
text = page_text
query = {
"model": "google/gemma-3-4b-it",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": (
f"{text}\n\n-----------\n"
"Given the text above, determine what type of document it is. Answer in JSON. The format of your json object should be {'primary_language': str, 'document_type': str, 'is_resume_cv': bool, 'is_academic_paper': bool, 'is_textbook': bool, 'is_news_article': bool, 'is_test_or_quiz': bool, 'is_homework_assignment': bool, 'is_class_syllabus': bool, 'is_meeting_minutes': bool, 'is_legal_contract': bool, 'is_form': bool, 'is_correspondence_or_letter': bool, 'is_public_order': bool, 'is_court_notice': bool, 'contains_pii': bool}"
),
}
],
}
],
"max_tokens": 100,
"temperature": 0.0,
"response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
}
url = f"http://localhost:{SERVER_PORT}/v1/chat/completions"
# ---------- HTTP call ---------------------------------------------------
try:
status, body = await apost(url, json_data=query)
except Exception as e:
logger.warning(f"Server network error: {e!s}")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
metrics.add_metrics(server_requests=1)
if status != 200:
logger.warning(f"Server HTTP {status}: {body[:250]!r}")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
# ---------- Parse base JSON --------------------------------------------
try:
base = json.loads(body)
except json.JSONDecodeError:
logger.warning(f"Server response is not valid JSON: {body[:250]!r}")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
# Token accounting if available
if "usage" in base:
metrics.add_metrics(
server_input_tokens=base["usage"].get("prompt_tokens", 0),
server_output_tokens=base["usage"].get("completion_tokens", 0),
)
# ---------- Extract the model message ----------------------------------
try:
content = base["choices"][0]["message"].get("content")
except (KeyError, IndexError, AttributeError) as e:
logger.warning(f"Missing fields in Server response: {e!s}")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
if not isinstance(content, str):
logger.warning("Server `content` is not a string; treating as error.")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
try:
pii_classification: PIIClassification = PIIClassification.model_validate_json(content)
return pii_classification
except ValidationError as e:
logger.warning(f"Unable to parse pii classification object: {e!s}")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
# Manual simple implementation of HTTP Post
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
# that at the scale of 100M+ requests they deadlock in different strange ways
async def apost(url, json_data):
parsed_url = urlparse(url)
host = parsed_url.hostname
port = parsed_url.port or 80
path = parsed_url.path or "/"
writer = None
try:
reader, writer = await asyncio.open_connection(host, port)
json_payload = json.dumps(json_data)
request = (
f"POST {path} HTTP/1.1\r\n"
f"Host: {host}\r\n"
f"Content-Type: application/json\r\n"
f"Content-Length: {len(json_payload)}\r\n"
f"Connection: close\r\n\r\n"
f"{json_payload}"
)
writer.write(request.encode())
await writer.drain()
# Read status line
status_line = await reader.readline()
if not status_line:
raise ConnectionError("No response from server")
status_parts = status_line.decode().strip().split(" ", 2)
if len(status_parts) < 2:
raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
status_code = int(status_parts[1])
# Read headers
headers = {}
while True:
line = await reader.readline()
if line in (b"\r\n", b"\n", b""):
break
key, _, value = line.decode().partition(":")
headers[key.strip().lower()] = value.strip()
# Read response body
if "content-length" in headers:
body_length = int(headers["content-length"])
response_body = await reader.readexactly(body_length)
else:
raise ConnectionError("Anything other than fixed content length responses are not implemented yet")
return status_code, response_body
except Exception as e:
# Pass through errors
raise e
finally:
# But just make sure to close the socket on your way out
if writer is not None:
try:
writer.close()
await writer.wait_closed()
except Exception:
pass
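# Usage sketch, as in _process_single_page above:
#   status, body = await apost(f"http://localhost:{SERVER_PORT}/v1/chat/completions", json_data=query)
# where status is the HTTP status code and body is the raw response bytes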
async def process_dolma_document(args, dolma_doc, sem):
"""
Query model to detect PII, enforcing a JSON schema.
Resilient to:
- Transport / HTTP errors
- Missing or malformed fields in the response
- Non-string or None `content`
- Bad JSON in the model's answer
Always returns a dict mapping attribute key names to [start_pos, end_pos, value] span lists.
"""
doc_id = dolma_doc.get("id")
text = dolma_doc.get("text", "") or ""
# Create keys for all fields in PIIClassification
prefix = args.model.replace("/", "_")
result_attributes = {}
# Initialize attribute lists for all PIIClassification fields
for field_name in PIIClassification.model_fields:
key_name = f"{prefix}_{field_name}"
result_attributes[key_name] = []
# If pdf_page_numbers is present, sample first 5000 characters of the document
if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
page_numbers = dolma_doc["attributes"]["pdf_page_numbers"]
logger.info(f"Document {doc_id} has {len(page_numbers)} pages, processing first 5000 characters")
# Take first 5000 characters of the document
sample_text = text[:5000]
# Process the sample with the semaphore to limit concurrent requests
async with sem:
pii_class = await _process_single_page(sample_text)
# Add all classification attributes to results
for field_name in PIIClassification.model_fields:
key_name = f"{prefix}_{field_name}"
attribute_value = getattr(pii_class, field_name)
# Add the classification result to all pages
for start_pos, end_pos, page_num in page_numbers:
result_attributes[key_name].append([start_pos, end_pos, attribute_value])
return result_attributes
else:
raise NotImplementedError("Missing code here, expecting this to be dolma docs made by olmocr....")
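# Shape of the returned attributes (hypothetical values) for
# args.model = "google/gemma-3-4b-it" and a two-page document:
#   {"google_gemma-3-4b-it_contains_pii": [[0, 2400, True], [2400, 5000, True]],
#    "google_gemma-3-4b-it_primary_language": [[0, 2400, "en"], [2400, 5000, "en"]],
#    ...}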
async def process_file(args, worker_id: int, file_uri: str):
"""
Download a JSONL file, query model per record, and collect attributes.
"""
# Fetch raw bytes (S3 or local)
if file_uri.startswith("s3://"):
raw = await asyncio.to_thread(get_s3_bytes_with_backoff, dataset_s3, file_uri)
else:
with open(file_uri, "rb") as f:
raw = f.read()
# Decompress if needed
if file_uri.endswith(".gz"):
file_bytes = gzip.decompress(raw)
elif file_uri.endswith(".ztd") or file_uri.endswith(".zst") or file_uri.endswith(".zstd"):
dctx = zstd.ZstdDecompressor()
file_bytes = dctx.decompress(raw, max_output_size=1_000_000_000)
else:
file_bytes = raw
lines = file_bytes.decode("utf-8").splitlines()
page_tasks = {}
# Send all records in parallel, max N queued at a time
sem = asyncio.Semaphore(args.parallel_requests)
async with asyncio.TaskGroup() as tg:
for line in lines:
dolma_doc = json.loads(line)
task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
page_tasks[dolma_doc["id"]] = (task, dolma_doc)
logger.info(f"Finished taskgroup with {len(page_tasks)} items for {file_uri}")
# Collect results and build attributes
attributes = []
for doc_id, (task, dolma_doc) in page_tasks.items():
doc_attributes = task.result()
attributes.append({"id": doc_id, "attributes": doc_attributes})
return attributes
async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, worker_id: int):
"""
Pop work-items off the queue, run PII tagging, write the attributes file
next to the dataset (keeping the original compression), mark the item done,
and drop an empty sentinel file in <workspace>/results/.
"""
while True:
await semaphore.acquire()
work_item = await work_queue.get_work()
if work_item is None:
logger.info(f"Worker {worker_id} exiting queue empty")
semaphore.release()
break
file_uri = work_item.work_paths[0]
logger.info(f"Worker {worker_id} processing {file_uri}")
try:
# ------------------------------------------------------------------
# Run the per-file pipeline
# ------------------------------------------------------------------
attributes = await process_file(args, worker_id, file_uri)
# 1. Build the relative path that mirrors documents/…
if file_uri.startswith("s3://"):
_, key = parse_s3_path(file_uri)
_, docs_prefix = parse_s3_path(args.dataset)
rel_path = key[len(os.path.join(docs_prefix, "documents/")) :]
else:
docs_root = os.path.join(args.dataset, "documents")
rel_path = os.path.relpath(file_uri, docs_root)
out_rel = os.path.join("attributes", args.attribute_name, rel_path)
out_jsonl = "\n".join(json.dumps(x) for x in attributes) + "\n"
# 2. Preserve compression type
if rel_path.endswith(".gz"):
payload = gzip.compress(out_jsonl.encode("utf-8"))
elif rel_path.endswith((".zst", ".ztd")):
payload = zstd.ZstdCompressor().compress(out_jsonl.encode("utf-8"))
else:
payload = out_jsonl.encode("utf-8")
# 3. Write to args.dataset (local or S3)
if args.dataset.startswith("s3://"):
bucket, prefix = parse_s3_path(args.dataset)
key = os.path.join(prefix, out_rel)
workspace_s3.put_object(Bucket=bucket, Key=key, Body=payload)
else:
out_path = os.path.join(args.dataset, out_rel)
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "wb") as fh:
fh.write(payload)
# 4. Mark queue item done
await work_queue.mark_done(work_item)
# 5. Drop empty sentinel file in <workspace>/results/
sentinel_rel = os.path.join("results", f"output_{work_item.hash}.jsonl")
if args.scratch.startswith("s3://"):
bkt, pfx = parse_s3_path(args.scratch)
key = os.path.join(pfx, sentinel_rel)
workspace_s3.put_object(Bucket=bkt, Key=key, Body=b"")
else:
sentinel_path = os.path.join(args.scratch, sentinel_rel)
os.makedirs(os.path.dirname(sentinel_path), exist_ok=True)
open(sentinel_path, "w").close()
except Exception as exc:
logger.exception(f"Worker {worker_id} exception: {exc!s}")
finally:
semaphore.release()
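# Resulting layout (hypothetical shard name):
#   <dataset>/documents/shard-000.jsonl.gz                     input
#   <dataset>/attributes/<attribute_name>/shard-000.jsonl.gz   output, same compression
#   <scratch>/results/output_<hash>.jsonl                      empty sentinel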
async def server_task(model_name_or_path, args, semaphore):
# Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
# mem_fraction_arg = ["--mem-fraction-static", "0.80"]
cmd = [
"vllm",
"serve",
model_name_or_path,
"--port",
str(SERVER_PORT),
"--uvicorn-log-level",
"warning",
"--disable-log-requests",
]
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
# Ensure the subprocess is terminated on exit
def _kill_proc():
proc.terminate()
atexit.register(_kill_proc)
# Shared variables between tasks
last_running_req, last_queue_req = 0, 0
server_printed_ready_message = False
last_semaphore_release = time.time()
async def process_line(line):
nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
server_logger.info(line)
# if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
# can see any warnings/errors more easily
if not server_printed_ready_message:
logger.info(line)
if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
server_printed_ready_message = True
last_semaphore_release = time.time()
match = re.search(r"Running: (\d+) reqs", line)
if match:
last_running_req = int(match.group(1))
match = re.search(r"Waiting: (\d+) reqs", line)
if match:
last_queue_req = int(match.group(1))
logger.info(f"running req: {last_running_req} queue req: {last_queue_req}")
async def read_stream(stream):
while True:
line = await stream.readline()
if not line:
break
try:
line = line.decode("utf-8").rstrip()
await process_line(line)
except Exception as ex:
logger.warning(f"Got {ex} when reading log line from inference server, skipping")
async def timeout_task():
nonlocal last_running_req, last_queue_req, last_semaphore_release
try:
while True:
await asyncio.sleep(1)
if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
semaphore.release()
last_semaphore_release = time.time()
logger.info("Semaphore released, allowing a worker to proceed.")
except asyncio.CancelledError:
pass # Clean up if the task is cancelled
# Start tasks to read stdout, stderr, and handle timeout logic
stdout_task = asyncio.create_task(read_stream(proc.stdout))
stderr_task = asyncio.create_task(read_stream(proc.stderr))
timeout_watcher = asyncio.create_task(timeout_task())
try:
await proc.wait()
except asyncio.CancelledError:
logger.info("Got cancellation request for server")
proc.terminate()
raise
timeout_watcher.cancel()
await asyncio.gather(stdout_task, stderr_task, timeout_watcher, return_exceptions=True)
async def server_host(model_name_or_path, args, semaphore):
MAX_RETRIES = 5
retry = 0
while retry < MAX_RETRIES:
await server_task(model_name_or_path, args, semaphore)
logger.warning("Server task ended")
retry += 1
if retry >= MAX_RETRIES:
logger.error(f"Ended up starting the server more than {retry} times, cancelling pipeline")
logger.error("")
logger.error("Please make sure vllm is installed according to the latest instructions for 0.8.4")
sys.exit(1)
async def check_server_ready():
max_attempts = 300
delay_sec = 1
url = f"http://localhost:{SERVER_PORT}/v1/models"
for attempt in range(1, max_attempts + 1):
try:
async with httpx.AsyncClient() as session:
response = await session.get(url)
if response.status_code == 200:
logger.info("server is ready.")
return
else:
logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
except Exception:
logger.warning(f"Attempt {attempt}: Please wait for model server to become ready...")
await asyncio.sleep(delay_sec)
raise Exception("model server did not become ready after waiting.")
async def download_model(model_name_or_path: str):
if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"):
logger.info(f"Downloading model directory from '{model_name_or_path}'")
model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
download_directory([model_name_or_path], model_cache_dir)
return model_cache_dir
elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
logger.info(f"Using local model path at '{model_name_or_path}'")
return model_name_or_path
else:
logger.info(f"Downloading model with hugging face '{model_name_or_path}'")
snapshot_download(repo_id=model_name_or_path)
return model_name_or_path
async def metrics_reporter(work_queue):
while True:
# Leading newlines preserve table formatting in logs
logger.info(f"Queue remaining: {work_queue.size}")
logger.info("\n" + str(metrics))
await asyncio.sleep(10)
def submit_beaker_job(args):
from beaker import ( # type: ignore
Beaker,
Constraints,
EnvVar,
ExperimentSpec,
ImageSource,
Priority,
ResultSpec,
SecretNotFound,
TaskContext,
TaskResources,
TaskSpec,
)
b = Beaker.from_env(default_workspace=args.beaker_workspace)
account = b.account.whoami()
owner = account.name
beaker_image = f"jakep/olmocr-tagging-{VERSION}"
task_name = f"olmocr-{os.path.basename(args.dataset.rstrip('/'))}"
# Take out --beaker flag so the workers will just run things
args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]
# Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally
args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i - 1] == "--pdfs"))]
try:
b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
except SecretNotFound:
print(
f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]"
)
if input().strip().lower() != "y":
print("Exiting...")
sys.exit(1)
b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
b.secret.write(
f"{owner}-AWS_CREDENTIALS_FILE",
open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(),
args.beaker_workspace,
)
env_var_secrets = [
EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
]
try:
b.secret.get("OLMOCR_PREVIEW_HF_TOKEN", args.beaker_workspace)
env_var_secrets.append(EnvVar(name="HF_TOKEN", secret="OLMOCR_PREVIEW_HF_TOKEN"))
except SecretNotFound:
pass
try:
b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
except SecretNotFound:
print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
lines = []
prev_empty = False
for line in iter(input, None):
if not line and prev_empty:
break
prev_empty = not line
lines.append(line)
gcs_sa_key = "\n".join(lines[:-1]).strip() # Remove the last empty line
if gcs_sa_key:
b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
# Create the experiment spec
experiment_spec = ExperimentSpec(
budget="ai2/oe-data",
description=task_name,
tasks=[
TaskSpec(
name=task_name,
propagate_failure=False,
propagate_preemption=False,
replicas=args.beaker_gpus,
context=TaskContext(
priority=Priority(args.beaker_priority),
preemptible=True,
),
image=ImageSource(beaker=beaker_image),
command=["python", "scripts/tagging_pipeline.py"] + args_list,
env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
resources=TaskResources(gpu_count=1),
constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
result=ResultSpec(path="/noop-results"),
)
],
)
experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)
print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")
async def main():
parser = argparse.ArgumentParser(description="Tagging pipeline for Dolma JSONL dataset")
parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with documents/ folder")
parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)")
parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")
# Beaker/job running stuff
parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
parser.add_argument("--beaker_workspace", help="Beaker workspace to submit to", default="ai2/olmocr")
parser.add_argument(
"--beaker_cluster",
help="Beaker clusters you want to run on",
default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"],
)
parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
parser.add_argument("--port", type=int, default=30024, help="Port for Model server")
args = parser.parse_args()
global SERVER_PORT, workspace_s3, dataset_s3
SERVER_PORT = args.port
workspace_s3 = boto3.client("s3")
dataset_s3 = boto3.client("s3")
# setup the job to work in beaker environment, load secrets, adjust logging, etc.
if "BEAKER_JOB_ID" in os.environ:
server_logger.addHandler(console_handler)
if "AWS_CREDENTIALS_FILE" in os.environ:
cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
with open(cred_path, "w") as f:
f.write(os.environ.get("AWS_CREDENTIALS_FILE"))
if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ:
cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
with open(cred_path, "w") as f:
f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
workspace_s3 = boto3.client("s3")
dataset_s3 = boto3.client("s3")
# Wait a little so that all beaker jobs in a task don't start and download the model at the same time
replica_count = int(os.environ.get("BEAKER_REPLICA_COUNT", "1"))
interval = 10 if (replica_count - 1) * 10 <= 240 else 240 / max(1, replica_count - 1)
sleep_time = int(int(os.environ.get("BEAKER_REPLICA_RANK", "0")) * interval)
logger.info(f"Beaker job sleeping for {sleep_time} seconds to stagger model downloads")
await asyncio.sleep(sleep_time)
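# e.g. with BEAKER_REPLICA_COUNT=30, interval = 240 / 29 ~ 8.3 s, so the
# replica with rank 5 sleeps int(5 * 8.3) = 41 seconds before downloading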
# Initialize work queue
if args.scratch.startswith("s3://"):
work_queue = S3WorkQueue(workspace_s3, args.scratch)
else:
work_queue = LocalWorkQueue(args.scratch)
# Discover input files
files = set()
if args.dataset.startswith("s3://"):
pattern = args.dataset.rstrip("/") + "/documents/*.jsonl*"
matched = expand_s3_glob(dataset_s3, pattern)
files = set(matched.keys())
else:
docs_dir = os.path.join(args.dataset, "documents")
for root, _, fns in os.walk(docs_dir):
for fn in fns:
if fn.endswith((".jsonl", ".jsonl.gz", ".jsonl.ztd", ".jsonl.zst", ".jsonl.zstd")):
files.add(os.path.join(root, fn))
# Populate the work queue if needed
await work_queue.populate_queue(list(files), items_per_group=1)
if args.beaker:
submit_beaker_job(args)
return
# If you get this far, then you are doing inference and need a GPU
check_torch_gpu_available()
logger.info(f"Starting pipeline with PID {os.getpid()}")
# Download the model before you do anything else
model_name_or_path = await download_model(args.model)
# Initialize the work queue
qsize = await work_queue.initialize_queue()
if qsize == 0:
logger.info("No work to do, exiting")
return
# Create a semaphore to control worker access
# We only allow one worker to move forward with requests until the server has no more requests in its queue
# This lets us get full utilization from many workers while still outputting dolma docs as soon as possible
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
semaphore = asyncio.Semaphore(1)
model_server = asyncio.create_task(server_host(model_name_or_path, args, semaphore))
await check_server_ready()
metrics_task = asyncio.create_task(metrics_reporter(work_queue))
# Create worker tasks to process the queue concurrently.
worker_tasks = []
for i in range(args.workers):
task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
worker_tasks.append(task)
# Wait for all worker tasks to finish
await asyncio.gather(*worker_tasks)
model_server.cancel()
metrics_task.cancel()
logger.info("Work done")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -55,8 +55,13 @@ class TestNormalizeText(unittest.TestCase):
self.assertEqual(normalize_text("Hello<br/>everyone"), normalize_text("Hello\neveryone"))
def test_two_stars(self):
self.assertEqual(normalize_text("**Georges V.** (2007) *Le Forez du VIe au IVe millénaire av. J.-C. Territoires, identités et stratégies des sociétés humaines du Massif central dans le bassin amont de la Loire (France)*, thèse de doctorat, université de Bourgogne, Dijon, 2 vol., 435 p."),
"Georges V. (2007) - Le Forez du VIe au IVe millénaire av. J.-C. Territoires, identités et stratégies des sociétés humaines du Massif central dans le bassin amont de la Loire (France), thèse de doctorat, université de Bourgogne, Dijon, 2 vol., 435 p.")
self.assertEqual(
normalize_text(
"**Georges V.** (2007) *Le Forez du VIe au IVe millénaire av. J.-C. Territoires, identités et stratégies des sociétés humaines du Massif central dans le bassin amont de la Loire (France)*, thèse de doctorat, université de Bourgogne, Dijon, 2 vol., 435 p."
),
"Georges V. (2007) - Le Forez du VIe au IVe millénaire av. J.-C. Territoires, identités et stratégies des sociétés humaines du Massif central dans le bassin amont de la Loire (France), thèse de doctorat, université de Bourgogne, Dijon, 2 vol., 435 p.",
)
class TestBasePDFTest(unittest.TestCase):
"""Test the BasePDFTest class"""