More cleanup

Jake Poznanski 2024-10-11 22:37:32 +00:00
parent 53fdb6108c
commit a45f86e4a4


@@ -8,6 +8,7 @@ import glob
 import tempfile
 import posixpath
+from dataclasses import dataclass
 from pypdf import PdfReader
 from tqdm import tqdm
 from typing import Optional
@@ -30,15 +31,17 @@ class DatabaseManager:
 
     def _initialize_tables(self):
         self.cursor.execute("""
-            CREATE TABLE IF NOT EXISTS index_table (
-                custom_id TEXT,
+            CREATE TABLE IF NOT EXISTS page_results (
                 s3_path TEXT,
+                page_num INTEGER,
                 start_index BIGINT,
-                end_index BIGINT
+                length BIGINT,
+                finish_reason STRING,
+                error STRING
             )
         """)
         self.cursor.execute("""
-            CREATE INDEX IF NOT EXISTS idx_custom_id ON index_table(custom_id)
+            CREATE INDEX IF NOT EXISTS idx_path ON index_table(s3_path)
         """)
         self.cursor.execute("""
             CREATE TABLE IF NOT EXISTS pdfs (
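
The new page_results columns replace an absolute end_index with a length. A minimal sketch of the likely intent, not taken from the commit: combined with the optional range arguments on get_s3_bytes (defined later in this file), a single model output line can be re-read from S3 without downloading the whole .jsonl file. The helper name and the inclusive-range behavior below are assumptions.

# Illustrative sketch only: re-read one inference line via the page_results columns.
# Assumes get_s3_bytes(s3_path, start_index, end_index) maps to an inclusive byte-range GET.
def read_inference_line(s3_path: str, start_index: int, length: int) -> str:
    raw = get_s3_bytes(s3_path, start_index, start_index + length - 1)
    return raw.decode("utf-8")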
@@ -60,15 +63,16 @@ class DatabaseManager:
                 value TEXT
             )
         """)
self.cursor.execute("SELECT value FROM metadata WHERE key='round'")
if self.cursor.fetchone() is None:
self.cursor.execute("INSERT INTO metadata (key, value) VALUES ('round', '0')")
         self.conn.commit()
 
-    def get_current_round(self):
-        self.cursor.execute("SELECT value FROM metadata WHERE key='round'")
+    def get_metadata(self, key: str) -> str:
+        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
         result = self.cursor.fetchone()
-        return int(result[0])
+        return result[0]
+
+    def get_current_round(self):
+        return int(self.get_metadata("round"))
 
     def is_file_processed(self, s3_path, etag):
         self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
@@ -76,6 +80,7 @@ class DatabaseManager:
         return result is not None and result[0] == etag
 
     def add_index_entries(self, index_entries):
+        # TODO: Make it take BatchInferenceLines
         if index_entries:
             self.cursor.executemany("""
                 INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
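
The TODO flags that add_index_entries still writes the old index_table columns rather than the new BatchInferenceLine/page_results shape. Purely as a hedged sketch of where that TODO might be heading, not what this commit implements:

# Hypothetical follow-up sketch: accept BatchInferenceLine objects and target page_results.
def add_index_entries(self, index_entries: list[BatchInferenceLine]) -> None:
    if index_entries:
        self.cursor.executemany(
            """
            INSERT INTO page_results (s3_path, page_num, start_index, length, finish_reason, error)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            [(e.s3_path, e.page_num, e.start_index, e.length, e.finish_reason, e.error) for e in index_entries],
        )
        self.conn.commit()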
@@ -113,45 +118,6 @@ class DatabaseManager:
     def close(self):
         self.conn.close()
 
-def build_index(s3_path):
-    db_manager = DatabaseManager(s3_path)
-    bucket, prefix = parse_s3_path(s3_path)
-
-    # List all .json and .jsonl files under s3_path with their ETags
-    files = expand_s3_glob(s3_path)
-
-    if not files:
-        print("No .json or .jsonl files found in the specified S3 path.")
-        db_manager.close()
-        return
-
-    # Prepare a list of files that need processing
-    files_to_process = [
-        (key, etag) for key, etag in files.items()
-        if not db_manager.is_file_processed(key, etag)
-    ]
-
-    if not files_to_process:
-        print("All files are up to date. No processing needed.")
-        db_manager.close()
-        return
-
-    # Use ProcessPoolExecutor to process files with tqdm progress bar
-    with ProcessPoolExecutor() as executor:
-        futures = [
-            executor.submit(process_file, bucket, key, etag)
-            for key, etag in files_to_process
-        ]
-
-        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
-            s3_path, key, etag, index_entries = future.result()
-
-            if index_entries:
-                db_manager.add_index_entries(index_entries)
-
-            # Update the processed_files table
-            db_manager.update_processed_file(key, etag)
-
-    db_manager.close()
-
 def parse_s3_path(s3_path):
     if not s3_path.startswith('s3://'):
         raise ValueError('s3_path must start with s3://')
@@ -159,13 +125,13 @@ def parse_s3_path(s3_path):
     bucket, _, prefix = path.partition('/')
     return bucket, prefix
 
 def expand_s3_glob(s3_glob: str) -> dict[str, str]:
     parsed = urlparse(s3_glob)
 
     bucket_name = parsed.netloc
     prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
     pattern = os.path.basename(parsed.path)
 
     paginator = s3.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
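
expand_s3_glob is unchanged here, but it is the listing primitive the new code elsewhere in this change relies on; per its annotation it returns a dict mapping matched S3 paths to their ETags. A usage sketch with a made-up bucket and prefix:

# Hypothetical paths; expand_s3_glob returns {path: etag} per its dict[str, str] annotation.
outputs = expand_s3_glob("s3://my-bucket/workspace/inference_outputs/*.jsonl")
for path, etag in outputs.items():
    print(path, etag)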
@@ -178,37 +144,43 @@ def expand_s3_glob(s3_glob: str) -> dict[str, str]:
     return matched_files
 
-def process_file(bucket, key, etag):
-    s3 = boto3.client('s3')  # Initialize s3 client in the worker process
-    s3_path = f's3://{bucket}/{key}'
-    try:
-        # Get the object
-        obj = s3.get_object(Bucket=bucket, Key=key)
-        # Read the content as bytes
-        content = obj['Body'].read()
-        # Process the file as JSONL
-        index_entries = process_jsonl_content(content, s3_path)
-        # Return the necessary data to the main process
-        return s3_path, key, etag, index_entries
-    except Exception as e:
-        print(f"Error processing file {s3_path}: {e}")
-        return s3_path, key, etag, []
-
-def process_jsonl_content(content, s3_path):
+@dataclass(frozen=True)
+class BatchInferenceLine:
+    s3_path: str
+    page_num: int  # 1 indexed!
+    start_index: int
+    length: int
+    finish_reason: str
+    error: Optional[str]
+
+def parse_custom_id(custom_id: str) -> tuple[str, int]:
+    s3_path = custom_id[:custom_id.rindex("-")]
+    page_num = int(custom_id[custom_id.rindex("-") + 1:])
+
+    return s3_path, page_num
+
+def process_jsonl_content(s3_path) -> list[BatchInferenceLine]:
+    content = get_s3_bytes(s3_path).decode("utf-8")
+
     start_index = 0
     index_entries = []
 
     lines = content.splitlines(keepends=True)
     for line in lines:
         line_length = len(line)
-        end_index = start_index + line_length
         try:
             data = json.loads(line)
-            custom_id = data.get('custom_id')
-            if custom_id:
-                index_entries.append((custom_id, s3_path, start_index, end_index))
+            s3_path, page_num = parse_custom_id(data["custom_id"])
+
+            assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"
+            index_entries.append(BatchInferenceLine(s3_path, page_num, start_index, line_length,
+                                                    finish_reason=data["outputs"][0]["finish_reason"], error=data.get("completion_error", None)))
         except json.JSONDecodeError:
             pass  # Handle JSON decode errors if necessary
-        start_index = end_index
+        start_index = start_index + line_length
 
     return index_entries
 
 def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
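
parse_custom_id packs the source path and a 1-indexed page number into a single batch-inference custom_id, splitting on the last '-' so dashes inside the path survive. A quick illustrative check (the paths are made up):

# Illustrative: custom_id is "<s3_path>-<page_num>"; rindex("-") splits on the final dash.
assert parse_custom_id("s3://my-bucket/docs/report-v2.pdf-12") == ("s3://my-bucket/docs/report-v2.pdf", 12)
assert parse_custom_id("s3://my-bucket/docs/a.pdf-1") == ("s3://my-bucket/docs/a.pdf", 1)  # pages are 1-indexed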
@@ -246,7 +218,7 @@ def get_pdf_num_pages(s3_path: str) -> Optional[int]:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
-    parser.add_argument('--pdfs', help='Glob path to PDFs (local or s3)', default=None)
+    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
     parser.add_argument('--file_size_limit', type=int, default=250, help='Max file size in MB')
 
     args = parser.parse_args()
@@ -258,12 +230,12 @@ if __name__ == '__main__':
     executor = ProcessPoolExecutor()
 
     # If you have new PDFs, add them to the list
-    if args.pdfs:
-        assert args.pdfs.startswith("s3://"), "PDFs must live on s3"
+    if args.add_pdfs:
+        assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"
 
-        print(f"Querying all PDFs at {args.pdfs}")
-        all_pdfs = expand_s3_glob(args.pdfs)
+        print(f"Querying all PDFs at {args.add_pdfs}")
+        all_pdfs = expand_s3_glob(args.add_pdfs)
         print(f"Found {len(all_pdfs)} total pdf paths")
 
         all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
@@ -279,8 +251,21 @@ if __name__ == '__main__':
     # Now build an index of all the pages that were processed within the workspace so far
-    build_index(f"{args.workspace}/*.jsonl")
+    inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")
 
-    # Now, for each pending book, find all pages which still need to be processed
-    # and add them to the next round's batch inference jobs
+    inference_output_paths = [
+        (key, etag) for key, etag in inference_output_paths.items()
+        if not db.is_file_processed(key, etag)
+    ]
+
+    future_to_path = {executor.submit(process_jsonl_content, s3_path): s3_path for s3_path, etag in inference_output_paths}
+
+    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
+        s3_path = future_to_path[future]
+        inference_lines = future.result()
+
+        db.add_index_entries(inference_lines)
+        db.update_processed_file(s3_path, etag=TODO)
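
The etag=TODO placeholder is left unresolved by this commit. One possible way to thread the ETag through, sketched here as an assumption rather than the author's plan, is to key the futures map on both the path and its ETag:

# Hypothetical sketch: carry (path, etag) with each future so update_processed_file gets a real ETag.
future_to_meta = {
    executor.submit(process_jsonl_content, s3_path): (s3_path, etag)
    for s3_path, etag in inference_output_paths
}

for future in tqdm(as_completed(future_to_meta), total=len(future_to_meta)):
    s3_path, etag = future_to_meta[future]
    db.add_index_entries(future.result())
    db.update_processed_file(s3_path, etag=etag)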