Merge pull request #138 from xcvil/sglang_server

feat: avoid sglang server starting with empty queue

initialize_queue() now returns the number of pending work items, so the pipeline can exit before launching the sglang server when there is nothing left to process.
Jake Poznanski, 2025-03-28 11:45:46 -07:00 (committed by GitHub)
commit 0892b1829b
2 changed files with 11 additions and 4 deletions


@@ -1051,8 +1051,11 @@ async def main():
     await download_model(args.model)
 
     # Initialize the work queue
-    await work_queue.initialize_queue()
+    qsize = await work_queue.initialize_queue()
+    if qsize == 0:
+        logger.info("No work to do, exiting")
+        return
 
     # Create a semaphore to control worker access
     # We only allow one worker to move forward with requests, until the server has no more requests in its queue
     # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
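The comments above describe the gating scheme: many workers run concurrently, but only one at a time is allowed to claim fresh work, so finished output starts flowing as early as possible. A minimal, self-contained sketch of that pattern; the names and the toy process() handler are illustrative, not the pipeline's actual code:

    import asyncio

    async def process(item: str) -> None:
        # Hypothetical stand-in for handling one work item.
        await asyncio.sleep(0.1)
        print(f"processed {item}")

    async def worker(sem: asyncio.Semaphore, queue: asyncio.Queue) -> None:
        while True:
            async with sem:  # only one worker claims new work at a time
                try:
                    item = queue.get_nowait()
                except asyncio.QueueEmpty:
                    return
            await process(item)  # processing itself runs concurrently

    async def run_demo() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        for i in range(8):
            queue.put_nowait(f"item-{i}")
        sem = asyncio.Semaphore(1)
        await asyncio.gather(*(worker(sem, queue) for _ in range(4)))

    asyncio.run(run_demo())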


@@ -45,7 +45,7 @@ class WorkQueue(abc.ABC):
         pass
 
     @abc.abstractmethod
-    async def initialize_queue(self) -> None:
+    async def initialize_queue(self) -> int:
         """
         Load the work queue from the relevant store (local or remote)
         and initialize it for processing.
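The abstract signature change means every backend now reports how many items it loaded. A minimal sketch of the contract, with a hypothetical in-memory backend standing in for the real local and S3 implementations:

    import abc
    import asyncio

    class WorkQueue(abc.ABC):
        @abc.abstractmethod
        async def initialize_queue(self) -> int:
            """Load pending work and return the number of queued items."""

    class InMemoryWorkQueue(WorkQueue):
        # Hypothetical toy backend, for illustration only.
        def __init__(self, items):
            self._items = list(items)
            self._queue: asyncio.Queue = asyncio.Queue()

        async def initialize_queue(self) -> int:
            for item in self._items:
                self._queue.put_nowait(item)
            return self._queue.qsize()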
@@ -255,7 +255,7 @@ class LocalWorkQueue(WorkQueue):
         # Write the combined data back to disk in zstd CSV format
         await asyncio.to_thread(upload_zstd_csv_local, self._index_path, combined_lines)
 
-    async def initialize_queue(self) -> None:
+    async def initialize_queue(self) -> int:
         """
         Load the work queue from the local index file and initialize it for processing.
         Removes already completed work items and randomizes the order.
@@ -282,6 +282,8 @@ class LocalWorkQueue(WorkQueue):
         logger.info(f"Initialized local queue with {self._queue.qsize()} work items")
 
+        return self._queue.qsize()
+
     async def is_completed(self, work_hash: str) -> bool:
         """
         Check if a work item has been completed locally by seeing if
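Per the docstring, initialization filters out already completed items and shuffles the remainder before counting, so the returned size reflects only work that is actually pending. The shape of that logic, sketched with illustrative helper names:

    import asyncio
    import random

    async def initialize_queue_sketch(all_items: list[str], completed: set[str]) -> int:
        # Drop already-finished work, randomize the rest, then report the count.
        remaining = [w for w in all_items if w not in completed]
        random.shuffle(remaining)
        queue: asyncio.Queue = asyncio.Queue()
        for w in remaining:
            queue.put_nowait(w)
        return queue.qsize()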
@@ -459,7 +461,7 @@ class S3WorkQueue(WorkQueue):
         if new_groups:
             await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, combined_lines)
 
-    async def initialize_queue(self) -> None:
+    async def initialize_queue(self) -> int:
         """
         Load the work queue from S3 and initialize it for processing.
         Removes already completed work items and randomizes the order.
@@ -492,6 +494,8 @@ class S3WorkQueue(WorkQueue):
         logger.info(f"Initialized queue with {self._queue.qsize()} work items")
 
+        return self._queue.qsize()
+
     async def is_completed(self, work_hash: str) -> bool:
         """
         Check if a work item has been completed.
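Putting it together, the new return value lets the caller bail out before any server is started. Exercising the hypothetical toy backend from the earlier sketch:

    import asyncio

    async def demo() -> None:
        queue = InMemoryWorkQueue([])            # empty backend from the sketch above
        qsize = await queue.initialize_queue()
        if qsize == 0:
            print("No work to do, exiting")      # mirrors the pipeline's early return
            return
        print(f"would start sglang server for {qsize} items")

    asyncio.run(demo())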