mirror of
https://github.com/allenai/olmocr.git
synced 2025-08-18 22:01:56 +00:00
Refactoring to assemble docs
This commit is contained in:
parent
da1b23fc47
commit
9fb464c654
@ -18,6 +18,8 @@ from tqdm import tqdm
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from pypdf import PdfReader
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
|
||||
from pdelfin.data.renderpdf import render_pdf_to_base64png
|
||||
@ -39,6 +41,12 @@ pdf_s3 = boto3.client('s3')
|
||||
|
||||
MAX_TOKENS = 3000
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PageResult:
|
||||
s3_path: str
|
||||
page_num: int
|
||||
response: PageResponse
|
||||
|
||||
|
||||
async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
|
||||
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
|
||||
@ -188,7 +196,7 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
|
||||
return queue
|
||||
|
||||
|
||||
async def process_page(session, pdf_path, page_num, args) -> PageResponse:
|
||||
async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
|
||||
COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
|
||||
|
||||
query = await build_page_query(
|
||||
@ -206,16 +214,21 @@ async def process_page(session, pdf_path, page_num, args) -> PageResponse:
|
||||
|
||||
try:
|
||||
base_response_data = await response.json()
|
||||
model_response_json = json.loads(base_response_data["outputs"][0]["text"])
|
||||
|
||||
model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
|
||||
page_response = PageResponse(**model_response_json)
|
||||
|
||||
return PageResult(pdf_s3_path, page_num, page_response)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse response for {pdf_path}-{page_num}")
|
||||
logger.warning(f"Could not parse response for {pdf_path}-{page_num}, reason: {e}")
|
||||
|
||||
raise ValueError("Could not process page")
|
||||
except Exception as e:
|
||||
logger.error(f"Exception while processing page {page_num}: {e}")
|
||||
return None
|
||||
raise
|
||||
|
||||
|
||||
async def process_pdf(args, pdf_s3_path):
|
||||
async def process_pdf(args, pdf_s3_path: str):
|
||||
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
|
||||
# TODO Switch to aioboto3 or something
|
||||
data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
|
||||
@ -231,29 +244,27 @@ async def process_pdf(args, pdf_s3_path):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for page_num in range(1, num_pages + 1):
|
||||
# Create a task for each page
|
||||
task = asyncio.create_task(process_page(session, tf.name, page_num, args))
|
||||
task = asyncio.create_task(process_page(args, session, pdf_s3_path, tf.name, page_num))
|
||||
page_tasks.append(task)
|
||||
|
||||
# Gather results from all page processing tasks
|
||||
page_results = await asyncio.gather(*page_tasks)
|
||||
try:
|
||||
page_results: list[PageResult] = await asyncio.gather(*page_tasks)
|
||||
except:
|
||||
logger.warning(f"Could not load page for {pdf_s3_path}, aborting document")
|
||||
return None
|
||||
|
||||
# If we failed to build a page, then this document is toast
|
||||
# TODO Abort earlier, if a page returns a None, then we can stop processing the whole pdf
|
||||
if any(page is None for page in page_results):
|
||||
logger.warning(f"PDF {pdf_s3_path} was not able to complete, not able to process a page")
|
||||
return None
|
||||
|
||||
# Build the document text and page spans
|
||||
document_text = ''
|
||||
pdf_page_spans = []
|
||||
current_char_pos = 0
|
||||
|
||||
for page_num, result in page_data:
|
||||
try:
|
||||
content = result['choices'][0]['message']['content']
|
||||
except (KeyError, IndexError) as e:
|
||||
logger.error(f"Failed to extract content for page {page_num}: {e}")
|
||||
continue
|
||||
for index, page_result in enumerate(page_results):
|
||||
if page_result.response.natural_text is not None:
|
||||
content = page_result.response.natural_text + ("\n" if index == len(page_results) - 1 else "")
|
||||
else:
|
||||
content = ""
|
||||
|
||||
start_pos = current_char_pos
|
||||
document_text += content
|
||||
@ -298,6 +309,8 @@ async def worker(args, queue):
|
||||
|
||||
# Take all the not None completed_pdfs and write them as a jsonl to the workspace output location
|
||||
# under the proper work_hash location
|
||||
for dolma_doc in completed_pdfs:
|
||||
logger.info("Done!", dolma_doc)
|
||||
|
||||
queue.task_done()
|
||||
|
||||
@ -330,8 +343,6 @@ async def sglang_server_task(args):
|
||||
# Make really sure we kill this subprocess on exit
|
||||
def _kill_proc():
|
||||
proc.terminate()
|
||||
time.sleep(3)
|
||||
proc.kill()
|
||||
|
||||
atexit.register(_kill_proc)
|
||||
|
||||
@ -390,6 +401,7 @@ async def main():
|
||||
pdf_s3 = pdf_session.client("s3")
|
||||
|
||||
check_poppler_version()
|
||||
logger.info(f"Starting pipeline with PID {os.getpid()}")
|
||||
|
||||
if args.pdfs:
|
||||
await populate_pdf_work_queue(args)
|
||||
|
@ -15,7 +15,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
|
||||
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class PageResponse:
|
||||
primary_language: Optional[str]
|
||||
is_rotation_valid: bool
|
||||
|
Loading…
x
Reference in New Issue
Block a user