diff --git a/pdelfin/assemblepipeline.py b/pdelfin/assemblepipeline.py
index 81cbb49..733497c 100644
--- a/pdelfin/assemblepipeline.py
+++ b/pdelfin/assemblepipeline.py
@@ -8,6 +8,7 @@
 import glob
 import tempfile
 import posixpath
+from dataclasses import dataclass
 from pypdf import PdfReader
 from tqdm import tqdm
 from typing import Optional
@@ -30,15 +31,17 @@ class DatabaseManager:

     def _initialize_tables(self):
         self.cursor.execute("""
-            CREATE TABLE IF NOT EXISTS index_table (
-                custom_id TEXT,
+            CREATE TABLE IF NOT EXISTS page_results (
                 s3_path TEXT,
+                page_num INTEGER,
                 start_index BIGINT,
-                end_index BIGINT
+                length BIGINT,
+                finish_reason TEXT,
+                error TEXT
             )
         """)
         self.cursor.execute("""
-            CREATE INDEX IF NOT EXISTS idx_custom_id ON index_table(custom_id)
+            CREATE INDEX IF NOT EXISTS idx_path ON page_results(s3_path)
         """)
         self.cursor.execute("""
             CREATE TABLE IF NOT EXISTS pdfs (
@@ -60,15 +63,16 @@ class DatabaseManager:
                 value TEXT
             )
         """)
-        self.cursor.execute("SELECT value FROM metadata WHERE key='round'")
-        if self.cursor.fetchone() is None:
-            self.cursor.execute("INSERT INTO metadata (key, value) VALUES ('round', '0')")
+        self.conn.commit()

-    def get_current_round(self):
-        self.cursor.execute("SELECT value FROM metadata WHERE key='round'")
+    def get_metadata(self, key: str) -> str:
+        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
         result = self.cursor.fetchone()
-        return int(result[0])
+        return result[0]
+
+    def get_current_round(self):
+        return int(self.get_metadata("round"))

     def is_file_processed(self, s3_path, etag):
         self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
@@ -76,6 +80,7 @@ class DatabaseManager:
         return result is not None and result[0] == etag

     def add_index_entries(self, index_entries):
+        # TODO: Make this take BatchInferenceLine entries and insert into the new page_results schema
         if index_entries:
             self.cursor.executemany("""
                 INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
@@ -113,45 +118,6 @@ class DatabaseManager:
     def close(self):
         self.conn.close()

-def build_index(s3_path):
-    db_manager = DatabaseManager(s3_path)
-
-    bucket, prefix = parse_s3_path(s3_path)
-
-    # List all .json and .jsonl files under s3_path with their ETags
-    files = expand_s3_glob(s3_path)
-
-    if not files:
-        print("No .json or .jsonl files found in the specified S3 path.")
-        db_manager.close()
-        return
-
-    # Prepare a list of files that need processing
-    files_to_process = [
-        (key, etag) for key, etag in files.items()
-        if not db_manager.is_file_processed(key, etag)
-    ]
-
-    if not files_to_process:
-        print("All files are up to date. No processing needed.")
-        db_manager.close()
-        return
-
-    # Use ProcessPoolExecutor to process files with tqdm progress bar
-    with ProcessPoolExecutor() as executor:
-        futures = [
-            executor.submit(process_file, bucket, key, etag)
-            for key, etag in files_to_process
-        ]
-        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
-            s3_path, key, etag, index_entries = future.result()
-            if index_entries:
-                db_manager.add_index_entries(index_entries)
-                # Update the processed_files table
-                db_manager.update_processed_file(key, etag)
-
-    db_manager.close()
-
 def parse_s3_path(s3_path):
     if not s3_path.startswith('s3://'):
         raise ValueError('s3_path must start with s3://')
@@ -159,13 +125,13 @@ def parse_s3_path(s3_path):
     bucket, _, prefix = path.partition('/')
     return bucket, prefix

+
 def expand_s3_glob(s3_glob: str) -> dict[str, str]:
     parsed = urlparse(s3_glob)
     bucket_name = parsed.netloc
     prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
     pattern = os.path.basename(parsed.path)
-
     paginator = s3.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
@@ -178,37 +144,43 @@ def expand_s3_glob(s3_glob: str) -> dict[str, str]:

     return matched_files

-def process_file(bucket, key, etag):
-    s3 = boto3.client('s3')  # Initialize s3 client in the worker process
-    s3_path = f's3://{bucket}/{key}'
-    try:
-        # Get the object
-        obj = s3.get_object(Bucket=bucket, Key=key)
-        # Read the content as bytes
-        content = obj['Body'].read()
-        # Process the file as JSONL
-        index_entries = process_jsonl_content(content, s3_path)
-        # Return the necessary data to the main process
-        return s3_path, key, etag, index_entries
-    except Exception as e:
-        print(f"Error processing file {s3_path}: {e}")
-        return s3_path, key, etag, []
+@dataclass(frozen=True)
+class BatchInferenceLine:
+    s3_path: str
+    page_num: int  # 1 indexed!
+    start_index: int
+    length: int
+    finish_reason: str
+    error: Optional[str]
+
+def parse_custom_id(custom_id: str) -> tuple[str, int]:
+    s3_path = custom_id[:custom_id.rindex("-")]
+    page_num = int(custom_id[custom_id.rindex("-") + 1:])
+
+    return s3_path, page_num
+
+def process_jsonl_content(s3_path) -> list[BatchInferenceLine]:
+    content = get_s3_bytes(s3_path).decode("utf-8")

-def process_jsonl_content(content, s3_path):
     start_index = 0
     index_entries = []

     lines = content.splitlines(keepends=True)

     for line in lines:
         line_length = len(line)
-        end_index = start_index + line_length
+
         try:
             data = json.loads(line)
-            custom_id = data.get('custom_id')
-            if custom_id:
-                index_entries.append((custom_id, s3_path, start_index, end_index))
+            s3_path, page_num = parse_custom_id(data["custom_id"])
+
+            assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"
+
+            index_entries.append(BatchInferenceLine(s3_path, page_num, start_index, line_length,
+                                                    finish_reason=data["outputs"][0]["finish_reason"], error=data.get("completion_error", None)))
         except json.JSONDecodeError:
             pass  # Handle JSON decode errors if necessary
+
+        start_index = start_index + line_length
+
     return index_entries

 def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
@@ -246,7 +218,7 @@ def get_pdf_num_pages(s3_path: str) -> Optional[int]:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
-    parser.add_argument('--pdfs', help='Glob path to PDFs (local or s3)', default=None)
+    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
    parser.add_argument('--file_size_limit', type=int, default=250, help='Max file size in MB')

     args = parser.parse_args()
@@ -258,12 +230,12 @@ if __name__ == '__main__':
     executor = ProcessPoolExecutor()

     # If you have new PDFs, add them to the list
-    if args.pdfs:
-        assert args.pdfs.startswith("s3://"), "PDFs must live on s3"
+    if args.add_pdfs:
+        assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"

-        print(f"Querying all PDFs at {args.pdfs}")
+        print(f"Querying all PDFs at {args.add_pdfs}")

-        all_pdfs = expand_s3_glob(args.pdfs)
+        all_pdfs = expand_s3_glob(args.add_pdfs)
         print(f"Found {len(all_pdfs)} total pdf paths")

         all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
@@ -279,8 +251,21 @@

     # Now build an index of all the pages that were processed within the workspace so far
-    build_index(f"{args.workspace}/*.jsonl")
+    inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")
+
+    inference_output_paths = [
+        (key, etag) for key, etag in inference_output_paths.items()
+        if not db.is_file_processed(key, etag)
+    ]
+
+    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}
+
+    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
+        s3_path, etag = future_to_path[future]
+
+        inference_lines = future.result()
+
+        db.add_index_entries(inference_lines)
+
+        db.update_processed_file(s3_path, etag=etag)
-
-    # Now, for each pending book, find all pages which still need to be processed
-    # and add them to the next round's batch inference jobs
-
\ No newline at end of file
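Note (not part of the patch): add_index_entries still targets the old index_table columns even though __main__ now passes it BatchInferenceLine objects, per the TODO above. A minimal sketch of how that DatabaseManager method might be updated, assuming only the page_results column layout created in _initialize_tables:

    def add_index_entries(self, index_entries: list[BatchInferenceLine]):
        # Hypothetical follow-up (not in this diff): write one row per page result
        # into the page_results table created in _initialize_tables.
        if index_entries:
            self.cursor.executemany("""
                INSERT INTO page_results (s3_path, page_num, start_index, length, finish_reason, error)
                VALUES (?, ?, ?, ?, ?, ?)
            """, [(e.s3_path, e.page_num, e.start_index, e.length, e.finish_reason, e.error) for e in index_entries])
            self.conn.commit()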