Refactoring tagging bigly

Jake Poznanski 2025-04-24 10:18:30 -07:00
parent 811d267bd5
commit c326fae03c

@@ -18,6 +18,7 @@ import os
 import re
 import sys
 import time
+import random
 from concurrent.futures import ProcessPoolExecutor
 from urllib.parse import urlparse
@@ -26,6 +27,8 @@ import httpx
 import torch
 import zstandard as zstd
 from huggingface_hub import snapshot_download
+from pydantic import BaseModel, Field, ValidationError
+from typing import Optional
 from olmocr.check import (
     check_sglang_version,
@@ -71,8 +74,89 @@ metrics = MetricsKeeper(window=60 * 5)
 tracker = WorkerTracker()

 # Process pool for offloading cpu bound work, like calculating anchor texts, max 32 workers, otherwise it can spawn way too many workers on a big machine
 process_pool = ProcessPoolExecutor(max_workers=min(multiprocessing.cpu_count() // 2 + 1, 32), mp_context=multiprocessing.get_context("spawn"))


+class PIIClassification(BaseModel):
+    is_resume_or_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv.")
+
+
+async def _process_single_page(page_text: str) -> PIIClassification:
+    """Classify a single page of text, returning is_resume_or_cv=None on any error."""
+    text = page_text
+
+    # Count the attempt up-front
+    metrics.add_metrics(sglang_documents=1)
+
+    query = {
+        "model": "google/gemma-3-4b-it",
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": (
+                            f"{text}\n\n-----------\n"
+                            "Given the text above, determine if the text above is a resume (résumé) or CV. Answer in a simple JSON block."
+                        ),
+                    }
+                ],
+            }
+        ],
+        "max_tokens": 100,
+        "temperature": 0.0,
+        "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
+    }
+
+    url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
+
+    # ---------- HTTP call ---------------------------------------------------
+    try:
+        status, body = await apost(url, json_data=query)
+    except Exception as e:
+        logger.warning(f"SGLang network error: {e!s}")
+        metrics.add_metrics(sglang_errors=1)
+        return PIIClassification(is_resume_or_cv=None)
+
+    if status != 200:
+        logger.warning(f"SGLang HTTP {status}: {body[:250]!r}")
+        metrics.add_metrics(sglang_errors=1)
+        return PIIClassification(is_resume_or_cv=None)
+
+    # ---------- Parse base JSON --------------------------------------------
+    try:
+        base = json.loads(body)
+    except json.JSONDecodeError:
+        logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}")
+        metrics.add_metrics(sglang_errors=1)
+        return PIIClassification(is_resume_or_cv=None)
+
+    # Token accounting if available
+    if "usage" in base:
+        metrics.add_metrics(
+            sglang_input_tokens=base["usage"].get("prompt_tokens", 0),
+            sglang_output_tokens=base["usage"].get("completion_tokens", 0),
+        )
+
+    # ---------- Extract the model message ----------------------------------
+    try:
+        content = base["choices"][0]["message"].get("content")
+    except (KeyError, IndexError, AttributeError) as e:
+        logger.warning(f"Missing fields in SGLang response: {e!s}")
+        metrics.add_metrics(sglang_errors=1)
+        return PIIClassification(is_resume_or_cv=None)
+
+    if not isinstance(content, str):
+        logger.warning("SGLang `content` is not a string; treating as error.")
+        metrics.add_metrics(sglang_errors=1)
+        return PIIClassification(is_resume_or_cv=None)
+
+    # ---------- Validate the model's JSON payload ---------------------------
+    try:
+        pii_classification: PIIClassification = PIIClassification.model_validate_json(content)
+        return pii_classification
+    except ValidationError as e:
+        logger.warning(f"Unable to parse pii classification object: {e!s}")
+        metrics.add_metrics(sglang_errors=1)
+        return PIIClassification(is_resume_or_cv=None)


 # Manual simple implementation of HTTP Post
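The helper above leans on two pydantic v2 features: model_json_schema() produces the JSON schema that SGLang's response_format uses for constrained decoding, and model_validate_json() turns the model's reply back into a typed object. A minimal sketch of that round trip (standard pydantic API; printed schema abbreviated):

from typing import Optional

from pydantic import BaseModel, Field, ValidationError

class PIIClassification(BaseModel):
    is_resume_or_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv.")

# This is the schema shipped to SGLang under response_format -> json_schema -> schema:
print(PIIClassification.model_json_schema())
# roughly: {'properties': {'is_resume_or_cv': {'anyOf': [{'type': 'boolean'}, {'type': 'null'}], ...}},
#          'required': ['is_resume_or_cv'], 'title': 'PIIClassification', 'type': 'object'}

# A well-formed reply validates into a typed object...
print(PIIClassification.model_validate_json('{"is_resume_or_cv": true}'))  # is_resume_or_cv=True

# ...while a malformed one raises ValidationError, which the helper maps to is_resume_or_cv=None.
try:
    PIIClassification.model_validate_json('{"is_resume_or_cv": "maybe"}')
except ValidationError:
    print("fell back to PIIClassification(is_resume_or_cv=None)")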
@@ -139,8 +223,7 @@ async def apost(url, json_data):
         except:
             pass

-async def process_dolma_document(dolma_doc):
+async def process_dolma_document(args, dolma_doc, sem):
     """
     Query SGLang to detect PII, enforcing a JSON schema.
@@ -148,7 +231,7 @@ async def process_dolma_document(dolma_doc):
     Transport / HTTP errors
     Missing or malformed fields in the response
     Non-string or None `content`
-    Bad JSON in the models answer
+    Bad JSON in the model's answer

-    Always returns: (doc_id, contains_pii: bool, text_length: int)
+    Returns a dict of attribute spans: {key_name: [[start, end, is_resume_or_cv], ...]}
     """
@@ -156,88 +239,41 @@
     text = dolma_doc.get("text", "") or ""
     text_len = len(text)

-    # Count the attempt up-front
-    metrics.add_metrics(sglang_documents=1)
+    key_name = f"{args.model.replace('/', '_')}_pii_classification"

-    # 1) Define the JSON Schema for the response
-    pii_schema = {"type": "object", "properties": {"is_resume_or_cv": {"type": "boolean"}}, "required": ["is_resume_or_cv"], "additionalProperties": False}
-
-    # 2) Build the request payload, including `response_format`
-    query = {
-        "model": "google/gemma-3-4b-it",
-        "messages": [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": (
-                            f"{text}\n\n-----------\n"
-                            "Given the text above, determine if the text above is a resume (résumé) or CV. Answer in a simple JSON block."
-                        ),
-                    }
-                ],
-            }
-        ],
-        "max_tokens": 100,
-        "temperature": 0.0,
-        "response_format": {"type": "json_schema", "json_schema": {"name": "PiiDetection", "schema": pii_schema}},
-    }
+    result_attributes = {
+        key_name: []
+    }

+    # If pdf_page_numbers is present, split the text and process each page separately
+    if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
+        page_numbers = dolma_doc["attributes"]["pdf_page_numbers"]
+        logger.info(f"Document {doc_id} has {len(page_numbers)} pages, processing each individually")
-    url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"

+        # Filter pages down to actual real content
+        selected_page_numbers = [p for p in page_numbers if p[0] < p[1]]

-    # ---------- HTTP call ---------------------------------------------------
-    try:
-        status, body = await apost(url, json_data=query)
-    except Exception as e:
-        logger.warning(f"SGLang network error: {e!s}")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len

+        # Sample 3 pages max per document
+        random.shuffle(selected_page_numbers)
+        selected_page_numbers = selected_page_numbers[:3]

+        # Iterate over the original span objects so the membership test below also works on JSON lists
+        for page_span in page_numbers:
+            start_pos, end_pos, page_num = page_span
+            if page_span in selected_page_numbers:
+                page_text = text[start_pos:end_pos]

-    if status != 200:
-        logger.warning(f"SGLang HTTP {status}: {body[:250]!r}")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    # ---------- Parse base JSON --------------------------------------------
-    try:
-        base = json.loads(body)
-    except json.JSONDecodeError:
-        logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    # Token accounting if available
-    if "usage" in base:
-        metrics.add_metrics(
-            sglang_input_tokens=base["usage"].get("prompt_tokens", 0),
-            sglang_output_tokens=base["usage"].get("completion_tokens", 0),
-        )
-
-    # ---------- Extract the model message ----------------------------------
-    try:
-        content = base["choices"][0]["message"].get("content")
-    except (KeyError, IndexError, AttributeError) as e:
-        logger.warning(f"Missing fields in SGLang response: {e!s}")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    if not isinstance(content, str):
-        logger.warning("SGLang `content` is not a string; treating as error.")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    # ---------- Parse the models JSON payload -----------------------------
-    try:
-        model_json = json.loads(content)
-        contains_pii = bool(model_json.get("is_resume_or_cv", False))
-    except json.JSONDecodeError:
-        logger.warning(f"Model JSON malformed: {content[:250]!r}")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    return doc_id, contains_pii, text_len
+                # Process each page with the semaphore to limit concurrent requests
+                async with sem:
+                    pii_class = await _process_single_page(page_text)
+
+                result_attributes[key_name].append([start_pos, end_pos, pii_class.is_resume_or_cv])
+            else:
+                result_attributes[key_name].append([start_pos, end_pos, None])
+
+        return result_attributes
+    else:
+        raise NotImplementedError("Missing code here, expecting this to be dolma docs made by olmocr....")


 async def process_file(args, worker_id: int, file_uri: str):
     """
@@ -265,14 +301,10 @@ async def process_file(args, worker_id: int, file_uri: str):
     # Send all records in parallel, max 500 queued at a time
     sem = asyncio.Semaphore(500)

-    async def _sem_process_dolma_document(dolma_doc):
-        async with sem:
-            return await process_dolma_document(dolma_doc)
-
     async with asyncio.TaskGroup() as tg:
         for line in lines:
-            data = json.loads(line)
-            task = tg.create_task(_sem_process_dolma_document(data))
-            page_tasks[data["id"]] = (task, data)
+            dolma_doc = json.loads(line)
+            task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
+            page_tasks[dolma_doc["id"]] = (task, dolma_doc)

     logger.info(f"Started taskgroup with {len(page_tasks)} items for {file_uri}")
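One effect of dropping the _sem_process_dolma_document wrapper: the semaphore now gates only the SGLang request inside process_dolma_document, not the surrounding page splitting and bookkeeping. The pattern, reduced to a runnable sketch (Python 3.11+ for asyncio.TaskGroup):

import asyncio

async def process(doc, sem):
    # cheap CPU-side work happens outside the semaphore...
    async with sem:
        await asyncio.sleep(0.01)  # ...only the expensive request is rate-limited
    return doc["id"]

async def main():
    sem = asyncio.Semaphore(500)  # max 500 in-flight requests, as in the diff
    async with asyncio.TaskGroup() as tg:
        tasks = [tg.create_task(process({"id": i}, sem)) for i in range(5)]
    print([t.result() for t in tasks])  # [0, 1, 2, 3, 4]

asyncio.run(main())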
@@ -367,7 +399,7 @@ async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, work
 async def sglang_server_task(model_name_or_path, args, semaphore):
     # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
-    mem_fraction_arg = ["--mem-fraction-static", "0.60"]
+    # mem_fraction_arg = ["--mem-fraction-static", "0.80"]

     cmd = [
         "python3",
@@ -380,7 +412,6 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
         "--log-level-http",
         "warning",
     ]
-    cmd.extend(mem_fraction_arg)

     proc = await asyncio.create_subprocess_exec(
         *cmd,
@@ -749,13 +780,10 @@ async def main():
     # Wait for all worker tasks to finish
     await asyncio.gather(*worker_tasks)

-    # Wait for server to stop
     process_pool.shutdown(wait=False)

     sglang_server.cancel()
     metrics_task.cancel()
     logger.info("Work done")


 if __name__ == "__main__":
     asyncio.run(main())
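For reference, the attributes record this pipeline now emits per document looks like the sketch below (key name follows the key_name pattern above for args.model == "google/gemma-3-4b-it"; offsets and verdicts are hypothetical):

doc_attributes = {
    "google_gemma-3-4b-it_pii_classification": [
        [0, 2413, True],      # sampled page: classified as resume/CV
        [2414, 5102, None],   # page not sampled (or span empty): no verdict
        [5103, 7001, False],  # sampled page: not a resume/CV
    ]
}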