mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-25 06:06:23 +00:00
A bit more work on tagging
This commit is contained in:
parent
72bcfd8f31
commit
1854ae1269
@ -94,10 +94,10 @@ class DocumentClassification(pydantic.BaseModel):
|
||||
is_correspondence_or_letter: bool
|
||||
is_public_order: bool
|
||||
is_court_notice: bool
|
||||
|
||||
|
||||
# Additional context or notes
|
||||
classification_notes: str
|
||||
|
||||
|
||||
@property
|
||||
def get_document_types(self) -> List[str]:
|
||||
"""Get a list of all document types identified"""
|
||||
@ -396,7 +396,7 @@ def process_single_page(args, doc_info, pdf_s3_client=None):
|
||||
# Generate attribute key names using model name
|
||||
model_prefix = args.openai_model.replace("/", "_").replace("-", "_").replace(".", "_")
|
||||
language_key_name = f"{model_prefix}_language"
|
||||
|
||||
|
||||
# Initialize result attributes with all DocumentClassification fields
|
||||
result_attributes = {
|
||||
language_key_name: [[start_pos, end_pos, classification.language_code.name]],
|
||||
|
||||
@ -526,12 +526,12 @@ def apply_rule(doc, rule):
|
||||
def calculate_attribute_aggregate(doc, attribute_name, operation_type):
|
||||
"""
|
||||
Calculate an aggregate value for a numeric attribute.
|
||||
|
||||
|
||||
Args:
|
||||
doc: The document JSON object
|
||||
attribute_name: The attribute field to aggregate (e.g., "pii_tagging_ratio")
|
||||
operation_type: The type of aggregation to perform (e.g., "avg")
|
||||
|
||||
|
||||
Returns:
|
||||
The aggregated value, or None if calculation is not possible
|
||||
"""
|
||||
@ -539,33 +539,33 @@ def calculate_attribute_aggregate(doc, attribute_name, operation_type):
|
||||
if "attributes" not in doc or not doc["attributes"]:
|
||||
logger.debug(f"Document {doc.get('id', 'unknown')} has no attributes")
|
||||
return None
|
||||
|
||||
|
||||
attributes = doc["attributes"]
|
||||
|
||||
|
||||
# Check if the specific attribute exists
|
||||
if attribute_name not in attributes:
|
||||
logger.debug(f"Document {doc.get('id', 'unknown')} doesn't have attribute: {attribute_name}")
|
||||
return None
|
||||
|
||||
|
||||
if not attributes[attribute_name]:
|
||||
logger.debug(f"Document {doc.get('id', 'unknown')} has empty attribute: {attribute_name}")
|
||||
return None
|
||||
|
||||
|
||||
# Extract the numeric values from the attribute spans
|
||||
# Each span is formatted as [start_pos, end_pos, value]
|
||||
values = [span[2] for span in attributes[attribute_name] if len(span) >= 3 and span[2] is not None]
|
||||
|
||||
|
||||
if not values:
|
||||
logger.debug(f"Document {doc.get('id', 'unknown')} has no valid values for attribute: {attribute_name}")
|
||||
return None
|
||||
|
||||
|
||||
# Convert all values to float
|
||||
try:
|
||||
numeric_values = [float(value) for value in values]
|
||||
except (ValueError, TypeError):
|
||||
logger.debug(f"Document {doc.get('id', 'unknown')} has non-numeric values for attribute: {attribute_name}")
|
||||
return None
|
||||
|
||||
|
||||
# Perform the aggregation
|
||||
if operation_type == "avg":
|
||||
if not numeric_values:
|
||||
@ -606,32 +606,32 @@ def apply_simple_rule(doc, attribute_name, rule_type):
|
||||
return False
|
||||
|
||||
# Handle numeric comparison rules (e.g., 'avg>0.3')
|
||||
if any(op in rule_type for op in ['>', '<', '>=', '<=', '==']):
|
||||
if any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
|
||||
# Parse the rule type into operation and comparison
|
||||
operation_parts = rule_type.split('>')
|
||||
operation_parts = rule_type.split(">")
|
||||
if len(operation_parts) == 2:
|
||||
operation_type, threshold = operation_parts
|
||||
comparison_op = '>'
|
||||
comparison_op = ">"
|
||||
else:
|
||||
operation_parts = rule_type.split('<')
|
||||
operation_parts = rule_type.split("<")
|
||||
if len(operation_parts) == 2:
|
||||
operation_type, threshold = operation_parts
|
||||
comparison_op = '<'
|
||||
comparison_op = "<"
|
||||
else:
|
||||
operation_parts = rule_type.split('>=')
|
||||
operation_parts = rule_type.split(">=")
|
||||
if len(operation_parts) == 2:
|
||||
operation_type, threshold = operation_parts
|
||||
comparison_op = '>='
|
||||
comparison_op = ">="
|
||||
else:
|
||||
operation_parts = rule_type.split('<=')
|
||||
operation_parts = rule_type.split("<=")
|
||||
if len(operation_parts) == 2:
|
||||
operation_type, threshold = operation_parts
|
||||
comparison_op = '<='
|
||||
comparison_op = "<="
|
||||
else:
|
||||
operation_parts = rule_type.split('==')
|
||||
operation_parts = rule_type.split("==")
|
||||
if len(operation_parts) == 2:
|
||||
operation_type, threshold = operation_parts
|
||||
comparison_op = '=='
|
||||
comparison_op = "=="
|
||||
else:
|
||||
raise ValueError(f"Invalid rule type: {rule_type}")
|
||||
|
||||
@ -648,15 +648,15 @@ def apply_simple_rule(doc, attribute_name, rule_type):
|
||||
return False
|
||||
|
||||
# Apply the comparison
|
||||
if comparison_op == '>':
|
||||
if comparison_op == ">":
|
||||
result = aggregate_value > threshold
|
||||
elif comparison_op == '<':
|
||||
elif comparison_op == "<":
|
||||
result = aggregate_value < threshold
|
||||
elif comparison_op == '>=':
|
||||
elif comparison_op == ">=":
|
||||
result = aggregate_value >= threshold
|
||||
elif comparison_op == '<=':
|
||||
elif comparison_op == "<=":
|
||||
result = aggregate_value <= threshold
|
||||
elif comparison_op == '==':
|
||||
elif comparison_op == "==":
|
||||
result = aggregate_value == threshold
|
||||
else:
|
||||
raise ValueError(f"Invalid comparison operator: {comparison_op}")
|
||||
@ -686,7 +686,7 @@ def apply_simple_rule(doc, attribute_name, rule_type):
|
||||
if result:
|
||||
logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (all {len(values)} values are True)")
|
||||
return result
|
||||
|
||||
|
||||
raise ValueError(f"Unknown rule type: {rule_type}")
|
||||
|
||||
|
||||
@ -939,8 +939,7 @@ class Parser:
|
||||
attribute_name, rule_type = rule_tuple
|
||||
|
||||
# Validate rule type
|
||||
if (rule_type not in ["any", "all"] and
|
||||
not any(op in rule_type for op in ['>', '<', '>=', '<=', '=='])):
|
||||
if rule_type not in ["any", "all"] and not any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
|
||||
raise ValueError(f"Invalid rule type: {rule_type}. Supported types: 'any', 'all', or numeric comparison (e.g., 'avg>0.3')")
|
||||
|
||||
return RuleNode(attribute_name, rule_type)
|
||||
@ -1018,9 +1017,9 @@ def parse_rule(rule_string):
|
||||
raise ValueError(f"Invalid rule format: {rule_string}. Expected format: 'attribute_name:rule_type'")
|
||||
|
||||
attribute_name, rule_type = parts
|
||||
|
||||
|
||||
# Check for numeric comparison rule_type
|
||||
if any(op in rule_type for op in ['>', '<', '>=', '<=', '==']):
|
||||
if any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
|
||||
# This is a numeric comparison rule - we'll validate it in apply_simple_rule
|
||||
return attribute_name, rule_type
|
||||
elif rule_type not in ["any", "all"]:
|
||||
@ -1290,13 +1289,13 @@ def format_rule_stats(rule_stats):
|
||||
def generate_html_report(docs, title, summary, output_path):
|
||||
"""
|
||||
Generate an HTML report file with document texts
|
||||
|
||||
|
||||
Args:
|
||||
docs: List of documents to include in the report
|
||||
title: Title of the report
|
||||
summary: Summary statistics to include at the top
|
||||
output_path: Path to save the HTML file
|
||||
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
@ -1476,10 +1475,10 @@ def generate_html_report(docs, title, summary, output_path):
|
||||
doc_id = doc.get("id", f"unknown_{i}")
|
||||
# Get document text, falling back to JSON representation if not available
|
||||
doc_text = doc.get("text", json.dumps(doc, indent=2))
|
||||
|
||||
|
||||
# The first document gets the "selected" class
|
||||
selected_class = " selected" if i == 1 else ""
|
||||
|
||||
|
||||
html += f"""
|
||||
<div id="doc-{i}" class="document{selected_class}" tabindex="0">
|
||||
<div class="document-id">Document ID: {doc_id}</div>
|
||||
@ -1601,11 +1600,11 @@ def generate_html_report(docs, title, summary, output_path):
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
|
||||
|
||||
|
||||
# Write HTML to file
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(html)
|
||||
|
||||
|
||||
logger.info(f"Generated HTML report: {output_path}")
|
||||
|
||||
|
||||
@ -1661,40 +1660,40 @@ def main():
|
||||
# Get document IDs for each category
|
||||
ref_matches = set()
|
||||
hyp_matches = set()
|
||||
|
||||
|
||||
# Create mappings from document IDs to documents
|
||||
doc_map = {doc["id"]: doc for doc in all_docs if "id" in doc}
|
||||
|
||||
|
||||
# Find documents that match the reference and hypothesis rules
|
||||
for doc_id, doc in doc_map.items():
|
||||
if apply_rule(doc, ref_rule):
|
||||
ref_matches.add(doc_id)
|
||||
if apply_rule(doc, hyp_rule):
|
||||
hyp_matches.add(doc_id)
|
||||
|
||||
|
||||
# Calculate document sets for each category
|
||||
true_positives_ids = ref_matches.intersection(hyp_matches)
|
||||
true_negatives_ids = set(doc_map.keys()) - ref_matches - hyp_matches
|
||||
false_positives_ids = hyp_matches - ref_matches
|
||||
false_negatives_ids = ref_matches - hyp_matches
|
||||
|
||||
|
||||
# Create document lists for each category
|
||||
true_positives = [doc_map[doc_id] for doc_id in true_positives_ids]
|
||||
true_negatives = [doc_map[doc_id] for doc_id in true_negatives_ids]
|
||||
false_positives = [doc_map[doc_id] for doc_id in false_positives_ids]
|
||||
false_negatives = [doc_map[doc_id] for doc_id in false_negatives_ids]
|
||||
|
||||
|
||||
# Calculate metrics
|
||||
tp = len(true_positives)
|
||||
tn = len(true_negatives)
|
||||
fp = len(false_positives)
|
||||
fn = len(false_negatives)
|
||||
|
||||
|
||||
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
||||
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
||||
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
|
||||
iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0
|
||||
|
||||
|
||||
# Prepare overall statistics
|
||||
overall_stats = {
|
||||
"total_docs": len(doc_map),
|
||||
@ -1733,37 +1732,29 @@ IoU: {iou:.4f}
|
||||
|
||||
# Generate HTML reports for each category
|
||||
logger.info("Generating HTML reports...")
|
||||
|
||||
|
||||
# True Positives
|
||||
generate_html_report(
|
||||
true_positives,
|
||||
"True Positives - Documents matching both Reference and Hypothesis Rules",
|
||||
summary,
|
||||
os.path.join(args.output_dir, "true_positives.html")
|
||||
true_positives, "True Positives - Documents matching both Reference and Hypothesis Rules", summary, os.path.join(args.output_dir, "true_positives.html")
|
||||
)
|
||||
|
||||
|
||||
# True Negatives
|
||||
generate_html_report(
|
||||
true_negatives,
|
||||
"True Negatives - Documents not matching either Rule",
|
||||
summary,
|
||||
os.path.join(args.output_dir, "true_negatives.html")
|
||||
)
|
||||
|
||||
generate_html_report(true_negatives, "True Negatives - Documents not matching either Rule", summary, os.path.join(args.output_dir, "true_negatives.html"))
|
||||
|
||||
# False Positives
|
||||
generate_html_report(
|
||||
false_positives,
|
||||
"False Positives - Documents matching Hypothesis but not Reference Rule",
|
||||
summary,
|
||||
os.path.join(args.output_dir, "false_positives.html")
|
||||
os.path.join(args.output_dir, "false_positives.html"),
|
||||
)
|
||||
|
||||
|
||||
# False Negatives
|
||||
generate_html_report(
|
||||
false_negatives,
|
||||
"False Negatives - Documents matching Reference but not Hypothesis Rule",
|
||||
summary,
|
||||
os.path.join(args.output_dir, "false_negatives.html")
|
||||
os.path.join(args.output_dir, "false_negatives.html"),
|
||||
)
|
||||
|
||||
# Generate index.html file that links to all reports
|
||||
@ -1857,7 +1848,7 @@ IoU: {iou:.4f}
|
||||
|
||||
with open(os.path.join(args.output_dir, "index.html"), "w", encoding="utf-8") as f:
|
||||
f.write(index_html)
|
||||
|
||||
|
||||
# Print summary
|
||||
logger.info("\n--- COMPARISON SUMMARY ---")
|
||||
logger.info(f"Documents Folder: {args.docs_folder}")
|
||||
@ -1887,7 +1878,7 @@ IoU: {iou:.4f}
|
||||
logger.info(f"Recall: {recall:.4f}")
|
||||
logger.info(f"F1 Score: {f1:.4f}")
|
||||
logger.info(f"IoU: {iou:.4f}")
|
||||
|
||||
|
||||
logger.info(f"\nResults saved to: {args.output_dir}/index.html")
|
||||
|
||||
|
||||
|
||||
788
scripts/tagging_pipeline_v2.py
Normal file
788
scripts/tagging_pipeline_v2.py
Normal file
@ -0,0 +1,788 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tagging pipeline for Dolma JSONL datasets.
|
||||
|
||||
For each .jsonl, .jsonl.gz, or .jsonl.ztd file under the dataset/documents folder,
|
||||
this script issues a model prompt completion
|
||||
collects the yes/no answers, and writes corresponding Dolma attributes JSONL files under
|
||||
scratch/attributes/, mirroring the input structure.
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import atexit
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
import httpx
|
||||
import zstandard as zstd
|
||||
from huggingface_hub import snapshot_download
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from olmocr.check import (
|
||||
check_torch_gpu_available,
|
||||
)
|
||||
from olmocr.metrics import MetricsKeeper
|
||||
from olmocr.s3_utils import (
|
||||
download_directory,
|
||||
expand_s3_glob,
|
||||
get_s3_bytes_with_backoff,
|
||||
parse_s3_path,
|
||||
)
|
||||
from olmocr.version import VERSION
|
||||
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
|
||||
|
||||
# Initialize logger
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.propagate = False
|
||||
|
||||
server_logger = logging.getLogger("vllm")
|
||||
server_logger.propagate = False
|
||||
|
||||
file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
|
||||
|
||||
# Add handlers to the logger
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(console_handler)
|
||||
server_logger.addHandler(file_handler)
|
||||
|
||||
|
||||
# Default port; overridden by --port
|
||||
SERVER_PORT = 30024
|
||||
|
||||
# Global variables for token statistics
|
||||
metrics = MetricsKeeper(window=60 * 5)
|
||||
|
||||
|
||||
class PIIClassification(BaseModel):
|
||||
primary_language: str = Field(..., description="Primary language as a two-letter code")
|
||||
document_type: str = Field(..., description="Basic summary of document type classification")
|
||||
is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
|
||||
|
||||
is_academic_paper: bool
|
||||
is_textbook: bool
|
||||
is_news_article: bool
|
||||
is_test_or_quiz: bool
|
||||
is_homework_assignment: bool
|
||||
is_class_syllabus: bool
|
||||
is_meeting_minutes: bool
|
||||
is_legal_contract: bool
|
||||
is_form: bool
|
||||
is_correspondence_or_letter: bool
|
||||
is_public_order: bool
|
||||
is_court_notice: bool
|
||||
|
||||
contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
|
||||
|
||||
|
||||
async def _process_single_page(page_text: str) -> PIIClassification:
|
||||
"""Helper function to process a single document or page."""
|
||||
text = page_text
|
||||
|
||||
query = {
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
f"{text}\n\n-----------\n"
|
||||
"Given the text above, determine what type of document it is. Answer in JSON. The format of your json object should be {'primary_language': str, 'document_type': str, 'is_resume_cv': bool, 'is_academic_paper': bool, 'is_textbook': bool, 'is_news_article': bool, 'is_test_or_quiz': bool, 'is_homework_assignment': bool, 'is_class_syllabus': bool, 'is_meeting_minutes': bool, 'is_legal_contract': bool, 'is_form': bool, 'is_correspondence_or_letter': bool, 'is_public_order': bool, 'is_court_notice': bool, 'contains_pii': bool}"
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
"temperature": 0.0,
|
||||
"response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
|
||||
}
|
||||
|
||||
url = f"http://localhost:{SERVER_PORT}/v1/chat/completions"
|
||||
|
||||
# ---------- HTTP call ---------------------------------------------------
|
||||
try:
|
||||
status, body = await apost(url, json_data=query)
|
||||
except Exception as e:
|
||||
logger.warning(f"Server network error: {e!s}")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
metrics.add_metrics(server_requests=1)
|
||||
|
||||
if status != 200:
|
||||
logger.warning(f"Server HTTP {status}: {body[:250]!r}")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
# ---------- Parse base JSON --------------------------------------------
|
||||
try:
|
||||
base = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Server response is not valid JSON: {body[:250]!r}")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
# Token accounting if available
|
||||
if "usage" in base:
|
||||
metrics.add_metrics(
|
||||
server_input_tokens=base["usage"].get("prompt_tokens", 0),
|
||||
server_output_tokens=base["usage"].get("completion_tokens", 0),
|
||||
)
|
||||
|
||||
# ---------- Extract the model message ----------------------------------
|
||||
try:
|
||||
content = base["choices"][0]["message"].get("content")
|
||||
except (KeyError, IndexError, AttributeError) as e:
|
||||
logger.warning(f"Missing fields in Server response: {e!s}")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
if not isinstance(content, str):
|
||||
logger.warning("Server `content` is not a string; treating as error.")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
try:
|
||||
pii_classification: PIIClassification = PIIClassification.model_validate_json(content)
|
||||
return pii_classification
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Unable to parse pii classification object: {e!s}")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
|
||||
# Manual simple implementation of HTTP Post
|
||||
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
|
||||
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
|
||||
# that at the scale of 100M+ requests, that they deadlock in different strange ways
|
||||
async def apost(url, json_data):
|
||||
parsed_url = urlparse(url)
|
||||
host = parsed_url.hostname
|
||||
port = parsed_url.port or 80
|
||||
path = parsed_url.path or "/"
|
||||
|
||||
writer = None
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(host, port)
|
||||
|
||||
json_payload = json.dumps(json_data)
|
||||
request = (
|
||||
f"POST {path} HTTP/1.1\r\n"
|
||||
f"Host: {host}\r\n"
|
||||
f"Content-Type: application/json\r\n"
|
||||
f"Content-Length: {len(json_payload)}\r\n"
|
||||
f"Connection: close\r\n\r\n"
|
||||
f"{json_payload}"
|
||||
)
|
||||
writer.write(request.encode())
|
||||
await writer.drain()
|
||||
|
||||
# Read status line
|
||||
status_line = await reader.readline()
|
||||
if not status_line:
|
||||
raise ConnectionError("No response from server")
|
||||
status_parts = status_line.decode().strip().split(" ", 2)
|
||||
if len(status_parts) < 2:
|
||||
raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
|
||||
status_code = int(status_parts[1])
|
||||
|
||||
# Read headers
|
||||
headers = {}
|
||||
while True:
|
||||
line = await reader.readline()
|
||||
if line in (b"\r\n", b"\n", b""):
|
||||
break
|
||||
key, _, value = line.decode().partition(":")
|
||||
headers[key.strip().lower()] = value.strip()
|
||||
|
||||
# Read response body
|
||||
if "content-length" in headers:
|
||||
body_length = int(headers["content-length"])
|
||||
response_body = await reader.readexactly(body_length)
|
||||
else:
|
||||
raise ConnectionError("Anything other than fixed content length responses are not implemented yet")
|
||||
|
||||
return status_code, response_body
|
||||
except Exception as e:
|
||||
# Pass through errors
|
||||
raise e
|
||||
finally:
|
||||
# But just make sure to close the socket on your way out
|
||||
if writer is not None:
|
||||
try:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
async def process_dolma_document(args, dolma_doc, sem):
|
||||
"""
|
||||
Query model to detect PII, enforcing a JSON schema.
|
||||
|
||||
Resilient to:
|
||||
• Transport / HTTP errors
|
||||
• Missing or malformed fields in the response
|
||||
• Non-string or None `content`
|
||||
• Bad JSON in the model's answer
|
||||
|
||||
Always returns: (doc_id, contains_pii: bool, text_length: int)
|
||||
"""
|
||||
doc_id = dolma_doc.get("id")
|
||||
text = dolma_doc.get("text", "") or ""
|
||||
|
||||
# Create keys for all fields in PIIClassification
|
||||
prefix = args.model.replace("/", "_")
|
||||
result_attributes = {}
|
||||
|
||||
# Initialize attribute lists for all PIIClassification fields
|
||||
for field_name in PIIClassification.model_fields:
|
||||
key_name = f"{prefix}_{field_name}"
|
||||
result_attributes[key_name] = []
|
||||
|
||||
# If pdf_page_numbers is present, sample first 5000 characters of the document
|
||||
if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
|
||||
page_numbers = dolma_doc["attributes"]["pdf_page_numbers"]
|
||||
|
||||
logger.info(f"Document {doc_id} has {len(page_numbers)} pages, processing first 5000 characters")
|
||||
|
||||
# Take first 5000 characters of the document
|
||||
sample_text = text[:5000]
|
||||
|
||||
# Process the sample with the semaphore to limit concurrent requests
|
||||
async with sem:
|
||||
pii_class = await _process_single_page(sample_text)
|
||||
|
||||
# Add all classification attributes to results
|
||||
for field_name in PIIClassification.model_fields:
|
||||
key_name = f"{prefix}_{field_name}"
|
||||
attribute_value = getattr(pii_class, field_name)
|
||||
|
||||
# Add the classification result to all pages
|
||||
for start_pos, end_pos, page_num in page_numbers:
|
||||
result_attributes[key_name].append([start_pos, end_pos, attribute_value])
|
||||
|
||||
return result_attributes
|
||||
else:
|
||||
raise NotImplementedError("Missing code here, expecting this to be dolma docs made by olmocr....")
|
||||
|
||||
|
||||
async def process_file(args, worker_id: int, file_uri: str):
|
||||
"""
|
||||
Download a JSONL file, query model per record, and collect attributes.
|
||||
"""
|
||||
# Fetch raw bytes (S3 or local)
|
||||
if file_uri.startswith("s3://"):
|
||||
raw = await asyncio.to_thread(get_s3_bytes_with_backoff, dataset_s3, file_uri)
|
||||
else:
|
||||
with open(file_uri, "rb") as f:
|
||||
raw = f.read()
|
||||
|
||||
# Decompress if needed
|
||||
if file_uri.endswith(".gz"):
|
||||
file_bytes = gzip.decompress(raw)
|
||||
elif file_uri.endswith(".ztd") or file_uri.endswith(".zst") or file_uri.endswith(".zstd"):
|
||||
dctx = zstd.ZstdDecompressor()
|
||||
file_bytes = dctx.decompress(raw, max_output_size=1_000_000_000)
|
||||
else:
|
||||
file_bytes = raw
|
||||
|
||||
lines = file_bytes.decode("utf-8").splitlines()
|
||||
page_tasks = {}
|
||||
|
||||
# Send all records in parallel, max N queued at a time
|
||||
sem = asyncio.Semaphore(args.parallel_requests)
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for line in lines:
|
||||
dolma_doc = json.loads(line)
|
||||
task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
|
||||
page_tasks[dolma_doc["id"]] = (task, dolma_doc)
|
||||
|
||||
logger.info(f"Finished taskgroup with {len(page_tasks)} items for {file_uri}")
|
||||
|
||||
# Collect results and build attributes
|
||||
attributes = []
|
||||
for doc_id, (task, dolma_doc) in page_tasks.items():
|
||||
doc_attributes = task.result()
|
||||
|
||||
attributes.append({"id": doc_id, "attributes": doc_attributes})
|
||||
|
||||
return attributes
|
||||
|
||||
|
||||
async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, worker_id: int):
|
||||
"""
|
||||
Pop work-items off the queue, run PII tagging, write the attributes file
|
||||
next to the dataset (keeping the original compression), mark the item done,
|
||||
and drop an empty sentinel file in <workspace>/results/.
|
||||
"""
|
||||
while True:
|
||||
await semaphore.acquire()
|
||||
work_item = await work_queue.get_work()
|
||||
|
||||
if work_item is None:
|
||||
logger.info(f"Worker {worker_id} exiting – queue empty")
|
||||
semaphore.release()
|
||||
break
|
||||
|
||||
file_uri = work_item.work_paths[0]
|
||||
logger.info(f"Worker {worker_id} processing {file_uri}")
|
||||
|
||||
try:
|
||||
# ------------------------------------------------------------------
|
||||
# Run the per-file pipeline
|
||||
# ------------------------------------------------------------------
|
||||
attributes = await process_file(args, worker_id, file_uri)
|
||||
|
||||
# 1. Build the relative path that mirrors documents/…
|
||||
if file_uri.startswith("s3://"):
|
||||
_, key = parse_s3_path(file_uri)
|
||||
_, docs_prefix = parse_s3_path(args.dataset)
|
||||
rel_path = key[len(os.path.join(docs_prefix, "documents/")) :]
|
||||
else:
|
||||
docs_root = os.path.join(args.dataset, "documents")
|
||||
rel_path = os.path.relpath(file_uri, docs_root)
|
||||
|
||||
out_rel = os.path.join("attributes", args.attribute_name, rel_path)
|
||||
out_jsonl = "\n".join(json.dumps(x) for x in attributes) + "\n"
|
||||
|
||||
# 2. Preserve compression type
|
||||
if rel_path.endswith(".gz"):
|
||||
payload = gzip.compress(out_jsonl.encode("utf-8"))
|
||||
elif rel_path.endswith((".zst", ".ztd")):
|
||||
payload = zstd.ZstdCompressor().compress(out_jsonl.encode("utf-8"))
|
||||
else:
|
||||
payload = out_jsonl.encode("utf-8")
|
||||
|
||||
# 3. Write to args.dataset (local or S3)
|
||||
if args.dataset.startswith("s3://"):
|
||||
bucket, prefix = parse_s3_path(args.dataset)
|
||||
key = os.path.join(prefix, out_rel)
|
||||
workspace_s3.put_object(Bucket=bucket, Key=key, Body=payload)
|
||||
else:
|
||||
out_path = os.path.join(args.dataset, out_rel)
|
||||
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
||||
with open(out_path, "wb") as fh:
|
||||
fh.write(payload)
|
||||
|
||||
# 4. Mark queue item done
|
||||
await work_queue.mark_done(work_item)
|
||||
|
||||
# 5. Drop empty sentinel file in <workspace>/results/
|
||||
sentinel_rel = os.path.join("results", f"output_{work_item.hash}.jsonl")
|
||||
if args.scratch.startswith("s3://"):
|
||||
bkt, pfx = parse_s3_path(args.scratch)
|
||||
key = os.path.join(pfx, sentinel_rel)
|
||||
workspace_s3.put_object(Bucket=bkt, Key=key, Body=b"")
|
||||
else:
|
||||
sentinel_path = os.path.join(args.scratch, sentinel_rel)
|
||||
os.makedirs(os.path.dirname(sentinel_path), exist_ok=True)
|
||||
open(sentinel_path, "w").close()
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(f"Worker {worker_id} exception: {exc!s}")
|
||||
finally:
|
||||
semaphore.release()
|
||||
|
||||
|
||||
async def server_task(model_name_or_path, args, semaphore):
|
||||
# Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
|
||||
# mem_fraction_arg = ["--mem-fraction-static", "0.80"]
|
||||
|
||||
cmd = [
|
||||
"vllm",
|
||||
"serve",
|
||||
model_name_or_path,
|
||||
"--port",
|
||||
str(SERVER_PORT),
|
||||
"--uvicorn-log-level",
|
||||
"warning",
|
||||
"--disable-log-requests",
|
||||
]
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
# Ensure the subprocess is terminated on exit
|
||||
def _kill_proc():
|
||||
proc.terminate()
|
||||
|
||||
atexit.register(_kill_proc)
|
||||
|
||||
# Shared variables between tasks
|
||||
last_running_req, last_queue_req = 0, 0
|
||||
server_printed_ready_message = False
|
||||
last_semaphore_release = time.time()
|
||||
|
||||
async def process_line(line):
|
||||
nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
|
||||
server_logger.info(line)
|
||||
|
||||
# if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
|
||||
# can see any warnings/errors more easily
|
||||
if not server_printed_ready_message:
|
||||
logger.info(line)
|
||||
|
||||
if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
|
||||
server_printed_ready_message = True
|
||||
last_semaphore_release = time.time()
|
||||
|
||||
match = re.search(r"Running: (\d+) reqs", line)
|
||||
if match:
|
||||
last_running_req = int(match.group(1))
|
||||
|
||||
match = re.search(r"Waiting: (\d+) reqs", line)
|
||||
if match:
|
||||
last_queue_req = int(match.group(1))
|
||||
logger.info(f"running req: {last_running_req} queue req: {last_queue_req}")
|
||||
|
||||
async def read_stream(stream):
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
break
|
||||
try:
|
||||
line = line.decode("utf-8").rstrip()
|
||||
await process_line(line)
|
||||
except Exception as ex:
|
||||
logger.warning(f"Got {ex} when reading log line from inference server, skipping")
|
||||
|
||||
async def timeout_task():
|
||||
nonlocal last_running_req, last_queue_req, last_semaphore_release
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
|
||||
semaphore.release()
|
||||
last_semaphore_release = time.time()
|
||||
logger.info("Semaphore released, allowing a worker to proceed.")
|
||||
except asyncio.CancelledError:
|
||||
pass # Clean up if the task is cancelled
|
||||
|
||||
# Start tasks to read stdout, stderr, and handle timeout logic
|
||||
stdout_task = asyncio.create_task(read_stream(proc.stdout))
|
||||
stderr_task = asyncio.create_task(read_stream(proc.stderr))
|
||||
timeout_task = asyncio.create_task(timeout_task())
|
||||
|
||||
try:
|
||||
await proc.wait()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Got cancellation request for server")
|
||||
proc.terminate()
|
||||
raise
|
||||
|
||||
timeout_task.cancel()
|
||||
await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
|
||||
|
||||
|
||||
async def server_host(model_name_or_path, args, semaphore):
|
||||
MAX_RETRIES = 5
|
||||
retry = 0
|
||||
|
||||
while retry < MAX_RETRIES:
|
||||
await server_task(model_name_or_path, args, semaphore)
|
||||
logger.warning("Server task ended")
|
||||
retry += 1
|
||||
|
||||
if retry >= MAX_RETRIES:
|
||||
logger.error(f"Ended up starting the server more than {retry} times, cancelling pipeline")
|
||||
logger.error("")
|
||||
logger.error("Please make sure vllm is installed according to the latest instructions for 0.8.4")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
async def check_server_ready():
|
||||
max_attempts = 300
|
||||
delay_sec = 1
|
||||
url = f"http://localhost:{SERVER_PORT}/v1/models"
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
async with httpx.AsyncClient() as session:
|
||||
response = await session.get(url)
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info("server is ready.")
|
||||
return
|
||||
else:
|
||||
logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
|
||||
except Exception:
|
||||
logger.warning(f"Attempt {attempt}: Please wait for model server to become ready...")
|
||||
|
||||
await asyncio.sleep(delay_sec)
|
||||
|
||||
raise Exception("model server did not become ready after waiting.")
|
||||
|
||||
|
||||
async def download_model(model_name_or_path: str):
|
||||
if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"):
|
||||
logger.info(f"Downloading model directory from '{model_name_or_path}'")
|
||||
model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
|
||||
download_directory([model_name_or_path], model_cache_dir)
|
||||
return model_cache_dir
|
||||
elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
|
||||
logger.info(f"Using local model path at '{model_name_or_path}'")
|
||||
return model_name_or_path
|
||||
else:
|
||||
logger.info(f"Downloading model with hugging face '{model_name_or_path}'")
|
||||
snapshot_download(repo_id=model_name_or_path)
|
||||
return model_name_or_path
|
||||
|
||||
|
||||
async def metrics_reporter(work_queue):
|
||||
while True:
|
||||
# Leading newlines preserve table formatting in logs
|
||||
logger.info(f"Queue remaining: {work_queue.size}")
|
||||
logger.info("\n" + str(metrics))
|
||||
await asyncio.sleep(10)
|
||||
|
||||
|
||||
def submit_beaker_job(args):
|
||||
from beaker import ( # type: ignore
|
||||
Beaker,
|
||||
Constraints,
|
||||
EnvVar,
|
||||
ExperimentSpec,
|
||||
ImageSource,
|
||||
Priority,
|
||||
ResultSpec,
|
||||
SecretNotFound,
|
||||
TaskContext,
|
||||
TaskResources,
|
||||
TaskSpec,
|
||||
)
|
||||
|
||||
b = Beaker.from_env(default_workspace=args.beaker_workspace)
|
||||
account = b.account.whoami()
|
||||
owner = account.name
|
||||
beaker_image = f"jakep/olmocr-tagging-{VERSION}"
|
||||
|
||||
task_name = f"olmocr-{os.path.basename(args.dataset.rstrip('/'))}"
|
||||
|
||||
# Take out --beaker flag so the workers will just run things
|
||||
args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]
|
||||
|
||||
# Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally
|
||||
args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i - 1] == "--pdfs"))]
|
||||
|
||||
try:
|
||||
b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
|
||||
b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
|
||||
b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
|
||||
except SecretNotFound:
|
||||
print(
|
||||
f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]"
|
||||
)
|
||||
|
||||
if input().strip().lower() != "y":
|
||||
print("Exiting...")
|
||||
sys.exit(1)
|
||||
|
||||
b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
|
||||
b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
|
||||
b.secret.write(
|
||||
f"{owner}-AWS_CREDENTIALS_FILE",
|
||||
open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(),
|
||||
args.beaker_workspace,
|
||||
)
|
||||
|
||||
env_var_secrets = [
|
||||
EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
|
||||
EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
|
||||
EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
|
||||
]
|
||||
|
||||
try:
|
||||
b.secret.get("OLMOCR_PREVIEW_HF_TOKEN", args.beaker_workspace)
|
||||
env_var_secrets.append(EnvVar(name="HF_TOKEN", secret="OLMOCR_PREVIEW_HF_TOKEN"))
|
||||
except SecretNotFound:
|
||||
pass
|
||||
|
||||
try:
|
||||
b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
|
||||
env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
|
||||
except SecretNotFound:
|
||||
print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
|
||||
lines = []
|
||||
prev_empty = False
|
||||
for line in iter(input, None):
|
||||
if not line and prev_empty:
|
||||
break
|
||||
prev_empty = not line
|
||||
lines.append(line)
|
||||
gcs_sa_key = "\n".join(lines[:-1]).strip() # Remove the last empty line
|
||||
if gcs_sa_key:
|
||||
b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
|
||||
env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
|
||||
|
||||
# Create the experiment spec
|
||||
experiment_spec = ExperimentSpec(
|
||||
budget="ai2/oe-data",
|
||||
description=task_name,
|
||||
tasks=[
|
||||
TaskSpec(
|
||||
name=task_name,
|
||||
propagate_failure=False,
|
||||
propagate_preemption=False,
|
||||
replicas=args.beaker_gpus,
|
||||
context=TaskContext(
|
||||
priority=Priority(args.beaker_priority),
|
||||
preemptible=True,
|
||||
),
|
||||
image=ImageSource(beaker=beaker_image),
|
||||
command=["python", "scripts/tagging_pipeline.py"] + args_list,
|
||||
env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
|
||||
resources=TaskResources(gpu_count=1),
|
||||
constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
|
||||
result=ResultSpec(path="/noop-results"),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)
|
||||
|
||||
print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(description="Tagging pipeline for Dolma JSONL dataset")
|
||||
parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with documents/ folder")
|
||||
parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)")
|
||||
parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
|
||||
parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
|
||||
parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
|
||||
parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")
|
||||
|
||||
# Beaker/job running stuff
|
||||
parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
|
||||
parser.add_argument("--beaker_workspace", help="Beaker workspace to submit to", default="ai2/olmocr")
|
||||
parser.add_argument(
|
||||
"--beaker_cluster",
|
||||
help="Beaker clusters you want to run on",
|
||||
default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"],
|
||||
)
|
||||
parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
|
||||
parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
|
||||
|
||||
parser.add_argument("--port", type=int, default=30024, help="Port for Model server")
|
||||
args = parser.parse_args()
|
||||
|
||||
global SERVER_PORT, workspace_s3, dataset_s3
|
||||
SERVER_PORT = args.port
|
||||
workspace_s3 = boto3.client("s3")
|
||||
dataset_s3 = boto3.client("s3")
|
||||
|
||||
# setup the job to work in beaker environment, load secrets, adjust logging, etc.
|
||||
if "BEAKER_JOB_ID" in os.environ:
|
||||
server_logger.addHandler(console_handler)
|
||||
if "AWS_CREDENTIALS_FILE" in os.environ:
|
||||
cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
|
||||
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
|
||||
with open(cred_path, "w") as f:
|
||||
f.write(os.environ.get("AWS_CREDENTIALS_FILE"))
|
||||
if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ:
|
||||
cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
|
||||
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
|
||||
with open(cred_path, "w") as f:
|
||||
f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
|
||||
workspace_s3 = boto3.client("s3")
|
||||
dataset_s3 = boto3.client("s3")
|
||||
|
||||
# Wait a little bit so that not all beaker jobs in a task start at the same time and download the model at the same time
|
||||
replica_count = int(os.environ.get("BEAKER_REPLICA_COUNT", "1"))
|
||||
interval = 10 if (replica_count - 1) * 10 <= 240 else 240 / max(1, replica_count - 1)
|
||||
sleep_time = int(int(os.environ.get("BEAKER_REPLICA_RANK", "0")) * interval)
|
||||
logger.info(f"Beaker job sleeping for {sleep_time} seconds to stagger model downloads")
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
# Initialize work queue
|
||||
if args.scratch.startswith("s3://"):
|
||||
work_queue = S3WorkQueue(workspace_s3, args.scratch)
|
||||
else:
|
||||
work_queue = LocalWorkQueue(args.scratch)
|
||||
|
||||
# Discover input files
|
||||
files = set()
|
||||
if args.dataset.startswith("s3://"):
|
||||
pattern = args.dataset.rstrip("/") + "/documents/*.jsonl*"
|
||||
matched = expand_s3_glob(dataset_s3, pattern)
|
||||
files = set(matched.keys())
|
||||
else:
|
||||
docs_dir = os.path.join(args.dataset, "documents")
|
||||
for root, _, fns in os.walk(docs_dir):
|
||||
for fn in fns:
|
||||
if fn.endswith((".jsonl", ".jsonl.gz", ".jsonl.ztd")):
|
||||
files.add(os.path.join(root, fn))
|
||||
|
||||
# Populate the work queue if needed
|
||||
await work_queue.populate_queue(list(files), items_per_group=1)
|
||||
|
||||
if args.beaker:
|
||||
submit_beaker_job(args)
|
||||
return
|
||||
|
||||
# If you get this far, then you are doing inference and need a GPU
|
||||
check_torch_gpu_available()
|
||||
|
||||
logger.info(f"Starting pipeline with PID {os.getpid()}")
|
||||
|
||||
# Download the model before you do anything else
|
||||
model_name_or_path = await download_model(args.model)
|
||||
|
||||
# Initialize the work queue
|
||||
qsize = await work_queue.initialize_queue()
|
||||
|
||||
if qsize == 0:
|
||||
logger.info("No work to do, exiting")
|
||||
return
|
||||
|
||||
# Create a semaphore to control worker access
|
||||
# We only allow one worker to move forward with requests, until the server has no more requests in its queue
|
||||
# This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
|
||||
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
|
||||
model_server = asyncio.create_task(server_host(model_name_or_path, args, semaphore))
|
||||
|
||||
await check_server_ready()
|
||||
|
||||
metrics_task = asyncio.create_task(metrics_reporter(work_queue))
|
||||
|
||||
# Create worker tasks to process the queue concurrently.
|
||||
worker_tasks = []
|
||||
for i in range(args.workers):
|
||||
task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
|
||||
worker_tasks.append(task)
|
||||
|
||||
# Wait for all worker tasks to finish
|
||||
await asyncio.gather(*worker_tasks)
|
||||
|
||||
model_server.cancel()
|
||||
metrics_task.cancel()
|
||||
logger.info("Work done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -55,8 +55,13 @@ class TestNormalizeText(unittest.TestCase):
|
||||
self.assertEqual(normalize_text("Hello<br/>everyone"), normalize_text("Hello\neveryone"))
|
||||
|
||||
def test_two_stars(self):
|
||||
self.assertEqual(normalize_text("**Georges V.** (2007) – *Le Forez du VIe au IVe millénaire av. J.-C. Territoires, identités et stratégies des sociétés humaines du Massif central dans le bassin amont de la Loire (France)*, thèse de doctorat, université de Bourgogne, Dijon, 2 vol., 435 p."),
|
||||
"Georges V. (2007) - Le Forez du VIe au IVe millénaire av. J.-C. Territoires, identités et stratégies des sociétés humaines du Massif central dans le bassin amont de la Loire (France), thèse de doctorat, université de Bourgogne, Dijon, 2 vol., 435 p.")
|
||||
self.assertEqual(
|
||||
normalize_text(
|
||||
"**Georges V.** (2007) – *Le Forez du VIe au IVe millénaire av. J.-C. Territoires, identités et stratégies des sociétés humaines du Massif central dans le bassin amont de la Loire (France)*, thèse de doctorat, université de Bourgogne, Dijon, 2 vol., 435 p."
|
||||
),
|
||||
"Georges V. (2007) - Le Forez du VIe au IVe millénaire av. J.-C. Territoires, identités et stratégies des sociétés humaines du Massif central dans le bassin amont de la Loire (France), thèse de doctorat, université de Bourgogne, Dijon, 2 vol., 435 p.",
|
||||
)
|
||||
|
||||
|
||||
class TestBasePDFTest(unittest.TestCase):
|
||||
"""Test the BasePDFTest class"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user