diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index 933a1ff..e8b6056 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -669,6 +669,10 @@ async def vllm_server_task(model_name_or_path, args, semaphore): except asyncio.CancelledError: logger.info("Got cancellation request for VLLM server") proc.terminate() + try: + await asyncio.wait_for(proc.wait(), timeout=10.0) + except asyncio.TimeoutError: + logger.warning("VLLM server did not terminate within 10 seconds") raise timeout_task.cancel() @@ -1209,6 +1213,9 @@ async def main(): vllm_server.cancel() metrics_task.cancel() + # Wait for cancelled tasks to complete + await asyncio.gather(vllm_server, metrics_task, return_exceptions=True) + # Output final metrics summary metrics_summary = metrics.get_metrics_summary() logger.info("=" * 80) @@ -1233,7 +1240,8 @@ async def main(): ) # Output finished_on_attempt statistics - logger.info("\nPages finished by attempt number:") + logger.info("") + logger.info("Pages finished by attempt number:") total_finished = sum(total_metrics.get(f"finished_on_attempt_{i}", 0) for i in range(args.max_page_retries)) cumulative = 0