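"""Process finished birr inference outputs into dolma docs.

Reads batch-inference result files (.json / .jsonl) from an S3 prefix, groups the
per-page results by source PDF, and writes one JSON document per fully successful
PDF into --output_dir. Errors and incomplete documents are logged to processing.log.

Example invocation (bucket and prefix are illustrative placeholders):

    python <this_script>.py s3://my-bucket/inference-outputs/ --output_dir output
"""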
import argparse
import os
import json
import hashlib
import logging
from collections import defaultdict
from typing import Optional
from tqdm import tqdm
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import boto3
from pypdf import PdfReader
from cached_path import cached_path
from smart_open import smart_open

from dataclasses import dataclass

# Import your existing modules if necessary
# from dolma_refine.evaluate.metrics import DocumentEditSimilarity
# from dolma_refine.evaluate.segmenters import SpacySegmenter
# from dolma_refine.evaluate.aligners import HirschbergAligner

@dataclass(frozen=True)
class NormalizedEntry:
    s3_path: str
    pagenum: int
    text: Optional[str]
    finish_reason: Optional[str]
    error: Optional[str] = None

    @staticmethod
    def from_goldkey(goldkey: str, **kwargs):
        s3_path = goldkey[:goldkey.rindex("-")]
        page_num = int(goldkey[goldkey.rindex("-") + 1:])
        return NormalizedEntry(s3_path, page_num, **kwargs)

    @property
    def goldkey(self):
        return f"{self.s3_path}-{self.pagenum}"
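
# Example of the goldkey convention used above (illustrative values): a custom_id
# such as "s3://bucket/docs/report.pdf-3" splits on its last "-" into
# s3_path="s3://bucket/docs/report.pdf" and pagenum=3, e.g.:
#
#   entry = NormalizedEntry.from_goldkey("s3://bucket/docs/report.pdf-3",
#                                        text="...", finish_reason="stop")
#   assert entry.goldkey == "s3://bucket/docs/report.pdf-3"
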
def normalize_json_entry(data: dict) -> NormalizedEntry:
    if "outputs" in data:
        # Birr case
        if data["outputs"] is None:
            text = None
            finish_reason = None
        else:
            text = data["outputs"][0]["text"]
            finish_reason = data["outputs"][0]["finish_reason"]

        # Try to parse the structured output if possible
        try:
            if text is not None:
                parsed_content = json.loads(text)
                text = parsed_content["natural_text"]
        except json.JSONDecodeError:
            pass

        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=text,
            finish_reason=finish_reason,
            error=data.get("completion_error", None)
        )
    else:
        # OpenAI case
        try:
            # Attempt to parse the JSON content from OpenAI's response
            parsed_content = json.loads(data["response"]["body"]["choices"][0]["message"]["content"])
            return NormalizedEntry.from_goldkey(
                goldkey=data["custom_id"],
                text=parsed_content["natural_text"],
                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
            )
        except json.JSONDecodeError:
            # Fallback if content is not valid JSON
            return NormalizedEntry.from_goldkey(
                goldkey=data["custom_id"],
                text=data["response"]["body"]["choices"][0]["message"]["content"],
                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
            )
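
# Illustrative input records for normalize_json_entry (field values are made up; only
# the layout the function reads matters).
#
# Birr-style record:
#   {"custom_id": "s3://bucket/doc.pdf-1",
#    "outputs": [{"text": "{\"natural_text\": \"Page one ...\"}", "finish_reason": "stop"}],
#    "completion_error": null}
#
# OpenAI batch-style record:
#   {"custom_id": "s3://bucket/doc.pdf-1",
#    "response": {"body": {"choices": [{"message": {"content": "{\"natural_text\": \"Page one ...\"}"},
#                                       "finish_reason": "stop"}]}}}
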
def parse_s3_path(s3_path):
    if not s3_path.startswith('s3://'):
        raise ValueError('Invalid S3 path')
    s3_path = s3_path[5:]
    bucket_name, _, key = s3_path.partition('/')
    return bucket_name, key
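
# Example (illustrative path): parse_s3_path('s3://my-bucket/runs/output/')
# returns ('my-bucket', 'runs/output/').
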
def main():
    parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
    parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
    parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(filename='processing.log', level=logging.INFO, format='%(asctime)s %(message)s')

    os.makedirs(args.output_dir, exist_ok=True)

    # Initialize S3 client
    s3 = boto3.client('s3')
    bucket_name, prefix = parse_s3_path(args.s3_path)

    # List all .json and .jsonl files in the specified S3 path
    paginator = s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    files = []
    for page in page_iterator:
        if 'Contents' in page:
            for obj in page['Contents']:
                key = obj['Key']
                if key.endswith('.json') or key.endswith('.jsonl'):
                    files.append(key)

    # Build documents mapping: source PDF s3_path -> list of per-page NormalizedEntry
    documents = defaultdict(list)

    print("Processing JSON files and building documents mapping...")
    for key in tqdm(files):
        file_s3_path = f's3://{bucket_name}/{key}'
        try:
            with smart_open(file_s3_path, 'r') as f:
                for line in f:
                    data = json.loads(line)
                    entry = normalize_json_entry(data)
                    documents[entry.s3_path].append(entry)
        except Exception as e:
            logging.error(f"Error processing file {file_s3_path}: {e}")

    total_documents = len(documents)
    successful_documents = 0
    total_pages = 0
    successful_pages = 0

    print("Processing documents...")
    for s3_path, entries in tqdm(documents.items()):
        try:
            # Download the PDF locally
            pdf_local_path = cached_path(s3_path, quiet=True)

            pdf = PdfReader(pdf_local_path)
            total_pages_in_pdf = len(pdf.pages)
        except Exception as e:
            logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
            continue

        total_pages += total_pages_in_pdf

        # Build mapping from pagenum to entry
        entry_by_pagenum = {entry.pagenum: entry for entry in entries}

        valid_entries = []
        missing_pages = []
        errors = []

        for page_num in range(total_pages_in_pdf):
            entry = entry_by_pagenum.get(page_num)
            if entry is None:
                missing_pages.append(page_num)
            elif entry.error is not None or entry.finish_reason != 'stop':
                errors.append(entry)
            else:
                valid_entries.append(entry)

        successful_pages += len(valid_entries)

        # Only write out documents where every page was processed successfully
        if not missing_pages and not errors:
            # Assemble text
            valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
            text = '\n'.join(entry.text for entry in valid_entries_sorted)

            # Generate a filename based on the s3_path
            doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
            output_filename = os.path.join(args.output_dir, f'{doc_hash}.json')

            output_data = {
                'source': s3_path,
                'total_pages': total_pages_in_pdf,
                'text': text
            }

            with open(output_filename, 'w') as f_out:
                json.dump(output_data, f_out)

            successful_documents += 1
        else:
            logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {[e.pagenum for e in errors]}')

    print(f'Total documents: {total_documents}')
    print(f'Successful documents: {successful_documents}')
    print(f'Total pages: {total_pages}')
    print(f'Successful pages: {successful_pages}')
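
# Each successfully assembled document is written to --output_dir as a single JSON file
# named by the md5 hex digest of its source path (values below are illustrative):
#
#   output/5d41402abc4b2a76b9719d911017c592.json
#   {"source": "s3://bucket/docs/report.pdf", "total_pages": 3, "text": "..."}
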
if __name__ == '__main__':
    main()