Mirror of https://github.com/allenai/olmocr.git, synced 2025-09-25 16:30:28 +00:00
Taking notes, starting on document assembly
commit 847064f46f (parent 8e5809da71)
pdelfin/assemblepipeline.py (new file, 198 lines)
@@ -0,0 +1,198 @@
import argparse
import os
import json
import hashlib
import logging
from collections import defaultdict
from typing import Optional
from tqdm import tqdm
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import boto3
from pypdf import PdfReader
from cached_path import cached_path
from smart_open import smart_open

from dataclasses import dataclass

# Import your existing modules if necessary
# from dolma_refine.evaluate.metrics import DocumentEditSimilarity
# from dolma_refine.evaluate.segmenters import SpacySegmenter
# from dolma_refine.evaluate.aligners import HirschbergAligner

@dataclass(frozen=True)
class NormalizedEntry:
    s3_path: str
    pagenum: int
    text: Optional[str]
    finish_reason: Optional[str]
    error: Optional[str] = None

    @staticmethod
    def from_goldkey(goldkey: str, **kwargs):
        s3_path = goldkey[:goldkey.rindex("-")]
        page_num = int(goldkey[goldkey.rindex("-") + 1:])
        return NormalizedEntry(s3_path, page_num, **kwargs)

    @property
    def goldkey(self):
        return f"{self.s3_path}-{self.pagenum}"

def normalize_json_entry(data: dict) -> NormalizedEntry:
    if "outputs" in data:
        # Birr case
        if data["outputs"] is None:
            text = None
            finish_reason = None
        else:
            text = data["outputs"][0]["text"]
            finish_reason = data["outputs"][0]["finish_reason"]

        # Try to parse the structured output if possible
        try:
            if text is not None:
                parsed_content = json.loads(text)
                text = parsed_content["natural_text"]
        except json.JSONDecodeError:
            pass

        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=text,
            finish_reason=finish_reason,
            error=data.get("completion_error", None)
        )
    else:
        # OpenAI case
        try:
            # Attempt to parse the JSON content from OpenAI's response
            parsed_content = json.loads(data["response"]["body"]["choices"][0]["message"]["content"])
            return NormalizedEntry.from_goldkey(
                goldkey=data["custom_id"],
                text=parsed_content["natural_text"],
                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
            )
        except json.JSONDecodeError:
            # Fallback if content is not valid JSON
            return NormalizedEntry.from_goldkey(
                goldkey=data["custom_id"],
                text=data["response"]["body"]["choices"][0]["message"]["content"],
                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
            )

def parse_s3_path(s3_path):
    if not s3_path.startswith('s3://'):
        raise ValueError('Invalid S3 path')
    s3_path = s3_path[5:]
    bucket_name, _, key = s3_path.partition('/')
    return bucket_name, key

def main():
    parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
    parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
    parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(filename='processing.log', level=logging.INFO, format='%(asctime)s %(message)s')

    os.makedirs(args.output_dir, exist_ok=True)

    # Initialize S3 client
    s3 = boto3.client('s3')
    bucket_name, prefix = parse_s3_path(args.s3_path)

    # List all .json and .jsonl files in the specified S3 path
    paginator = s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    files = []
    for page in page_iterator:
        if 'Contents' in page:
            for obj in page['Contents']:
                key = obj['Key']
                if key.endswith('.json') or key.endswith('.jsonl'):
                    files.append(key)

    # Build documents mapping
    documents = defaultdict(list)

    print("Processing JSON files and building documents mapping...")
    for key in tqdm(files):
        file_s3_path = f's3://{bucket_name}/{key}'
        try:
            with smart_open(file_s3_path, 'r') as f:
                for line in f:
                    data = json.loads(line)
                    entry = normalize_json_entry(data)
                    documents[entry.s3_path].append(entry)
        except Exception as e:
            logging.error(f"Error processing file {file_s3_path}: {e}")

    total_documents = len(documents)
    successful_documents = 0
    total_pages = 0
    successful_pages = 0

    print("Processing documents...")
    for s3_path, entries in tqdm(documents.items()):
        try:
            # Download the PDF locally
            pdf_local_path = cached_path(s3_path, quiet=True)

            pdf = PdfReader(pdf_local_path)
            total_pages_in_pdf = len(pdf.pages)
        except Exception as e:
            logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
            continue

        total_pages += total_pages_in_pdf

        # Build mapping from pagenum to entry
        entry_by_pagenum = {entry.pagenum: entry for entry in entries}

        valid_entries = []
        missing_pages = []
        errors = []

        for page_num in range(total_pages_in_pdf):
            entry = entry_by_pagenum.get(page_num)
            if entry is None:
                missing_pages.append(page_num)
            elif entry.error is not None or entry.finish_reason != 'stop':
                errors.append(entry)
            else:
                valid_entries.append(entry)

        successful_pages += len(valid_entries)

        if not missing_pages and not errors:
            # Assemble text
            valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
            text = '\n'.join(entry.text for entry in valid_entries_sorted)

            # Generate a filename based on the s3_path
            doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
            output_filename = os.path.join(args.output_dir, f'{doc_hash}.json')

            output_data = {
                'source': s3_path,
                'total_pages': total_pages_in_pdf,
                'text': text
            }

            with open(output_filename, 'w') as f_out:
                json.dump(output_data, f_out)

            successful_documents += 1
        else:
            logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {[e.pagenum for e in errors]}')

    print(f'Total documents: {total_documents}')
    print(f'Successful documents: {successful_documents}')
    print(f'Total pages: {total_pages}')
    print(f'Successful pages: {successful_pages}')

if __name__ == '__main__':
    main()
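
Note: normalize_json_entry above accepts two record shapes, Birr batch output (an "outputs" list) and OpenAI batch output (a "response" body), and maps both onto the same NormalizedEntry keyed by the "{s3_path}-{pagenum}" goldkey. A minimal sketch of that behavior, using made-up records (not real inference output) and assuming pdelfin is importable as a package:

from pdelfin.assemblepipeline import normalize_json_entry

# Hypothetical Birr-style record: the model output is a JSON string with "natural_text".
birr_record = {
    "custom_id": "s3://bucket/docs/a.pdf-3",
    "outputs": [{"text": '{"natural_text": "Page three text"}', "finish_reason": "stop"}],
}

# Hypothetical OpenAI batch record for the same page.
openai_record = {
    "custom_id": "s3://bucket/docs/a.pdf-3",
    "response": {"body": {"choices": [{
        "message": {"content": '{"natural_text": "Page three text"}'},
        "finish_reason": "stop",
    }]}},
}

for record in (birr_record, openai_record):
    entry = normalize_json_entry(record)
    assert entry.s3_path == "s3://bucket/docs/a.pdf"
    assert entry.pagenum == 3
    assert entry.text == "Page three text"
    assert entry.finish_reason == "stop"

Both shapes normalize to identical entries, which is what lets the assembly loop group pages by entry.s3_path regardless of which batch system produced them.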

@@ -285,6 +285,8 @@ def do_eval(gold_data_path: str, eval_data_path: str, review_page_name: str, rev

        # if len(pd["gold_text"]) < 200 and len(pd["eval_text"]) < 200:
        #     continue
        if "[Error processing this page: overrun" not in pd["eval_text"]:
            continue

        page_eval_data.append(pd)
@@ -1,81 +0,0 @@
import subprocess
from typing import Literal

import pymupdf
import pypdfium2 as pdfium


def get_page_text(
    local_pdf_path: str, page_num: int, pdf_engine: Literal["pdftotext", "pymupdf", "pdfium"] = "pdftotext"
) -> str:
    if pdf_engine == "pdftotext":
        pdftotext_result = subprocess.run(
            [
                "pdftotext",
                "-f",
                str(page_num),
                "-l",
                str(page_num),
                local_pdf_path,
                "-",
            ],
            timeout=60,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        assert pdftotext_result.returncode == 0

        return pdftotext_result.stdout.decode("utf-8")
    elif pdf_engine == "pymupdf":
        pm_doc = pymupdf.open(local_pdf_path)
        return pm_doc[page_num - 1].get_text()
    elif pdf_engine == "pdfium":
        pdf = pdfium.PdfDocument(local_pdf_path, autoclose=True)
        page = pdf[page_num - 1]
        textpage = page.get_textpage()

        # Extract text from the whole page
        result = textpage.get_text_range()
        pdf.close()
        return result
    else:
        raise NotImplementedError()


def get_document_text(local_pdf_path: str, pdf_engine: Literal["pdftotext", "pymupdf", "pdfium"] = "pdftotext") -> str:
    if pdf_engine == "pdftotext":
        pdftotext_result = subprocess.run(
            [
                "pdftotext",
                local_pdf_path,
                "-",
            ],
            timeout=60,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        assert pdftotext_result.returncode == 0

        return pdftotext_result.stdout.decode("utf-8")
    elif pdf_engine == "pymupdf":
        pm_doc = pymupdf.open(local_pdf_path)
        result = ""

        for page in pm_doc:
            result += page.get_text()
            result += "\n"

        return result
    elif pdf_engine == "pdfium":
        pdf = pdfium.PdfDocument(local_pdf_path, autoclose=True)
        result = ""

        for page in pdf:
            textpage = page.get_textpage()
            result += textpage.get_text_range()
            result += "\n"

        pdf.close()
        return result
    else:
        raise NotImplementedError()
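
Note: both removed helpers treat page_num as 1-based; the pdftotext branch passes it straight to the 1-based -f/-l flags, while the pymupdf and pdfium branches index with page_num - 1. A minimal usage sketch ("sample.pdf" is a placeholder path, not a file from this repo):

# Hypothetical usage of the removed helpers; "sample.pdf" is a placeholder.
first_page = get_page_text("sample.pdf", page_num=1, pdf_engine="pymupdf")
whole_doc = get_document_text("sample.pdf", pdf_engine="pdfium")
print(len(first_page), len(whole_doc))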
@@ -1,8 +1,19 @@
# The way this script works is it gets a list of pdfs to process
# and an output/scratch folder location, either locally or in s3, to work with.
# On the first run, with an empty output folder, it will queue up each page in each pdf to go into a VLM.
# Then, the user queues up that task in birr, and it outputs to a new subfolder in the same location.
# Then, you run the script again, and it will see that you have some valid output files.
# If so, it will check those output files, and if it has a complete document, it will build a dolma doc for it, and that's considered done.
# Any remaining pages that errored out, or failed due to stop_reason not being "stop" (e.g. over length),
# will be queued up as another, hopefully much smaller, set of tasks to send into batch inference again.
# This process will keep going until you run it with the --fallback option, at which point it will
# just use basic text extraction on any remaining pages and assemble the rest of the dolma docs.
#
#
#
import os
import glob
import random
import subprocess
import base64
import argparse
import boto3
import json
@@ -258,6 +269,7 @@ def main():

    # Print the number of PDFs that resulted in at least one output
    print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
    print(f"Now you should run these prompts through mise/birr")

if __name__ == "__main__":
    main()
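
The header comments above describe an iterative loop: assemble whatever documents are complete, re-queue pages that errored or stopped early, and only fall back to basic text extraction on the final --fallback pass. A minimal sketch of that per-page decision, written against the entry fields used in assemblepipeline.py (an illustrative helper, not code from this commit):

from typing import Dict, List, Tuple

from pdelfin.assemblepipeline import NormalizedEntry

def triage_pages(
    total_pages_in_pdf: int,
    entry_by_pagenum: Dict[int, NormalizedEntry],
    fallback: bool = False,
) -> Tuple[List[int], List[int]]:
    """Split a document's pages into (finished, needs_requeue) for the next pass."""
    finished, needs_requeue = [], []
    for page_num in range(total_pages_in_pdf):
        entry = entry_by_pagenum.get(page_num)
        ok = entry is not None and entry.error is None and entry.finish_reason == "stop"
        if ok:
            finished.append(page_num)
        elif fallback:
            # On the --fallback pass, leftover pages get basic text extraction instead of another VLM round.
            finished.append(page_num)
        else:
            needs_requeue.append(page_num)
    return finished, needs_requeue

A document is done once needs_requeue is empty; until then, only the pages in needs_requeue go back into the next batch-inference round.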