Reworking to be async

parent a103ce730f
commit a39350e074

@@ -9,6 +9,7 @@ import subprocess
import atexit
import hashlib
import base64
import asyncio

from tqdm import tqdm
from io import BytesIO

@@ -73,39 +74,21 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
        sha1.update(pdf.encode('utf-8'))
    return sha1.hexdigest()

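For reference, the surrounding helper is only partially visible in this hunk; a minimal sketch of what the full function most likely looks like, assuming the loop iterates over a sorted work_group so the hash is independent of input order:

def compute_workgroup_sha1(work_group: list[str]) -> str:
    # Hash the sorted PDF paths so every work group gets a stable identifier
    sha1 = hashlib.sha1()
    for pdf in sorted(work_group):
        sha1.update(pdf.encode('utf-8'))
    return sha1.hexdigest()
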
async def start_sglang_server(args):
    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
    download_directory(args.model, model_cache_dir)

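As committed, start_sglang_server only downloads the model; the Popen call that actually launches the server still sits under __main__ below. One way the coroutine could be completed without blocking the event loop is asyncio's subprocess API, sketched here as an assumption rather than what this commit does:

async def start_sglang_server(args):
    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
    download_directory(args.model, model_cache_dir)

    # Assumption: launch sglang via asyncio instead of subprocess.Popen
    proc = await asyncio.create_subprocess_exec(
        "python3", "-m", "sglang.launch_server",
        "--model-path", model_cache_dir,
        "--chat-template", args.model_chat_template,
        "--context-length", str(args.model_max_context),
    )
    return proc
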
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
    parser.add_argument('--workers', type=int, default=10, help='Number of workers to run at a time')

    parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
                        default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
                                 "gs://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/",
                                 "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"])
    parser.add_argument('--model_max_context', type=int, default="8192", help="Maximum context length that the model was fine tuned under")
    parser.add_argument('--model_chat_template', type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
    args = parser.parse_args()

    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

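The workspace_s3 and pdf_s3 handles read by download_zstd_csv and expand_s3_glob further down are presumably module-level defaults that these profile flags override; a minimal sketch of that pattern, which is an assumption about code outside this diff:

# Assumed module-level defaults, overridden when --workspace_profile / --pdf_profile are set
workspace_s3 = boto3.client("s3")
pdf_s3 = boto3.client("s3")
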
    # Start up the sglang server
    sglang_process = subprocess.Popen([
        "python3", "-m", "sglang.launch_server",
        "--model-path", model_cache_dir,
        "--chat-template", args.model_chat_template,
        "--context-length", str(args.model_max_context),
    ])

async def populate_pdf_work_queue(args):
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
    check_poppler_version()

    # Check list of pdfs and that it matches what's in the workspace
    if args.pdfs:
        if args.pdfs.startswith("s3://"):
            logger.info(f"Expanding s3 glob at {args.pdfs}")
            all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)

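The rest of populate_pdf_work_queue falls outside the visible hunk. Judging from the --group_size flag and the compute_workgroup_sha1 helper above, it presumably chunks the expanded PDF list into work items and writes them back to the zstd-compressed index; a rough sketch under those assumptions, where upload_zstd_csv is a hypothetical counterpart to download_zstd_csv:

# Sketch only: group the expanded PDF list into work items of args.group_size
all_pdfs = sorted(all_pdfs)
new_work_lines = []
for i in range(0, len(all_pdfs), args.group_size):
    group = all_pdfs[i:i + args.group_size]
    group_hash = compute_workgroup_sha1(group)
    new_work_lines.append(",".join([group_hash] + group))

# upload_zstd_csv is hypothetical; the real writer is not shown in this diff
upload_zstd_csv(workspace_s3, index_file_s3_path, new_work_lines)
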
@@ -167,41 +150,8 @@ if __name__ == '__main__':

    logger.info("Completed adding new PDFs.")

    # TODO
    # If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
    # If not, then your job is to do the actual work

    # Download the model from the best place available
    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
    download_directory(args.model, model_cache_dir)

    # Start up the sglang server
    sglang_process = subprocess.Popen([
        "python3", "-m", "sglang.launch_server",
        "--model-path", model_cache_dir,
        "--chat-template", args.model_chat_template,
        "--context-length", str(args.model_max_context),
    ])

    # Register atexit function and signal handlers to guarantee process termination
    def terminate_processes():
        print("Terminating child processes...")
        sglang_process.terminate()
        try:
            sglang_process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            print("Forcing termination of child processes.")
            sglang_process.kill()
        print("Child processes terminated.")

    atexit.register(terminate_processes)

    def signal_handler(sig, frame):
        terminate_processes()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

async def load_pdf_work_queue(args) -> asyncio.Queue:
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")

    # Read in the work queue from s3
    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)

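The lines skipped by the next hunk presumably turn work_queue_lines into a dict keyed by group hash and collect done_work_hashes from results already written to the workspace; neither is visible here. A rough sketch of the first half, assuming each csv line has the shape "<group sha1>,<pdf path>,...":

# Sketch only: rebuild the hash -> [pdf paths] mapping from the downloaded csv lines
work_queue = {}
for line in work_queue_lines:
    if not line.strip():
        continue
    parts = line.strip().split(",")
    work_queue[parts[0]] = parts[1:]
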
@@ -225,7 +175,109 @@ if __name__ == '__main__':
    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
    remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}

    logger.info(f"Remaining work items: {len(remaining_work_queue)}")
    queue = asyncio.Queue()

    for work in remaining_work_queue:
        await queue.put((work, remaining_work_queue[work]))

    return queue

async def worker(args, queue):
    while True:
        work = await queue.get()

        logger.info(f"Got work to do for {work}")
        queue.task_done()

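The worker is still a stub in this commit. A fuller version needs to pair every get() with exactly one task_done() even when processing fails, and to tolerate cancellation once main() drains the queue; a minimal sketch of that shape, where process_work_item is hypothetical:

async def worker(args, queue):
    while True:
        try:
            work_hash, pdf_paths = await queue.get()
        except asyncio.CancelledError:
            return  # main() cancels workers after queue.join() completes

        try:
            logger.info(f"Got work to do for {work_hash}")
            # process_work_item is hypothetical; the real per-PDF logic is not in this commit
            await process_work_item(args, work_hash, pdf_paths)
        except Exception:
            logger.exception(f"Failed to process work item {work_hash}")
        finally:
            queue.task_done()
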
async def main():
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
    parser.add_argument('--workers', type=int, default=10, help='Number of workers to run at a time')

    parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
                        default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
                                 "gs://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/",
                                 "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"])
    parser.add_argument('--model_max_context', type=int, default="8192", help="Maximum context length that the model was fine tuned under")
    parser.add_argument('--model_chat_template', type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
    args = parser.parse_args()

    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

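One thing to watch now that this block lives inside main(): these assignments bind local names, so if workspace_s3 and pdf_s3 are module-level clients used by load_pdf_work_queue and expand_s3_glob (as the earlier top-level version suggests), the profile overrides will not reach them without a global declaration. A minimal sketch of that fix, assuming the module-level clients exist:

async def main():
    global workspace_s3, pdf_s3  # rebind the module-level clients the helpers read
    ...
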
    check_poppler_version()

    if args.pdfs:
        await populate_pdf_work_queue(args)

    work_queue = await load_pdf_work_queue(args)
    logger.info(f"Work queue prepared with {work_queue.qsize()} items")

    # Create worker tasks to process the queue concurrently.
    tasks = []
    for i in range(args.workers):
        task = asyncio.create_task(worker(args, work_queue))
        tasks.append(task)

    # Wait for the queue to be fully processed
    await work_queue.join()

    # Cancel our worker tasks.
    for task in tasks:
        task.cancel()

    # Wait until all worker tasks are cancelled.
    await asyncio.gather(*tasks, return_exceptions=True)

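Note that work_queue.join() only returns once every item put on the queue has been matched by a task_done() call, so a worker that raises between get() and task_done() would hang this shutdown sequence; the try/finally shape in the worker sketch above avoids that.
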
if __name__ == "__main__":
    asyncio.run(main())

# TODO
# If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
# If not, then your job is to do the actual work

# Download the model from the best place available

# Register atexit function and signal handlers to guarantee process termination
# def terminate_processes():
#     print("Terminating child processes...")
#     sglang_process.terminate()
#     try:
#         sglang_process.wait(timeout=30)
#     except subprocess.TimeoutExpired:
#         print("Forcing termination of child processes.")
#         sglang_process.kill()
#     print("Child processes terminated.")

# atexit.register(terminate_processes)

# def signal_handler(sig, frame):
#     terminate_processes()
#     sys.exit(0)

# signal.signal(signal.SIGINT, signal_handler)
# signal.signal(signal.SIGTERM, signal_handler)

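If this cleanup is revived for the async version, the event loop's own signal hooks are the natural replacement for signal.signal; a minimal sketch, assuming the server process is launched via asyncio.create_subprocess_exec as in the earlier start_sglang_server sketch:

# Sketch only: async-friendly replacement for the commented-out handlers above
def install_termination_handlers(loop, sglang_proc):
    def _terminate():
        logger.info("Terminating sglang server...")
        sglang_proc.terminate()

    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, _terminate)
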
# logger.info(f"Remaining work items: {len(remaining_work_queue)}")

# TODO
# Spawn up to N workers to do:

@@ -238,12 +290,12 @@ if __name__ == '__main__':
# Possible future addon, in beaker, discover other nodes on this same job
# Send them a message when you take a work item off the queue

    try:
        while True:
            time.sleep(1)
# try:
#     while True:
#         time.sleep(1)

            if sglang_process.returncode is not None:
                logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
    except KeyboardInterrupt:
        logger.info("Got keyboard interrupt, exiting everything")
        sys.exit(1)
# if sglang_process.returncode is not None:
#     logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
# except KeyboardInterrupt:
#     logger.info("Got keyboard interrupt, exiting everything")
# sys.exit(1)