diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index 7567d6c..895437f 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -18,6 +18,8 @@ from tqdm import tqdm
 from io import BytesIO
 from PIL import Image
 from pypdf import PdfReader
+from dataclasses import dataclass
+from typing import Optional
 
 from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
 from pdelfin.data.renderpdf import render_pdf_to_base64png
@@ -39,6 +41,12 @@ pdf_s3 = boto3.client('s3')
 
 MAX_TOKENS = 3000
 
+@dataclass(frozen=True)
+class PageResult:
+    s3_path: str
+    page_num: int
+    response: PageResponse
+
 async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
     assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
 
@@ -188,7 +196,7 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
 
     return queue
 
-async def process_page(session, pdf_path, page_num, args) -> PageResponse:
+async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
 
     query = await build_page_query(
@@ -206,16 +214,21 @@ async def process_page(session, pdf_path, page_num, args) -> PageResponse:
 
         try:
             base_response_data = await response.json()
-            model_response_json = json.loads(base_response_data["outputs"][0]["text"])
+
+            model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
             page_response = PageResponse(**model_response_json)
+
+            return PageResult(pdf_s3_path, page_num, page_response)
         except Exception as e:
-            logger.warning(f"Could not parse response for {pdf_path}-{page_num}")
+            logger.warning(f"Could not parse response for {pdf_s3_path}-{page_num}, reason: {e}")
+
+            raise ValueError("Could not process page")
     except Exception as e:
         logger.error(f"Exception while processing page {page_num}: {e}")
-        return None
+        raise
 
-async def process_pdf(args, pdf_s3_path):
+async def process_pdf(args, pdf_s3_path: str):
     with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
         # TODO Switch to aioboto3 or something
         data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
 
@@ -231,29 +244,27 @@ async def process_pdf(args, pdf_s3_path):
         async with aiohttp.ClientSession() as session:
             for page_num in range(1, num_pages + 1):
                 # Create a task for each page
-                task = asyncio.create_task(process_page(session, tf.name, page_num, args))
+                task = asyncio.create_task(process_page(args, session, pdf_s3_path, tf.name, page_num))
                 page_tasks.append(task)
 
             # Gather results from all page processing tasks
-            page_results = await asyncio.gather(*page_tasks)
+            try:
+                page_results: list[PageResult] = await asyncio.gather(*page_tasks)
+            except Exception:
+                logger.warning(f"Could not load page for {pdf_s3_path}, aborting document")
+                return None
 
-        # If we failed to build a page, then this document is toast
-        # TODO Abort earlier, if a page returns a None, then we can stop processing the whole pdf
-        if any(page is None for page in page_results):
-            logger.warning(f"PDF {pdf_s3_path} was not able to complete, not able to process a page")
-            return None
 
         # Build the document text and page spans
         document_text = ''
         pdf_page_spans = []
         current_char_pos = 0
 
-        for page_num, result in page_data:
-            try:
-                content = result['choices'][0]['message']['content']
-            except (KeyError, IndexError) as e:
-                logger.error(f"Failed to extract content for page {page_num}: {e}")
-                continue
+        for index, page_result in enumerate(page_results):
+            if page_result.response.natural_text is not None:
+                content = page_result.response.natural_text + ("\n" if index == len(page_results) - 1 else "")
+            else:
+                content = ""
 
             start_pos = current_char_pos
             document_text += content
@@ -298,6 +309,8 @@ async def worker(args, queue):
 
             # Take all the not None completed_pdfs and write them as a jsonl to the workspace output location
            # under the proper work_hash location
+            for dolma_doc in completed_pdfs:
+                logger.info("Done! %s", dolma_doc)
 
             queue.task_done()
 
@@ -330,8 +343,6 @@ async def sglang_server_task(args):
 
     # Make really sure we kill this subprocess on exit
     def _kill_proc():
        proc.terminate()
-        time.sleep(3)
-        proc.kill()
 
     atexit.register(_kill_proc)
@@ -390,6 +401,7 @@ async def main():
         pdf_s3 = pdf_session.client("s3")
 
     check_poppler_version()
+    logger.info(f"Starting pipeline with PID {os.getpid()}")
 
     if args.pdfs:
         await populate_pdf_work_queue(args)
diff --git a/pdelfin/prompts/prompts.py b/pdelfin/prompts/prompts.py
index ed3c2b2..84439e8 100644
--- a/pdelfin/prompts/prompts.py
+++ b/pdelfin/prompts/prompts.py
@@ -15,7 +15,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )
 
-@dataclass
+@dataclass(frozen=True)
 class PageResponse:
     primary_language: Optional[str]
     is_rotation_valid: bool
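
Review note: the heart of this change is the error-handling flow. process_page now parses the OpenAI-compatible chat-completions shape served at /v1/chat/completions (choices[0].message.content rather than the older outputs[0].text) and returns a frozen PageResult instead of None, raising on any failure so that the asyncio.gather in process_pdf aborts the whole document as soon as one page fails. Below is a minimal, self-contained sketch of that control flow; process_page_stub, process_pdf_stub, and the simulated page-3 failure are hypothetical stand-ins for illustration, not pdelfin code.

import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class PageResult:
    # Mirrors the diff's new dataclass; `response` is simplified to a
    # plain string here instead of pdelfin's PageResponse.
    s3_path: str
    page_num: int
    response: str


async def process_page_stub(s3_path: str, page_num: int) -> PageResult:
    # Stand-in for the new process_page: raise on failure instead of
    # returning None.
    await asyncio.sleep(0)  # yield control, as the real aiohttp call would
    if page_num == 3:  # simulate one unparseable model response
        raise ValueError("Could not process page")
    return PageResult(s3_path, page_num, f"text of page {page_num}")


async def process_pdf_stub(s3_path: str, num_pages: int) -> Optional[list]:
    tasks = [asyncio.create_task(process_page_stub(s3_path, n))
             for n in range(1, num_pages + 1)]
    try:
        # gather re-raises the first page failure, so one bad page aborts
        # the whole document: the same early exit the diff introduces
        return await asyncio.gather(*tasks)
    except Exception:
        print(f"Could not load page for {s3_path}, aborting document")
        return None


print(asyncio.run(process_pdf_stub("s3://bucket/doc.pdf", 5)))  # prints None

One caveat worth noting: gather only re-raises the first failure, and the sibling page tasks keep running until the event loop shuts down, so a follow-up could cancel the remaining tasks explicitly once the document is abandoned.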