mirror of
https://github.com/allenai/olmocr.git
synced 2025-08-18 22:01:56 +00:00
Quicker results by limited workers via semaphore while still utilizing gpu
This commit is contained in:
parent
615409568d
commit
4f2f4fda7d
@ -14,6 +14,7 @@ import asyncio
|
||||
import aiohttp
|
||||
import datetime
|
||||
import tempfile
|
||||
import re
|
||||
|
||||
from tqdm import tqdm
|
||||
from io import BytesIO
|
||||
@ -73,7 +74,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
|
||||
|
||||
# Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
|
||||
image_base64 = asyncio.to_thread(render_pdf_to_base64png, local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)
|
||||
|
||||
|
||||
# GET ANCHOR TEXT IS NOT THREAD SAFE!! Ahhhh..... don't try to do it
|
||||
# and it's also CPU bound, so it needs to run in a process pool
|
||||
loop = asyncio.get_running_loop()
|
||||
@ -287,7 +288,7 @@ async def process_pdf(args, pdf_s3_path: str):
|
||||
logger.exception(f"Could not load page for {pdf_s3_path}, aborting document")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
# Build the document text and page spans
|
||||
document_text = ""
|
||||
pdf_page_spans = []
|
||||
@ -332,11 +333,14 @@ async def process_pdf(args, pdf_s3_path: str):
|
||||
return dolma_doc
|
||||
|
||||
|
||||
async def worker(args, queue):
|
||||
async def worker(args, queue, semaphore):
|
||||
while True:
|
||||
[work_hash, pdfs] = await queue.get()
|
||||
|
||||
try:
|
||||
# Wait until allowed to proceed
|
||||
await semaphore.acquire()
|
||||
|
||||
dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
|
||||
dolma_docs = [doc for doc in dolma_docs if doc is not None]
|
||||
|
||||
@ -372,7 +376,7 @@ async def worker(args, queue):
|
||||
queue.task_done()
|
||||
|
||||
|
||||
async def sglang_server_task(args):
|
||||
async def sglang_server_task(args, semaphore):
|
||||
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
|
||||
# TODO cache locally
|
||||
#download_directory(args.model, model_cache_dir)
|
||||
@ -390,20 +394,53 @@ async def sglang_server_task(args):
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"python3",
|
||||
|
||||
"-m", "sglang.launch_server",
|
||||
"--model-path", model_cache_dir,
|
||||
"--chat-template", args.model_chat_template,
|
||||
"--context-length", str(args.model_max_context),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
# Make really sure we kill this subprocess on exit
|
||||
# Make sure we kill this subprocess on exit
|
||||
def _kill_proc():
|
||||
proc.terminate()
|
||||
|
||||
atexit.register(_kill_proc)
|
||||
|
||||
last_queue_req = None # To track transitions
|
||||
async def process_line(line):
|
||||
# Parse the line and update semaphore if necessary
|
||||
match = re.search(r'#running-req: (\d+), #queue-req: (\d+)', line)
|
||||
if match:
|
||||
logger.info(line)
|
||||
running_req = int(match.group(1))
|
||||
queue_req = int(match.group(2))
|
||||
|
||||
nonlocal last_queue_req
|
||||
if last_queue_req is not None and last_queue_req != 0 and queue_req == 0:
|
||||
# Release the semaphore when queue_req transitions from non-zero to zero
|
||||
if semaphore.locked():
|
||||
semaphore.release()
|
||||
logger.info("Semaphore released, allowing a worker to proceed.")
|
||||
|
||||
last_queue_req = queue_req
|
||||
|
||||
async def read_stream(stream):
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
break
|
||||
line = line.decode('utf-8').rstrip()
|
||||
await process_line(line)
|
||||
|
||||
# Start tasks to read stdout and stderr
|
||||
stdout_task = asyncio.create_task(read_stream(proc.stdout))
|
||||
stderr_task = asyncio.create_task(read_stream(proc.stderr))
|
||||
|
||||
await proc.wait()
|
||||
await stdout_task
|
||||
await stderr_task
|
||||
|
||||
|
||||
async def sglang_server_ready():
|
||||
@ -463,7 +500,13 @@ async def main():
|
||||
if args.pdfs:
|
||||
await populate_pdf_work_queue(args)
|
||||
|
||||
sglang_server = asyncio.create_task(sglang_server_task(args))
|
||||
# Create a semaphore to control worker access
|
||||
# We only allow one worker to move forward with requests, until the server has no more requests in its queue
|
||||
# This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
|
||||
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
|
||||
sglang_server = asyncio.create_task(sglang_server_task(args, semaphore))
|
||||
|
||||
work_queue = await load_pdf_work_queue(args)
|
||||
logger.info(f"Work queue prepared with {work_queue.qsize()} items")
|
||||
@ -473,7 +516,7 @@ async def main():
|
||||
# Create worker tasks to process the queue concurrently.
|
||||
worker_tasks = []
|
||||
for i in range(args.workers):
|
||||
task = asyncio.create_task(worker(args, work_queue))
|
||||
task = asyncio.create_task(worker(args, work_queue, semaphore))
|
||||
worker_tasks.append(task)
|
||||
|
||||
# Wait for the queue to be fully processed
|
||||
@ -501,4 +544,3 @@ if __name__ == "__main__":
|
||||
# TODO
|
||||
# Possible future addon, in beaker, discover other nodes on this same job
|
||||
# Send them a message when you take a work item off the queue
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user