From 39333f2c963d3e789fbab498ab41c38bb050e8e8 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Mon, 14 Oct 2024 17:09:11 +0000
Subject: [PATCH] New pipeline stuff

---
 pdelfin/assemblepipeline.py | 79 +++++++++++++++++++++++++++++++++++--
 1 file changed, 76 insertions(+), 3 deletions(-)

diff --git a/pdelfin/assemblepipeline.py b/pdelfin/assemblepipeline.py
index fd62216..a04de09 100644
--- a/pdelfin/assemblepipeline.py
+++ b/pdelfin/assemblepipeline.py
@@ -7,6 +7,7 @@ import argparse
 import glob
 import tempfile
 import posixpath
+import smart_open
 
 from dataclasses import dataclass
 from pypdf import PdfReader
@@ -193,7 +194,66 @@ class DatabaseManager:
         self.conn.close()
 
 
+# Writes batches of lines out to a set of files, keeping each file below some maximum size
+class BatchWriter:
+    def __init__(self, output_prefix: str, max_size_mb: int = 250):
+        self.output_prefix = output_prefix
+        self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
+        self.batch = []
+        self.batch_size = 0
+        parsed = urlparse(output_prefix)
+        self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
+
+        if not self.is_s3:
+            os.makedirs(output_prefix, exist_ok=True)
+
+    def _compute_hash(self, content: str) -> str:
+        """Compute a 20-character SHA1 hash of the given content."""
+        sha1 = hashlib.sha1()
+        sha1.update(content.encode('utf-8'))
+        return sha1.hexdigest()[:20]
+
+    def _get_output_path(self, hash_str: str) -> str:
+        """Generate the full output path with hash in the filename."""
+        parsed = urlparse(self.output_prefix)
+        if self.is_s3:
+            bucket = parsed.netloc
+            key = parsed.path.lstrip('/')
+            if key and not key.endswith('/'):
+                key += '/'
+            full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
+            return f"s3://{bucket}/{full_key}"
+        else:
+            filename = f"output_{hash_str}.jsonl"
+            return os.path.join(self.output_prefix, filename)
+
+    def write_line(self, line: str):
+        line_size = len(line.encode('utf-8')) + 1  # +1 for newline
+
+        if self.batch_size + line_size > self.max_size:
+            self._write_batch()
+
+        self.batch.append(line)
+        self.batch_size += line_size
+
+    def _write_batch(self):
+        if not self.batch:
+            return
+
+        batch_content = "\n".join(self.batch) + "\n"
+        hash_str = self._compute_hash(batch_content)
+        output_path = self._get_output_path(hash_str)
+
+        with smart_open.open(output_path, 'w') as f_out:
+            f_out.write(batch_content)
+        print(f"Wrote batch to {output_path}")
+
+        self.batch = []
+        self.batch_size = 0
+
+    def close(self):
+        self._write_batch()
+
+
 def parse_s3_path(s3_path):
@@ -315,6 +375,10 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> list
     existing_pages = db.get_index_entries(pdf.s3_path)
     new_queries = []
+
+    # Shortcut out of downloading the actual PDF
+    if set(page.page_num for page in existing_pages if page.is_usable()) == set(range(1, pdf.num_pages + 1)):
+        return []
 
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
@@ -337,7 +401,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
     parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
-    parser.add_argument('--file_size_limit', type=int, default=250, help='Max file size in MB')
+    parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
     args = parser.parse_args()
 
     db = DatabaseManager(args.workspace)
@@ -383,21 +447,30 @@ if __name__ == '__main__':
     for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
         s3_path, etag = future_to_path[future]
-        inference_lines = future.result()
+        inference_records = future.result()
 
-        db.add_index_entries(inference_lines)
+        db.add_index_entries(inference_records)
         db.update_processed_file(s3_path, etag=etag)
 
     # Now query each pdf, if you have all of the pages needed (all pages present, error is null and finish_reason is stop), then you assemble it into a dolma document and output it
     # If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use
     future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf): pdf for pdf in db.get_pdfs_by_status("pending")}
+    potentially_done_pdfs = []
+    new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{round}", args.max_size_mb)
 
     for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
         pdf = future_to_path[future]
         inference_lines = future.result()
 
+        if len(inference_lines) == 0:
+            potentially_done_pdfs.append(pdf)
+        for line in inference_lines:
+            new_inference_writer.write_line(json.dumps(line))
 
+    new_inference_writer.close()
+
+    # Now, finally, assemble any potentially done docs into dolma documents
     # TODO
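Usage sketch (not part of the patch itself): a minimal example of how the new BatchWriter is driven. It assumes the patched module is importable as pdelfin.assemblepipeline; the ./scratch prefix and the sample records are illustrative placeholders, not values from the pipeline.

    import json
    from pdelfin.assemblepipeline import BatchWriter  # assumed import path after applying the patch

    # Lines are buffered in memory and flushed to output_<sha1-prefix>.jsonl files,
    # each kept under the configured size limit (250 MB by default).
    writer = BatchWriter("./scratch/inference/round_0", max_size_mb=250)  # hypothetical local prefix
    for record in [{"custom_id": "doc1-page1"}, {"custom_id": "doc1-page2"}]:
        writer.write_line(json.dumps(record))
    writer.close()  # flush whatever remains in the final partial batch

Because each batch file is named after a 20-character SHA1 prefix of its own contents, rewriting an identical batch produces the same filename rather than piling up duplicates under the output prefix.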