Mirror of https://github.com/allenai/olmocr.git, synced 2025-06-27 04:00:02 +00:00
Working on new pipeline script
parent a8b50ae8fa
commit 49b5b233c3
@@ -1,276 +1,146 @@
-import os
-import json
-import sys
-import argparse
-import hashlib
-import logging
-import boto3
-from collections import defaultdict
-from typing import Optional
-from dataclasses import dataclass, asdict
-from tqdm import tqdm
-from concurrent.futures import ThreadPoolExecutor, as_completed
-
-from pypdf import PdfReader
-from cached_path import cached_path
-from smart_open import smart_open
-from pdelfin.prompts.anchor import get_anchor_text
-
-
-@dataclass(frozen=True)
-class NormalizedEntry:
-    s3_path: str
-    pagenum: int
-    text: Optional[str]
-    finish_reason: Optional[str]
-    error: Optional[str] = None
-
-    @staticmethod
-    def from_goldkey(goldkey: str, **kwargs):
-        s3_path = goldkey[:goldkey.rindex("-")]
-        page_num = int(goldkey[goldkey.rindex("-") + 1:])
-        return NormalizedEntry(s3_path, page_num, **kwargs)
-
-    @property
-    def goldkey(self):
-        return f"{self.s3_path}-{self.pagenum}"
-
-
-def normalize_json_entry(data: dict) -> NormalizedEntry:
-    if "outputs" in data:
-        # Birr case
-        if data["outputs"] is None:
-            text = None
-            finish_reason = None
-        else:
-            text = data["outputs"][0]["text"]
-            finish_reason = data["outputs"][0]["finish_reason"]
-
-        # Try to parse the structured output if possible
-        try:
-            if text is not None:
-                parsed_content = json.loads(text)
-                text = parsed_content["natural_text"]
-        except json.JSONDecodeError:
-            pass
-
-        return NormalizedEntry.from_goldkey(
-            goldkey=data["custom_id"],
-            text=text,
-            finish_reason=finish_reason,
-            error=data.get("completion_error", None)
-        )
-    else:
-        # OpenAI case
-        try:
-            # Attempt to parse the JSON content from OpenAI's response
-            parsed_content = json.loads(data["response"]["body"]["choices"][0]["message"]["content"])
-            return NormalizedEntry.from_goldkey(
-                goldkey=data["custom_id"],
-                text=parsed_content["natural_text"],
-                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
-            )
-        except json.JSONDecodeError:
-            # Fallback if content is not valid JSON
-            return NormalizedEntry.from_goldkey(
-                goldkey=data["custom_id"],
-                text=data["response"]["body"]["choices"][0]["message"]["content"],
-                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
-            )
-
-
-def parse_s3_path(s3_path):
-    if not s3_path.startswith("s3://"):
-        raise ValueError("Invalid S3 path")
-    s3_path = s3_path[5:]
-    bucket_name, _, key = s3_path.partition("/")
-    return bucket_name, key
-
-
-def process_document(s3_path, entries, output_dir):
-    """
-    Processes a single document:
-    - Downloads the PDF
-    - Validates and assembles text
-    - Writes the output JSON if successful
-    - Returns processing results for aggregation
-    """
-    try:
-        # Download the PDF locally
-        pdf_local_path = cached_path(s3_path, quiet=True)
-        pdf = PdfReader(pdf_local_path)
-        total_pages_in_pdf = len(pdf.pages)
-    except Exception as e:
-        logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
-        return {
-            "processed": 1,
-            "successful_documents": 0,
-            "successful_pages": 0,
-            "total_pages": 0,
-            "errored_entries": []
-        }
-
-    # Build mapping from pagenum to entry
-    entry_by_pagenum = {entry.pagenum: entry for entry in entries}
-
-    valid_entries = []
-    missing_pages = []
-    errors = []
-
-    # Iterate from 1 to total_pages_in_pdf inclusive
-    for page_num in range(1, total_pages_in_pdf + 1):
-        entry = entry_by_pagenum.get(page_num)
-        if entry is None:
-            missing_pages.append(page_num)
-        elif entry.error is not None or entry.finish_reason != "stop":
-            errors.append(entry)
-        else:
-            valid_entries.append(entry)
-
-    if not missing_pages and not errors:
-        # Assemble text
-        valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
-        text = "\n".join(entry.text for entry in valid_entries_sorted if entry.text)
-
-        # Generate a filename based on the s3_path
-        doc_hash = hashlib.md5(s3_path.encode("utf-8")).hexdigest()
-        output_filename = os.path.join(output_dir, f"{doc_hash}.json")
-
-        output_data = {
-            "source": s3_path,
-            "total_pages": total_pages_in_pdf,
-            "text": text
-        }
-
-        try:
-            with open(output_filename, "w") as f_out:
-                json.dump(output_data, f_out)
-            return {
-                "processed": 1,
-                "successful_documents": 1,
-                "successful_pages": len(valid_entries),
-                "total_pages": total_pages_in_pdf,
-                "errored_entries": []
-            }
-        except Exception as e:
-            logging.error(f"Error writing output file {output_filename}: {e}")
-            return {
-                "processed": 1,
-                "successful_documents": 0,
-                "successful_pages": 0,
-                "total_pages": total_pages_in_pdf,
-                "errored_entries": []
-            }
-    else:
-        missing = [page for page in missing_pages]
-        error_pages = [e.pagenum for e in errors]
-        logging.info(f"Document {s3_path} has missing pages: {missing} or errors in pages: {error_pages}")
-        # Collect the errored entries
-        errored_entries = [asdict(entry) for entry in errors]
-        return {
-            "processed": 1,
-            "successful_documents": 0,
-            "successful_pages": len(valid_entries),
-            "total_pages": total_pages_in_pdf,
-            "errored_entries": errored_entries
-        }
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Process finished birr inference outputs into dolma docs")
-    parser.add_argument("s3_path", help="S3 path to the directory containing JSON or JSONL files")
-    parser.add_argument("--output_dir", default="output", help="Directory to save the output files")
-    parser.add_argument("--max_workers", type=int, default=8, help="Maximum number of worker threads")
-    args = parser.parse_args()
-
-    # Set up logging
-    logging.basicConfig(filename="processing.log", level=logging.INFO, format="%(asctime)s %(message)s")
-
-    os.makedirs(args.output_dir, exist_ok=True)
-
-    # Initialize S3 client
-    s3 = boto3.client("s3")
-    bucket_name, prefix = parse_s3_path(args.s3_path)
-
-    # List all .json and .jsonl files in the specified S3 path
-    paginator = s3.get_paginator("list_objects_v2")
-    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
-
-    files = []
-    for page in page_iterator:
-        if "Contents" in page:
-            for obj in page["Contents"]:
-                key = obj["Key"]
-                if key.endswith(".json") or key.endswith(".jsonl"):
-                    files.append(key)
-
-    # Build documents mapping
-    documents = defaultdict(list)
-
-    print("Processing JSON files and building documents mapping...")
-    for key in tqdm(files):
-        file_s3_path = f"s3://{bucket_name}/{key}"
-        try:
-            with smart_open(file_s3_path, "r") as f:
-                for line in f:
-                    data = json.loads(line)
-                    entry = normalize_json_entry(data)
-                    documents[entry.s3_path].append(entry)
-        except Exception as e:
-            logging.error(f"Error processing file {file_s3_path}: {e}")
-
-    total_documents = len(documents)
-    successful_documents = 0
-    total_pages = 0
-    successful_pages = 0
-    all_errored_entries = []
-
-    print("Processing documents with ThreadPoolExecutor...")
-    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
-        # Prepare futures
-        future_to_s3 = {
-            executor.submit(
-                process_document,
-                s3_path,
-                entries,
-                args.output_dir
-            ): s3_path for s3_path, entries in documents.items()
-        }
-
-        # Use tqdm to display progress
-        for future in tqdm(as_completed(future_to_s3), total=len(future_to_s3)):
-            try:
-                result = future.result()
-                successful_documents += result.get("successful_documents", 0)
-                successful_pages += result.get("successful_pages", 0)
-                total_pages += result.get("total_pages", 0)
-                all_errored_entries.extend(result.get("errored_entries", []))
-            except Exception as e:
-                s3_path = future_to_s3[future]
-                logging.error(f"Error processing document {s3_path}: {e}")
-
-    # Write errored entries to a new JSONL file
-    os.makedirs(os.path.join(args.output_dir, "cleanups"), exist_ok=True)
-    os.makedirs(os.path.join(args.output_dir, "errors"), exist_ok=True)
-    error_output_file = os.path.join(args.output_dir, "errors", "errored_pages.jsonl")
-
-    with open(error_output_file, "w") as f_err:
-        for entry in all_errored_entries:
-            json.dump(entry, f_err)
-            f_err.write("\n")
-
-    clean_output_file = os.path.join(args.output_dir, "cleanups", "cleanup_pages.jsonl")
-
-    with open(clean_output_file, "w") as f_err:
-        for entry in all_errored_entries:
-            local_path = cached_path(entry["s3_path"])
-            entry["text"] = get_anchor_text(local_path, entry["pagenum"], pdf_engine="pdftotext")
-            entry["error"] = None
-            entry["finish_reason"] = "stop"
-            json.dump(entry, f_err)
-            f_err.write("\n")
-
-    print(f"Total documents: {total_documents}")
-    print(f"Successful documents: {successful_documents}")
-    print(f"Total pages: {total_pages}")
-    print(f"Successful pages: {successful_pages}")
-
-
-if __name__ == "__main__":
-    main()
+import os
+import json
+import hashlib
+import argparse
+import boto3
+import duckdb
+from tqdm import tqdm
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+def build_index(s3_path):
+    # Hash the s3_path to get a cache key
+    cache_key = hashlib.sha256(s3_path.encode('utf-8')).hexdigest()
+    cache_dir = os.path.join('.cache', cache_key)
+    os.makedirs(cache_dir, exist_ok=True)
+    db_path = os.path.join(cache_dir, 'index.db')
+
+    # Connect to duckdb and create tables if not exist
+    print("Building page index at", db_path)
+    conn = duckdb.connect(database=db_path)
+    cursor = conn.cursor()
+    cursor.execute("""
+        CREATE TABLE IF NOT EXISTS index_table (
+            custom_id TEXT,
+            s3_path TEXT,
+            start_index BIGINT,
+            end_index BIGINT
+        )
+    """)
+    cursor.execute("""
+        CREATE TABLE IF NOT EXISTS processed_files (
+            s3_path TEXT PRIMARY KEY,
+            etag TEXT
+        )
+    """)
+    conn.commit()
+    conn.close()
+
+    s3 = boto3.client('s3')
+    bucket, prefix = parse_s3_path(s3_path)
+
+    # List all .json and .jsonl files under s3_path with their ETags
+    files = list_s3_files(s3, bucket, prefix)
+
+    # Filter out files that have already been processed
+    files_to_process = filter_processed_files(db_path, files)
+
+    if not files_to_process:
+        print("All files have been processed. Nothing to do.")
+        return
+
+    # Use ThreadPoolExecutor to process files with tqdm progress bar
+    with ThreadPoolExecutor(max_workers=8) as executor:
+        futures = [executor.submit(process_file, s3, bucket, key, etag, db_path) for key, etag in files_to_process.items()]
+        for _ in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
+            pass
+
+
+def parse_s3_path(s3_path):
+    if not s3_path.startswith('s3://'):
+        raise ValueError('s3_path must start with s3://')
+    path = s3_path[5:]
+    bucket, _, prefix = path.partition('/')
+    return bucket, prefix
+
+
+def list_s3_files(s3, bucket, prefix):
+    paginator = s3.get_paginator('list_objects_v2')
+    page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix)
+    files = {}
+    for page in page_iterator:
+        contents = page.get('Contents', [])
+        for obj in contents:
+            key = obj['Key']
+            if key.endswith('.json') or key.endswith('.jsonl'):
+                # Retrieve ETag for each file
+                files[key] = obj['ETag'].strip('"')
+    return files
+
+
+def filter_processed_files(db_path, files):
+    conn = duckdb.connect(database=db_path)
+    cursor = conn.cursor()
+
+    # Retrieve processed files
+    cursor.execute("SELECT s3_path, etag FROM processed_files")
+    processed = dict(cursor.fetchall())
+
+    # Filter out files that are already processed with the same ETag
+    files_to_process = {}
+    for key, etag in files.items():
+        if key not in processed or processed[key] != etag:
+            files_to_process[key] = etag
+
+    conn.close()
+    return files_to_process
+
+
+def process_file(s3, bucket, key, etag, db_path):
+    try:
+        # Get the object
+        obj = s3.get_object(Bucket=bucket, Key=key)
+        s3_path = f's3://{bucket}/{key}'
+
+        # Read the content as bytes
+        content = obj['Body'].read()
+
+        # Connect to duckdb
+        conn = duckdb.connect(database=db_path)
+        cursor = conn.cursor()
+
+        # Process the file as JSONL
+        process_jsonl_content(content, s3_path, cursor)
+
+        # Update the processed_files table
+        cursor.execute("""
+            INSERT INTO processed_files (s3_path, etag)
+            VALUES (?, ?)
+            ON CONFLICT (s3_path) DO UPDATE SET etag=excluded.etag
+        """, (key, etag))
+
+        conn.commit()
+        conn.close()
+    except Exception as e:
+        print(f"Error processing file {key}: {e}")
+
+
+def process_jsonl_content(content, s3_path, cursor):
+    start_index = 0
+    lines = content.splitlines(keepends=True)
+    for line in lines:
+        line_length = len(line)
+        end_index = start_index + line_length
+        try:
+            data = json.loads(line)
+            custom_id = data.get('custom_id')
+            if custom_id:
+                cursor.execute("""
+                    INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
+                    VALUES (?, ?, ?, ?)
+                """, (custom_id, s3_path, start_index, end_index))
+        except json.JSONDecodeError:
+            pass  # Handle JSON decode errors if necessary
+        start_index = end_index
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Build a local index of JSON files from S3.')
+    parser.add_argument('s3_path', help='The S3 path to process (e.g., s3://bucket/prefix/)')
+    args = parser.parse_args()
+    build_index(args.s3_path)
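For context, a minimal sketch of how the index written by build_index above could be consumed; this is not part of the commit, and the cache path and custom_id below are placeholders. It looks up one custom_id in index.db and fetches only that line from the source JSONL with a ranged S3 GET (end_index is stored exclusive, while HTTP byte ranges are inclusive).

import duckdb
import boto3

db_path = ".cache/<sha256-of-s3-path>/index.db"   # placeholder cache location
custom_id = "s3://my-bucket/pdfs/example.pdf-1"   # placeholder goldkey-style id

# Look up the byte range recorded for this custom_id
conn = duckdb.connect(database=db_path)
row = conn.execute(
    "SELECT s3_path, start_index, end_index FROM index_table WHERE custom_id = ?",
    (custom_id,),
).fetchone()
conn.close()

if row is not None:
    jsonl_s3_path, start_index, end_index = row
    bucket, key = jsonl_s3_path[5:].split("/", 1)
    s3 = boto3.client("s3")
    # Fetch just the indexed byte range instead of the whole JSONL file
    obj = s3.get_object(Bucket=bucket, Key=key, Range=f"bytes={start_index}-{end_index - 1}")
    print(obj["Body"].read().decode("utf-8"))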
@@ -27,7 +27,8 @@ dependencies = [
     "lingua-language-detector",
     "Pillow",
     "ftfy",
-    "bleach"
+    "bleach",
+    "duckdb",
 ]
 license = {file = "LICENSE"}
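The new "duckdb" dependency backs the processed_files bookkeeping in the script above. A minimal standalone sketch of the upsert pattern it relies on, assuming a DuckDB release that supports ON CONFLICT ... DO UPDATE (paths and ETags here are made up):

import duckdb

conn = duckdb.connect(database=":memory:")
conn.execute("CREATE TABLE processed_files (s3_path TEXT PRIMARY KEY, etag TEXT)")

# First insert records the file; re-inserting the same key updates the ETag in place
conn.execute("INSERT INTO processed_files VALUES (?, ?)", ("s3://bucket/outputs/a.jsonl", "etag-1"))
conn.execute("""
    INSERT INTO processed_files (s3_path, etag)
    VALUES (?, ?)
    ON CONFLICT (s3_path) DO UPDATE SET etag = excluded.etag
""", ("s3://bucket/outputs/a.jsonl", "etag-2"))

print(conn.execute("SELECT * FROM processed_files").fetchall())  # [('s3://bucket/outputs/a.jsonl', 'etag-2')]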