Refactoring to assemble docs

This commit is contained in:
Jake Poznanski 2024-11-11 11:46:49 -08:00
parent da1b23fc47
commit 9fb464c654
2 changed files with 33 additions and 21 deletions

View File

@ -18,6 +18,8 @@ from tqdm import tqdm
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from pypdf import PdfReader from pypdf import PdfReader
from dataclasses import dataclass
from typing import Optional
from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
from pdelfin.data.renderpdf import render_pdf_to_base64png from pdelfin.data.renderpdf import render_pdf_to_base64png
@ -39,6 +41,12 @@ pdf_s3 = boto3.client('s3')
MAX_TOKENS = 3000 MAX_TOKENS = 3000
@dataclass(frozen=True)
class PageResult:
    """Immutable record pairing one page's parsed model output with its PDF.

    Returned by ``process_page`` after the completion response has been
    parsed into a ``PageResponse``.
    """

    # S3 path of the PDF this page belongs to.
    s3_path: str
    # Page number within that PDF; the caller in this file counts from 1.
    page_num: int
    # Parsed model response for this page.
    response: PageResponse
async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict: async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query" assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
@ -188,7 +196,7 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
return queue return queue
async def process_page(session, pdf_path, page_num, args) -> PageResponse: async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = "http://localhost:30000/v1/chat/completions" COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
query = await build_page_query( query = await build_page_query(
@ -206,16 +214,21 @@ async def process_page(session, pdf_path, page_num, args) -> PageResponse:
try: try:
base_response_data = await response.json() base_response_data = await response.json()
model_response_json = json.loads(base_response_data["outputs"][0]["text"])
model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json) page_response = PageResponse(**model_response_json)
return PageResult(pdf_s3_path, page_num, page_response)
except Exception as e: except Exception as e:
logger.warning(f"Could not parse response for {pdf_path}-{page_num}") logger.warning(f"Could not parse response for {pdf_path}-{page_num}, reason: {e}")
raise ValueError("Could not process page")
except Exception as e: except Exception as e:
logger.error(f"Exception while processing page {page_num}: {e}") logger.error(f"Exception while processing page {page_num}: {e}")
return None raise
async def process_pdf(args, pdf_s3_path): async def process_pdf(args, pdf_s3_path: str):
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf: with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
# TODO Switch to aioboto3 or something # TODO Switch to aioboto3 or something
data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path)) data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
@ -231,29 +244,27 @@ async def process_pdf(args, pdf_s3_path):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
for page_num in range(1, num_pages + 1): for page_num in range(1, num_pages + 1):
# Create a task for each page # Create a task for each page
task = asyncio.create_task(process_page(session, tf.name, page_num, args)) task = asyncio.create_task(process_page(args, session, pdf_s3_path, tf.name, page_num))
page_tasks.append(task) page_tasks.append(task)
# Gather results from all page processing tasks # Gather results from all page processing tasks
page_results = await asyncio.gather(*page_tasks) try:
page_results: list[PageResult] = await asyncio.gather(*page_tasks)
# If we failed to build a page, then this document is toast except:
# TODO Abort earlier, if a page returns a None, then we can stop processing the whole pdf logger.warning(f"Could not load page for {pdf_s3_path}, aborting document")
if any(page is None for page in page_results):
logger.warning(f"PDF {pdf_s3_path} was not able to complete, not able to process a page")
return None return None
# Build the document text and page spans # Build the document text and page spans
document_text = '' document_text = ''
pdf_page_spans = [] pdf_page_spans = []
current_char_pos = 0 current_char_pos = 0
for page_num, result in page_data: for index, page_result in enumerate(page_results):
try: if page_result.response.natural_text is not None:
content = result['choices'][0]['message']['content'] content = page_result.response.natural_text + ("\n" if index == len(page_results) - 1 else "")
except (KeyError, IndexError) as e: else:
logger.error(f"Failed to extract content for page {page_num}: {e}") content = ""
continue
start_pos = current_char_pos start_pos = current_char_pos
document_text += content document_text += content
@ -298,6 +309,8 @@ async def worker(args, queue):
# Take all the not None completed_pdfs and write them as a jsonl to the workspace output location # Take all the not None completed_pdfs and write them as a jsonl to the workspace output location
# under the proper work_hash location # under the proper work_hash location
for dolma_doc in completed_pdfs:
logger.info("Done!", dolma_doc)
queue.task_done() queue.task_done()
@ -330,8 +343,6 @@ async def sglang_server_task(args):
# Make really sure we kill this subprocess on exit # Make really sure we kill this subprocess on exit
def _kill_proc(): def _kill_proc():
proc.terminate() proc.terminate()
time.sleep(3)
proc.kill()
atexit.register(_kill_proc) atexit.register(_kill_proc)
@ -390,6 +401,7 @@ async def main():
pdf_s3 = pdf_session.client("s3") pdf_s3 = pdf_session.client("s3")
check_poppler_version() check_poppler_version()
logger.info(f"Starting pipeline with PID {os.getpid()}")
if args.pdfs: if args.pdfs:
await populate_pdf_work_queue(args) await populate_pdf_work_queue(args)

View File

@ -15,7 +15,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END" f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
) )
@dataclass @dataclass(frozen=True)
class PageResponse: class PageResponse:
primary_language: Optional[str] primary_language: Optional[str]
is_rotation_valid: bool is_rotation_valid: bool