diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index f7c07c9..2817b07 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -2,6 +2,8 @@ import logging
 import argparse
 import boto3
 import os
+import subprocess
+import atexit
 
 from tqdm import tqdm
 
@@ -76,6 +78,7 @@ if __name__ == '__main__':
         logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
 
         # Group the new PDFs into chunks of group_size
+        # TODO: Figure out the group size automatically by sampling a few pdfs and taking the mean/median number of pages, etc.
         new_groups = []
         current_group = []
         for pdf in sorted(new_pdfs):  # Sort for consistency
@@ -109,6 +112,25 @@ if __name__ == '__main__':
     download_directory(args.model, model_cache_dir)
 
     # Start up the sglang server
+    sglang_process = subprocess.Popen([
+        "python3", "-m", "sglang.launch_server",
+        "--model-path", model_cache_dir,
+        "--chat-template", args.model_chat_template,
+        "--context-length", str(args.model_max_context),
+    ])
+
+    # Register an atexit handler to guarantee the server process is terminated
+    def terminate_processes():
+        print("Terminating child processes...")
+        sglang_process.terminate()
+        try:
+            sglang_process.wait(timeout=30)
+        except subprocess.TimeoutExpired:
+            print("Forcing termination of child processes.")
+            sglang_process.kill()
+        print("Child processes terminated.")
+
+    atexit.register(terminate_processes)
 
     # Read in the work queue from s3
     # Read in the done items from the s3 workspace
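
One possible shape for the TODO in the grouping hunk: sample a handful of the incoming PDFs, read their page counts, and size each group so it holds roughly a fixed number of pages. The sketch below is not code from this diff; the `pypdf` dependency, the `target_pages_per_group` value, and the `estimate_group_size` helper are all illustrative assumptions, and it assumes the sampled paths are readable locally (S3 paths would need to be downloaded first).

```python
import random
from pypdf import PdfReader

def estimate_group_size(pdf_paths, target_pages_per_group=500, sample_size=10):
    """Pick a group size so each group holds ~target_pages_per_group pages.

    Samples a few PDFs and uses the median page count as the estimate;
    the helper name, the 500-page target, and the sample size are all
    hypothetical, not values from the repo.
    """
    sample = random.sample(list(pdf_paths), min(sample_size, len(pdf_paths)))
    page_counts = sorted(len(PdfReader(path).pages) for path in sample)
    median_pages = page_counts[len(page_counts) // 2]
    # Guarantee at least one PDF per group, even if the median is huge
    return max(1, target_pages_per_group // median_pages)
```

The median is used rather than the mean so that one unusually long PDF in the sample does not shrink every group.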
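
One caveat on the cleanup hunk: `atexit` handlers fire on normal interpreter exit, but not when the pipeline itself dies to an unhandled signal, so a SIGTERM from a scheduler would leave the sglang server orphaned. A common companion pattern, sketched below under that assumption and not part of this diff, is to convert SIGTERM/SIGINT into a normal exit so the registered handler still runs.

```python
import signal
import sys

def handle_signal(signum, frame):
    # sys.exit() raises SystemExit, which unwinds the interpreter normally
    # and therefore runs the atexit handlers, terminating the sglang server.
    sys.exit(0)

signal.signal(signal.SIGTERM, handle_signal)
signal.signal(signal.SIGINT, handle_signal)
```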