From 6d766307be8320a8ef9cc8f0a5d9b4c3d49a61f0 Mon Sep 17 00:00:00 2001 From: xcvil Date: Sun, 23 Mar 2025 23:45:28 +0100 Subject: [PATCH 1/2] feat: avoid sglang server starting with empty queue --- olmocr/pipeline.py | 45 +++++++++++++++++++++++--------------------- olmocr/work_queue.py | 10 +++++++--- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index c4436af..9b022f9 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -1047,35 +1047,38 @@ async def main(): await download_model(args.model) # Initialize the work queue - await work_queue.initialize_queue() + qsize = await work_queue.initialize_queue() - # Create a semaphore to control worker access - # We only allow one worker to move forward with requests, until the server has no more requests in its queue - # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible - # As soon as one worker is no longer saturating the gpu, the next one can start sending requests - semaphore = asyncio.Semaphore(1) + if qsize > 0: + # Create a semaphore to control worker access + # We only allow one worker to move forward with requests, until the server has no more requests in its queue + # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible + # As soon as one worker is no longer saturating the gpu, the next one can start sending requests + semaphore = asyncio.Semaphore(1) - sglang_server = asyncio.create_task(sglang_server_host(args, semaphore)) + sglang_server = asyncio.create_task(sglang_server_host(args, semaphore)) - await sglang_server_ready() + await sglang_server_ready() - metrics_task = asyncio.create_task(metrics_reporter(work_queue)) + metrics_task = asyncio.create_task(metrics_reporter(work_queue)) - # Create worker tasks to process the queue concurrently. - worker_tasks = [] - for i in range(args.workers): - task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i)) - worker_tasks.append(task) + # Create worker tasks to process the queue concurrently. + worker_tasks = [] + for i in range(args.workers): + task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i)) + worker_tasks.append(task) - # Wait for all worker tasks to finish - await asyncio.gather(*worker_tasks) + # Wait for all worker tasks to finish + await asyncio.gather(*worker_tasks) - # Wait for server to stop - process_pool.shutdown(wait=False) + # Wait for server to stop + process_pool.shutdown(wait=False) - sglang_server.cancel() - metrics_task.cancel() - logger.info("Work done") + sglang_server.cancel() + metrics_task.cancel() + logger.info("Work done") + else: + logger.info("No work to do, exiting") if __name__ == "__main__": diff --git a/olmocr/work_queue.py b/olmocr/work_queue.py index 8d6be16..32e0903 100644 --- a/olmocr/work_queue.py +++ b/olmocr/work_queue.py @@ -45,7 +45,7 @@ class WorkQueue(abc.ABC): pass @abc.abstractmethod - async def initialize_queue(self) -> None: + async def initialize_queue(self) -> int: """ Load the work queue from the relevant store (local or remote) and initialize it for processing. @@ -255,7 +255,7 @@ class LocalWorkQueue(WorkQueue): # Write the combined data back to disk in zstd CSV format await asyncio.to_thread(upload_zstd_csv_local, self._index_path, combined_lines) - async def initialize_queue(self) -> None: + async def initialize_queue(self) -> int: """ Load the work queue from the local index file and initialize it for processing. Removes already completed work items and randomizes the order. @@ -282,6 +282,8 @@ class LocalWorkQueue(WorkQueue): logger.info(f"Initialized local queue with {self._queue.qsize()} work items") + return self._queue.qsize() + async def is_completed(self, work_hash: str) -> bool: """ Check if a work item has been completed locally by seeing if @@ -459,7 +461,7 @@ class S3WorkQueue(WorkQueue): if new_groups: await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, combined_lines) - async def initialize_queue(self) -> None: + async def initialize_queue(self) -> int: """ Load the work queue from S3 and initialize it for processing. Removes already completed work items and randomizes the order. @@ -492,6 +494,8 @@ class S3WorkQueue(WorkQueue): logger.info(f"Initialized queue with {self._queue.qsize()} work items") + return self._queue.qsize() + async def is_completed(self, work_hash: str) -> bool: """ Check if a work item has been completed. From ee687e25d6ed8f5f10e5d786d7a12ce6e80af64b Mon Sep 17 00:00:00 2001 From: Xiaochen Zheng Date: Thu, 27 Mar 2025 23:09:50 +0100 Subject: [PATCH 2/2] Update suggested changes for qsize check --- olmocr/pipeline.py | 58 +++++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index 9b022f9..c7f7073 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -1049,36 +1049,36 @@ async def main(): # Initialize the work queue qsize = await work_queue.initialize_queue() - if qsize > 0: - # Create a semaphore to control worker access - # We only allow one worker to move forward with requests, until the server has no more requests in its queue - # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible - # As soon as one worker is no longer saturating the gpu, the next one can start sending requests - semaphore = asyncio.Semaphore(1) - - sglang_server = asyncio.create_task(sglang_server_host(args, semaphore)) - - await sglang_server_ready() - - metrics_task = asyncio.create_task(metrics_reporter(work_queue)) - - # Create worker tasks to process the queue concurrently. - worker_tasks = [] - for i in range(args.workers): - task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i)) - worker_tasks.append(task) - - # Wait for all worker tasks to finish - await asyncio.gather(*worker_tasks) - - # Wait for server to stop - process_pool.shutdown(wait=False) - - sglang_server.cancel() - metrics_task.cancel() - logger.info("Work done") - else: + if qsize == 0: logger.info("No work to do, exiting") + return + # Create a semaphore to control worker access + # We only allow one worker to move forward with requests, until the server has no more requests in its queue + # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible + # As soon as one worker is no longer saturating the gpu, the next one can start sending requests + semaphore = asyncio.Semaphore(1) + + sglang_server = asyncio.create_task(sglang_server_host(args, semaphore)) + + await sglang_server_ready() + + metrics_task = asyncio.create_task(metrics_reporter(work_queue)) + + # Create worker tasks to process the queue concurrently. + worker_tasks = [] + for i in range(args.workers): + task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i)) + worker_tasks.append(task) + + # Wait for all worker tasks to finish + await asyncio.gather(*worker_tasks) + + # Wait for server to stop + process_pool.shutdown(wait=False) + + sglang_server.cancel() + metrics_task.cancel() + logger.info("Work done") if __name__ == "__main__":