diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index 5b04a92..61084bf 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -36,11 +36,16 @@ logging.basicConfig(level=logging.INFO)
 # Quiet logs from pypdf
 logging.getLogger("pypdf").setLevel(logging.ERROR)
 
-# Global s3 client for the whole script, feel free to adjust params if you need it
+# Global s3 clients for the whole script; we keep two separate ones in case your workspace and your pdfs are in different accounts
 workspace_s3 = boto3.client('s3')
 pdf_s3 = boto3.client('s3')
 
-MAX_TOKENS = 3000
+# Global variables for token statistics
+total_input_tokens = 0
+total_output_tokens = 0
+process_start_time = time.perf_counter()
+last_batch_time = process_start_time
+
 
 @dataclass(frozen=True)
 class PageResult:
@@ -48,8 +53,12 @@ class PageResult:
     page_num: int
     response: PageResponse
 
+    total_input_tokens: int
+    total_output_tokens: int
+
 
 async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
+    MAX_TOKENS = 3000
     assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
 
     # Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
@@ -216,7 +225,9 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
         model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
         page_response = PageResponse(**model_response_json)
 
-        return PageResult(pdf_s3_path, page_num, page_response)
+        return PageResult(pdf_s3_path, page_num, page_response,
+                          total_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
+                          total_output_tokens=base_response_data["usage"].get("completion_tokens", 0))
     except Exception as e:
         logger.exception(f"Exception while processing page {page_num}: {e}")
         raise
@@ -250,7 +261,7 @@ async def process_pdf(args, pdf_s3_path: str):
 
         # Build the document text and page spans
-        document_text = ''
+        document_text = ""
         pdf_page_spans = []
         current_char_pos = 0
 
@@ -264,7 +275,7 @@ async def process_pdf(args, pdf_s3_path: str):
             document_text += content
             current_char_pos = len(document_text)
             pdf_page_spans.append({
-                'pdf_page_number': page_num,
+                'pdf_page_number': page_result.page_num,
                 'start_char': start_pos,
                 'end_char': current_char_pos
             })
@@ -276,6 +287,8 @@ async def process_pdf(args, pdf_s3_path: str):
         metadata = {
             "Source-File": pdf_s3_path,
             "pdf-total-pages": num_pages,
+            "total-input-tokens": sum(page.total_input_tokens for page in page_results),
+            "total-output-tokens": sum(page.total_output_tokens for page in page_results)
         }
 
         id_ = hashlib.sha1(document_text.encode()).hexdigest()
@@ -297,16 +310,45 @@ async def process_pdf(args, pdf_s3_path: str):
 
 
 async def worker(args, queue):
     while True:
+        [work_hash, pdfs] = await queue.get()
-
-        completed_pdfs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
-
-        # Take all the not None completed_pdfs and write them as a jsonl to the workspace output location
-        # under the proper work_hash location
-        for dolma_doc in completed_pdfs:
-            pass
-
-        queue.task_done()
+
+        try:
+            dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
+            dolma_docs = [doc for doc in dolma_docs if doc is not None]
+
+            # Write the Dolma documents to a local temporary file in JSONL format
+            with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
+                for doc in dolma_docs:
+                    tf.write(json.dumps(doc))
+                    tf.write('\n')
+                tf.flush()
+
+            # Define the output S3 path using the work_hash
+            output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
+
+            bucket, key = parse_s3_path(output_s3_path)
+            workspace_s3.upload_file(tf.name, bucket, key)
+
+            # Sum up stats and report them since the last batch finished
+            global total_input_tokens, total_output_tokens, last_batch_time
+            batch_input_tokens = sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs)
+            batch_output_tokens = sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs)
+            batch_time = time.perf_counter() - last_batch_time
+            logger.info(f"Tokens per second (since last batch): input {batch_input_tokens / batch_time:.1f}, output {batch_output_tokens / batch_time:.1f}, total {(batch_input_tokens + batch_output_tokens) / batch_time:.1f}")
+
+            # Log statistics since process start
+            total_input_tokens += batch_input_tokens
+            total_output_tokens += batch_output_tokens
+            total_time = time.perf_counter() - process_start_time
+            logger.info(f"Tokens per second (since process start): input {total_input_tokens / total_time:.1f}, output {total_output_tokens / total_time:.1f}, total {(total_input_tokens + total_output_tokens) / total_time:.1f}")
+
+            # Update the last batch timestamp so the next report covers only the upcoming batch
+            last_batch_time = time.perf_counter()
+        except Exception as e:
+            logger.exception(f"Exception occurred while processing work_hash {work_hash}: {e}")
+        finally:
+            queue.task_done()
 
 
 async def sglang_server_task(args):
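
The worker's write-then-upload step is easy to get subtly wrong: the temporary file has to survive the `with` block so boto3 can read it back by name. A minimal standalone sketch of that pattern follows; `upload_jsonl`, `docs`, `bucket`, and `key` are illustrative names rather than identifiers from this change.

import json
import os
import tempfile

import boto3

workspace_s3 = boto3.client("s3")

def upload_jsonl(docs, bucket, key):
    # delete=False keeps the file on disk after the context manager closes it,
    # so upload_file can re-open it by name.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tf:
        for doc in docs:
            tf.write(json.dumps(doc))
            tf.write("\n")

    # Stream the closed local file to s3://bucket/key, then clean up the temp file.
    workspace_s3.upload_file(tf.name, bucket, key)
    os.remove(tf.name)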
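The throughput bookkeeping is the other moving part here. Below is a runnable sketch of the same accounting outside the async worker; `report_batch` is a hypothetical helper and the logger setup exists only for the sketch, but the `perf_counter` arithmetic mirrors the diff: measure elapsed time against the stored timestamps, log both rates, then advance `last_batch_time`.

import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

total_input_tokens = 0
total_output_tokens = 0
process_start_time = time.perf_counter()
last_batch_time = process_start_time

def report_batch(batch_input_tokens, batch_output_tokens):
    # Log tokens/sec for the batch that just finished, then for the whole run.
    global total_input_tokens, total_output_tokens, last_batch_time

    now = time.perf_counter()
    batch_time = now - last_batch_time
    logger.info(f"Since last batch: {(batch_input_tokens + batch_output_tokens) / batch_time:.1f} tok/s")

    total_input_tokens += batch_input_tokens
    total_output_tokens += batch_output_tokens
    total_time = now - process_start_time
    logger.info(f"Since process start: {(total_input_tokens + total_output_tokens) / total_time:.1f} tok/s")

    # Reset the batch clock only after both reports have used the old value.
    last_batch_time = now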