diff --git a/pdelfin/assemblepipeline.py b/pdelfin/assemblepipeline.py
index 7af3a47..40a6e4c 100644
--- a/pdelfin/assemblepipeline.py
+++ b/pdelfin/assemblepipeline.py
@@ -88,10 +88,93 @@ def parse_s3_path(s3_path):
     bucket_name, _, key = s3_path.partition('/')
     return bucket_name, key
 
+def process_document(s3_path, entries, output_dir):
+    """
+    Processes a single document:
+    - Downloads the PDF
+    - Validates and assembles text
+    - Writes the output JSON if successful
+    - Returns processing results for aggregation
+    """
+    try:
+        # Download the PDF locally
+        pdf_local_path = cached_path(s3_path, quiet=True)
+        pdf = PdfReader(pdf_local_path)
+        total_pages_in_pdf = len(pdf.pages)
+    except Exception as e:
+        logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
+        return {
+            'processed': 1,
+            'successful_documents': 0,
+            'successful_pages': 0,
+            'total_pages': 0
+        }
+
+    # Build mapping from pagenum to entry
+    entry_by_pagenum = {entry.pagenum: entry for entry in entries}
+
+    valid_entries = []
+    missing_pages = []
+    errors = []
+
+    # Iterate from 1 to total_pages_in_pdf inclusive
+    for page_num in range(1, total_pages_in_pdf + 1):
+        entry = entry_by_pagenum.get(page_num)
+        if entry is None:
+            missing_pages.append(page_num)
+        elif entry.error is not None or entry.finish_reason != 'stop':
+            errors.append(entry)
+        else:
+            valid_entries.append(entry)
+
+    if not missing_pages and not errors:
+        # Assemble text
+        valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
+        text = '\n'.join(entry.text for entry in valid_entries_sorted if entry.text)
+
+        # Generate a filename based on the s3_path
+        doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
+        output_filename = os.path.join(output_dir, f'{doc_hash}.json')
+
+        output_data = {
+            'source': s3_path,
+            'total_pages': total_pages_in_pdf,
+            'text': text
+        }
+
+        try:
+            with open(output_filename, 'w') as f_out:
+                json.dump(output_data, f_out)
+            return {
+                'processed': 1,
+                'successful_documents': 1,
+                'successful_pages': len(valid_entries),
+                'total_pages': total_pages_in_pdf
+            }
+        except Exception as e:
+            logging.error(f"Error writing output file {output_filename}: {e}")
+            return {
+                'processed': 1,
+                'successful_documents': 0,
+                'successful_pages': 0,
+                'total_pages': total_pages_in_pdf
+            }
+    else:
+        missing = [page for page in missing_pages]
+        error_pages = [e.pagenum for e in errors]
+        logging.info(f'Document {s3_path} has missing pages: {missing} or errors in pages: {error_pages}')
+        return {
+            'processed': 1,
+            'successful_documents': 0,
+            'successful_pages': len(valid_entries),
+            'total_pages': total_pages_in_pdf
+        }
+
 def main():
     parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
     parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
     parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
+    parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of worker threads')
     args = parser.parse_args()
 
     # Set up logging
@@ -135,59 +218,28 @@ def main():
     total_pages = 0
     successful_pages = 0
 
-    print("Processing documents...")
-    for s3_path, entries in tqdm(documents.items()):
-        try:
-            # Download the PDF locally
-            pdf_local_path = cached_path(s3_path, quiet=True)
+    print("Processing documents with ThreadPoolExecutor...")
+    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
+        # Prepare futures
+        future_to_s3 = {
+            executor.submit(
+                process_document,
+                s3_path,
+                entries,
+                args.output_dir
+            ): s3_path for s3_path, entries in documents.items()
+        }
 
-            pdf = PdfReader(pdf_local_path)
-            total_pages_in_pdf = len(pdf.pages)
-        except Exception as e:
-            logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
-            continue
-
-        total_pages += total_pages_in_pdf
-
-        # Build mapping from pagenum to entry
-        entry_by_pagenum = {entry.pagenum: entry for entry in entries}
-
-        valid_entries = []
-        missing_pages = []
-        errors = []
-
-        for page_num in range(total_pages_in_pdf):
-            entry = entry_by_pagenum.get(page_num)
-            if entry is None:
-                missing_pages.append(page_num)
-            elif entry.error is not None or entry.finish_reason != 'stop':
-                errors.append(entry)
-            else:
-                valid_entries.append(entry)
-
-        successful_pages += len(valid_entries)
-
-        if not missing_pages and not errors:
-            # Assemble text
-            valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
-            text = '\n'.join(entry.text for entry in valid_entries_sorted)
-
-            # Generate a filename based on the s3_path
-            doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
-            output_filename = os.path.join(args.output_dir, f'{doc_hash}.json')
-
-            output_data = {
-                'source': s3_path,
-                'total_pages': total_pages_in_pdf,
-                'text': text
-            }
-
-            with open(output_filename, 'w') as f_out:
-                json.dump(output_data, f_out)
-
-            successful_documents += 1
-        else:
-            logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {[e.pagenum for e in errors]}')
+        # Use tqdm to display progress
+        for future in tqdm(as_completed(future_to_s3), total=len(future_to_s3)):
+            try:
+                result = future.result()
+                successful_documents += result.get('successful_documents', 0)
+                successful_pages += result.get('successful_pages', 0)
+                total_pages += result.get('total_pages', 0)
+            except Exception as e:
+                s3_path = future_to_s3[future]
+                logging.error(f"Error processing document {s3_path}: {e}")
 
     print(f'Total documents: {total_documents}')
     print(f'Successful documents: {successful_documents}')
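
Note: the new main() loop calls ThreadPoolExecutor and as_completed, but no import hunk is shown above. If the rest of the patch does not already add one, the change assumes a standard-library import near the top of assemblepipeline.py, roughly (a sketch, not part of this diff):

    # concurrent.futures is in the Python standard library; both names are used by the reworked main() loop
    from concurrent.futures import ThreadPoolExecutor, as_completed

With the new flag, the script would be run along these lines (hypothetical bucket path; assumes the module keeps its existing command-line entry point):

    python pdelfin/assemblepipeline.py s3://my-bucket/birr-output/ --output_dir output --max_workers 16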