mirror of https://github.com/allenai/olmocr.git
Refactoring to assemble docs

This commit is contained in:
parent da1b23fc47
commit 9fb464c654
@@ -18,6 +18,8 @@ from tqdm import tqdm
 from io import BytesIO
 from PIL import Image
 from pypdf import PdfReader
+from dataclasses import dataclass
+from typing import Optional
 
 from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
 from pdelfin.data.renderpdf import render_pdf_to_base64png
@@ -39,6 +41,12 @@ pdf_s3 = boto3.client('s3')
 
 MAX_TOKENS = 3000
 
+@dataclass(frozen=True)
+class PageResult:
+    s3_path: str
+    page_num: int
+    response: PageResponse
+
 
 async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
     assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
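The new PageResult container is frozen, so a page's result becomes immutable once constructed. A minimal standalone sketch of that behavior (PageResponse here is a trimmed stand-in for the class touched later in this diff):

```python
# Illustration only, not repo code: trimmed stand-ins to exercise frozen=True.
from dataclasses import dataclass, FrozenInstanceError
from typing import Optional

@dataclass(frozen=True)
class PageResponse:
    primary_language: Optional[str]
    is_rotation_valid: bool

@dataclass(frozen=True)
class PageResult:
    s3_path: str
    page_num: int
    response: PageResponse

result = PageResult("s3://bucket/doc.pdf", 1, PageResponse("en", True))
try:
    result.page_num = 2  # frozen=True forbids mutation after construction
except FrozenInstanceError:
    print("PageResult is immutable")
```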
@@ -188,7 +196,7 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
     return queue
 
 
-async def process_page(session, pdf_path, page_num, args) -> PageResponse:
+async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
 
     query = await build_page_query(
@@ -206,16 +214,21 @@ async def process_page(session, pdf_path, page_num, args) -> PageResponse:
 
         try:
             base_response_data = await response.json()
-            model_response_json = json.loads(base_response_data["outputs"][0]["text"])
+            model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
             page_response = PageResponse(**model_response_json)
 
+            return PageResult(pdf_s3_path, page_num, page_response)
         except Exception as e:
-            logger.warning(f"Could not parse response for {pdf_path}-{page_num}")
+            logger.warning(f"Could not parse response for {pdf_path}-{page_num}, reason: {e}")
 
+            raise ValueError("Could not process page")
     except Exception as e:
         logger.error(f"Exception while processing page {page_num}: {e}")
-        return None
+        raise
 
 
-async def process_pdf(args, pdf_s3_path):
+async def process_pdf(args, pdf_s3_path: str):
     with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
         # TODO Switch to aioboto3 or something
         data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
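The parsing change in this hunk moves from an `outputs[0].text` shape to the OpenAI-style chat completions shape served at the COMPLETION_URL above, where the model's JSON output sits inside `choices[0].message.content`. A small sketch with an invented payload, assuming the model was prompted to emit a JSON object:

```python
# Sketch of the response shape the new code expects from an OpenAI-compatible
# /v1/chat/completions endpoint; the payload below is made up for illustration.
import json

base_response_data = {
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": '{"primary_language": "en", "is_rotation_valid": true, "natural_text": "Hello world"}',
            },
        }
    ]
}

# Same two-step decode as the diff: unwrap the chat message, then parse the
# JSON document the model was instructed to produce.
model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
print(model_response_json["natural_text"])  # -> Hello world
```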
@@ -231,29 +244,27 @@ async def process_pdf(args, pdf_s3_path):
         async with aiohttp.ClientSession() as session:
             for page_num in range(1, num_pages + 1):
                 # Create a task for each page
-                task = asyncio.create_task(process_page(session, tf.name, page_num, args))
+                task = asyncio.create_task(process_page(args, session, pdf_s3_path, tf.name, page_num))
                 page_tasks.append(task)
 
             # Gather results from all page processing tasks
-            page_results = await asyncio.gather(*page_tasks)
+            try:
+                page_results: list[PageResult] = await asyncio.gather(*page_tasks)
+            except:
+                logger.warning(f"Could not load page for {pdf_s3_path}, aborting document")
+                return None
 
-        # If we failed to build a page, then this document is toast
-        # TODO Abort earlier, if a page returns a None, then we can stop processing the whole pdf
-        if any(page is None for page in page_results):
-            logger.warning(f"PDF {pdf_s3_path} was not able to complete, not able to process a page")
-            return None
 
         # Build the document text and page spans
         document_text = ''
         pdf_page_spans = []
         current_char_pos = 0
 
-        for page_num, result in page_data:
-            try:
-                content = result['choices'][0]['message']['content']
-            except (KeyError, IndexError) as e:
-                logger.error(f"Failed to extract content for page {page_num}: {e}")
-                continue
+        for index, page_result in enumerate(page_results):
+            if page_result.response.natural_text is not None:
+                content = page_result.response.natural_text + ("\n" if index == len(page_results) - 1 else "")
+            else:
+                content = ""
 
             start_pos = current_char_pos
             document_text += content
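After this rewrite every page contributes a string (possibly empty) and the loop tracks character offsets into the growing document text. A rough sketch of that bookkeeping, with plain strings standing in for PageResult objects; the `(page, start, end)` span format is an assumption for illustration, not taken from the repo:

```python
# Toy version of the offset bookkeeping; page_texts stands in for the
# natural_text of each page result.
page_texts = ["Page one text", "Page two text", ""]  # "" when natural_text is None

document_text = ""
pdf_page_spans = []
current_char_pos = 0

for index, content in enumerate(page_texts):
    start_pos = current_char_pos
    document_text += content
    current_char_pos += len(content)
    pdf_page_spans.append((index + 1, start_pos, current_char_pos))

print(pdf_page_spans)  # [(1, 0, 13), (2, 13, 26), (3, 26, 26)]
```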
@@ -298,6 +309,8 @@ async def worker(args, queue):
 
             # Take all the not None completed_pdfs and write them as a jsonl to the workspace output location
             # under the proper work_hash location
+            for dolma_doc in completed_pdfs:
+                logger.info("Done!", dolma_doc)
 
             queue.task_done()
 
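The loop added here only logs each completed document; the comment above it describes the eventual goal of writing the non-None Dolma docs as JSONL under the work hash. A hedged sketch of what that serialization step could look like (helper name and behavior are assumptions, not the repo's code):

```python
# Hypothetical helper, not from the repo: turn completed Dolma docs into one
# JSONL payload, skipping documents that failed (None).
import json

def dolma_docs_to_jsonl(completed_pdfs) -> str:
    lines = [json.dumps(doc) for doc in completed_pdfs if doc is not None]
    return "\n".join(lines) + ("\n" if lines else "")

print(dolma_docs_to_jsonl([{"id": "doc-1", "text": "..."}, None]))
```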
@@ -330,8 +343,6 @@ async def sglang_server_task(args):
     # Make really sure we kill this subprocess on exit
     def _kill_proc():
         proc.terminate()
-        time.sleep(3)
-        proc.kill()
 
     atexit.register(_kill_proc)
 
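The cleanup hook now only calls `terminate()`; the three-second wait followed by `kill()` is dropped. A self-contained sketch of the remaining pattern (the spawned child here is just an example target, not the sglang server):

```python
# Minimal atexit cleanup sketch: register a hook so the child process is
# terminated when the parent interpreter exits.
import atexit
import subprocess
import sys

proc = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])

def _kill_proc():
    proc.terminate()  # no escalation to kill() after this commit

atexit.register(_kill_proc)
```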
@@ -390,6 +401,7 @@ async def main():
     pdf_s3 = pdf_session.client("s3")
 
     check_poppler_version()
+    logger.info(f"Starting pipeline with PID {os.getpid()}")
 
     if args.pdfs:
         await populate_pdf_work_queue(args)
@@ -15,7 +15,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )
 
-@dataclass
+@dataclass(frozen=True)
 class PageResponse:
     primary_language: Optional[str]
     is_rotation_valid: bool
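Making PageResponse frozen matches the frozen PageResult introduced earlier in this diff: an immutable dataclass gets a value-based hash, so identical responses can be deduplicated or used as dict keys. A short sketch under that reading (fields trimmed to what appears in this hunk):

```python
# Illustration only: frozen=True makes PageResponse hashable by value.
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class PageResponse:
    primary_language: Optional[str]
    is_rotation_valid: bool

a = PageResponse("en", True)
b = PageResponse("en", True)
print(len({a, b}))  # -> 1, identical responses collapse in a set
```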