Refactoring to assemble docs

This commit is contained in:
Jake Poznanski 2024-11-11 11:46:49 -08:00
parent da1b23fc47
commit 9fb464c654
2 changed files with 33 additions and 21 deletions

View File

@ -18,6 +18,8 @@ from tqdm import tqdm
from io import BytesIO
from PIL import Image
from pypdf import PdfReader
from dataclasses import dataclass
from typing import Optional
from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
from pdelfin.data.renderpdf import render_pdf_to_base64png
@ -39,6 +41,12 @@ pdf_s3 = boto3.client('s3')
MAX_TOKENS = 3000
@dataclass(frozen=True)
class PageResult:
s3_path: str
page_num: int
response: PageResponse
async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
@ -188,7 +196,7 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
return queue
async def process_page(session, pdf_path, page_num, args) -> PageResponse:
async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
query = await build_page_query(
@ -206,16 +214,21 @@ async def process_page(session, pdf_path, page_num, args) -> PageResponse:
try:
base_response_data = await response.json()
model_response_json = json.loads(base_response_data["outputs"][0]["text"])
model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)
return PageResult(pdf_s3_path, page_num, page_response)
except Exception as e:
logger.warning(f"Could not parse response for {pdf_path}-{page_num}")
logger.warning(f"Could not parse response for {pdf_path}-{page_num}, reason: {e}")
raise ValueError("Could not process page")
except Exception as e:
logger.error(f"Exception while processing page {page_num}: {e}")
return None
raise
async def process_pdf(args, pdf_s3_path):
async def process_pdf(args, pdf_s3_path: str):
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
# TODO Switch to aioboto3 or something
data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
@ -231,29 +244,27 @@ async def process_pdf(args, pdf_s3_path):
async with aiohttp.ClientSession() as session:
for page_num in range(1, num_pages + 1):
# Create a task for each page
task = asyncio.create_task(process_page(session, tf.name, page_num, args))
task = asyncio.create_task(process_page(args, session, pdf_s3_path, tf.name, page_num))
page_tasks.append(task)
# Gather results from all page processing tasks
page_results = await asyncio.gather(*page_tasks)
# If we failed to build a page, then this document is toast
# TODO Abort earlier, if a page returns a None, then we can stop processing the whole pdf
if any(page is None for page in page_results):
logger.warning(f"PDF {pdf_s3_path} was not able to complete, not able to process a page")
try:
page_results: list[PageResult] = await asyncio.gather(*page_tasks)
except:
logger.warning(f"Could not load page for {pdf_s3_path}, aborting document")
return None
# Build the document text and page spans
document_text = ''
pdf_page_spans = []
current_char_pos = 0
for page_num, result in page_data:
try:
content = result['choices'][0]['message']['content']
except (KeyError, IndexError) as e:
logger.error(f"Failed to extract content for page {page_num}: {e}")
continue
for index, page_result in enumerate(page_results):
if page_result.response.natural_text is not None:
content = page_result.response.natural_text + ("\n" if index == len(page_results) - 1 else "")
else:
content = ""
start_pos = current_char_pos
document_text += content
@ -298,6 +309,8 @@ async def worker(args, queue):
# Take all the not None completed_pdfs and write them as a jsonl to the workspace output location
# under the proper work_hash location
for dolma_doc in completed_pdfs:
logger.info("Done!", dolma_doc)
queue.task_done()
@ -330,8 +343,6 @@ async def sglang_server_task(args):
# Make really sure we kill this subprocess on exit
def _kill_proc():
proc.terminate()
time.sleep(3)
proc.kill()
atexit.register(_kill_proc)
@ -390,6 +401,7 @@ async def main():
pdf_s3 = pdf_session.client("s3")
check_poppler_version()
logger.info(f"Starting pipeline with PID {os.getpid()}")
if args.pdfs:
await populate_pdf_work_queue(args)

View File

@ -15,7 +15,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
@dataclass
@dataclass(frozen=True)
class PageResponse:
primary_language: Optional[str]
is_rotation_valid: bool