From 38dc5a2a0f1c8a0a9668f2d22c222f448c7ec3f3 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Wed, 23 Oct 2024 16:28:46 +0000
Subject: [PATCH] Refactor BatchWriter to stream batches through temp files,
 and cap the number of in-flight futures

---
 pdelfin/birrpipeline.py | 226 ++++++++++++++++++++++++++++------------
 1 file changed, 153 insertions(+), 73 deletions(-)

diff --git a/pdelfin/birrpipeline.py b/pdelfin/birrpipeline.py
index 2c9a2ef..96d6659 100644
--- a/pdelfin/birrpipeline.py
+++ b/pdelfin/birrpipeline.py
@@ -4,7 +4,7 @@ import boto3
 import sqlite3
 import json
 import argparse
-import glob
+import uuid
 import tempfile
 import datetime
 import posixpath
@@ -19,6 +19,7 @@ from tqdm import tqdm
 from functools import partial
 from typing import Optional, List, Tuple, Dict, Callable, Any
 from urllib.parse import urlparse
+import concurrent.futures
 from concurrent.futures import ProcessPoolExecutor, as_completed
 
 from pdelfin.data.renderpdf import render_pdf_to_base64png
@@ -241,26 +242,107 @@ class DatabaseManager:
         self.conn.close()
 
 
-# Writes batches of lines out to a set of files, keeping each file below some maximum size
+# Writes batches of JSON objects out to a set of files, keeping each file below some maximum size
 class BatchWriter:
-    def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
+    def __init__(
+        self,
+        output_prefix: str,
+        max_size_mb: int = 250,
+        after_flush: Optional[Callable[[List[Any]], Any]] = None,
+    ):
         self.output_prefix = output_prefix
         self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
-        self.batch = []
+        self.batch_objects = []
         self.batch_size = 0
         self.after_flush = after_flush
         self.threads = []
+        self.temp_file = None  # The temporary file object
+        self.temp_file_path = None  # Path to the temporary file
 
         parsed = urlparse(output_prefix)
-        self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
+        self.is_s3 = parsed.scheme in ("s3", "s3a", "s3n")
 
         if not self.is_s3:
             os.makedirs(output_prefix, exist_ok=True)
 
-    def _compute_hash(self, content: str) -> str:
-        """Compute a 20-character SHA1 hash of the given content."""
+    def write_line(self, obj: Optional[Any]):
+        if obj is None:
+            return
+
+        line_bytes = json.dumps(obj, ensure_ascii=False).encode("utf-8")
+        line_size = len(line_bytes) + 1  # +1 for newline
+
+        # Flush the current batch first if this line would push it over the limit
+        if self.batch_size + line_size > self.max_size:
+            self._write_batch()
+
+        if self.batch_size == 0:
+            # Open a new temporary file
+            self.temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)
+            self.temp_file_path = self.temp_file.name
+
+        self.temp_file.write(line_bytes + b"\n")
+        self.batch_objects.append(obj)
+        self.batch_size += line_size
+
+    def _write_batch(self):
+        if self.batch_size == 0:
+            return
+
+        # Close the temp file
+        self.temp_file.flush()
+        self.temp_file.close()
+
+        # Start a new thread to upload the temp file
+        thread = threading.Thread(
+            target=self._write_batch_to_file, args=(self.temp_file_path, self.batch_objects)
+        )
+        thread.start()
+        self.threads.append(thread)
+
+        # Reset batch_objects and batch_size
+        self.batch_objects = []
+        self.batch_size = 0
+        self.temp_file = None
+        self.temp_file_path = None
+
+    def _write_batch_to_file(self, temp_file_path: str, batch_objects: List[Any]):
+        # Compute hash based on file content
+        hash_str = self._compute_hash(temp_file_path)
+        output_path = self._get_output_path(hash_str)
+
+        if self.is_s3:
+            parsed = urlparse(output_path)
+            bucket = parsed.netloc
+            key = parsed.path.lstrip("/")
+
+            # Upload the temp file using the s3 client directly
+            try:
+                workspace_s3.upload_file(temp_file_path, bucket, key)
+            except Exception as e:
+                print(f"Failed to upload {temp_file_path} to {output_path}: {e}")
+
+            # Delete the temporary file now that the upload attempt is done
+            os.remove(temp_file_path)
+        else:
+            # Move the temp file to the output path; the rename consumes the
+            # temp file, so no separate os.remove is needed on this branch
+            os.rename(temp_file_path, output_path)
+
+        # After writing, call the after_flush callback if it is set
+        if self.after_flush:
+            self.after_flush(batch_objects)
+
+    def _compute_hash(self, temp_file_path: str) -> str:
+        """Compute a 20-character SHA1 hash of the file content."""
         sha1 = hashlib.sha1()
-        sha1.update(content.encode('utf-8'))
+        with open(temp_file_path, "rb") as f:
+            # Stream the file in 1 MB chunks to avoid loading it all at once
+            while True:
+                data = f.read(1024 * 1024)
+                if not data:
+                    break
+                sha1.update(data)
         return sha1.hexdigest()[:20]
 
     def _get_output_path(self, hash_str: str) -> str:
@@ -268,59 +350,15 @@ class BatchWriter:
         parsed = urlparse(self.output_prefix)
         if self.is_s3:
             bucket = parsed.netloc
-            key = parsed.path.lstrip('/')
-            if key and not key.endswith('/'):
-                key += '/'
+            key = parsed.path.lstrip("/")
+            if key and not key.endswith("/"):
+                key += "/"
             full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
             return f"s3://{bucket}/{full_key}"
         else:
             filename = f"output_{hash_str}.jsonl"
             return os.path.join(self.output_prefix, filename)
 
-    def write_line(self, line: Optional[str]):
-        if line is None or not line.strip():
-            return
-
-        line_size = len(line.encode('utf-8')) + 1  # +1 for newline
-
-        if self.batch_size + line_size > self.max_size:
-            self._write_batch()
-
-        self.batch.append(line)
-        self.batch_size += line_size
-
-    def _write_batch(self):
-        if not self.batch:
-            return
-
-        batch_lines = self.batch.copy()
-        batch_content = "\n".join(batch_lines) + "\n"
-        hash_str = self._compute_hash(batch_content)
-        output_path = self._get_output_path(hash_str)
-
-        # Start a new thread to write the batch
-        thread = threading.Thread(
-            target=self._write_batch_to_file,
-            args=(batch_content, output_path, batch_lines)
-        )
-        thread.start()
-        self.threads.append(thread)
-
-        # Clear the batch and batch_size
-        self.batch = []
-        self.batch_size = 0
-
-    def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
-        if self.is_s3:
-            put_s3_bytes(workspace_s3, output_path, batch_content.encode("utf-8"))
-        else:
-            with open(output_path, 'w', encoding='utf-8') as f_out:
-                f_out.write(batch_content)
-
-        # After writing, call the after_flush callback if it is set
-        if self.after_flush:
-            self.after_flush(batch_lines)
-
     def close(self):
         self._write_batch()
         # Wait for all threads to finish
@@ -520,11 +558,11 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
     return dolma_doc
 
 
-def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
+def mark_pdfs_done(s3_workspace: str, dolma_docs: list[dict]):
     db = DatabaseManager(s3_workspace)
 
-    for line in dolma_doc_lines:
-        db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")
+    for doc in dolma_docs:
+        db.update_pdf_status(doc["metadata"]["Source-File"], "completed")
 
 def get_current_round(s3_workspace: str) -> int:
     path = s3_workspace[5:]
@@ -610,13 +648,13 @@ if __name__ == '__main__':
     print("Indexing all batch inference sent to this workspace")
     inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
-    inference_output_paths = [
-        (s3_path, etag) for s3_path, etag in inference_output_paths.items()
+    inference_output_paths = {
+        s3_path: etag for s3_path, etag in inference_output_paths.items()
         if not db.is_file_processed(s3_path, etag)
-    ]
+    }
 
     print(f"Found {len(inference_output_paths):,} new batch inference results to index")
 
-    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}
+    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths.items()}
 
     for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
         s3_path, etag = future_to_path[future]
@@ -638,29 +676,71 @@ if __name__ == '__main__':
         potentially_done_pdfs = db.get_pdfs_by_status("pending")
     else:
         print(f"\nCreating batch inference files for new PDFs")
-        future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim, args.target_anchor_text_len): pdf for pdf in db.get_pdfs_by_status("pending")}
+        pdf_list = list(db.get_pdfs_by_status("pending"))
+        pdf_iter = iter(pdf_list)
+        pending_futures = {}
         potentially_done_pdfs = []
         lines_written = 0
         new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)
+        total_pdfs = len(pdf_list)
+        max_pending = 5000
 
-        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
-            pdf = future_to_path[future]
-            inference_lines = future.result()
+        with tqdm(total=total_pdfs) as pbar:
+            # Submit initial batch of futures
+            for _ in range(min(max_pending, total_pdfs)):
+                pdf = next(pdf_iter)
+                future = executor.submit(
+                    build_pdf_queries,
+                    args.workspace,
+                    pdf,
+                    current_round,
+                    args.target_longest_image_dim,
+                    args.target_anchor_text_len,
+                )
+                pending_futures[future] = pdf
 
-            if len(inference_lines) == 0:
-                potentially_done_pdfs.append(pdf)
+            while pending_futures:
+                # Wait for the next future to complete
+                done, _ = concurrent.futures.wait(
+                    pending_futures.keys(),
+                    return_when=concurrent.futures.FIRST_COMPLETED,
+                )
 
-            for line in inference_lines:
-                lines_written += 1
+                for future in done:
+                    pdf = pending_futures.pop(future)
+                    inference_lines = future.result()
 
-                if line is not None:
-                    new_inference_writer.write_line(json.dumps(line))
+                    if len(inference_lines) == 0:
+                        potentially_done_pdfs.append(pdf)
+
+                    # Only count lines that are actually written out
+                    for line in inference_lines:
+                        if line is not None:
+                            lines_written += 1
+                            new_inference_writer.write_line(line)
+
+                    pbar.update(1)
+
+                    # Submit a new future if there are more PDFs
+                    try:
+                        pdf = next(pdf_iter)
+                        new_future = executor.submit(
+                            build_pdf_queries,
+                            args.workspace,
+                            pdf,
+                            current_round,
+                            args.target_longest_image_dim,
+                            args.target_anchor_text_len,
+                        )
+                        pending_futures[new_future] = pdf
+                    except StopIteration:
+                        pass  # No more PDFs to process
 
         new_inference_writer.close()
 
         if lines_written > 0:
             print(f"Added {lines_written:,} new batch inference requests")
 
+    # Now, finally, assemble any potentially done docs into dolma documents
     print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
     future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
@@ -671,7 +751,7 @@ if __name__ == '__main__':
             dolma_doc = future.result()
 
             if dolma_doc is not None:
-                new_output_writer.write_line(json.dumps(dolma_doc))
+                new_output_writer.write_line(dolma_doc)
 
     new_output_writer.close()
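
Note on the new BatchWriter contract: write_line() now takes a
JSON-serializable object rather than a pre-encoded string. Each object is
serialized and appended to a temporary file, and once a batch would exceed
max_size_mb the temp file is handed to a background thread that names it
output_<sha1-of-content>.jsonl and uploads or moves it; after_flush receives
the list of objects that went into the flushed file. A hypothetical usage
sketch (the local prefix, size limit, and callback below are illustrative,
and it assumes the patched module is importable):

    from pdelfin.birrpipeline import BatchWriter

    writer = BatchWriter(
        "/tmp/batch_out",   # a local prefix; an s3://bucket/prefix also works
        max_size_mb=1,      # small limit so the rollover is easy to observe
        after_flush=lambda objs: print(f"flushed {len(objs)} objects"),
    )
    for i in range(10_000):
        writer.write_line({"id": i, "text": "example"})
    writer.close()  # flushes the final partial batch and joins writer threads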
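Note on the bounded submission loop: instead of submitting one future per PDF
up front (which queues every work item at once), the loop now primes a window
of at most max_pending=5000 futures and tops it back up each time one
completes. A minimal standalone sketch of the same sliding-window pattern,
not code from this patch; run_bounded and square are illustrative names:

    import itertools
    import concurrent.futures
    from concurrent.futures import ProcessPoolExecutor

    def square(n: int) -> int:
        # Stand-in for an expensive worker such as build_pdf_queries
        return n * n

    def run_bounded(items, worker, executor, max_pending=5000):
        # Keep at most max_pending futures in flight; yields results as
        # they complete, so output order is not input order.
        item_iter = iter(items)
        pending = {
            executor.submit(worker, item)
            for item in itertools.islice(item_iter, max_pending)
        }
        while pending:
            done, pending = concurrent.futures.wait(
                pending, return_when=concurrent.futures.FIRST_COMPLETED
            )
            for future in done:
                yield future.result()
            # Refill the window: one new submission per completed future
            pending |= {
                executor.submit(worker, item)
                for item in itertools.islice(item_iter, len(done))
            }

    if __name__ == "__main__":
        with ProcessPoolExecutor() as executor:
            print(sum(run_bounded(range(10_000), square, executor, max_pending=500)))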