mirror of
https://github.com/allenai/olmocr.git
synced 2025-07-04 23:52:58 +00:00
148 lines
5.7 KiB
Python
148 lines
5.7 KiB
Python
import asyncio
|
|
import time
|
|
from collections import defaultdict, deque
|
|
from typing import Any, Deque, Dict, List, Set
|
|
|
|
|
|
class MetricsKeeper:
    """Accumulates named numeric metrics and renders them as per-second rates.

    Two views are maintained for every metric name:

    * a cumulative total since the keeper was created, and
    * a rolling sum over the most recent ``window`` seconds.
    """

    def __init__(self, window=60 * 5):
        """
        Initializes the MetricsKeeper.

        Args:
            window (int): Time window in seconds for recent metrics. Defaults to 5 minutes.
        """
        self.window = window  # Time window in seconds
        self.start_time = time.time()  # Timestamp when MetricsKeeper was created
        self.total_metrics = defaultdict(int)  # Cumulative metrics since start
        self.window_metrics: Deque[Any] = deque()  # (timestamp, metrics_dict) pairs, oldest first
        self.window_sum = defaultdict(int)  # Sum of metrics within the window

    def add_metrics(self, **kwargs):
        """
        Adds metrics to the keeper.

        Each keyword is added to both the lifetime total and the rolling
        window sum; entries older than the window are then evicted.

        Args:
            **kwargs: Arbitrary keyword arguments representing metric names and their values.
        """
        current_time = time.time()

        for key, value in kwargs.items():
            self.total_metrics[key] += value
            self.window_sum[key] += value

        # Remember the sample so its contribution can be subtracted once it ages out.
        self.window_metrics.append((current_time, kwargs))

        self._expire_old(current_time)

    def _expire_old(self, now):
        """Drop samples recorded more than ``self.window`` seconds before ``now``."""
        cutoff = now - self.window
        while self.window_metrics and self.window_metrics[0][0] < cutoff:
            _, old_metrics = self.window_metrics.popleft()
            for key, value in old_metrics.items():
                self.window_sum[key] -= value
                if self.window_sum[key] <= 0:
                    del self.window_sum[key]  # Clean up to prevent negative counts

    def __str__(self):
        """
        Returns a formatted string of metrics showing tokens/sec since start and within the window.

        Stale window entries are evicted first, so the "Recently" column is
        accurate even when no metric has been added for longer than the window.

        Returns:
            str: Formatted metrics string as a table.
        """
        current_time = time.time()
        self._expire_old(current_time)

        elapsed_time = current_time - self.start_time
        # Before a full window has elapsed, average over the time actually observed.
        window_time = min(self.window, elapsed_time) if elapsed_time > 0 else 1  # Prevent division by zero

        header = f"{'Metric Name':<30} {'Lifetime (tokens/sec)':>25} {'Recently (tokens/sec)':>25}"
        separator = "-" * len(header)
        lines = [header, separator]

        # Sort metrics alphabetically for consistency
        for key in sorted(self.total_metrics):
            total = self.total_metrics[key]
            window = self.window_sum.get(key, 0)
            total_rate = total / elapsed_time if elapsed_time > 0 else 0
            window_rate = window / window_time if window_time > 0 else 0
            # The name column width must match the header's (30) to keep the table aligned.
            lines.append(f"{key:<30} {total_rate:>25.2f} {window_rate:>25.2f}")

        return "\n".join(lines)
|
|
|
|
|
|
class WorkerTracker:
    """Keeps, per worker, a count of how many times each state was recorded."""

    def __init__(self):
        """
        Sets up an empty tracker.

        ``worker_status`` maps a worker ID to a dictionary of state name ->
        count; both levels are created on first access. ``lock`` is acquired
        by every coroutine method that touches that mapping.
        """
        self.worker_status: Dict[int, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
        self.lock = asyncio.Lock()

    async def clear_work(self, worker_id: int):
        """
        Reset every state count recorded for one worker.

        Args:
            worker_id (int): The ID of the worker whose counts are cleared.
        """
        async with self.lock:
            counts = self.worker_status[worker_id]
            counts.clear()

    async def track_work(self, worker_id: int, work_item_id: str, state: str):
        """
        Record one more occurrence of ``state`` for the given worker.

        Args:
            worker_id (int): The ID of the worker.
            work_item_id (str): Identifier of the work item; accepted but not used here.
            state (str): The state whose counter is incremented.
        """
        async with self.lock:
            counts = self.worker_status[worker_id]
            counts[state] += 1

    async def get_status_table(self) -> str:
        """
        Render the current per-worker state counts as a text table.

        Columns are "Worker ID" followed by every state seen on any worker,
        in sorted order; rows are ordered by worker ID. Each column is as
        wide as its longest cell.

        Returns:
            str: The header line, a separator line, and one line per worker.
        """
        async with self.lock:
            # Union of the state names seen across all workers, in a stable order.
            state_names: List[str] = sorted({name for counts in self.worker_status.values() for name in counts})

            headers = ["Worker ID", *state_names]
            rows = [
                [str(worker_id), *(str(counts.get(name, 0)) for name in state_names)]
                for worker_id, counts in sorted(self.worker_status.items())
            ]

            # Every row has one cell per header, so zipping yields whole columns.
            widths = [max(map(len, column)) for column in zip(headers, *rows)]

            def render(cells):
                return " | ".join(cell.ljust(width) for cell, width in zip(cells, widths))

            separator = "-+-".join("-" * width for width in widths)
            return "\n".join([render(headers), separator, *map(render, rows)])

    def __str__(self):
        """
        Plain string conversion is deliberately unsupported.

        Building the table requires awaiting the lock, so callers must use
        'await get_status_table()' instead.
        """
        raise NotImplementedError("Use 'await get_status_table()' to get the status table.")
|