Reworking to be async

Jake Poznanski 2024-11-08 08:14:20 -08:00
parent a103ce730f
commit a39350e074


@@ -9,6 +9,7 @@ import subprocess
import atexit
import hashlib
import base64
import asyncio
from tqdm import tqdm
from io import BytesIO
@@ -73,39 +74,21 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
        sha1.update(pdf.encode('utf-8'))
    return sha1.hexdigest()
async def start_sglang_server(args):
    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
    download_directory(args.model, model_cache_dir)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
    parser.add_argument('--workers', type=int, default=10, help='Number of workers to run at a time')
    parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
                        default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
                                 "gs://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/",
                                 "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"])
    parser.add_argument('--model_max_context', type=int, default="8192", help="Maximum context length that the model was fine tuned under")
    parser.add_argument('--model_chat_template', type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
    args = parser.parse_args()
    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")
    # Start up the sglang server
    sglang_process = subprocess.Popen([
        "python3", "-m", "sglang.launch_server",
        "--model-path", model_cache_dir,
        "--chat-template", args.model_chat_template,
        "--context-length", str(args.model_max_context),
    ])
async def populate_pdf_work_queue(args):
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")

    check_poppler_version()

    # Check list of pdfs and that it matches what's in the workspace
    if args.pdfs:
        if args.pdfs.startswith("s3://"):
            logger.info(f"Expanding s3 glob at {args.pdfs}")
            all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
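
A note on the server launch: the rework keeps a synchronous subprocess.Popen inside an async def. If the launch itself should live on the event loop, asyncio's own subprocess API is the usual fit. A minimal sketch, assuming only the command line visible above; the function name is illustrative and not part of this commit:

import asyncio

async def start_sglang_server_async(model_cache_dir, chat_template, max_context):
    # Non-blocking equivalent of the subprocess.Popen call above;
    # the returned Process exposes an awaitable wait().
    return await asyncio.create_subprocess_exec(
        "python3", "-m", "sglang.launch_server",
        "--model-path", model_cache_dir,
        "--chat-template", chat_template,
        "--context-length", str(max_context),
    )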
@@ -167,41 +150,8 @@ if __name__ == '__main__':
logger.info("Completed adding new PDFs.")
# TODO
# If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
# If not, then your job is to do the actual work
# Download the model from the best place available
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
download_directory(args.model, model_cache_dir)
# Start up the sglang server
sglang_process = subprocess.Popen([
"python3", "-m", "sglang.launch_server",
"--model-path", model_cache_dir,
"--chat-template", args.model_chat_template,
"--context-length", str(args.model_max_context),
])
# Register atexit function and signal handlers to guarantee process termination
def terminate_processes():
print("Terminating child processes...")
sglang_process.terminate()
try:
sglang_process.wait(timeout=30)
except subprocess.TimeoutExpired:
print("Forcing termination of child processes.")
sglang_process.kill()
print("Child processes terminated.")
atexit.register(terminate_processes)
def signal_handler(sig, frame):
terminate_processes()
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
async def load_pdf_work_queue(args) -> asyncio.Queue:
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")

    # Read in the work queue from s3
    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
@@ -225,7 +175,109 @@ if __name__ == '__main__':
    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
    remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
    logger.info(f"Remaining work items: {len(remaining_work_queue)}")

    queue = asyncio.Queue()

    for work in remaining_work_queue:
        await queue.put((work, remaining_work_queue[work]))

    return queue
async def worker(args, queue):
    while True:
        work = await queue.get()
        logger.info(f"Got work to do for {work}")
        queue.task_done()
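
The worker above is a stub: it pops an item, logs it, and immediately calls task_done(). Once real work lands, the standard shape is to pair each get() with a task_done() in a finally block, so the work_queue.join() in main() completes even when an item fails. A sketch under that assumption; process_pdf_group is a hypothetical placeholder, not a function in this commit:

async def worker(args, queue):
    while True:
        work_hash, pdf_group = await queue.get()
        try:
            # Hypothetical stand-in for the real per-item processing.
            await process_pdf_group(args, work_hash, pdf_group)
        finally:
            # Always mark the item done so queue.join() can finish.
            queue.task_done()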
async def main():
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
    parser.add_argument('--workers', type=int, default=10, help='Number of workers to run at a time')
    parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
                        default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
                                 "gs://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/",
                                 "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"])
    parser.add_argument('--model_max_context', type=int, default="8192", help="Maximum context length that the model was fine tuned under")
    parser.add_argument('--model_chat_template', type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
    args = parser.parse_args()
    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    check_poppler_version()

    if args.pdfs:
        await populate_pdf_work_queue(args)

    work_queue = await load_pdf_work_queue(args)
    logger.info(f"Work queue prepared with {work_queue.qsize()} items")
    # Create worker tasks to process the queue concurrently.
    tasks = []
    for i in range(args.workers):
        task = asyncio.create_task(worker(args, work_queue))
        tasks.append(task)

    # Wait for the queue to be fully processed
    await work_queue.join()

    # Cancel our worker tasks.
    for task in tasks:
        task.cancel()

    # Wait until all worker tasks are cancelled.
    await asyncio.gather(*tasks, return_exceptions=True)
if __name__ == "__main__":
    asyncio.run(main())
# TODO
# If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
# If not, then your job is to do the actual work
# Download the model from the best place available
# Register atexit function and signal handlers to guarantee process termination
# def terminate_processes():
# print("Terminating child processes...")
# sglang_process.terminate()
# try:
# sglang_process.wait(timeout=30)
# except subprocess.TimeoutExpired:
# print("Forcing termination of child processes.")
# sglang_process.kill()
# print("Child processes terminated.")
# atexit.register(terminate_processes)
# def signal_handler(sig, frame):
# terminate_processes()
# sys.exit(0)
# signal.signal(signal.SIGINT, signal_handler)
# signal.signal(signal.SIGTERM, signal_handler)
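
The commented-out handlers above are the pre-async approach (signal.signal plus sys.exit from inside the handler). Under a running event loop, the loop's own hook is the conventional replacement, since it defers the callback to a safe point instead of interrupting a task mid-await. A sketch, assuming a Unix platform (loop.add_signal_handler is not implemented on Windows); the helper name is illustrative:

import asyncio
import signal

def install_signal_handlers(sglang_process):
    loop = asyncio.get_running_loop()

    def _terminate():
        # Mirrors the commented-out terminate_processes() above.
        print("Terminating child processes...")
        sglang_process.terminate()

    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, _terminate)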
# logger.info(f"Remaining work items: {len(remaining_work_queue)}")
# TODO
# Spawn up to N workers to do:
@@ -238,12 +290,12 @@ if __name__ == '__main__':
# Possible future addon, in beaker, discover other nodes on this same job
# Send them a message when you take a work item off the queue
    try:
        while True:
            time.sleep(1)

    # try:
    #     while True:
    #         time.sleep(1)

            if sglang_process.returncode is not None:
                logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
    except KeyboardInterrupt:
        logger.info("Got keyboard interrupt, exiting everything")
        sys.exit(1)

    # if sglang_process.returncode is not None:
    #     logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
    # except KeyboardInterrupt:
    #     logger.info("Got keyboard interrupt, exiting everything")
    #     sys.exit(1)
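
One detail worth noting for a future revision: Popen.returncode is only refreshed by poll() or wait(), so the commented-out loop above, as written, would keep reading None even after the server died. An async-friendly watcher would call poll() and yield between checks. A sketch under that assumption; the function name and one-second interval are illustrative:

import asyncio
import logging

logger = logging.getLogger(__name__)

async def watch_sglang(sglang_process):
    # poll() refreshes returncode without blocking the event loop.
    while sglang_process.poll() is None:
        await asyncio.sleep(1)
    logger.error(f"Sglang server exited with code {sglang_process.returncode}, exiting.")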