diff --git a/pdelfin/assemblepipeline.py b/pdelfin/assemblepipeline.py
index 84c80ee..3d9c94d 100644
--- a/pdelfin/assemblepipeline.py
+++ b/pdelfin/assemblepipeline.py
@@ -7,6 +7,54 @@ import argparse
 from tqdm import tqdm
 from concurrent.futures import ProcessPoolExecutor, as_completed
 
+class DatabaseManager:
+    def __init__(self, db_path):
+        self.db_path = db_path
+        self.conn = sqlite3.connect(self.db_path)
+        self.cursor = self.conn.cursor()
+        self._initialize_tables()
+
+    def _initialize_tables(self):
+        self.cursor.execute("""
+            CREATE TABLE IF NOT EXISTS index_table (
+                custom_id TEXT,
+                s3_path TEXT,
+                start_index BIGINT,
+                end_index BIGINT
+            )
+        """)
+        self.cursor.execute("""
+            CREATE TABLE IF NOT EXISTS processed_files (
+                s3_path TEXT PRIMARY KEY,
+                etag TEXT
+            )
+        """)
+        self.conn.commit()
+
+    def is_file_processed(self, s3_path, etag):
+        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
+        result = self.cursor.fetchone()
+        return result is not None and result[0] == etag
+
+    def add_index_entries(self, index_entries):
+        if index_entries:
+            self.cursor.executemany("""
+                INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
+                VALUES (?, ?, ?, ?)
+            """, index_entries)
+            self.conn.commit()
+
+    def update_processed_file(self, s3_path, etag):
+        self.cursor.execute("""
+            INSERT INTO processed_files (s3_path, etag)
+            VALUES (?, ?)
+            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
+        """, (s3_path, etag))
+        self.conn.commit()
+
+    def close(self):
+        self.conn.close()
+
 def build_index(s3_path):
     # Hash the s3_path to get a cache key
     cache_key = hashlib.sha256(s3_path.encode('utf-8')).hexdigest()
@@ -14,25 +62,9 @@ def build_index(s3_path):
     os.makedirs(home_cache_dir, exist_ok=True)
     db_path = os.path.join(home_cache_dir, 'index.db')
 
-    # Connect to sqlite and create tables if not exist
+    # Initialize the database manager
     print("Building page index at", db_path)
-    conn = sqlite3.connect(db_path)
-    cursor = conn.cursor()
-    cursor.execute("""
-        CREATE TABLE IF NOT EXISTS index_table (
-            custom_id TEXT,
-            s3_path TEXT,
-            start_index BIGINT,
-            end_index BIGINT
-        )
-    """)
-    cursor.execute("""
-        CREATE TABLE IF NOT EXISTS processed_files (
-            s3_path TEXT PRIMARY KEY,
-            etag TEXT
-        )
-    """)
-    conn.commit()
+    db_manager = DatabaseManager(db_path)
 
     s3 = boto3.client('s3')
     bucket, prefix = parse_s3_path(s3_path)
@@ -42,42 +74,34 @@ def build_index(s3_path):
     if not files:
         print("No .json or .jsonl files found in the specified S3 path.")
+        db_manager.close()
         return
 
     # Prepare a list of files that need processing
-    files_to_process = []
-    for key, etag in files.items():
-        cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (key,))
-        db_result = cursor.fetchone()
-        if db_result and db_result[0] == etag:
-            # File has already been processed with the same ETag
-            pass  # Skip
-        else:
-            files_to_process.append((key, etag))
+    files_to_process = [
+        (key, etag) for key, etag in files.items()
+        if not db_manager.is_file_processed(key, etag)
+    ]
 
     if not files_to_process:
         print("All files are up to date. No processing needed.")
+        db_manager.close()
        return
 
     # Use ProcessPoolExecutor to process files with tqdm progress bar
     with ProcessPoolExecutor() as executor:
-        futures = [executor.submit(process_file, bucket, key, etag) for key, etag in files_to_process]
+        futures = [
+            executor.submit(process_file, bucket, key, etag)
+            for key, etag in files_to_process
+        ]
 
         for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
            s3_path, key, etag, index_entries = future.result()
             if index_entries:
-                cursor.executemany("""
-                    INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
-                    VALUES (?, ?, ?, ?)
-                """, index_entries)
+                db_manager.add_index_entries(index_entries)
             # Update the processed_files table
-            cursor.execute("""
-                INSERT INTO processed_files (s3_path, etag)
-                VALUES (?, ?)
-                ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
-            """, (key, etag))
-            conn.commit()
+            db_manager.update_processed_file(key, etag)
 
-    conn.close()
+    db_manager.close()
 
 def parse_s3_path(s3_path):
     if not s3_path.startswith('s3://'):
@@ -139,5 +163,3 @@ if __name__ == '__main__':
 
     # Step one, build an index of all the pages that were processed
     build_index(args.s3_path)
-
-