Reworking to be async

This commit is contained in:
Jake Poznanski 2024-11-08 08:14:20 -08:00
parent a103ce730f
commit a39350e074

View File

@ -9,6 +9,7 @@ import subprocess
import atexit import atexit
import hashlib import hashlib
import base64 import base64
import asyncio
from tqdm import tqdm from tqdm import tqdm
from io import BytesIO from io import BytesIO
@ -73,8 +74,123 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
sha1.update(pdf.encode('utf-8')) sha1.update(pdf.encode('utf-8'))
return sha1.hexdigest() return sha1.hexdigest()
async def start_sglang_server(args):
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
download_directory(args.model, model_cache_dir)
if __name__ == '__main__': # Start up the sglang server
sglang_process = subprocess.Popen([
"python3", "-m", "sglang.launch_server",
"--model-path", model_cache_dir,
"--chat-template", args.model_chat_template,
"--context-length", str(args.model_max_context),
])
async def populate_pdf_work_queue(args):
index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
if args.pdfs.startswith("s3://"):
logger.info(f"Expanding s3 glob at {args.pdfs}")
all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
elif os.path.exists(args.pdfs):
logger.info(f"Loading file at {args.pdfs}")
with open(args.pdfs, "r") as f:
all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
else:
raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
all_pdfs = set(all_pdfs)
logger.info(f"Found {len(all_pdfs):,} total pdf paths")
existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
# Parse existing work items into groups
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_pdfs = parts[1:]
existing_groups[group_hash] = group_pdfs
existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)
logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
# Remove existing PDFs from all_pdfs
new_pdfs = all_pdfs - existing_pdf_set
logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
# Group the new PDFs into chunks of group_size
# TODO: Figure out the group size automatically by sampling a few pdfs, and taking the mean/median number of pages, etc.
new_groups = []
current_group = []
for pdf in sorted(new_pdfs): # Sort for consistency
current_group.append(pdf)
if len(current_group) == args.group_size:
group_hash = compute_workgroup_sha1(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = compute_workgroup_sha1(current_group)
new_groups.append((group_hash, current_group))
logger.info(f"Created {len(new_groups):,} new work groups")
# Combine existing groups with new groups
combined_groups = existing_groups.copy()
for group_hash, group_pdfs in new_groups:
combined_groups[group_hash] = group_pdfs
# Prepare lines to write back
combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]
# Upload the combined work items back to S3
if new_groups:
upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
logger.info("Completed adding new PDFs.")
async def load_pdf_work_queue(args) -> asyncio.Queue:
index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
# Read in the work queue from s3
work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
work_queue = {}
for line in work_queue_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_pdfs = parts[1:]
work_queue[group_hash] = group_pdfs
# Read in the done items from the s3 workspace
done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
done_work_hashes = set()
for item in done_work_items:
filename = os.path.basename(item)
if filename.startswith('output_') and filename.endswith('.jsonl'):
group_hash = filename[len('output_'):-len('.jsonl')]
done_work_hashes.add(group_hash)
remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
queue = asyncio.Queue()
for work in remaining_work_queue:
await queue.put((work, remaining_work_queue[work]))
return queue
async def worker(args, queue):
while True:
work = await queue.get()
logger.info(f"Got work to do for {work}")
queue.task_done()
async def main():
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline') parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/') parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None) parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
@ -101,131 +217,67 @@ if __name__ == '__main__':
pdf_session = boto3.Session(profile_name=args.pdf_profile) pdf_session = boto3.Session(profile_name=args.pdf_profile)
pdf_s3 = pdf_session.client("s3") pdf_s3 = pdf_session.client("s3")
index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
check_poppler_version() check_poppler_version()
# Check list of pdfs and that it matches what's in the workspace
if args.pdfs: if args.pdfs:
if args.pdfs.startswith("s3://"): await populate_pdf_work_queue(args)
logger.info(f"Expanding s3 glob at {args.pdfs}")
all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
elif os.path.exists(args.pdfs):
logger.info(f"Loading file at {args.pdfs}")
with open(args.pdfs, "r") as f:
all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
else:
raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
all_pdfs = set(all_pdfs) work_queue = await load_pdf_work_queue(args)
logger.info(f"Found {len(all_pdfs):,} total pdf paths") logger.info(f"Work queue prepared with {work_queue.qsize()} items")
existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
# Parse existing work items into groups # Create worker tasks to process the queue concurrently.
existing_groups = {} tasks = []
for line in existing_lines: for i in range(args.workers):
if line.strip(): task = asyncio.create_task(worker(args, work_queue))
parts = line.strip().split(",") tasks.append(task)
group_hash = parts[0]
group_pdfs = parts[1:]
existing_groups[group_hash] = group_pdfs
existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)
logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace") # Wait for the queue to be fully processed
await work_queue.join()
# Remove existing PDFs from all_pdfs # Cancel our worker tasks.
new_pdfs = all_pdfs - existing_pdf_set for task in tasks:
logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace") task.cancel()
# Group the new PDFs into chunks of group_size # Wait until all worker tasks are cancelled.
# TODO: Figure out the group size automatically by sampling a few pdfs, and taking the mean/median number of pages, etc. await asyncio.gather(*tasks, return_exceptions=True)
new_groups = []
current_group = []
for pdf in sorted(new_pdfs): # Sort for consistency
current_group.append(pdf)
if len(current_group) == args.group_size:
group_hash = compute_workgroup_sha1(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = compute_workgroup_sha1(current_group)
new_groups.append((group_hash, current_group))
logger.info(f"Created {len(new_groups):,} new work groups") if __name__ == "__main__":
asyncio.run(main())
# Combine existing groups with new groups
combined_groups = existing_groups.copy()
for group_hash, group_pdfs in new_groups:
combined_groups[group_hash] = group_pdfs
# Prepare lines to write back
combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]
# Upload the combined work items back to S3
if new_groups:
upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
logger.info("Completed adding new PDFs.")
# TODO # TODO
# If there is a beaker flag, then your job is to trigger this script with N replicas on beaker # If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
# If not, then your job is to do the actual work # If not, then your job is to do the actual work
# Download the model from the best place available # Download the model from the best place available
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
download_directory(args.model, model_cache_dir)
# Start up the sglang server
sglang_process = subprocess.Popen([
"python3", "-m", "sglang.launch_server",
"--model-path", model_cache_dir,
"--chat-template", args.model_chat_template,
"--context-length", str(args.model_max_context),
])
# Register atexit function and signal handlers to guarantee process termination # Register atexit function and signal handlers to guarantee process termination
def terminate_processes(): # def terminate_processes():
print("Terminating child processes...") # print("Terminating child processes...")
sglang_process.terminate() # sglang_process.terminate()
try: # try:
sglang_process.wait(timeout=30) # sglang_process.wait(timeout=30)
except subprocess.TimeoutExpired: # except subprocess.TimeoutExpired:
print("Forcing termination of child processes.") # print("Forcing termination of child processes.")
sglang_process.kill() # sglang_process.kill()
print("Child processes terminated.") # print("Child processes terminated.")
atexit.register(terminate_processes) # atexit.register(terminate_processes)
def signal_handler(sig, frame): # def signal_handler(sig, frame):
terminate_processes() # terminate_processes()
sys.exit(0) # sys.exit(0)
signal.signal(signal.SIGINT, signal_handler) # signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler) # signal.signal(signal.SIGTERM, signal_handler)
# Read in the work queue from s3
work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path) # logger.info(f"Remaining work items: {len(remaining_work_queue)}")
work_queue = {}
for line in work_queue_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_pdfs = parts[1:]
work_queue[group_hash] = group_pdfs
# Read in the done items from the s3 workspace
done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
done_work_hashes = set()
for item in done_work_items:
filename = os.path.basename(item)
if filename.startswith('output_') and filename.endswith('.jsonl'):
group_hash = filename[len('output_'):-len('.jsonl')]
done_work_hashes.add(group_hash)
remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
logger.info(f"Remaining work items: {len(remaining_work_queue)}")
# TODO # TODO
# Spawn up to N workers to do: # Spawn up to N workers to do:
@ -238,12 +290,12 @@ if __name__ == '__main__':
# Possible future addon, in beaker, discover other nodes on this same job # Possible future addon, in beaker, discover other nodes on this same job
# Send them a message when you take a work item off the queue # Send them a message when you take a work item off the queue
try: # try:
while True: # while True:
time.sleep(1) # time.sleep(1)
if sglang_process.returncode is not None: # if sglang_process.returncode is not None:
logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.") # logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
except KeyboardInterrupt: # except KeyboardInterrupt:
logger.info("Got keyboard interrupt, exiting everything") # logger.info("Got keyboard interrupt, exiting everything")
sys.exit(1) # sys.exit(1)