New pipeline stuff

Jake Poznanski 2024-10-14 17:09:11 +00:00
parent 4d6eaf654d
commit 39333f2c96


@@ -7,6 +7,7 @@ import argparse
import glob
import tempfile
import posixpath
import smart_open
from dataclasses import dataclass
from pypdf import PdfReader
@@ -193,7 +194,66 @@ class DatabaseManager:
        self.conn.close()


# Writes batches of lines out to a set of files, keeping each file below some maximum size
class BatchWriter:
    def __init__(self, output_prefix: str, max_size_mb: int = 250):
        self.output_prefix = output_prefix
        self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
        self.batch = []
        self.batch_size = 0

        parsed = urlparse(output_prefix)
        self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')

        if not self.is_s3:
            os.makedirs(output_prefix, exist_ok=True)

    def _compute_hash(self, content: str) -> str:
        """Compute a 20-character SHA1 hash of the given content."""
        sha1 = hashlib.sha1()
        sha1.update(content.encode('utf-8'))
        return sha1.hexdigest()[:20]

    def _get_output_path(self, hash_str: str) -> str:
        """Generate the full output path with hash in the filename."""
        parsed = urlparse(self.output_prefix)
        if self.is_s3:
            bucket = parsed.netloc
            key = parsed.path.lstrip('/')
            if key and not key.endswith('/'):
                key += '/'
            full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
            return f"s3://{bucket}/{full_key}"
        else:
            filename = f"output_{hash_str}.jsonl"
            return os.path.join(self.output_prefix, filename)

    def write_line(self, line: str):
        line_size = len(line.encode('utf-8')) + 1  # +1 for newline
        if self.batch_size + line_size > self.max_size:
            self._write_batch()
        self.batch.append(line)
        self.batch_size += line_size

    def _write_batch(self):
        if not self.batch:
            return
        batch_content = "\n".join(self.batch) + "\n"
        hash_str = self._compute_hash(batch_content)
        output_path = self._get_output_path(hash_str)
        with smart_open.open(output_path, 'w') as f_out:
            f_out.write(batch_content)
        print(f"Wrote batch to {output_path}")
        self.batch = []
        self.batch_size = 0

    def close(self):
        self._write_batch()
def parse_s3_path(s3_path):
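
For orientation, a minimal usage sketch of the new writer (not part of the commit; the local prefix and record contents are made up, and an s3:// prefix behaves the same since all writes go through smart_open):

    import json

    # Hypothetical usage sketch for BatchWriter; prefix and record are illustrative only.
    writer = BatchWriter("/tmp/inference_batches", max_size_mb=250)
    record = {"custom_id": "example.pdf-page_1", "prompt": "..."}  # stand-in inference request
    writer.write_line(json.dumps(record))
    writer.close()  # flushes the final partial batch to output_<20-char-sha1>.jsonl
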
@@ -315,6 +375,10 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> list
    existing_pages = db.get_index_entries(pdf.s3_path)
    new_queries = []

    # Shortcut out of downloading the actual PDF
    if set(page.page_num for page in existing_pages if page.is_usable()) == set(range(1, pdf.num_pages + 1)):
        return []

    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
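
The shortcut above skips downloading the PDF when every page already has a usable index entry; the check reduces to a set comparison, illustrated standalone here with made-up values:

    # Standalone illustration of the page-completeness shortcut (made-up values).
    num_pages = 3
    usable_pages = {1, 2, 3}  # page_num of entries where is_usable() returned True
    if usable_pages == set(range(1, num_pages + 1)):
        print("all pages already usable, skip downloading the PDF")
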
@@ -337,7 +401,7 @@ if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done (e.g., s3://bucket/prefix/)')
    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
    parser.add_argument('--file_size_limit', type=int, default=250, help='Max file size in MB')
    parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
    args = parser.parse_args()

    db = DatabaseManager(args.workspace)
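
For reference, a hypothetical invocation after this change (the script name and bucket are placeholders; --max_size_mb replaces the old --file_size_limit flag):

    python pipeline.py s3://my-bucket/pdf-workspace/ --add_pdfs 's3://my-bucket/raw-pdfs/*.pdf' --max_size_mb 250
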
@@ -383,21 +447,30 @@ if __name__ == '__main__':
    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        s3_path, etag = future_to_path[future]
        inference_lines = future.result()
        inference_records = future.result()
        db.add_index_entries(inference_lines)
        db.add_index_entries(inference_records)
        db.update_processed_file(s3_path, etag=etag)

    # Now query each PDF; if all of the pages it needs are present (no errors and finish_reason is stop), assemble it into a dolma document and output it.
    # If any page is missing or has errors, output a new batch of inference items to run instead.
    future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf): pdf for pdf in db.get_pdfs_by_status("pending")}
    potentially_done_pdfs = []
    new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{round}", args.max_size_mb)

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        pdf = future_to_path[future]
        inference_lines = future.result()

        if len(inference_lines) == 0:
            potentially_done_pdfs.append(pdf)

        for line in inference_lines:
            new_inference_writer.write_line(json.dumps(line))

    new_inference_writer.close()

    # Now, finally, assemble any potentially done docs into dolma documents
    # TODO
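
The final assembly step is still a TODO in this commit. Purely as a sketch of the intended direction, joining completed pages into a Dolma-style document could look roughly like the following; every field and helper name here is an assumption for illustration, not code from this repo:

    import datetime
    import hashlib

    # Hypothetical sketch only: assemble one Dolma-style document from completed page records.
    # `completed_pages` is assumed to be a list of dicts with `page_num` and `text` keys.
    def assemble_dolma_document(pdf_s3_path, completed_pages):
        text = "\n".join(p["text"] for p in sorted(completed_pages, key=lambda p: p["page_num"]))
        return {
            "id": hashlib.sha1(pdf_s3_path.encode("utf-8")).hexdigest(),
            "text": text,
            "source": "pdf-batch-inference",  # placeholder source name
            "added": datetime.datetime.utcnow().isoformat(),
            "metadata": {"s3_path": pdf_s3_path, "num_pages": len(completed_pages)},
        }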