From ae9b1c405dafdcfca47178f68f738238fb9cfe0f Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 12 Nov 2024 13:28:39 -0800
Subject: [PATCH] Better stats

---
 pdelfin/beakerpipeline.py | 20 ++++++---
 pdelfin/metrics.py        | 90 ++++++++++++++++++++++++++++++++++++---
 2 files changed, 98 insertions(+), 12 deletions(-)

diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index b5bfab1..8f3006b 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -30,7 +30,7 @@ from pdelfin.data.renderpdf import render_pdf_to_base64png
 from pdelfin.prompts import build_finetuning_prompt, PageResponse
 from pdelfin.prompts.anchor import get_anchor_text
 from pdelfin.check import check_poppler_version
-from pdelfin.metrics import MetricsKeeper
+from pdelfin.metrics import MetricsKeeper, WorkerTracker
 
 # Initialize logger
 logger = logging.getLogger(__name__)
@@ -62,6 +62,7 @@ pdf_s3 = boto3.client('s3')
 
 # Global variables for token statistics
 metrics = MetricsKeeper(window=60*5)
+tracker = WorkerTracker()
 
 # Process pool for offloading cpu bound work, like calculating anchor texts
 process_pool = ProcessPoolExecutor()
@@ -229,11 +230,12 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
 
     return queue
 
-async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
+async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
     MAX_RETRIES = 3
 
     attempt = 0
+    await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "started")
 
     while attempt < MAX_RETRIES:
         query = await build_page_query(
@@ -255,6 +257,7 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
                 model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
                 page_response = PageResponse(**model_response_json)
 
+                await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "finished")
                 return PageResult(
                     pdf_s3_path,
                     page_num,
@@ -282,8 +285,9 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
     logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.")
+    await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "errored")
     raise ValueError(f"Could not process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts")
 
 
-async def process_pdf(args, pdf_s3_path: str):
+async def process_pdf(args, worker_id: int, pdf_s3_path: str):
     with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
         # TODO Switch to aioboto3 or something
         data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
@@ -299,7 +303,7 @@ async def process_pdf(args, pdf_s3_path: str):
         async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600), connector=aiohttp.TCPConnector(limit=10)) as session:
             for page_num in range(1, num_pages + 1):
                 # Create a task for each page
-                task = asyncio.create_task(process_page(args, session, pdf_s3_path, tf.name, page_num))
+                task = asyncio.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
                 page_tasks.append(task)
 
         # Gather results from all page processing tasks
@@ -362,7 +366,7 @@ async def worker(args, queue, semaphore, worker_id):
         # Wait until allowed to proceed
         await semaphore.acquire()
 
-        dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
+        dolma_docs = await asyncio.gather(*[process_pdf(args, worker_id, pdf) for pdf in pdfs])
         dolma_docs = [doc for doc in dolma_docs if doc is not None]
 
         # Write the Dolma documents to a local temporary file in JSONL format
@@ -501,11 +505,15 @@ async def sglang_server_ready():
 
     raise Exception("sglang server did not become ready after waiting.")
 
+
 async def metrics_reporter():
     while True:
-        logger.info(metrics)
+        # Leading newlines preserve table formatting in logs
+        logger.info("\n" + str(metrics))
+        logger.info("\n" + str(await tracker.get_status_table()))
         await asyncio.sleep(10)
 
+
 async def main():
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
diff --git a/pdelfin/metrics.py b/pdelfin/metrics.py
index 11a6f06..44e5609 100644
--- a/pdelfin/metrics.py
+++ b/pdelfin/metrics.py
@@ -1,5 +1,8 @@
 import time
+import asyncio
+import logging
 from collections import deque, defaultdict
+from typing import Dict
 
 class MetricsKeeper:
     def __init__(self, window=60*5):
@@ -47,20 +50,95 @@ class MetricsKeeper:
         Returns a formatted string of metrics showing tokens/sec since start and within the window.
 
         Returns:
-            str: Formatted metrics string.
+            str: Formatted metrics string as a table.
         """
         current_time = time.time()
         elapsed_time = current_time - self.start_time
         window_time = min(self.window, elapsed_time) if elapsed_time > 0 else 1 # Prevent division by zero
 
-        metrics_strings = []
+        # Header
+        header = f"{'Metric Name':<20} {'Lifetime (tokens/sec)':>25} {'Window (tokens/sec)':>25}"
+        separator = "-" * len(header)
+        lines = [header, separator]
+
+        # Sort metrics alphabetically for consistency
         for key in sorted(self.total_metrics.keys()):
             total = self.total_metrics[key]
             window = self.window_sum.get(key, 0)
             total_rate = total / elapsed_time if elapsed_time > 0 else 0
             window_rate = window / window_time if window_time > 0 else 0
-            metrics_strings.append(
-                f"{key}: {total_rate:.2f}/sec (last {int(window_time)}s: {window_rate:.2f}/sec)"
-            )
+            line = f"{key:<20} {total_rate:>25.2f} {window_rate:>25.2f}"
+            lines.append(line)
 
-        return ", ".join(metrics_strings)
+        return "\n".join(lines)
+
+
+class WorkerTracker:
+    def __init__(self):
+        """
+        Initializes the WorkerTracker with a default dictionary.
+        Each worker ID maps to another dictionary that holds counts for each state.
+        """
+        # Mapping from worker_id to a dictionary of state counts
+        self.worker_status: Dict[int, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
+        self.lock = asyncio.Lock()
+
+    async def track_work(self, worker_id: int, work_item_id: str, state: str):
+        """
+        Update the state count for a specific worker.
+
+        Args:
+            worker_id (int): The ID of the worker.
+            work_item_id (str): The unique identifier of the work item (unused in this implementation).
+            state (str): The state to increment for the work item.
+        """
+        async with self.lock:
+            self.worker_status[worker_id][state] += 1
+            logging.getLogger(__name__).debug(f"Worker {worker_id} - State '{state}' incremented to {self.worker_status[worker_id][state]}.")
+
+    async def get_status_table(self) -> str:
+        """
+        Generate a formatted table of the current status of all workers.
+
+        Returns:
+            str: A string representation of the workers' statuses.
+        """
+        async with self.lock:
+            # Determine all unique states across all workers
+            all_states = set()
+            for states in self.worker_status.values():
+                all_states.update(states.keys())
+            all_states = sorted(all_states)
+
+            headers = ["Worker ID"] + all_states
+            rows = []
+            for worker_id, states in sorted(self.worker_status.items()):
+                row = [str(worker_id)]
+                for state in all_states:
+                    count = states.get(state, 0)
+                    row.append(str(count))
+                rows.append(row)
+
+            # Calculate column widths
+            col_widths = [len(header) for header in headers]
+            for row in rows:
+                for idx, cell in enumerate(row):
+                    col_widths[idx] = max(col_widths[idx], len(cell))
+
+            # Create the table header
+            header_line = " | ".join(header.ljust(col_widths[idx]) for idx, header in enumerate(headers))
+            separator = "-+-".join('-' * col_widths[idx] for idx in range(len(headers)))
+
+            # Create the table rows
+            row_lines = [" | ".join(cell.ljust(col_widths[idx]) for idx, cell in enumerate(row)) for row in rows]
+
+            # Combine all parts
+            table = "\n".join([header_line, separator] + row_lines)
+            return table
+
+    def __str__(self):
+        """
+        String representation is not directly supported.
+        Use 'await get_status_table()' to retrieve the status table.
+        """
+        raise NotImplementedError("Use 'await get_status_table()' to get the status table.")
\ No newline at end of file
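
A minimal usage sketch (not part of the patch) showing how the new WorkerTracker can be exercised on its own to eyeball the status table. The worker IDs and work-item strings below are made-up examples; the states mirror the "started"/"finished"/"errored" values used in beakerpipeline.py.

import asyncio

from pdelfin.metrics import WorkerTracker


async def demo():
    tracker = WorkerTracker()
    # Simulate two workers moving page work items through the pipeline states
    await tracker.track_work(0, "s3://bucket/a.pdf-1", "started")
    await tracker.track_work(0, "s3://bucket/a.pdf-1", "finished")
    await tracker.track_work(1, "s3://bucket/b.pdf-3", "started")
    await tracker.track_work(1, "s3://bucket/b.pdf-3", "errored")
    # Prints a "Worker ID | errored | finished | started" table with per-state counts
    print(await tracker.get_status_table())


asyncio.run(demo())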