Better stats

This commit is contained in:
Jake Poznanski 2024-11-12 13:28:39 -08:00
parent 9ce28c0504
commit ae9b1c405d
2 changed files with 98 additions and 12 deletions

View File

@ -30,7 +30,7 @@ from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt, PageResponse from pdelfin.prompts import build_finetuning_prompt, PageResponse
from pdelfin.prompts.anchor import get_anchor_text from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.check import check_poppler_version from pdelfin.check import check_poppler_version
from pdelfin.metrics import MetricsKeeper from pdelfin.metrics import MetricsKeeper, WorkerTracker
# Initialize logger # Initialize logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -62,6 +62,7 @@ pdf_s3 = boto3.client('s3')
# Global variables for token statistics # Global variables for token statistics
metrics = MetricsKeeper(window=60*5) metrics = MetricsKeeper(window=60*5)
tracker = WorkerTracker()
# Process pool for offloading cpu bound work, like calculating anchor texts # Process pool for offloading cpu bound work, like calculating anchor texts
process_pool = ProcessPoolExecutor() process_pool = ProcessPoolExecutor()
@ -229,11 +230,12 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
return queue return queue
async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult: async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = "http://localhost:30000/v1/chat/completions" COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
MAX_RETRIES = 3 MAX_RETRIES = 3
attempt = 0 attempt = 0
await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "started")
while attempt < MAX_RETRIES: while attempt < MAX_RETRIES:
query = await build_page_query( query = await build_page_query(
@ -255,6 +257,7 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"]) model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json) page_response = PageResponse(**model_response_json)
await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "finished")
return PageResult( return PageResult(
pdf_s3_path, pdf_s3_path,
page_num, page_num,
@ -282,8 +285,9 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.") logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.")
raise ValueError(f"Could not process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts") raise ValueError(f"Could not process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts")
await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "errored")
async def process_pdf(args, pdf_s3_path: str): async def process_pdf(args, worker_id: int, pdf_s3_path: str):
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf: with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
# TODO Switch to aioboto3 or something # TODO Switch to aioboto3 or something
data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path)) data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
@ -299,7 +303,7 @@ async def process_pdf(args, pdf_s3_path: str):
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600), connector=aiohttp.TCPConnector(limit=10)) as session: async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600), connector=aiohttp.TCPConnector(limit=10)) as session:
for page_num in range(1, num_pages + 1): for page_num in range(1, num_pages + 1):
# Create a task for each page # Create a task for each page
task = asyncio.create_task(process_page(args, session, pdf_s3_path, tf.name, page_num)) task = asyncio.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
page_tasks.append(task) page_tasks.append(task)
# Gather results from all page processing tasks # Gather results from all page processing tasks
@ -362,7 +366,7 @@ async def worker(args, queue, semaphore, worker_id):
# Wait until allowed to proceed # Wait until allowed to proceed
await semaphore.acquire() await semaphore.acquire()
dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs]) dolma_docs = await asyncio.gather(*[process_pdf(args, worker_id, pdf) for pdf in pdfs])
dolma_docs = [doc for doc in dolma_docs if doc is not None] dolma_docs = [doc for doc in dolma_docs if doc is not None]
# Write the Dolma documents to a local temporary file in JSONL format # Write the Dolma documents to a local temporary file in JSONL format
@ -501,11 +505,15 @@ async def sglang_server_ready():
raise Exception("sglang server did not become ready after waiting.") raise Exception("sglang server did not become ready after waiting.")
async def metrics_reporter(): async def metrics_reporter():
while True: while True:
logger.info(metrics) # Leading newlines preserve table formatting in logs
logger.info("\n" + str(metrics))
logger.info("\n" + str(await tracker.get_status_table()))
await asyncio.sleep(10) await asyncio.sleep(10)
async def main(): async def main():
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline') parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/') parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')

View File

