Quicker results by limiting workers via a semaphore while still utilizing the GPU

Jake Poznanski 2024-11-12 08:18:22 -08:00
parent 615409568d
commit 4f2f4fda7d


@@ -14,6 +14,7 @@ import asyncio
 import aiohttp
 import datetime
 import tempfile
+import re
 from tqdm import tqdm
 from io import BytesIO
@@ -73,7 +74,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
     # Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
     image_base64 = asyncio.to_thread(render_pdf_to_base64png, local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)

     # GET ANCHOR TEXT IS NOT THREAD SAFE!! Ahhhh..... don't try to do it
     # and it's also CPU bound, so it needs to run in a process pool
     loop = asyncio.get_running_loop()
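The comments above explain why the anchor-text step cannot share the asyncio.to_thread trick used for rendering: it is CPU bound and not thread safe, so it has to run in a separate process. A minimal sketch of that pattern, with a hypothetical get_anchor_text standing in for the real extraction step:

    import asyncio
    from concurrent.futures import ProcessPoolExecutor

    def get_anchor_text(pdf_path: str, page: int) -> str:
        # Hypothetical stand-in for the CPU-bound, non-thread-safe extraction
        return f"anchor text for {pdf_path} page {page}"

    async def build_page_query_sketch(pdf_path: str, page: int) -> str:
        loop = asyncio.get_running_loop()
        # A ProcessPoolExecutor runs the call in a separate process, which
        # sidesteps both the GIL and the thread-safety problem
        with ProcessPoolExecutor() as pool:
            return await loop.run_in_executor(pool, get_anchor_text, pdf_path, page)

    if __name__ == "__main__":
        print(asyncio.run(build_page_query_sketch("doc.pdf", 1)))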
@@ -287,7 +288,7 @@ async def process_pdf(args, pdf_s3_path: str):
             logger.exception(f"Could not load page for {pdf_s3_path}, aborting document")
             return None

     # Build the document text and page spans
     document_text = ""
     pdf_page_spans = []
@@ -332,11 +333,14 @@ async def process_pdf(args, pdf_s3_path: str):
     return dolma_doc

-async def worker(args, queue):
+async def worker(args, queue, semaphore):
     while True:
         [work_hash, pdfs] = await queue.get()

         try:
+            # Wait until allowed to proceed
+            await semaphore.acquire()
+
             dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
             dolma_docs = [doc for doc in dolma_docs if doc is not None]
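Note the asymmetry here: the worker consumes a semaphore permit before it starts sending requests, but it never releases one. The matching release happens in sglang_server_task below, once the server reports an empty request queue.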
@@ -372,7 +376,7 @@ async def worker(args, queue):
         queue.task_done()

-async def sglang_server_task(args):
+async def sglang_server_task(args, semaphore):
     model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
     # TODO cache locally
     #download_directory(args.model, model_cache_dir)
@@ -390,20 +394,53 @@ async def sglang_server_task(args):
     proc = await asyncio.create_subprocess_exec(
         "python3",
         "-m", "sglang.launch_server",
         "--model-path", model_cache_dir,
         "--chat-template", args.model_chat_template,
         "--context-length", str(args.model_max_context),
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.PIPE,
     )

-    # Make really sure we kill this subprocess on exit
+    # Make sure we kill this subprocess on exit
     def _kill_proc():
         proc.terminate()

     atexit.register(_kill_proc)

+    last_queue_req = None  # To track transitions
+
+    async def process_line(line):
+        # Parse the line and update semaphore if necessary
+        match = re.search(r'#running-req: (\d+), #queue-req: (\d+)', line)
+        if match:
+            logger.info(line)
+            running_req = int(match.group(1))
+            queue_req = int(match.group(2))
+
+            nonlocal last_queue_req
+            if last_queue_req is not None and last_queue_req != 0 and queue_req == 0:
+                # Release the semaphore when queue_req transitions from non-zero to zero
+                if semaphore.locked():
+                    semaphore.release()
+                    logger.info("Semaphore released, allowing a worker to proceed.")
+
+            last_queue_req = queue_req
+
+    async def read_stream(stream):
+        while True:
+            line = await stream.readline()
+            if not line:
+                break
+            line = line.decode('utf-8').rstrip()
+            await process_line(line)
+
+    # Start tasks to read stdout and stderr
+    stdout_task = asyncio.create_task(read_stream(proc.stdout))
+    stderr_task = asyncio.create_task(read_stream(proc.stderr))
+
     await proc.wait()
+    await stdout_task
+    await stderr_task

 async def sglang_server_ready():
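The release signal comes from scraping the sglang server's scheduler logs. A quick illustration of what process_line matches; the exact shape of the log line here is an assumption inferred from the regex, not a captured log:

    import re

    # Assumed shape of an sglang scheduler log line
    line = "Decode batch. #running-req: 12, #queue-req: 0"
    match = re.search(r'#running-req: (\d+), #queue-req: (\d+)', line)
    assert match is not None
    running_req, queue_req = int(match.group(1)), int(match.group(2))
    # A transition of queue_req from non-zero on a previous line to 0 here
    # is what triggers the semaphore release.
    print(running_req, queue_req)  # 12 0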
@@ -463,7 +500,13 @@ async def main():
     if args.pdfs:
         await populate_pdf_work_queue(args)

-    sglang_server = asyncio.create_task(sglang_server_task(args))
+    # Create a semaphore to control worker access
+    # We only allow one worker to move forward with requests, until the server has no more requests in its queue
+    # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
+    # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
+    semaphore = asyncio.Semaphore(1)
+    sglang_server = asyncio.create_task(sglang_server_task(args, semaphore))

     work_queue = await load_pdf_work_queue(args)
     logger.info(f"Work queue prepared with {work_queue.qsize()} items")
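Taken together, the comments above describe a one-permit handoff: asyncio.Semaphore(1) starts with a single permit, each worker consumes one before sending requests and never gives it back, and the log monitor mints a fresh permit whenever the server's queue drains. A self-contained toy demo of that handoff; the names are illustrative, not from the codebase:

    import asyncio

    async def gated_worker(name: str, sem: asyncio.Semaphore):
        await sem.acquire()  # consume a permit; the worker never releases it
        print(f"{name}: sending requests")
        await asyncio.sleep(0.1)  # stand-in for processing a batch of PDFs

    async def queue_monitor(sem: asyncio.Semaphore, ticks: int):
        for _ in range(ticks):
            await asyncio.sleep(0.2)  # stand-in for seeing #queue-req hit 0
            if sem.locked():
                sem.release()  # mint a permit so the next worker can start

    async def main():
        sem = asyncio.Semaphore(1)  # one permit: the first worker starts at once
        workers = [asyncio.create_task(gated_worker(f"worker-{i}", sem)) for i in range(3)]
        await queue_monitor(sem, ticks=3)
        await asyncio.gather(*workers)

    asyncio.run(main())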
@@ -473,7 +516,7 @@ async def main():
     # Create worker tasks to process the queue concurrently.
     worker_tasks = []
     for i in range(args.workers):
-        task = asyncio.create_task(worker(args, work_queue))
+        task = asyncio.create_task(worker(args, work_queue, semaphore))
         worker_tasks.append(task)

     # Wait for the queue to be fully processed
@@ -501,4 +544,3 @@ if __name__ == "__main__":
 # TODO
 # Possible future addon, in beaker, discover other nodes on this same job
 # Send them a message when you take a work item off the queue