Quicker results by limiting workers via a semaphore while still utilizing the GPU

Jake Poznanski 2024-11-12 08:18:22 -08:00
parent 615409568d
commit 4f2f4fda7d


@@ -14,6 +14,7 @@ import asyncio
import aiohttp
import datetime
import tempfile
import re
from tqdm import tqdm
from io import BytesIO
@@ -332,11 +333,14 @@ async def process_pdf(args, pdf_s3_path: str):
return dolma_doc
async def worker(args, queue):
async def worker(args, queue, semaphore):
while True:
[work_hash, pdfs] = await queue.get()
try:
# Wait until allowed to proceed
await semaphore.acquire()
dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
dolma_docs = [doc for doc in dolma_docs if doc is not None]
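The worker change above boils down to: take a batch off the queue, then block on the semaphore before firing the whole burst of requests. A minimal runnable sketch of that gating pattern, with `fake_process` standing in for `process_pdf` and a timer standing in for the server-side release signal (both illustrative, not from this commit):

```python
import asyncio

async def fake_process(item):
    # Illustrative stand-in for process_pdf: pretend to hit the inference server
    await asyncio.sleep(0.1)
    return f"doc-{item}"

async def worker(name, queue, semaphore):
    while True:
        pdfs = await queue.get()
        try:
            # Wait until the monitor says the server has spare capacity
            await semaphore.acquire()
            docs = await asyncio.gather(*[fake_process(p) for p in pdfs])
            print(name, "produced", docs)
        finally:
            queue.task_done()

async def main():
    queue = asyncio.Queue()
    for batch in ([1, 2], [3, 4], [5, 6]):
        queue.put_nowait(batch)
    semaphore = asyncio.Semaphore(1)  # one worker may start a new burst at a time
    workers = [asyncio.create_task(worker(f"w{i}", queue, semaphore)) for i in range(3)]
    # Stand-in for the server monitor: release the permit on a timer
    for _ in range(3):
        await asyncio.sleep(0.3)
        if semaphore.locked():
            semaphore.release()
    await queue.join()
    for w in workers:
        w.cancel()

asyncio.run(main())
```

Note that the workers never release the semaphore themselves; as in the real code, only the monitor hands out the next permit.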
@@ -372,7 +376,7 @@ async def worker(args, queue):
queue.task_done()
async def sglang_server_task(args):
async def sglang_server_task(args, semaphore):
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
# TODO cache locally
#download_directory(args.model, model_cache_dir)
@@ -390,20 +394,53 @@ async def sglang_server_task(args):
proc = await asyncio.create_subprocess_exec(
"python3",
"-m", "sglang.launch_server",
"--model-path", model_cache_dir,
"--chat-template", args.model_chat_template,
"--context-length", str(args.model_max_context),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
# Make really sure we kill this subprocess on exit
# Make sure we kill this subprocess on exit
def _kill_proc():
proc.terminate()
atexit.register(_kill_proc)
last_queue_req = None # To track transitions
async def process_line(line):
# Parse the line and update semaphore if necessary
match = re.search(r'#running-req: (\d+), #queue-req: (\d+)', line)
if match:
logger.info(line)
running_req = int(match.group(1))
queue_req = int(match.group(2))
nonlocal last_queue_req
if last_queue_req is not None and last_queue_req != 0 and queue_req == 0:
# Release the semaphore when queue_req transitions from non-zero to zero
if semaphore.locked():
semaphore.release()
logger.info("Semaphore released, allowing a worker to proceed.")
last_queue_req = queue_req
async def read_stream(stream):
while True:
line = await stream.readline()
if not line:
break
line = line.decode('utf-8').rstrip()
await process_line(line)
# Start tasks to read stdout and stderr
stdout_task = asyncio.create_task(read_stream(proc.stdout))
stderr_task = asyncio.create_task(read_stream(proc.stderr))
await proc.wait()
await stdout_task
await stderr_task
async def sglang_server_ready():
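The release side lives in process_line above: the semaphore is freed only when `#queue-req` transitions from non-zero to zero. A standalone sketch of just that transition detector, fed fabricated log lines (only the regex is taken from the code above; the sample lines are illustrative):

```python
import asyncio
import re

async def main():
    semaphore = asyncio.Semaphore(1)
    await semaphore.acquire()  # pretend a worker currently holds the permit

    last_queue_req = None  # tracks the previous #queue-req value

    def process_line(line):
        nonlocal last_queue_req
        match = re.search(r'#running-req: (\d+), #queue-req: (\d+)', line)
        if not match:
            return
        queue_req = int(match.group(2))
        # Release only on the non-zero -> zero edge, so one drain of the
        # server queue admits exactly one additional worker
        if last_queue_req is not None and last_queue_req != 0 and queue_req == 0:
            if semaphore.locked():
                semaphore.release()
                print("semaphore released")
        last_queue_req = queue_req

    for line in [
        "#running-req: 8, #queue-req: 5",
        "#running-req: 6, #queue-req: 0",   # edge: 5 -> 0, releases
        "#running-req: 0, #queue-req: 0",   # no edge, nothing happens
    ]:
        process_line(line)

asyncio.run(main())
```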
@@ -463,7 +500,13 @@ async def main():
if args.pdfs:
await populate_pdf_work_queue(args)
sglang_server = asyncio.create_task(sglang_server_task(args))
# Create a semaphore to control worker access
# We only allow one worker to move forward with requests until the server has no more requests in its queue
# This gives us full utilization from many workers while still outputting dolma docs as soon as possible
# As soon as one worker is no longer saturating the GPU, the next one can start sending requests
semaphore = asyncio.Semaphore(1)
sglang_server = asyncio.create_task(sglang_server_task(args, semaphore))
work_queue = await load_pdf_work_queue(args)
logger.info(f"Work queue prepared with {work_queue.qsize()} items")
@@ -473,7 +516,7 @@ async def main():
# Create worker tasks to process the queue concurrently.
worker_tasks = []
for i in range(args.workers):
task = asyncio.create_task(worker(args, work_queue))
task = asyncio.create_task(worker(args, work_queue, semaphore))
worker_tasks.append(task)
# Wait for the queue to be fully processed
@@ -501,4 +544,3 @@ if __name__ == "__main__":
# TODO
# Possible future addon, in beaker, discover other nodes on this same job
# Send them a message when you take a work item off the queue