@ -1,5 +1,8 @@
import time import time
import asyncio
from collections import deque, defaultdict from collections import deque, defaultdict
from dataclasses import dataclass, field
from typing import Dict
class MetricsKeeper: class MetricsKeeper:
def __init__(self, window=60*5): def __init__(self, window=60*5):
@ -47,20 +50,95 @@ class MetricsKeeper:
Returns a formatted string of metrics showing tokens/sec since start and within the window. Returns a formatted string of metrics showing tokens/sec since start and within the window.
Returns: Returns:
str: Formatted metrics string. str: Formatted metrics string as a table.
""" """
current_time = time.time() current_time = time.time()
elapsed_time = current_time - self.start_time elapsed_time = current_time - self.start_time
window_time = min(self.window, elapsed_time) if elapsed_time > 0 else 1 # Prevent division by zero window_time = min(self.window, elapsed_time) if elapsed_time > 0 else 1 # Prevent division by zero
metrics_strings = [] # Header
header = f"{'Metric Name':<20} {'Lifetime (tokens/sec)':>25} {'Window (tokens/sec)':>25}"
separator = "-" * len(header)
lines = [header, separator]
# Sort metrics alphabetically for consistency
for key in sorted(self.total_metrics.keys()): for key in sorted(self.total_metrics.keys()):
total = self.total_metrics[key] total = self.total_metrics[key]
window = self.window_sum.get(key, 0) window = self.window_sum.get(key, 0)
total_rate = total / elapsed_time if elapsed_time > 0 else 0 total_rate = total / elapsed_time if elapsed_time > 0 else 0
window_rate = window / window_time if window_time > 0 else 0 window_rate = window / window_time if window_time > 0 else 0
metrics_strings.append( line = f"{key:<20} {total_rate:>25.2f} {window_rate:>25.2f}"
f"{key}: {total_rate:.2f}/sec (last {int(window_time)}s: {window_rate:.2f}/sec)" lines.append(line)
)
return ", ".join(metrics_strings) return "\n".join(lines)
class WorkerTracker:
def __init__(self):
"""
Initializes the WorkerTracker with a default dictionary.
Each worker ID maps to another dictionary that holds counts for each state.
"""
# Mapping from worker_id to a dictionary of state counts
self.worker_status: Dict[int, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
self.lock = asyncio.Lock()
async def track_work(self, worker_id: int, work_item_id: str, state: str):
"""
Update the state count for a specific worker.
Args:
worker_id (int): The ID of the worker.
work_item_id (str): The unique identifier of the work item (unused in this implementation).
state (str): The state to increment for the work item.
"""
async with self.lock:
self.worker_status[worker_id][state] += 1
logger.debug(f"Worker {worker_id} - State '{state}' incremented to {self.worker_status[worker_id][state]}.")
async def get_status_table(self) -> str:
"""
Generate a formatted table of the current status of all workers.
Returns:
str: A string representation of the workers' statuses.
"""
async with self.lock:
# Determine all unique states across all workers
all_states = set()
for states in self.worker_status.values():
all_states.update(states.keys())
all_states = sorted(all_states)
headers = ["Worker ID"] + all_states
rows = []
for worker_id, states in sorted(self.worker_status.items()):
row = [str(worker_id)]
for state in all_states:
count = states.get(state, 0)
row.append(str(count))
rows.append(row)
# Calculate column widths
col_widths = [len(header) for header in headers]
for row in rows:
for idx, cell in enumerate(row):
col_widths[idx] = max(col_widths[idx], len(cell))
# Create the table header
header_line = " | ".join(header.ljust(col_widths[idx]) for idx, header in enumerate(headers))
separator = "-+-".join('-' * col_widths[idx] for idx in range(len(headers)))
# Create the table rows
row_lines = [" | ".join(cell.ljust(col_widths[idx]) for idx, cell in enumerate(row)) for row in rows]
# Combine all parts
table = "\n".join([header_line, separator] + row_lines)
return table
def __str__(self):
"""
String representation is not directly supported.
Use 'await get_status_table()' to retrieve the status table.
"""
raise NotImplementedError("Use 'await get_status_table()' to get the status table.")