mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-16 18:52:50 +00:00
Update suggested changes for qsize check
This commit is contained in:
parent
6d766307be
commit
ee687e25d6
@ -1049,36 +1049,36 @@ async def main():
|
||||
# Initialize the work queue
|
||||
qsize = await work_queue.initialize_queue()
|
||||
|
||||
if qsize > 0:
|
||||
# Create a semaphore to control worker access
|
||||
# We only allow one worker to move forward with requests, until the server has no more requests in its queue
|
||||
# This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
|
||||
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
|
||||
sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))
|
||||
|
||||
await sglang_server_ready()
|
||||
|
||||
metrics_task = asyncio.create_task(metrics_reporter(work_queue))
|
||||
|
||||
# Create worker tasks to process the queue concurrently.
|
||||
worker_tasks = []
|
||||
for i in range(args.workers):
|
||||
task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
|
||||
worker_tasks.append(task)
|
||||
|
||||
# Wait for all worker tasks to finish
|
||||
await asyncio.gather(*worker_tasks)
|
||||
|
||||
# Wait for server to stop
|
||||
process_pool.shutdown(wait=False)
|
||||
|
||||
sglang_server.cancel()
|
||||
metrics_task.cancel()
|
||||
logger.info("Work done")
|
||||
else:
|
||||
if qsize == 0:
|
||||
logger.info("No work to do, exiting")
|
||||
return
|
||||
# Create a semaphore to control worker access
|
||||
# We only allow one worker to move forward with requests, until the server has no more requests in its queue
|
||||
# This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
|
||||
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
|
||||
sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))
|
||||
|
||||
await sglang_server_ready()
|
||||
|
||||
metrics_task = asyncio.create_task(metrics_reporter(work_queue))
|
||||
|
||||
# Create worker tasks to process the queue concurrently.
|
||||
worker_tasks = []
|
||||
for i in range(args.workers):
|
||||
task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
|
||||
worker_tasks.append(task)
|
||||
|
||||
# Wait for all worker tasks to finish
|
||||
await asyncio.gather(*worker_tasks)
|
||||
|
||||
# Wait for server to stop
|
||||
process_pool.shutdown(wait=False)
|
||||
|
||||
sglang_server.cancel()
|
||||
metrics_task.cancel()
|
||||
logger.info("Work done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user