mirror of https://github.com/allenai/olmocr.git
synced 2025-12-25 06:06:23 +00:00

Refactoring tagging bigly

This commit is contained in:
parent 811d267bd5
commit c326fae03c
@@ -18,6 +18,7 @@ import os
import re
import sys
import time
import random
from concurrent.futures import ProcessPoolExecutor
from urllib.parse import urlparse

@@ -26,6 +27,8 @@ import httpx
import torch
import zstandard as zstd
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, ValidationError
from typing import Optional

from olmocr.check import (
    check_sglang_version,
@@ -71,8 +74,89 @@ metrics = MetricsKeeper(window=60 * 5)
tracker = WorkerTracker()


# Process pool for offloading cpu bound work, like calculating anchor texts, max 32 workers, otherwise it can spawn way too many workers on a big machine
process_pool = ProcessPoolExecutor(max_workers=min(multiprocessing.cpu_count() // 2 + 1, 32), mp_context=multiprocessing.get_context("spawn"))
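As an aside, a minimal sketch of how CPU-bound work can be handed to such a pool from async code; run_in_executor is the standard bridge, and count_words here is an illustrative stand-in, not a helper from this file:

# Illustrative sketch: bridging asyncio to a ProcessPoolExecutor.
import asyncio
import multiprocessing
from concurrent.futures import ProcessPoolExecutor

def count_words(text: str) -> int:
    # CPU-bound work; runs in a child process
    return len(text.split())

async def demo() -> None:
    loop = asyncio.get_running_loop()
    with ProcessPoolExecutor(max_workers=2, mp_context=multiprocessing.get_context("spawn")) as pool:
        print(await loop.run_in_executor(pool, count_words, "one two three"))  # 3

if __name__ == "__main__":
    # The __main__ guard matters with the "spawn" context: children re-import this module.
    asyncio.run(demo())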
class PIIClassification(BaseModel):
    is_resume_or_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv.")


async def _process_single_page(page_text: str) -> PIIClassification:
    """Helper function to process a single document or page."""
    text = page_text
    text_len = len(text)

    # Count the attempt up-front
    metrics.add_metrics(sglang_documents=1)

    query = {
        "model": "google/gemma-3-4b-it",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": (
                            f"{text}\n\n-----------\n"
                            "Given the text above, determine if the text above is a resume (résumé) or CV. Answer in a simple JSON block."
                        ),
                    }
                ],
            }
        ],
        "max_tokens": 100,
        "temperature": 0.0,
        "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
    }

    url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"

    # ---------- HTTP call ---------------------------------------------------
    try:
        status, body = await apost(url, json_data=query)
    except Exception as e:
        logger.warning(f"SGLang network error: {e!s}")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(is_resume_or_cv=None)

    if status != 200:
        logger.warning(f"SGLang HTTP {status}: {body[:250]!r}")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(is_resume_or_cv=None)

    # ---------- Parse base JSON --------------------------------------------
    try:
        base = json.loads(body)
    except json.JSONDecodeError:
        logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(is_resume_or_cv=None)

    # Token accounting if available
    if "usage" in base:
        metrics.add_metrics(
            sglang_input_tokens=base["usage"].get("prompt_tokens", 0),
            sglang_output_tokens=base["usage"].get("completion_tokens", 0),
        )

    # ---------- Extract the model message ----------------------------------
    try:
        content = base["choices"][0]["message"].get("content")
    except (KeyError, IndexError, AttributeError) as e:
        logger.warning(f"Missing fields in SGLang response: {e!s}")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(is_resume_or_cv=None)

    if not isinstance(content, str):
        logger.warning("SGLang `content` is not a string; treating as error.")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(is_resume_or_cv=None)

    try:
        pii_classification: PIIClassification = PIIClassification.model_validate_json(content)
        return pii_classification
    except ValidationError as e:
        logger.warning(f"Unable to parse pii classification object: {e!s}")
        metrics.add_metrics(sglang_errors=1)
        return PIIClassification(is_resume_or_cv=None)
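The helper above generates the JSON schema for response_format directly from the pydantic model and validates the reply with the same class; the round-trip can be sketched in isolation (assuming pydantic v2, with a made-up reply string):

# Sketch of the schema round-trip: pydantic model -> response_format -> validated reply.
from typing import Optional
from pydantic import BaseModel, Field, ValidationError

class PIIClassification(BaseModel):
    is_resume_or_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv.")

response_format = {
    "type": "json_schema",
    "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()},
}

reply = '{"is_resume_or_cv": true}'  # stand-in for the model's answer
try:
    print(PIIClassification.model_validate_json(reply).is_resume_or_cv)  # True
except ValidationError as e:
    print(f"schema violation: {e}")

Because every failure path returns PIIClassification(is_resume_or_cv=None), callers can treat None as "unknown" without wrapping each call in try/except.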


# Manual simple implementation of HTTP Post
@@ -139,8 +223,7 @@ async def apost(url, json_data):
         except:
             pass
 
 
-async def process_dolma_document(dolma_doc):
+async def process_dolma_document(args, dolma_doc, sem):
     """
     Query SGLang to detect PII, enforcing a JSON schema.
 
@@ -148,7 +231,7 @@ async def process_dolma_document(dolma_doc):
     • Transport / HTTP errors
     • Missing or malformed fields in the response
     • Non-string or None `content`
-    • Bad JSON in the model’s answer
+    • Bad JSON in the model's answer
 
     Always returns: (doc_id, contains_pii: bool, text_length: int)
     """
@@ -156,88 +239,41 @@ async def process_dolma_document(dolma_doc):
     text = dolma_doc.get("text", "") or ""
     text_len = len(text)
 
-    # Count the attempt up-front
-    metrics.add_metrics(sglang_documents=1)
-
-    # 1) Define the JSON Schema for the response
-    pii_schema = {"type": "object", "properties": {"is_resume_or_cv": {"type": "boolean"}}, "required": ["is_resume_or_cv"], "additionalProperties": False}
-
-    # 2) Build the request payload, including `response_format`
-    query = {
-        "model": "google/gemma-3-4b-it",
-        "messages": [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": (
-                            f"{text}\n\n-----------\n"
-                            "Given the text above, determine if the text above is a resume (résumé) or CV. Answer in a simple JSON block."
-                        ),
-                    }
-                ],
-            }
-        ],
-        "max_tokens": 100,
-        "temperature": 0.0,
-        "response_format": {"type": "json_schema", "json_schema": {"name": "PiiDetection", "schema": pii_schema}},
-    }
-
-    url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
-
-    # ---------- HTTP call ---------------------------------------------------
-    try:
-        status, body = await apost(url, json_data=query)
-    except Exception as e:
-        logger.warning(f"SGLang network error: {e!s}")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    if status != 200:
-        logger.warning(f"SGLang HTTP {status}: {body[:250]!r}")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    # ---------- Parse base JSON --------------------------------------------
-    try:
-        base = json.loads(body)
-    except json.JSONDecodeError:
-        logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    # Token accounting if available
-    if "usage" in base:
-        metrics.add_metrics(
-            sglang_input_tokens=base["usage"].get("prompt_tokens", 0),
-            sglang_output_tokens=base["usage"].get("completion_tokens", 0),
-        )
-
-    # ---------- Extract the model message ----------------------------------
-    try:
-        content = base["choices"][0]["message"].get("content")
-    except (KeyError, IndexError, AttributeError) as e:
-        logger.warning(f"Missing fields in SGLang response: {e!s}")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    if not isinstance(content, str):
-        logger.warning("SGLang `content` is not a string; treating as error.")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    # ---------- Parse the model’s JSON payload -----------------------------
-    try:
-        model_json = json.loads(content)
-        contains_pii = bool(model_json.get("is_resume_or_cv", False))
-    except json.JSONDecodeError:
-        logger.warning(f"Model JSON malformed: {content[:250]!r}")
-        metrics.add_metrics(sglang_errors=1)
-        return doc_id, False, text_len
-
-    return doc_id, contains_pii, text_len
+    key_name = f"{args.model.replace('/', '_')}_pii_classification"
+
+    result_attributes = {
+        key_name: []
+    }
+
+    # If pdf_page_numbers is present, split the text and process each page separately
+    if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
+        page_numbers = dolma_doc["attributes"]["pdf_page_numbers"]
+
+        logger.info(f"Document {doc_id} has {len(page_numbers)} pages, processing each individually")
+
+        # Filter pages down to actual real content
+        selected_page_numbers = [p for p in page_numbers if p[0] < p[1]]
+
+        # Sample 3 pages max per document
+        random.shuffle(selected_page_numbers)
+        selected_page_numbers = selected_page_numbers[:3]
+
+        for start_pos, end_pos, page_num in page_numbers:
+            if (start_pos, end_pos, page_num) in selected_page_numbers:
+                page_text = text[start_pos:end_pos]
+
+                # Process each page with the semaphore to limit concurrent requests
+                async with sem:
+                    pii_class = await _process_single_page(page_text)
+
+                result_attributes[key_name].append([start_pos, end_pos, pii_class.is_resume_or_cv])
+            else:
+                result_attributes[key_name].append([start_pos, end_pos, None])
+
+        return result_attributes
+    else:
+        raise NotImplementedError("Missing code here, expecting this to be dolma docs made by olmocr....")
 
 
 async def process_file(args, worker_id: int, file_uri: str):
     """
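Stepping back from the diff: the new page loop assumes pdf_page_numbers holds [start_pos, end_pos, page_num] character spans into the document text, samples at most three non-empty spans, and records None for the rest. A toy sketch of that sample-then-slice step (note the membership test compares the JSON-decoded lists directly):

# Toy sketch: sample up to k non-empty page spans and slice their text.
import random

def sample_page_texts(text: str, page_numbers: list, k: int = 3) -> list:
    selected = [p for p in page_numbers if p[0] < p[1]]  # drop empty pages
    random.shuffle(selected)
    selected = selected[:k]
    # Preserve page order; unsampled pages are tagged None
    return [[p[0], p[1], text[p[0]:p[1]] if p in selected else None] for p in page_numbers]

doc_text = "page one text." + "page two text."
spans = [[0, 14, 1], [14, 28, 2], [28, 28, 3]]  # third span is empty
print(sample_page_texts(doc_text, spans, k=1))

Assuming args.model matches the hardcoded google/gemma-3-4b-it, each span row then lands under the attribute key google_gemma-3-4b-it_pii_classification, per the key_name construction above.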
@@ -265,14 +301,10 @@ async def process_file(args, worker_id: int, file_uri: str):
     # Send all records in parallel, max 500 queued at a time
     sem = asyncio.Semaphore(500)
 
-    async def _sem_process_dolma_document(dolma_doc):
-        async with sem:
-            return await process_dolma_document(dolma_doc)
-
     async with asyncio.TaskGroup() as tg:
         for line in lines:
-            data = json.loads(line)
-            task = tg.create_task(_sem_process_dolma_document(data))
-            page_tasks[data["id"]] = (task, data)
+            dolma_doc = json.loads(line)
+            task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
+            page_tasks[dolma_doc["id"]] = (task, dolma_doc)
 
     logger.info(f"Started taskgroup with {len(page_tasks)} items for {file_uri}")
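The refactor moves the semaphore into process_dolma_document itself, so the TaskGroup still creates one task per record up front while the semaphore caps how many are actually in flight. A toy version of the pattern (asyncio.TaskGroup needs Python 3.11+):

# Toy sketch: unbounded task creation, bounded execution via a shared semaphore.
import asyncio

async def handle(record: int, sem: asyncio.Semaphore) -> int:
    async with sem:  # at most 3 coroutines pass this point at once
        await asyncio.sleep(0.01)
        return record * 2

async def demo() -> None:
    sem = asyncio.Semaphore(3)
    async with asyncio.TaskGroup() as tg:  # implicitly awaits all tasks on exit
        tasks = [tg.create_task(handle(r, sem)) for r in range(10)]
    print([t.result() for t in tasks])

asyncio.run(demo())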
@@ -367,7 +399,7 @@ async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, work


async def sglang_server_task(model_name_or_path, args, semaphore):
    # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
    mem_fraction_arg = ["--mem-fraction-static", "0.60"]
    # mem_fraction_arg = ["--mem-fraction-static", "0.80"]

    cmd = [
        "python3",

@@ -380,7 +412,6 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
        "--log-level-http",
        "warning",
    ]
    cmd.extend(mem_fraction_arg)

    proc = await asyncio.create_subprocess_exec(
        *cmd,
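sglang_server_task follows the usual asyncio subprocess recipe: build cmd as a list, spawn it with create_subprocess_exec, and consume its output. A minimal sketch with a stand-in command (the real call launches the sglang server with the flags above):

# Minimal sketch: spawn a child process and stream its merged output line by line.
import asyncio

async def demo() -> None:
    proc = await asyncio.create_subprocess_exec(
        "python3", "-c", "print('server up')",  # stand-in for the sglang launch command
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    assert proc.stdout is not None
    async for line in proc.stdout:  # StreamReader supports async iteration
        print("[server]", line.decode().rstrip())
    await proc.wait()

asyncio.run(demo())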
@@ -749,13 +780,10 @@ async def main():
    # Wait for all worker tasks to finish
    await asyncio.gather(*worker_tasks)

    # Wait for server to stop
    process_pool.shutdown(wait=False)

    sglang_server.cancel()
    metrics_task.cancel()
    logger.info("Work done")


if __name__ == "__main__":
    asyncio.run(main())