diff --git a/pdelfin/assemblepipeline.py b/pdelfin/assemblepipeline.py
new file mode 100644
index 0000000..7af3a47
--- /dev/null
+++ b/pdelfin/assemblepipeline.py
@@ -0,0 +1,198 @@
+import argparse
+import os
+import json
+import hashlib
+import logging
+from collections import defaultdict
+from typing import Optional
+from tqdm import tqdm
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import boto3
+from pypdf import PdfReader
+from cached_path import cached_path
+from smart_open import smart_open
+
+from dataclasses import dataclass
+
+# Import your existing modules if necessary
+# from dolma_refine.evaluate.metrics import DocumentEditSimilarity
+# from dolma_refine.evaluate.segmenters import SpacySegmenter
+# from dolma_refine.evaluate.aligners import HirschbergAligner
+
+@dataclass(frozen=True)
+class NormalizedEntry:
+    s3_path: str
+    pagenum: int
+    text: Optional[str]
+    finish_reason: Optional[str]
+    error: Optional[str] = None
+
+    @staticmethod
+    def from_goldkey(goldkey: str, **kwargs):
+        s3_path = goldkey[:goldkey.rindex("-")]
+        page_num = int(goldkey[goldkey.rindex("-") + 1:])
+        return NormalizedEntry(s3_path, page_num, **kwargs)
+
+    @property
+    def goldkey(self):
+        return f"{self.s3_path}-{self.pagenum}"
+
+def normalize_json_entry(data: dict) -> NormalizedEntry:
+    if "outputs" in data:
+        # Birr case
+        if data["outputs"] is None:
+            text = None
+            finish_reason = None
+        else:
+            text = data["outputs"][0]["text"]
+            finish_reason = data["outputs"][0]["finish_reason"]
+
+        # Try to parse the structured output if possible
+        try:
+            if text is not None:
+                parsed_content = json.loads(text)
+                text = parsed_content["natural_text"]
+        except json.JSONDecodeError:
+            pass
+
+        return NormalizedEntry.from_goldkey(
+            goldkey=data["custom_id"],
+            text=text,
+            finish_reason=finish_reason,
+            error=data.get("completion_error", None)
+        )
+    else:
+        # OpenAI case
+        try:
+            # Attempt to parse the JSON content from OpenAI's response
+            parsed_content = json.loads(data["response"]["body"]["choices"][0]["message"]["content"])
+            return NormalizedEntry.from_goldkey(
+                goldkey=data["custom_id"],
+                text=parsed_content["natural_text"],
+                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
+            )
+        except json.JSONDecodeError:
+            # Fallback if content is not valid JSON
+            return NormalizedEntry.from_goldkey(
+                goldkey=data["custom_id"],
+                text=data["response"]["body"]["choices"][0]["message"]["content"],
+                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
+            )
+
+def parse_s3_path(s3_path):
+    if not s3_path.startswith('s3://'):
+        raise ValueError('Invalid S3 path')
+    s3_path = s3_path[5:]
+    bucket_name, _, key = s3_path.partition('/')
+    return bucket_name, key
+
+def main():
+    parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
+    parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
+    parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
+    args = parser.parse_args()
+
+    # Set up logging
+    logging.basicConfig(filename='processing.log', level=logging.INFO, format='%(asctime)s %(message)s')
+
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Initialize S3 client
+    s3 = boto3.client('s3')
+    bucket_name, prefix = parse_s3_path(args.s3_path)
+
+    # List all .json and .jsonl files in the specified S3 path
+    paginator = s3.get_paginator('list_objects_v2')
+    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
+
+    files = []
+    for page in page_iterator:
+        if 'Contents' in page:
+            for obj in page['Contents']:
+                key = obj['Key']
+                if key.endswith('.json') or key.endswith('.jsonl'):
+                    files.append(key)
+
+    # Build documents mapping
+    documents = defaultdict(list)
+
+    print("Processing JSON files and building documents mapping...")
+    for key in tqdm(files):
+        file_s3_path = f's3://{bucket_name}/{key}'
+        try:
+            with smart_open(file_s3_path, 'r') as f:
+                for line in f:
+                    data = json.loads(line)
+                    entry = normalize_json_entry(data)
+                    documents[entry.s3_path].append(entry)
+        except Exception as e:
+            logging.error(f"Error processing file {file_s3_path}: {e}")
+
+    total_documents = len(documents)
+    successful_documents = 0
+    total_pages = 0
+    successful_pages = 0
+
+    print("Processing documents...")
+    for s3_path, entries in tqdm(documents.items()):
+        try:
+            # Download the PDF locally
+            pdf_local_path = cached_path(s3_path, quiet=True)
+
+            pdf = PdfReader(pdf_local_path)
+            total_pages_in_pdf = len(pdf.pages)
+        except Exception as e:
+            logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
+            continue
+
+        total_pages += total_pages_in_pdf
+
+        # Build mapping from pagenum to entry
+        entry_by_pagenum = {entry.pagenum: entry for entry in entries}
+
+        valid_entries = []
+        missing_pages = []
+        errors = []
+
+        for page_num in range(total_pages_in_pdf):
+            entry = entry_by_pagenum.get(page_num)
+            if entry is None:
+                missing_pages.append(page_num)
+            elif entry.error is not None or entry.finish_reason != 'stop':
+                errors.append(entry)
+            else:
+                valid_entries.append(entry)
+
+        successful_pages += len(valid_entries)
+
+        if not missing_pages and not errors:
+            # Assemble text
+            valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
+            text = '\n'.join(entry.text for entry in valid_entries_sorted)
+
+            # Generate a filename based on the s3_path
+            doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
+            output_filename = os.path.join(args.output_dir, f'{doc_hash}.json')
+
+            output_data = {
+                'source': s3_path,
+                'total_pages': total_pages_in_pdf,
+                'text': text
+            }
+
+            with open(output_filename, 'w') as f_out:
+                json.dump(output_data, f_out)
+
+            successful_documents += 1
+        else:
+            logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {[e.pagenum for e in errors]}')
+
+    print(f'Total documents: {total_documents}')
+    print(f'Successful documents: {successful_documents}')
+    print(f'Total pages: {total_pages}')
+    print(f'Successful pages: {successful_pages}')
+
+if __name__ == '__main__':
+    main()
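For orientation, here is a minimal sketch of how the goldkey round-trip in the new NormalizedEntry class behaves; the S3 path and page number are invented for illustration. Note that rindex("-") splits on the last hyphen, so hyphens inside the bucket or key names are handled correctly.

from pdelfin.assemblepipeline import NormalizedEntry

# Hypothetical goldkey of the form "<s3 path>-<page number>", split on the last hyphen.
entry = NormalizedEntry.from_goldkey(
    goldkey="s3://my-bucket/docs/example.pdf-3",  # invented example path
    text="Page three text",
    finish_reason="stop",
)
assert entry.s3_path == "s3://my-bucket/docs/example.pdf"
assert entry.pagenum == 3
assert entry.goldkey == "s3://my-bucket/docs/example.pdf-3"  # round-trips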
diff --git a/pdelfin/eval/runeval.py b/pdelfin/eval/runeval.py
index 8bcbb3d..d2510a8 100644
--- a/pdelfin/eval/runeval.py
+++ b/pdelfin/eval/runeval.py
@@ -285,6 +285,8 @@ def do_eval(gold_data_path: str, eval_data_path: str, review_page_name: str, rev
 
             # if len(pd["gold_text"]) < 200 and len(pd["eval_text"]) < 200:
             #     continue
 
+            if "[Error processing this page: overrun" not in pd["eval_text"]:
+                continue
 
             page_eval_data.append(pd)
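Note that the new condition is inverted: a page is skipped unless its eval text contains the overrun marker, so only overrun pages reach page_eval_data, presumably to focus review on those failures. A small sketch of the semantics, with invented records:

pages = [
    {"eval_text": "normal page text"},
    {"eval_text": "[Error processing this page: overrun the context limit]"},
]
# Mirrors the filter above: keep only pages whose eval text contains the marker.
kept = [pd for pd in pages if "[Error processing this page: overrun" in pd["eval_text"]]
assert len(kept) == 1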
diff --git a/pdelfin/extract_text.py b/pdelfin/extract_text.py
deleted file mode 100644
index 5e7ed94..0000000
--- a/pdelfin/extract_text.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import subprocess
-from typing import Literal
-
-import pymupdf
-import pypdfium2 as pdfium
-
-
-def get_page_text(
-    local_pdf_path: str, page_num: int, pdf_engine: Literal["pdftotext", "pymupdf", "pdfium"] = "pdftotext"
-) -> str:
-    if pdf_engine == "pdftotext":
-        pdftotext_result = subprocess.run(
-            [
-                "pdftotext",
-                "-f",
-                str(page_num),
-                "-l",
-                str(page_num),
-                local_pdf_path,
-                "-",
-            ],
-            timeout=60,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-        )
-        assert pdftotext_result.returncode == 0
-
-        return pdftotext_result.stdout.decode("utf-8")
-    elif pdf_engine == "pymupdf":
-        pm_doc = pymupdf.open(local_pdf_path)
-        return pm_doc[page_num - 1].get_text()
-    elif pdf_engine == "pdfium":
-        pdf = pdfium.PdfDocument(local_pdf_path, autoclose=True)
-        page = pdf[page_num - 1]
-        textpage = page.get_textpage()
-
-        # Extract text from the whole page
-        result = textpage.get_text_range()
-        pdf.close()
-        return result
-    else:
-        raise NotImplementedError()
-
-
-def get_document_text(local_pdf_path: str, pdf_engine: Literal["pdftotext", "pymupdf", "pdfium"] = "pdftotext") -> str:
-    if pdf_engine == "pdftotext":
-        pdftotext_result = subprocess.run(
-            [
-                "pdftotext",
-                local_pdf_path,
-                "-",
-            ],
-            timeout=60,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-        )
-        assert pdftotext_result.returncode == 0
-
-        return pdftotext_result.stdout.decode("utf-8")
-    elif pdf_engine == "pymupdf":
-        pm_doc = pymupdf.open(local_pdf_path)
-        result = ""
-
-        for page in pm_doc:
-            result += page.get_text()
-            result += "\n"
-
-        return result
-    elif pdf_engine == "pdfium":
-        pdf = pdfium.PdfDocument(local_pdf_path, autoclose=True)
-        result = ""
-
-        for page in pdf:
-            textpage = page.get_textpage()
-            result += textpage.get_text_range()
-            result += "\n"
-
-        pdf.close()
-        return result
-    else:
-        raise NotImplementedError()
diff --git a/pdelfin/data/runpipeline.py b/pdelfin/runpipeline.py
similarity index 89%
rename from pdelfin/data/runpipeline.py
rename to pdelfin/runpipeline.py
index 57ba200..5452eb9 100644
--- a/pdelfin/data/runpipeline.py
+++ b/pdelfin/runpipeline.py
@@ -1,8 +1,19 @@
+# How this script works: it takes a list of PDFs to process and an
+# output/scratch folder location, either local or in S3, to work with.
+# On the first run, with an empty output folder, it queues up each page of each PDF to go through a VLM.
+# The user then submits that task in birr, which writes its output to a new subfolder in the same location.
+# On the next run, the script sees those output files and checks them; whenever it has a complete
+# document, it builds a dolma doc for it, and that document is considered done.
+# Any remaining pages that errored out, or that failed because stop_reason was not "stop" (e.g. over length),
+# are queued up as another, hopefully much smaller, set of tasks to send back into batch inference.
+# This process repeats until you run the script with the --fallback option, at which point it
+# falls back to basic text extraction for any remaining pages and assembles the rest of the dolma docs.
+#
+#
+#
 import os
 import glob
 import random
-import subprocess
-import base64
 import argparse
 import boto3
 import json
@@ -258,6 +269,7 @@ def main():
 
     # Print the number of PDFs that resulted in at least one output
     print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
+    print("Now you should run these prompts through mise/birr")
 
 if __name__ == "__main__":
     main()
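The comment block above mentions a --fallback mode that applies basic text extraction to pages the VLM never completed. A minimal sketch of what such a fallback could look like, shelling out to pdftotext much as the deleted extract_text.py did; fallback_page_text is a hypothetical helper, not part of this diff:

import subprocess

def fallback_page_text(local_pdf_path: str, page_num: int) -> str:
    # Extract a single page (1-indexed) as plain text via the pdftotext CLI.
    result = subprocess.run(
        ["pdftotext", "-f", str(page_num), "-l", str(page_num), local_pdf_path, "-"],
        timeout=60,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    result.check_returncode()
    return result.stdout.decode("utf-8")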