Mirror of https://github.com/allenai/olmocr.git, synced 2025-08-18 22:01:56 +00:00
Better stats
This commit is contained in:
parent 9ce28c0504
commit ae9b1c405d
@@ -30,7 +30,7 @@ from pdelfin.data.renderpdf import render_pdf_to_base64png
 from pdelfin.prompts import build_finetuning_prompt, PageResponse
 from pdelfin.prompts.anchor import get_anchor_text
 from pdelfin.check import check_poppler_version
-from pdelfin.metrics import MetricsKeeper
+from pdelfin.metrics import MetricsKeeper, WorkerTracker

 # Initialize logger
 logger = logging.getLogger(__name__)
@@ -62,6 +62,7 @@ pdf_s3 = boto3.client('s3')

 # Global variables for token statistics
 metrics = MetricsKeeper(window=60*5)
+tracker = WorkerTracker()

 # Process pool for offloading cpu bound work, like calculating anchor texts
 process_pool = ProcessPoolExecutor()
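Note: the commit keeps two module-level singletons: metrics aggregates token counts over a five-minute sliding window, and tracker tallies per-worker page states. A minimal sketch of the sliding-window idea, with illustrative names only (this is not the actual pdelfin implementation):

import time
from collections import deque, defaultdict

class SlidingWindowRates:
    """Illustrative sliding-window rate keeper, not pdelfin's MetricsKeeper."""

    def __init__(self, window=60 * 5):
        self.window = window           # seconds of history to keep
        self.start = time.time()
        self.total = defaultdict(int)  # lifetime counts per key
        self.events = deque()          # (timestamp, key, amount)

    def add(self, key, amount=1):
        now = time.time()
        self.total[key] += amount
        self.events.append((now, key, amount))
        # Evict events that have fallen out of the window
        while self.events and self.events[0][0] < now - self.window:
            self.events.popleft()

    def window_rate(self, key):
        # Clamp the span to at least 1s so early calls do not divide by ~zero
        span = max(min(self.window, time.time() - self.start), 1)
        return sum(a for t, k, a in self.events if k == key) / span

rates = SlidingWindowRates(window=60 * 5)
rates.add("server_output_tokens", 512)
print(f"{rates.window_rate('server_output_tokens'):.2f} tokens/sec")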
@@ -229,11 +230,12 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
     return queue


-async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
+async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
     MAX_RETRIES = 3

     attempt = 0
+    await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "started")

     while attempt < MAX_RETRIES:
         query = await build_page_query(
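Note: threading worker_id through process_page is what lets each page report a started / finished / errored lifecycle to the shared tracker. The same pattern as a reusable wrapper, shown as a sketch (tracked is a hypothetical helper, not part of pdelfin):

async def tracked(tracker, worker_id, item_id, coro):
    """Hypothetical wrapper: record lifecycle states around any awaitable."""
    await tracker.track_work(worker_id, item_id, "started")
    try:
        result = await coro
    except Exception:
        await tracker.track_work(worker_id, item_id, "errored")
        raise
    await tracker.track_work(worker_id, item_id, "finished")
    return result

# Usage would look like:
#     await tracked(tracker, worker_id, f"{pdf_s3_path}-{page_num}", process_page(...))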
@@ -255,6 +257,7 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
             model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
             page_response = PageResponse(**model_response_json)

+            await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "finished")
             return PageResult(
                 pdf_s3_path,
                 page_num,
@@ -282,8 +285,9 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
     logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.")
+    await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "errored")
     raise ValueError(f"Could not process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts")


-async def process_pdf(args, pdf_s3_path: str):
+async def process_pdf(args, worker_id: int, pdf_s3_path: str):
     with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
         # TODO Switch to aioboto3 or something
         data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
@@ -299,7 +303,7 @@ async def process_pdf(args, pdf_s3_path: str):
     async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600), connector=aiohttp.TCPConnector(limit=10)) as session:
         for page_num in range(1, num_pages + 1):
             # Create a task for each page
-            task = asyncio.create_task(process_page(args, session, pdf_s3_path, tf.name, page_num))
+            task = asyncio.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
             page_tasks.append(task)

         # Gather results from all page processing tasks
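Note: every page becomes its own asyncio task before anything is awaited, so all pages of a PDF are in flight against the inference server at once. A runnable sketch of that fan-out, with a stand-in for process_page:

import asyncio

async def fake_process_page(page_num):
    """Stand-in for process_page; sleeps in place of model inference."""
    await asyncio.sleep(0.01)
    return {"page": page_num}

async def process_all_pages(num_pages):
    # Create all tasks first so pages run concurrently,
    # then collect every result in page order
    page_tasks = [asyncio.create_task(fake_process_page(n)) for n in range(1, num_pages + 1)]
    return await asyncio.gather(*page_tasks)

print(asyncio.run(process_all_pages(3)))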
@@ -362,7 +366,7 @@ async def worker(args, queue, semaphore, worker_id):
         # Wait until allowed to proceed
         await semaphore.acquire()

-        dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
+        dolma_docs = await asyncio.gather(*[process_pdf(args, worker_id, pdf) for pdf in pdfs])
         dolma_docs = [doc for doc in dolma_docs if doc is not None]

         # Write the Dolma documents to a local temporary file in JSONL format
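Note: the worker applies the same fan-out across whole PDFs, with a failed PDF surfacing as None and being filtered out before the JSONL write. A sketch under those assumptions (process_pdf_stub is illustrative):

import asyncio
import random

async def process_pdf_stub(worker_id, pdf):
    """Stand-in for process_pdf: returns a Dolma-style doc, or None on failure."""
    await asyncio.sleep(0)  # placeholder for real S3 + inference I/O
    return None if random.random() < 0.2 else {"worker": worker_id, "source": pdf}

async def run_worker(worker_id, pdfs):
    docs = await asyncio.gather(*[process_pdf_stub(worker_id, p) for p in pdfs])
    return [d for d in docs if d is not None]  # drop failed PDFs

print(asyncio.run(run_worker(0, ["a.pdf", "b.pdf", "c.pdf"])))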
@@ -501,11 +505,15 @@ async def sglang_server_ready():

     raise Exception("sglang server did not become ready after waiting.")


 async def metrics_reporter():
     while True:
-        logger.info(metrics)
+        # Leading newlines preserve table formatting in logs
+        logger.info("\n" + str(metrics))
+        logger.info("\n" + str(await tracker.get_status_table()))
         await asyncio.sleep(10)


 async def main():
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
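Note: metrics_reporter now logs both tables every 10 seconds; the leading newline pushes each table onto its own lines so its columns are not skewed by the logger's prefix. A small sketch of the periodic-reporter pattern (names are illustrative):

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

async def report_every(seconds, render):
    """Log a multi-line table periodically; render() returns the table string."""
    while True:
        # The leading "\n" starts the table on a fresh line,
        # keeping its columns aligned in the log output
        logger.info("\n" + render())
        await asyncio.sleep(seconds)

async def main():
    task = asyncio.create_task(report_every(1, lambda: "name  rate\nfoo   1.00"))
    await asyncio.sleep(2.5)
    task.cancel()

asyncio.run(main())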
--- a/pdelfin/metrics.py
+++ b/pdelfin/metrics.py
@@ -1,5 +1,8 @@
 import time
+import asyncio
 from collections import deque, defaultdict
+from dataclasses import dataclass, field
+from typing import Dict

 class MetricsKeeper:
     def __init__(self, window=60*5):
@@ -47,20 +50,95 @@ class MetricsKeeper:
         Returns a formatted string of metrics showing tokens/sec since start and within the window.

         Returns:
-            str: Formatted metrics string.
+            str: Formatted metrics string as a table.
         """
         current_time = time.time()
         elapsed_time = current_time - self.start_time
         window_time = min(self.window, elapsed_time) if elapsed_time > 0 else 1  # Prevent division by zero

-        metrics_strings = []
+        # Header
+        header = f"{'Metric Name':<20} {'Lifetime (tokens/sec)':>25} {'Window (tokens/sec)':>25}"
+        separator = "-" * len(header)
+        lines = [header, separator]

         # Sort metrics alphabetically for consistency
         for key in sorted(self.total_metrics.keys()):
             total = self.total_metrics[key]
             window = self.window_sum.get(key, 0)
             total_rate = total / elapsed_time if elapsed_time > 0 else 0
             window_rate = window / window_time if window_time > 0 else 0
-            metrics_strings.append(
-                f"{key}: {total_rate:.2f}/sec (last {int(window_time)}s: {window_rate:.2f}/sec)"
-            )
+            line = f"{key:<20} {total_rate:>25.2f} {window_rate:>25.2f}"
+            lines.append(line)

-        return ", ".join(metrics_strings)
+        return "\n".join(lines)
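Note: the __str__ rewrite swaps the comma-joined summary for fixed-width columns built with f-string alignment specifiers. A quick check of the formatting (the metric name and rates are made up):

header = f"{'Metric Name':<20} {'Lifetime (tokens/sec)':>25} {'Window (tokens/sec)':>25}"
print(header)
print("-" * len(header))
# <20 left-aligns the name; >25.2f right-aligns the rate to two decimals
print(f"{'finetune_tokens':<20} {1234.56:>25.2f} {98.70:>25.2f}")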
+
+
+class WorkerTracker:
+    def __init__(self):
+        """
+        Initializes the WorkerTracker with a default dictionary.
+        Each worker ID maps to another dictionary that holds counts for each state.
+        """
+        # Mapping from worker_id to a dictionary of state counts
+        self.worker_status: Dict[int, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
+        self.lock = asyncio.Lock()
+
+    async def track_work(self, worker_id: int, work_item_id: str, state: str):
+        """
+        Update the state count for a specific worker.
+
+        Args:
+            worker_id (int): The ID of the worker.
+            work_item_id (str): The unique identifier of the work item (unused in this implementation).
+            state (str): The state to increment for the work item.
+        """
+        async with self.lock:
+            self.worker_status[worker_id][state] += 1
+            logger.debug(f"Worker {worker_id} - State '{state}' incremented to {self.worker_status[worker_id][state]}.")
+
+    async def get_status_table(self) -> str:
+        """
+        Generate a formatted table of the current status of all workers.
+
+        Returns:
+            str: A string representation of the workers' statuses.
+        """
+        async with self.lock:
+            # Determine all unique states across all workers
+            all_states = set()
+            for states in self.worker_status.values():
+                all_states.update(states.keys())
+            all_states = sorted(all_states)
+
+            headers = ["Worker ID"] + all_states
+            rows = []
+            for worker_id, states in sorted(self.worker_status.items()):
+                row = [str(worker_id)]
+                for state in all_states:
+                    count = states.get(state, 0)
+                    row.append(str(count))
+                rows.append(row)
+
+            # Calculate column widths
+            col_widths = [len(header) for header in headers]
+            for row in rows:
+                for idx, cell in enumerate(row):
+                    col_widths[idx] = max(col_widths[idx], len(cell))
+
+            # Create the table header
+            header_line = " | ".join(header.ljust(col_widths[idx]) for idx, header in enumerate(headers))
+            separator = "-+-".join('-' * col_widths[idx] for idx in range(len(headers)))
+
+            # Create the table rows
+            row_lines = [" | ".join(cell.ljust(col_widths[idx]) for idx, cell in enumerate(row)) for row in rows]
+
+            # Combine all parts
+            table = "\n".join([header_line, separator] + row_lines)
+            return table
+
+    def __str__(self):
+        """
+        String representation is not directly supported.
+        Use 'await get_status_table()' to retrieve the status table.
+        """
+        raise NotImplementedError("Use 'await get_status_table()' to get the status table.")
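Note: a quick usage sketch for the new class, assuming it lands in pdelfin/metrics.py as the pipeline's import suggests; the work-item IDs here are made up:

import asyncio

from pdelfin.metrics import WorkerTracker

async def demo():
    tracker = WorkerTracker()
    await tracker.track_work(0, "doc.pdf-1", "started")
    await tracker.track_work(0, "doc.pdf-1", "finished")
    await tracker.track_work(1, "doc.pdf-2", "started")
    # Prints a table along the lines of:
    #   Worker ID | finished | started
    #   ----------+----------+--------
    #   0         | 1        | 1
    #   1         | 0        | 1
    print(await tracker.get_status_table())

asyncio.run(demo())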