mirror of
https://github.com/allenai/olmocr.git
synced 2025-08-19 14:22:26 +00:00
Quicker results by limited workers via semaphore while still utilizing gpu
This commit is contained in:
parent
615409568d
commit
4f2f4fda7d
@ -14,6 +14,7 @@ import asyncio
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import datetime
|
import datetime
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import re
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -332,11 +333,14 @@ async def process_pdf(args, pdf_s3_path: str):
|
|||||||
return dolma_doc
|
return dolma_doc
|
||||||
|
|
||||||
|
|
||||||
async def worker(args, queue):
|
async def worker(args, queue, semaphore):
|
||||||
while True:
|
while True:
|
||||||
[work_hash, pdfs] = await queue.get()
|
[work_hash, pdfs] = await queue.get()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Wait until allowed to proceed
|
||||||
|
await semaphore.acquire()
|
||||||
|
|
||||||
dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
|
dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
|
||||||
dolma_docs = [doc for doc in dolma_docs if doc is not None]
|
dolma_docs = [doc for doc in dolma_docs if doc is not None]
|
||||||
|
|
||||||
@ -372,7 +376,7 @@ async def worker(args, queue):
|
|||||||
queue.task_done()
|
queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
async def sglang_server_task(args):
|
async def sglang_server_task(args, semaphore):
|
||||||
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
|
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
|
||||||
# TODO cache locally
|
# TODO cache locally
|
||||||
#download_directory(args.model, model_cache_dir)
|
#download_directory(args.model, model_cache_dir)
|
||||||
@ -390,20 +394,53 @@ async def sglang_server_task(args):
|
|||||||
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
"python3",
|
"python3",
|
||||||
|
|
||||||
"-m", "sglang.launch_server",
|
"-m", "sglang.launch_server",
|
||||||
"--model-path", model_cache_dir,
|
"--model-path", model_cache_dir,
|
||||||
"--chat-template", args.model_chat_template,
|
"--chat-template", args.model_chat_template,
|
||||||
"--context-length", str(args.model_max_context),
|
"--context-length", str(args.model_max_context),
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make really sure we kill this subprocess on exit
|
# Make sure we kill this subprocess on exit
|
||||||
def _kill_proc():
|
def _kill_proc():
|
||||||
proc.terminate()
|
proc.terminate()
|
||||||
|
|
||||||
atexit.register(_kill_proc)
|
atexit.register(_kill_proc)
|
||||||
|
|
||||||
|
last_queue_req = None # To track transitions
|
||||||
|
async def process_line(line):
|
||||||
|
# Parse the line and update semaphore if necessary
|
||||||
|
match = re.search(r'#running-req: (\d+), #queue-req: (\d+)', line)
|
||||||
|
if match:
|
||||||
|
logger.info(line)
|
||||||
|
running_req = int(match.group(1))
|
||||||
|
queue_req = int(match.group(2))
|
||||||
|
|
||||||
|
nonlocal last_queue_req
|
||||||
|
if last_queue_req is not None and last_queue_req != 0 and queue_req == 0:
|
||||||
|
# Release the semaphore when queue_req transitions from non-zero to zero
|
||||||
|
if semaphore.locked():
|
||||||
|
semaphore.release()
|
||||||
|
logger.info("Semaphore released, allowing a worker to proceed.")
|
||||||
|
|
||||||
|
last_queue_req = queue_req
|
||||||
|
|
||||||
|
async def read_stream(stream):
|
||||||
|
while True:
|
||||||
|
line = await stream.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
line = line.decode('utf-8').rstrip()
|
||||||
|
await process_line(line)
|
||||||
|
|
||||||
|
# Start tasks to read stdout and stderr
|
||||||
|
stdout_task = asyncio.create_task(read_stream(proc.stdout))
|
||||||
|
stderr_task = asyncio.create_task(read_stream(proc.stderr))
|
||||||
|
|
||||||
await proc.wait()
|
await proc.wait()
|
||||||
|
await stdout_task
|
||||||
|
await stderr_task
|
||||||
|
|
||||||
|
|
||||||
async def sglang_server_ready():
|
async def sglang_server_ready():
|
||||||
@ -463,7 +500,13 @@ async def main():
|
|||||||
if args.pdfs:
|
if args.pdfs:
|
||||||
await populate_pdf_work_queue(args)
|
await populate_pdf_work_queue(args)
|
||||||
|
|
||||||
sglang_server = asyncio.create_task(sglang_server_task(args))
|
# Create a semaphore to control worker access
|
||||||
|
# We only allow one worker to move forward with requests, until the server has no more requests in its queue
|
||||||
|
# This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
|
||||||
|
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
|
||||||
|
semaphore = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
sglang_server = asyncio.create_task(sglang_server_task(args, semaphore))
|
||||||
|
|
||||||
work_queue = await load_pdf_work_queue(args)
|
work_queue = await load_pdf_work_queue(args)
|
||||||
logger.info(f"Work queue prepared with {work_queue.qsize()} items")
|
logger.info(f"Work queue prepared with {work_queue.qsize()} items")
|
||||||
@ -473,7 +516,7 @@ async def main():
|
|||||||
# Create worker tasks to process the queue concurrently.
|
# Create worker tasks to process the queue concurrently.
|
||||||
worker_tasks = []
|
worker_tasks = []
|
||||||
for i in range(args.workers):
|
for i in range(args.workers):
|
||||||
task = asyncio.create_task(worker(args, work_queue))
|
task = asyncio.create_task(worker(args, work_queue, semaphore))
|
||||||
worker_tasks.append(task)
|
worker_tasks.append(task)
|
||||||
|
|
||||||
# Wait for the queue to be fully processed
|
# Wait for the queue to be fully processed
|
||||||
@ -501,4 +544,3 @@ if __name__ == "__main__":
|
|||||||
# TODO
|
# TODO
|
||||||
# Possible future addon, in beaker, discover other nodes on this same job
|
# Possible future addon, in beaker, discover other nodes on this same job
|
||||||
# Send them a message when you take a work item off the queue
|
# Send them a message when you take a work item off the queue
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user