From ee687e25d6ed8f5f10e5d786d7a12ce6e80af64b Mon Sep 17 00:00:00 2001
From: Xiaochen Zheng
Date: Thu, 27 Mar 2025 23:09:50 +0100
Subject: [PATCH] Update suggested changes for qsize check

---
 olmocr/pipeline.py | 58 +++++++++++++++++++++++-----------------------
 1 file changed, 29 insertions(+), 29 deletions(-)

diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py
index 9b022f9..c7f7073 100644
--- a/olmocr/pipeline.py
+++ b/olmocr/pipeline.py
@@ -1049,36 +1049,36 @@ async def main():
     # Initialize the work queue
     qsize = await work_queue.initialize_queue()
 
-    if qsize > 0:
-        # Create a semaphore to control worker access
-        # We only allow one worker to move forward with requests, until the server has no more requests in its queue
-        # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
-        # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
-        semaphore = asyncio.Semaphore(1)
-
-        sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))
-
-        await sglang_server_ready()
-
-        metrics_task = asyncio.create_task(metrics_reporter(work_queue))
-
-        # Create worker tasks to process the queue concurrently.
-        worker_tasks = []
-        for i in range(args.workers):
-            task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
-            worker_tasks.append(task)
-
-        # Wait for all worker tasks to finish
-        await asyncio.gather(*worker_tasks)
-
-        # Wait for server to stop
-        process_pool.shutdown(wait=False)
-
-        sglang_server.cancel()
-        metrics_task.cancel()
-        logger.info("Work done")
-    else:
+    if qsize == 0:
         logger.info("No work to do, exiting")
+        return
+
+    # Create a semaphore to control worker access
+    # We only allow one worker to move forward with requests, until the server has no more requests in its queue
+    # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
+    # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
+    semaphore = asyncio.Semaphore(1)
+
+    sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))
+
+    await sglang_server_ready()
+
+    metrics_task = asyncio.create_task(metrics_reporter(work_queue))
+
+    # Create worker tasks to process the queue concurrently.
+    worker_tasks = []
+    for i in range(args.workers):
+        task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
+        worker_tasks.append(task)
+
+    # Wait for all worker tasks to finish
+    await asyncio.gather(*worker_tasks)
+
+    # Wait for server to stop
+    process_pool.shutdown(wait=False)
+
+    sglang_server.cancel()
+    metrics_task.cancel()
+    logger.info("Work done")
 
 
 if __name__ == "__main__":
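
For reference, below is a minimal standalone sketch of the structure the patch introduces: check the queue size up front, return early when there is nothing to do, and otherwise gate workers behind a single-slot semaphore. The queue contents, worker body, and worker count are placeholders for illustration, not olmocr's actual pipeline code.

    import asyncio


    async def worker(queue: asyncio.Queue, semaphore: asyncio.Semaphore, worker_id: int) -> None:
        # Only one worker holds the semaphore at a time, mirroring the
        # "one worker saturates the server before the next starts" gating in the patch.
        while True:
            async with semaphore:
                try:
                    item = queue.get_nowait()
                except asyncio.QueueEmpty:
                    return
                await asyncio.sleep(0)  # placeholder for the real request/processing step
                print(f"worker {worker_id} processed item {item}")


    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        for item in range(5):  # placeholder work items
            queue.put_nowait(item)

        # Guard clause: exit early instead of nesting the rest of main() under `if qsize > 0:`.
        if queue.qsize() == 0:
            print("No work to do, exiting")
            return

        semaphore = asyncio.Semaphore(1)
        tasks = [asyncio.create_task(worker(queue, semaphore, i)) for i in range(3)]
        await asyncio.gather(*tasks)
        print("Work done")


    if __name__ == "__main__":
        asyncio.run(main())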