From cd8e28e459e776b67a4e9038ee47612a4f3a7f73 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Mon, 14 Oct 2024 18:19:17 +0000
Subject: [PATCH] Pipeline working hopefully soon

---
 pdelfin/assemblepipeline.py | 61 +++++++++++++++++++++++++++++++++++------------------------
 1 file changed, 37 insertions(+), 24 deletions(-)

diff --git a/pdelfin/assemblepipeline.py b/pdelfin/assemblepipeline.py
index 37494bc..930ae9a 100644
--- a/pdelfin/assemblepipeline.py
+++ b/pdelfin/assemblepipeline.py
@@ -27,7 +27,8 @@ s3 = boto3.client('s3')
 class DatabaseManager:
     @dataclass(frozen=True)
     class BatchInferenceRecord:
-        s3_path: str
+        inference_s3_path: str
+        pdf_s3_path: str
         page_num: int  # 1 indexed!
         start_index: int
         length: int
@@ -56,7 +57,8 @@ class DatabaseManager:
     def _initialize_tables(self):
         self.cursor.execute("""
             CREATE TABLE IF NOT EXISTS page_results (
-                s3_path TEXT,
+                inference_s3_path TEXT,
+                pdf_s3_path TEXT,
                 page_num INTEGER,
                 start_index BIGINT,
                 length BIGINT,
@@ -115,29 +117,30 @@ class DatabaseManager:
     def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
         if index_entries:
             self.cursor.executemany("""
-                INSERT INTO page_results (s3_path, page_num, start_index, length, finish_reason, error)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """, [(entry.s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
+                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error)
+                VALUES (?, ?, ?, ?, ?, ?, ?)
+            """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
 
             self.conn.commit()
 
     def get_index_entries(self, s3_path: str) -> List[BatchInferenceRecord]:
         self.cursor.execute("""
-            SELECT s3_path, page_num, start_index, length, finish_reason, error
+            SELECT inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error
             FROM page_results
-            WHERE s3_path = ?
-            ORDER BY page_num ASC
+            WHERE pdf_s3_path = ?
+            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
         """, (s3_path,))
 
         rows = self.cursor.fetchall()
 
         return [
             self.BatchInferenceRecord(
-                s3_path=row[0],
-                page_num=row[1],
-                start_index=row[2],
-                length=row[3],
-                finish_reason=row[4],
-                error=row[5]
+                inference_s3_path=row[0],
+                pdf_s3_path=row[1],
+                page_num=row[2],
+                start_index=row[3],
+                length=row[4],
+                finish_reason=row[5],
+                error=row[6]
             )
             for row in rows
         ]
@@ -186,6 +189,7 @@ class DatabaseManager:
             SELECT s3_path, num_pages, status
             FROM pdfs
             WHERE status == ?
+            ORDER BY s3_path DESC
         """, (status, ))
 
         rows = self.cursor.fetchall()
@@ -334,8 +338,8 @@ def parse_custom_id(custom_id: str) -> Tuple[str, int]:
     page_num = int(custom_id[custom_id.rindex("-") + 1:])
     return s3_path, page_num
 
-def process_jsonl_content(s3_path) -> List[DatabaseManager.BatchInferenceRecord]:
-    content = get_s3_bytes(s3_path).decode("utf-8")
+def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
+    content = get_s3_bytes(inference_s3_path).decode("utf-8")
     start_index = 0
     index_entries = []
 
@@ -345,12 +349,13 @@ def process_jsonl_content(s3_path) -> List[DatabaseManager.BatchInferenceRecord]
         try:
             data = json.loads(line)
-            s3_path, page_num = parse_custom_id(data["custom_id"])
+            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])
 
             assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"
 
             index_entries.append(DatabaseManager.BatchInferenceRecord(
-                s3_path=s3_path,
+                inference_s3_path=inference_s3_path,
+                pdf_s3_path=pdf_s3_path,
                 page_num=page_num,
                 start_index=start_index,
                 length=line_length,
@@ -410,17 +415,22 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
     db = DatabaseManager(s3_workspace)
     existing_pages = db.get_index_entries(pdf.s3_path)
     document_text = ""
+    last_page_start_index = 0
+    pdf_page_spans = []
 
     for target_page_num in range(1, pdf.num_pages + 1):
         target_page = next(page for page in existing_pages if page.is_usable() and page.page_num == target_page_num)
 
-        target_row = get_s3_bytes(target_page.s3_path,
-                                  start_index=target_page.start_index,
-                                  end_index=target_page.start_index+target_page.length)
+        target_row = get_s3_bytes(target_page.inference_s3_path,
+                                  start_index=target_page.start_index,
+                                  end_index=target_page.start_index+target_page.length)
 
         target_data = json.loads(target_row.decode("utf-8"))
-        document_text += target_data + "\n"
+        document_text += target_data["natural_text"] + "\n"
+
+        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
+        last_page_start_index = len(document_text)
 
     metadata = {
         "Source-File": pdf.s3_path,
@@ -431,10 +441,13 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
     dolma_doc = {
         "id": id_,
         "text": document_text,
-        "source": "s2pdf",
+        "source": "pdelfin",
         "added": datetime.datetime.now().strftime("%Y-%m-%d"),
         "created": datetime.datetime.now().strftime("%Y-%m-%d"),
         "metadata": metadata,
+        "attributes": {
+            "pdf_page_numbers": pdf_page_spans
+        }
     }
 
     return dolma_doc
@@ -500,7 +513,7 @@ if __name__ == '__main__':
         future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf): pdf for pdf in db.get_pdfs_by_status("pending")}
         potentially_done_pdfs = []
         lines_written = 0
-        new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{round}", args.max_size_mb)
+        new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{db.get_current_round()}", args.max_size_mb)
 
         for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
             pdf = future_to_path[future]
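
Notes on the code paths this patch touches.

Batch inference rows are keyed by a custom_id that the pipeline appears to build as f"{pdf_s3_path}-{page_num}": only the tail of parse_custom_id is visible in the diff, so the slice that recovers the path in the sketch below is an assumed reconstruction.

    def parse_custom_id(custom_id: str):
        # The page number follows the last "-"; everything before it is the PDF's S3 path.
        s3_path = custom_id[:custom_id.rindex("-")]  # assumed; this line is not visible in the diff
        page_num = int(custom_id[custom_id.rindex("-") + 1:])
        return s3_path, page_num

    assert parse_custom_id("s3://bucket/docs/report.pdf-3") == ("s3://bucket/docs/report.pdf", 3)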
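process_jsonl_content records each JSONL line's start offset and length so that build_dolma_doc can later re-fetch a single row with a ranged read instead of re-downloading the whole inference file. get_s3_bytes itself is not part of this patch; a minimal sketch of such a helper, assuming "s3://bucket/key" paths and an exclusive end_index, would map those offsets onto an HTTP Range header. (One caveat: the offsets are computed on the decoded string, so if a JSONL line ever contains multi-byte UTF-8, character and byte offsets would diverge.)

    # Hypothetical sketch of a ranged S3 read like get_s3_bytes; the real helper may differ.
    import boto3

    s3 = boto3.client("s3")

    def get_s3_bytes_sketch(s3_path, start_index=0, end_index=None):
        # Assumed path shape: "s3://bucket/key"
        bucket, _, key = s3_path.removeprefix("s3://").partition("/")
        if end_index is not None:
            # HTTP Range ends are inclusive, so the exclusive end_index drops by one
            byte_range = f"bytes={start_index}-{end_index - 1}"
        else:
            byte_range = f"bytes={start_index}-"
        return s3.get_object(Bucket=bucket, Key=key, Range=byte_range)["Body"].read()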
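The new pdf_page_spans bookkeeping in build_dolma_doc stores one [start_char, end_char, page_num] triple per page, indexing into the final document_text. A small worked example of the same pattern:

    document_text = ""
    last_page_start_index = 0
    pdf_page_spans = []

    for page_num, natural_text in enumerate(["Page one text", "Page two text"], start=1):
        document_text += natural_text + "\n"
        pdf_page_spans.append([last_page_start_index, len(document_text), page_num])
        last_page_start_index = len(document_text)

    # Each span slices its page back out of the assembled text:
    start, end, _ = pdf_page_spans[1]
    assert document_text[start:end] == "Page two text\n"
    assert pdf_page_spans == [[0, 14, 1], [14, 28, 2]]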
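Put together, build_dolma_doc now emits a Dolma-style document whose attributes carry the page spans. An example of the resulting shape, with illustrative field values rather than output from a real run:

    example_dolma_doc = {
        "id": "c0ffee",
        "text": "Page one text\nPage two text\n",
        "source": "pdelfin",
        "added": "2024-10-14",
        "created": "2024-10-14",
        "metadata": {"Source-File": "s3://bucket/docs/report.pdf"},
        "attributes": {
            # one [start_char, end_char, page_num] span per PDF page
            "pdf_page_numbers": [[0, 14, 1], [14, 28, 2]],
        },
    }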