olmocr/pdelfin/assemblepipeline.py
2024-10-09 22:14:28 +00:00

199 lines
6.8 KiB
Python

import argparse
import os
import json
import hashlib
import logging
from collections import defaultdict
from typing import Optional
from tqdm import tqdm
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import boto3
from pypdf import PdfReader
from cached_path import cached_path
from smart_open import smart_open
from dataclasses import dataclass
# Import your existing modules if necessary
# from dolma_refine.evaluate.metrics import DocumentEditSimilarity
# from dolma_refine.evaluate.segmenters import SpacySegmenter
# from dolma_refine.evaluate.aligners import HirschbergAligner
@dataclass(frozen=True)
class NormalizedEntry:
s3_path: str
pagenum: int
text: Optional[str]
finish_reason: Optional[str]
error: Optional[str] = None
@staticmethod
def from_goldkey(goldkey: str, **kwargs):
s3_path = goldkey[:goldkey.rindex("-")]
page_num = int(goldkey[goldkey.rindex("-") + 1:])
return NormalizedEntry(s3_path, page_num, **kwargs)
@property
def goldkey(self):
return f"{self.s3_path}-{self.pagenum}"
def normalize_json_entry(data: dict) -> NormalizedEntry:
if "outputs" in data:
# Birr case
if data["outputs"] is None:
text = None
finish_reason = None
else:
text = data["outputs"][0]["text"]
finish_reason = data["outputs"][0]["finish_reason"]
# Try to parse the structured output if possible
try:
if text is not None:
parsed_content = json.loads(text)
text = parsed_content["natural_text"]
except json.JSONDecodeError:
pass
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"],
text=text,
finish_reason=finish_reason,
error=data.get("completion_error", None)
)
else:
# OpenAI case
try:
# Attempt to parse the JSON content from OpenAI's response
parsed_content = json.loads(data["response"]["body"]["choices"][0]["message"]["content"])
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"],
text=parsed_content["natural_text"],
finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
)
except json.JSONDecodeError:
# Fallback if content is not valid JSON
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"],
text=data["response"]["body"]["choices"][0]["message"]["content"],
finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
)
def parse_s3_path(s3_path):
if not s3_path.startswith('s3://'):
raise ValueError('Invalid S3 path')
s3_path = s3_path[5:]
bucket_name, _, key = s3_path.partition('/')
return bucket_name, key
def main():
parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
args = parser.parse_args()
# Set up logging
logging.basicConfig(filename='processing.log', level=logging.INFO, format='%(asctime)s %(message)s')
os.makedirs(args.output_dir, exist_ok=True)
# Initialize S3 client
s3 = boto3.client('s3')
bucket_name, prefix = parse_s3_path(args.s3_path)
# List all .json and .jsonl files in the specified S3 path
paginator = s3.get_paginator('list_objects_v2')
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
files = []
for page in page_iterator:
if 'Contents' in page:
for obj in page['Contents']:
key = obj['Key']
if key.endswith('.json') or key.endswith('.jsonl'):
files.append(key)
# Build documents mapping
documents = defaultdict(list)
print("Processing JSON files and building documents mapping...")
for key in tqdm(files):
file_s3_path = f's3://{bucket_name}/{key}'
try:
with smart_open(file_s3_path, 'r') as f:
for line in f:
data = json.loads(line)
entry = normalize_json_entry(data)
documents[entry.s3_path].append(entry)
except Exception as e:
logging.error(f"Error processing file {file_s3_path}: {e}")
total_documents = len(documents)
successful_documents = 0
total_pages = 0
successful_pages = 0
print("Processing documents...")
for s3_path, entries in tqdm(documents.items()):
try:
# Download the PDF locally
pdf_local_path = cached_path(s3_path, quiet=True)
pdf = PdfReader(pdf_local_path)
total_pages_in_pdf = len(pdf.pages)
except Exception as e:
logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
continue
total_pages += total_pages_in_pdf
# Build mapping from pagenum to entry
entry_by_pagenum = {entry.pagenum: entry for entry in entries}
valid_entries = []
missing_pages = []
errors = []
for page_num in range(total_pages_in_pdf):
entry = entry_by_pagenum.get(page_num)
if entry is None:
missing_pages.append(page_num)
elif entry.error is not None or entry.finish_reason != 'stop':
errors.append(entry)
else:
valid_entries.append(entry)
successful_pages += len(valid_entries)
if not missing_pages and not errors:
# Assemble text
valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
text = '\n'.join(entry.text for entry in valid_entries_sorted)
# Generate a filename based on the s3_path
doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
output_filename = os.path.join(args.output_dir, f'{doc_hash}.json')
output_data = {
'source': s3_path,
'total_pages': total_pages_in_pdf,
'text': text
}
with open(output_filename, 'w') as f_out:
json.dump(output_data, f_out)
successful_documents += 1
else:
logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {[e.pagenum for e in errors]}')
print(f'Total documents: {total_documents}')
print(f'Successful documents: {successful_documents}')
print(f'Total pages: {total_pages}')
print(f'Successful pages: {successful_pages}')
if __name__ == '__main__':
main()