Working on new pipeline script

Jake Poznanski 2024-10-10 22:10:26 +00:00
parent a8b50ae8fa
commit 49b5b233c3
2 changed files with 126 additions and 255 deletions


@@ -1,276 +1,146 @@
-import argparse
-import os
-import json
-import sys
-import hashlib
-import logging
-from collections import defaultdict
-from typing import Optional
-
-import boto3
-from tqdm import tqdm
-from pathlib import Path
-from concurrent.futures import ThreadPoolExecutor, as_completed
-
-from pypdf import PdfReader
-from cached_path import cached_path
-from smart_open import smart_open
-
-from pdelfin.prompts.anchor import get_anchor_text
-from dataclasses import dataclass, asdict
-
-
-@dataclass(frozen=True)
-class NormalizedEntry:
-    s3_path: str
-    pagenum: int
-    text: Optional[str]
-    finish_reason: Optional[str]
-    error: Optional[str] = None
-
-    @staticmethod
-    def from_goldkey(goldkey: str, **kwargs):
-        s3_path = goldkey[:goldkey.rindex("-")]
-        page_num = int(goldkey[goldkey.rindex("-") + 1:])
-        return NormalizedEntry(s3_path, page_num, **kwargs)
-
-    @property
-    def goldkey(self):
-        return f"{self.s3_path}-{self.pagenum}"
-
-
-def normalize_json_entry(data: dict) -> NormalizedEntry:
-    if "outputs" in data:
-        # Birr case
-        if data["outputs"] is None:
-            text = None
-            finish_reason = None
-        else:
-            text = data["outputs"][0]["text"]
-            finish_reason = data["outputs"][0]["finish_reason"]
-
-        # Try to parse the structured output if possible
-        try:
-            if text is not None:
-                parsed_content = json.loads(text)
-                text = parsed_content["natural_text"]
-        except json.JSONDecodeError:
-            pass
-
-        return NormalizedEntry.from_goldkey(
-            goldkey=data["custom_id"],
-            text=text,
-            finish_reason=finish_reason,
-            error=data.get("completion_error", None)
-        )
-    else:
-        # OpenAI case
-        try:
-            # Attempt to parse the JSON content from OpenAI's response
-            parsed_content = json.loads(data["response"]["body"]["choices"][0]["message"]["content"])
-            return NormalizedEntry.from_goldkey(
-                goldkey=data["custom_id"],
-                text=parsed_content["natural_text"],
-                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
-            )
-        except json.JSONDecodeError:
-            # Fallback if content is not valid JSON
-            return NormalizedEntry.from_goldkey(
-                goldkey=data["custom_id"],
-                text=data["response"]["body"]["choices"][0]["message"]["content"],
-                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
-            )
-
-
-def parse_s3_path(s3_path):
-    if not s3_path.startswith("s3://"):
-        raise ValueError("Invalid S3 path")
-    s3_path = s3_path[5:]
-    bucket_name, _, key = s3_path.partition("/")
-    return bucket_name, key
-
-
-def process_document(s3_path, entries, output_dir):
-    """
-    Processes a single document:
-    - Downloads the PDF
-    - Validates and assembles text
-    - Writes the output JSON if successful
-    - Returns processing results for aggregation
-    """
-    try:
-        # Download the PDF locally
-        pdf_local_path = cached_path(s3_path, quiet=True)
-        pdf = PdfReader(pdf_local_path)
-        total_pages_in_pdf = len(pdf.pages)
-    except Exception as e:
-        logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
-        return {
-            "processed": 1,
-            "successful_documents": 0,
-            "successful_pages": 0,
-            "total_pages": 0,
-            "errored_entries": []
-        }
-
-    # Build mapping from pagenum to entry
-    entry_by_pagenum = {entry.pagenum: entry for entry in entries}
-
-    valid_entries = []
-    missing_pages = []
-    errors = []
-
-    # Iterate from 1 to total_pages_in_pdf inclusive
-    for page_num in range(1, total_pages_in_pdf + 1):
-        entry = entry_by_pagenum.get(page_num)
-        if entry is None:
-            missing_pages.append(page_num)
-        elif entry.error is not None or entry.finish_reason != "stop":
-            errors.append(entry)
-        else:
-            valid_entries.append(entry)
-
-    if not missing_pages and not errors:
-        # Assemble text
-        valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
-        text = "\n".join(entry.text for entry in valid_entries_sorted if entry.text)
-
-        # Generate a filename based on the s3_path
-        doc_hash = hashlib.md5(s3_path.encode("utf-8")).hexdigest()
-        output_filename = os.path.join(output_dir, f"{doc_hash}.json")
-
-        output_data = {
-            "source": s3_path,
-            "total_pages": total_pages_in_pdf,
-            "text": text
-        }
-
-        try:
-            with open(output_filename, "w") as f_out:
-                json.dump(output_data, f_out)
-            return {
-                "processed": 1,
-                "successful_documents": 1,
-                "successful_pages": len(valid_entries),
-                "total_pages": total_pages_in_pdf,
-                "errored_entries": []
-            }
-        except Exception as e:
-            logging.error(f"Error writing output file {output_filename}: {e}")
-            return {
-                "processed": 1,
-                "successful_documents": 0,
-                "successful_pages": 0,
-                "total_pages": total_pages_in_pdf,
-                "errored_entries": []
-            }
-    else:
-        missing = [page for page in missing_pages]
-        error_pages = [e.pagenum for e in errors]
-        logging.info(f"Document {s3_path} has missing pages: {missing} or errors in pages: {error_pages}")
-
-        # Collect the errored entries
-        errored_entries = [asdict(entry) for entry in errors]
-        return {
-            "processed": 1,
-            "successful_documents": 0,
-            "successful_pages": len(valid_entries),
-            "total_pages": total_pages_in_pdf,
-            "errored_entries": errored_entries
-        }
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Process finished birr inference outputs into dolma docs")
-    parser.add_argument("s3_path", help="S3 path to the directory containing JSON or JSONL files")
-    parser.add_argument("--output_dir", default="output", help="Directory to save the output files")
-    parser.add_argument("--max_workers", type=int, default=8, help="Maximum number of worker threads")
-    args = parser.parse_args()
-
-    # Set up logging
-    logging.basicConfig(filename="processing.log", level=logging.INFO, format="%(asctime)s %(message)s")
-
-    os.makedirs(args.output_dir, exist_ok=True)
-
-    # Initialize S3 client
-    s3 = boto3.client("s3")
-
-    bucket_name, prefix = parse_s3_path(args.s3_path)
-
-    # List all .json and .jsonl files in the specified S3 path
-    paginator = s3.get_paginator("list_objects_v2")
-    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
-
-    files = []
-    for page in page_iterator:
-        if "Contents" in page:
-            for obj in page["Contents"]:
-                key = obj["Key"]
-                if key.endswith(".json") or key.endswith(".jsonl"):
-                    files.append(key)
-
-    # Build documents mapping
-    documents = defaultdict(list)
-
-    print("Processing JSON files and building documents mapping...")
-    for key in tqdm(files):
-        file_s3_path = f"s3://{bucket_name}/{key}"
-        try:
-            with smart_open(file_s3_path, "r") as f:
-                for line in f:
-                    data = json.loads(line)
-                    entry = normalize_json_entry(data)
-                    documents[entry.s3_path].append(entry)
-        except Exception as e:
-            logging.error(f"Error processing file {file_s3_path}: {e}")
-
-    total_documents = len(documents)
-    successful_documents = 0
-    total_pages = 0
-    successful_pages = 0
-
-    all_errored_entries = []
-
-    print("Processing documents with ThreadPoolExecutor...")
-    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
-        # Prepare futures
-        future_to_s3 = {
-            executor.submit(
-                process_document,
-                s3_path,
-                entries,
-                args.output_dir
-            ): s3_path for s3_path, entries in documents.items()
-        }
-
-        # Use tqdm to display progress
-        for future in tqdm(as_completed(future_to_s3), total=len(future_to_s3)):
-            try:
-                result = future.result()
-                successful_documents += result.get("successful_documents", 0)
-                successful_pages += result.get("successful_pages", 0)
-                total_pages += result.get("total_pages", 0)
-                all_errored_entries.extend(result.get("errored_entries", []))
-            except Exception as e:
-                s3_path = future_to_s3[future]
-                logging.error(f"Error processing document {s3_path}: {e}")
-
-    # Write errored entries to a new JSONL file
-    os.makedirs(os.path.join(args.output_dir, "cleanups"), exist_ok=True)
-    os.makedirs(os.path.join(args.output_dir, "errors"), exist_ok=True)
-
-    error_output_file = os.path.join(args.output_dir, "errors", "errored_pages.jsonl")
-    with open(error_output_file, "w") as f_err:
-        for entry in all_errored_entries:
-            json.dump(entry, f_err)
-            f_err.write("\n")
-
-    clean_output_file = os.path.join(args.output_dir, "cleanups", "cleanup_pages.jsonl")
-    with open(clean_output_file, "w") as f_err:
-        for entry in all_errored_entries:
-            local_path = cached_path(entry["s3_path"])
-            entry["text"] = get_anchor_text(local_path, entry["pagenum"], pdf_engine="pdftotext")
-            entry["error"] = None
-            entry["finish_reason"] = "stop"
-            json.dump(entry, f_err)
-            f_err.write("\n")
-
-    print(f"Total documents: {total_documents}")
-    print(f"Successful documents: {successful_documents}")
-    print(f"Total pages: {total_pages}")
-    print(f"Successful pages: {successful_pages}")
-
-
-if __name__ == "__main__":
-    main()
+import os
+import json
+import argparse
+import hashlib
+
+import boto3
+import duckdb
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+def build_index(s3_path):
+    # Hash the s3_path to get a cache key
+    cache_key = hashlib.sha256(s3_path.encode('utf-8')).hexdigest()
+    cache_dir = os.path.join('.cache', cache_key)
+    os.makedirs(cache_dir, exist_ok=True)
+    db_path = os.path.join(cache_dir, 'index.db')
+
+    # Connect to duckdb and create tables if not exist
+    print("Building page index at", db_path)
+    conn = duckdb.connect(database=db_path)
+    cursor = conn.cursor()
+    cursor.execute("""
+        CREATE TABLE IF NOT EXISTS index_table (
+            custom_id TEXT,
+            s3_path TEXT,
+            start_index BIGINT,
+            end_index BIGINT
+        )
+    """)
+    cursor.execute("""
+        CREATE TABLE IF NOT EXISTS processed_files (
+            s3_path TEXT PRIMARY KEY,
+            etag TEXT
+        )
+    """)
+    conn.commit()
+    conn.close()
+
+    s3 = boto3.client('s3')
+    bucket, prefix = parse_s3_path(s3_path)
+
+    # List all .json and .jsonl files under s3_path with their ETags
+    files = list_s3_files(s3, bucket, prefix)
+
+    # Filter out files that have already been processed
+    files_to_process = filter_processed_files(db_path, files)
+
+    if not files_to_process:
+        print("All files have been processed. Nothing to do.")
+        return
+
+    # Use ThreadPoolExecutor to process files with tqdm progress bar
+    with ThreadPoolExecutor(max_workers=8) as executor:
+        futures = [executor.submit(process_file, s3, bucket, key, etag, db_path) for key, etag in files_to_process.items()]
+        for _ in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
+            pass
+
+
+def parse_s3_path(s3_path):
+    if not s3_path.startswith('s3://'):
+        raise ValueError('s3_path must start with s3://')
+    path = s3_path[5:]
+    bucket, _, prefix = path.partition('/')
+    return bucket, prefix
+
+
+def list_s3_files(s3, bucket, prefix):
+    paginator = s3.get_paginator('list_objects_v2')
+    page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix)
+
+    files = {}
+    for page in page_iterator:
+        contents = page.get('Contents', [])
+        for obj in contents:
+            key = obj['Key']
+            if key.endswith('.json') or key.endswith('.jsonl'):
+                # Retrieve ETag for each file
+                files[key] = obj['ETag'].strip('"')
+    return files
+
+
+def filter_processed_files(db_path, files):
+    conn = duckdb.connect(database=db_path)
+    cursor = conn.cursor()
+
+    # Retrieve processed files
+    cursor.execute("SELECT s3_path, etag FROM processed_files")
+    processed = dict(cursor.fetchall())
+
+    # Filter out files that are already processed with the same ETag
+    files_to_process = {}
+    for key, etag in files.items():
+        if key not in processed or processed[key] != etag:
+            files_to_process[key] = etag
+
+    conn.close()
+    return files_to_process
+
+
+def process_file(s3, bucket, key, etag, db_path):
+    try:
+        # Get the object
+        obj = s3.get_object(Bucket=bucket, Key=key)
+        s3_path = f's3://{bucket}/{key}'
+
+        # Read the content as bytes
+        content = obj['Body'].read()
+
+        # Connect to duckdb
+        conn = duckdb.connect(database=db_path)
+        cursor = conn.cursor()
+
+        # Process the file as JSONL
+        process_jsonl_content(content, s3_path, cursor)
+
+        # Update the processed_files table
+        cursor.execute("""
+            INSERT INTO processed_files (s3_path, etag)
+            VALUES (?, ?)
+            ON CONFLICT (s3_path) DO UPDATE SET etag=excluded.etag
+        """, (key, etag))
+
+        conn.commit()
+        conn.close()
+    except Exception as e:
+        print(f"Error processing file {key}: {e}")
+
+
+def process_jsonl_content(content, s3_path, cursor):
+    start_index = 0
+    lines = content.splitlines(keepends=True)
+    for line in lines:
+        line_length = len(line)
+        end_index = start_index + line_length
+        try:
+            data = json.loads(line)
+            custom_id = data.get('custom_id')
+            if custom_id:
+                cursor.execute("""
+                    INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
+                    VALUES (?, ?, ?, ?)
+                """, (custom_id, s3_path, start_index, end_index))
+        except json.JSONDecodeError:
+            pass  # Handle JSON decode errors if necessary
+        start_index = end_index
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Build a local index of JSON files from S3.')
+    parser.add_argument('s3_path', help='The S3 path to process (e.g., s3://bucket/prefix/)')
+    args = parser.parse_args()
+
+    build_index(args.s3_path)
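
The index that build_index() writes maps each custom_id to the byte range of its line inside the source .jsonl on S3, so a single record can later be re-read without downloading the whole file. A minimal sketch of how such a lookup could work, not part of this commit: lookup_entry() is illustrative, it assumes the same .cache/&lt;sha256 of s3_path&gt;/index.db layout used above, and it accounts for end_index being exclusive while HTTP Range headers are inclusive.

import hashlib
import json
import os

import boto3
import duckdb


def lookup_entry(index_s3_path, custom_id):
    # Open the local DuckDB index that build_index() created for this S3 prefix
    cache_key = hashlib.sha256(index_s3_path.encode('utf-8')).hexdigest()
    db_path = os.path.join('.cache', cache_key, 'index.db')
    conn = duckdb.connect(database=db_path)
    row = conn.execute(
        "SELECT s3_path, start_index, end_index FROM index_table WHERE custom_id = ?",
        [custom_id],
    ).fetchone()
    conn.close()
    if row is None:
        return None

    # Fetch only the indexed byte range of the .jsonl file with a ranged GET
    s3_path, start_index, end_index = row
    bucket, _, key = s3_path[len("s3://"):].partition("/")
    obj = boto3.client("s3").get_object(
        Bucket=bucket, Key=key, Range=f"bytes={start_index}-{end_index - 1}"
    )
    return json.loads(obj["Body"].read())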


@@ -27,7 +27,8 @@ dependencies = [
     "lingua-language-detector",
     "Pillow",
     "ftfy",
-    "bleach"
+    "bleach",
+    "duckdb",
 ]
 license = {file = "LICENSE"}