diff --git a/Dockerfile b/Dockerfile index 1404bef..a9e69bd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -63,6 +63,9 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match" RUN uv pip install --system --no-cache -e . RUN uv pip install --system --no-cache ".[gpu]" --extra-index-url https://download.pytorch.org/whl/cu128 + +# TODO Try this and measure performance on it +#RUN uv pip install --system https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.5%2Bcu128torch2.7-cp38-abi3-linux_x86_64.whl RUN uv pip install --system --no-cache ".[bench]" RUN playwright install-deps RUN playwright install chromium diff --git a/olmocr/metrics.py b/olmocr/metrics.py index 6795fd2..d95cb37 100644 --- a/olmocr/metrics.py +++ b/olmocr/metrics.py @@ -72,6 +72,38 @@ class MetricsKeeper: return "\n".join(lines) + def get_total_metrics(self): + """ + Returns the total cumulative metrics since the MetricsKeeper was created. + + Returns: + dict: Dictionary of metric names to their total values. + """ + return dict(self.total_metrics) + + def get_metrics_summary(self): + """ + Returns a summary of metrics including totals and rates. + + Returns: + dict: Dictionary containing total metrics and overall rates. + """ + current_time = time.time() + elapsed_time = current_time - self.start_time + + summary = { + "elapsed_time_seconds": elapsed_time, + "total_metrics": dict(self.total_metrics), + "rates": {} + } + + # Calculate rates for each metric + if elapsed_time > 0: + for key, value in self.total_metrics.items(): + summary["rates"][f"{key}_per_sec"] = value / elapsed_time + + return summary + class WorkerTracker: def __init__(self): diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index f9b9e6b..9ff75d0 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -607,7 +607,7 @@ async def vllm_server_task(model_name_or_path, args, semaphore): logger.error("Cannot continue, sampling errors detected, model is probably corrupt") sys.exit(1) - if not server_printed_ready_message and "The server is fired up and ready to roll!" in line: + if not server_printed_ready_message and ("The server is fired up and ready to roll!" in line or "vllm server is ready" in line): server_printed_ready_message = True last_semaphore_release = time.time() @@ -740,7 +740,8 @@ def submit_beaker_job(args): b = Beaker.from_env(default_workspace=args.beaker_workspace) account = b.account.whoami() owner = account.name - beaker_image = f"jakep/olmocr-inference-{VERSION}" + #beaker_image = f"jakep/olmocr-inference-{VERSION}" + beaker_image = "jakep/olmocr-benchmark-0.1.71-d71703317d" task_name = f"olmocr-{os.path.basename(args.workspace.rstrip('/'))}" @@ -1163,6 +1164,37 @@ async def main(): vllm_server.cancel() metrics_task.cancel() + + # Output final metrics summary + metrics_summary = metrics.get_metrics_summary() + logger.info("=" * 80) + logger.info("FINAL METRICS SUMMARY") + logger.info("=" * 80) + logger.info(f"Total elapsed time: {metrics_summary['elapsed_time_seconds']:.2f} seconds") + + # Output token counts and rates + total_metrics = metrics_summary['total_metrics'] + rates = metrics_summary['rates'] + + # Calculate total tokens (input + output) + total_tokens = total_metrics.get('server_input_tokens', 0) + total_metrics.get('server_output_tokens', 0) + total_finished_tokens = total_metrics.get('finished_input_tokens', 0) + total_metrics.get('finished_output_tokens', 0) + + logger.info(f"Total tokens processed: {total_tokens:,}") + logger.info(f" - Input tokens: {total_metrics.get('server_input_tokens', 0):,}") + logger.info(f" - Output tokens: {total_metrics.get('server_output_tokens', 0):,}") + + logger.info(f"Total tokens in finished documents: {total_finished_tokens:,}") + logger.info(f" - Finished input tokens: {total_metrics.get('finished_input_tokens', 0):,}") + logger.info(f" - Finished output tokens: {total_metrics.get('finished_output_tokens', 0):,}") + + # Output rates + if 'server_output_tokens_per_sec' in rates: + logger.info(f"Output tokens/sec rate: {rates['server_output_tokens_per_sec']:.2f}") + if 'server_input_tokens_per_sec' in rates: + logger.info(f"Input tokens/sec rate: {rates['server_input_tokens_per_sec']:.2f}") + + logger.info("=" * 80) logger.info("Work done")