diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index 53fb080..02d3ab3 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -33,6 +33,7 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
         sha1.update(pdf.encode('utf-8'))
     return sha1.hexdigest()
 
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
@@ -76,12 +77,18 @@ if __name__ == '__main__':
 
         all_pdfs = set(all_pdfs)
         logger.info(f"Found {len(all_pdfs):,} total pdf paths")
-        
+
         existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
 
         # Parse existing work items into groups
-        existing_groups = [line.strip().split(",") for line in existing_lines if line.strip()]
-        existing_pdf_set = set(pdf for group in existing_groups for pdf in group)
+        existing_groups = {}
+        for line in existing_lines:
+            if line.strip():
+                parts = line.strip().split(",")
+                group_hash = parts[0]
+                group_pdfs = parts[1:]
+                existing_groups[group_hash] = group_pdfs
+        existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)
 
         logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
 
@@ -96,18 +103,22 @@ if __name__ == '__main__':
         for pdf in sorted(new_pdfs):  # Sort for consistency
             current_group.append(pdf)
             if len(current_group) == args.group_size:
-                new_groups.append(current_group)
+                group_hash = compute_workgroup_sha1(current_group)
+                new_groups.append((group_hash, current_group))
                 current_group = []
         if current_group:
-            new_groups.append(current_group)
+            group_hash = compute_workgroup_sha1(current_group)
+            new_groups.append((group_hash, current_group))
 
         logger.info(f"Created {len(new_groups):,} new work groups")
 
         # Combine existing groups with new groups
-        combined_groups = existing_groups + new_groups
+        combined_groups = existing_groups.copy()
+        for group_hash, group_pdfs in new_groups:
+            combined_groups[group_hash] = group_pdfs
 
         # Prepare lines to write back
-        combined_lines = [",".join(group) for group in combined_groups]
+        combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]
 
         # Upload the combined work items back to S3
         if new_groups:
@@ -119,9 +130,9 @@ if __name__ == '__main__':
    # If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
    # If not, then your job is to do the actual work
 
-    # Donwload the model from the best place available
+    # Download the model from the best place available
    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
-    #download_directory(args.model, model_cache_dir)
+    download_directory(args.model, model_cache_dir)
 
    # Start up the sglang server
    sglang_process = subprocess.Popen([
@@ -131,7 +142,6 @@ if __name__ == '__main__':
        "--context-length", str(args.model_max_context),
    ])
 
-
    # Register atexit function and signal handlers to guarantee process termination
    def terminate_processes():
        print("Terminating child processes...")
@@ -153,11 +163,28 @@ if __name__ == '__main__':
    signal.signal(signal.SIGTERM, signal_handler)
 
    # Read in the work queue from s3
-    work_queue = download_zstd_csv(workspace_s3, index_file_s3_path)
-    work_queue = {compute_workgroup_sha1(pdfs): pdfs for pdfs in work_queue}
+    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
+    work_queue = {}
+    for line in work_queue_lines:
+        if line.strip():
+            parts = line.strip().split(",")
+            group_hash = parts[0]
+            group_pdfs = parts[1:]
+            work_queue[group_hash] = group_pdfs
 
    # Read in the done items from the s3 workspace
-    done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/*.jsonl")
+    done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
+    done_work_hashes = set()
+    for item in done_work_items:
+        filename = os.path.basename(item)
+        if filename.startswith('output_') and filename.endswith('.jsonl'):
+            group_hash = filename[len('output_'):-len('.jsonl')]
+            done_work_hashes.add(group_hash)
+
+    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
+    remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
+
+    logger.info(f"Remaining work items: {len(remaining_work_queue)}")
 
    # TODO
    # Spawn up to N workers to do:
@@ -178,4 +205,4 @@ if __name__ == '__main__':
            logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
    except KeyboardInterrupt:
        logger.info("Got keyboard interrupt, exiting everything")
-        sys.exit(1)
\ No newline at end of file
+        sys.exit(1)
diff --git a/pyproject.toml b/pyproject.toml
index 8af92d4..7a973e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,8 @@ dependencies = [
    "markdown2",
    "filelock",
    "orjson",
+    "requests",
+    "zstandard",
 ]
 
 license = {file = "LICENSE"}