import os
import hashlib
import boto3
import sqlite3
import json
import argparse
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
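
# This script builds a local SQLite index over the .json/.jsonl files found in
# an S3 workspace (the outputs of a batch inference pipeline). Each JSON line
# carrying a `custom_id` is recorded with its source file's S3 path and byte
# range, presumably so individual records can later be located without
# re-reading whole files.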


class DatabaseManager:
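    """SQLite-backed cache for the index, stored under ~/.cache/pdelfin.

    Tables:
      - index_table: custom_id -> (s3_path, start_index, end_index) byte range
      - pdfs: per-PDF page count and processing status
      - processed_files: ETag of each already-indexed S3 file
      - metadata: generic key/value pairs, e.g. the current round
    """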

    def __init__(self, s3_workspace: str):
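        # Derive a per-workspace cache directory from a SHA-256 hash of the
        # normalized workspace path, so each S3 workspace gets its own index.db.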
        cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
        home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
        os.makedirs(home_cache_dir, exist_ok=True)
        self.db_path = os.path.join(home_cache_dir, 'index.db')

        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self._initialize_tables()

    def _initialize_tables(self):
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS index_table (
                custom_id TEXT,
                s3_path TEXT,
                start_index BIGINT,
                end_index BIGINT
            )
        """)
        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_custom_id ON index_table(custom_id)
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS pdfs (
                s3_path TEXT PRIMARY KEY,
                num_pages INTEGER,
                status TEXT DEFAULT 'pending'
            )
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS processed_files (
                s3_path TEXT PRIMARY KEY,
                etag TEXT
            )
        """)
        # Generic metadata such as current round
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        """)
        self.cursor.execute("SELECT value FROM metadata WHERE key='round'")
        if self.cursor.fetchone() is None:
            self.cursor.execute("INSERT INTO metadata (key, value) VALUES ('round', '0')")

        self.conn.commit()

    def get_current_round(self):
        self.cursor.execute("SELECT value FROM metadata WHERE key='round'")
        result = self.cursor.fetchone()
        return int(result[0])

    def is_file_processed(self, s3_path, etag):
        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result is not None and result[0] == etag

    def add_index_entries(self, index_entries):
        if index_entries:
            self.cursor.executemany("""
                INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
                VALUES (?, ?, ?, ?)
            """, index_entries)
            self.conn.commit()

    def update_processed_file(self, s3_path, etag):
        self.cursor.execute("""
            INSERT INTO processed_files (s3_path, etag)
            VALUES (?, ?)
            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
        """, (s3_path, etag))
        self.conn.commit()

    def close(self):
        self.conn.close()


def build_index(s3_path):
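    """Index every .json/.jsonl file under the given S3 workspace.

    Files whose ETag is unchanged since the last run are skipped; the rest are
    processed in parallel worker processes and their index entries written to
    the local SQLite database.
    """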
    db_manager = DatabaseManager(s3_path)

    s3 = boto3.client('s3')
    bucket, prefix = parse_s3_path(s3_path)

    # List all .json and .jsonl files under s3_path with their ETags
    files = list_s3_files(s3, bucket, prefix)

    if not files:
        print("No .json or .jsonl files found in the specified S3 path.")
        db_manager.close()
        return

    # Prepare a list of files that need processing
    files_to_process = [
        (key, etag) for key, etag in files.items()
        if not db_manager.is_file_processed(key, etag)
    ]

    if not files_to_process:
        print("All files are up to date. No processing needed.")
        db_manager.close()
        return

    # Use ProcessPoolExecutor to process files with tqdm progress bar
    with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(process_file, bucket, key, etag)
            for key, etag in files_to_process
        ]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
            s3_path, key, etag, index_entries = future.result()
            if index_entries:
                db_manager.add_index_entries(index_entries)
            # Update the processed_files table
            db_manager.update_processed_file(key, etag)

    db_manager.close()


def parse_s3_path(s3_path):
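    """Split an s3:// URI into (bucket, prefix)."""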
    if not s3_path.startswith('s3://'):
        raise ValueError('s3_path must start with s3://')
    path = s3_path[5:]
    bucket, _, prefix = path.partition('/')
    return bucket, prefix


def list_s3_files(s3, bucket, prefix):
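    """Return a dict mapping each .json/.jsonl key under the prefix to its ETag."""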
    paginator = s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix)
    files = {}
    for page in page_iterator:
        contents = page.get('Contents', [])
        for obj in contents:
            key = obj['Key']
            if key.endswith('.json') or key.endswith('.jsonl'):
                # Retrieve ETag for each file
                files[key] = obj['ETag'].strip('"')
    return files


def process_file(bucket, key, etag):
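    """Download one S3 file in a worker process and index its JSONL content.

    Returns (s3_path, key, etag, index_entries); index_entries is empty when
    the file cannot be processed.
    """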
    s3 = boto3.client('s3')  # Initialize s3 client in the worker process
    s3_path = f's3://{bucket}/{key}'
    try:
        # Get the object
        obj = s3.get_object(Bucket=bucket, Key=key)
        # Read the content as bytes
        content = obj['Body'].read()
        # Process the file as JSONL
        index_entries = process_jsonl_content(content, s3_path)
        # Return the necessary data to the main process
        return s3_path, key, etag, index_entries
    except Exception as e:
        print(f"Error processing file {s3_path}: {e}")
        return s3_path, key, etag, []


def process_jsonl_content(content, s3_path):
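    """Compute (custom_id, s3_path, start_index, end_index) for each JSON line.

    Offsets are byte positions within the original file (lines are split with
    keepends=True), so start_index/end_index delimit exactly one record.
    """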
    start_index = 0
    index_entries = []
    lines = content.splitlines(keepends=True)
    for line in lines:
        line_length = len(line)
        end_index = start_index + line_length
        try:
            data = json.loads(line)
            custom_id = data.get('custom_id')
            if custom_id:
                index_entries.append((custom_id, s3_path, start_index, end_index))
        except json.JSONDecodeError:
            pass  # Handle JSON decode errors if necessary
        start_index = end_index
    return index_entries


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done, e.g., s3://bucket/prefix/')
    parser.add_argument('--pdf_glob_path', help='Glob path to PDFs (local or s3)', default=None)
    parser.add_argument('--file_size_limit', type=int, default=250, help='Max file size in MB')
    args = parser.parse_args()

    db = DatabaseManager(args.workspace)
    print(f"Loaded db at {db.db_path}")
    print(f"Current round is {db.get_current_round()}")

    if args.pdf_glob_path:
        # Add new pdfs to be processed if they don't exist in the database
        # TODO
        pass

    # Step one, build an index of all the pages that were processed
    build_index(args.workspace)
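
# Example invocation (bucket and prefix are placeholders, and the filename of
# this script is assumed):
#   python build_index.py s3://my-bucket/pdf-workspace/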