Taking notes, starting on document assembly

Jake Poznanski 2024-10-09 22:14:28 +00:00
parent 8e5809da71
commit 847064f46f
4 changed files with 214 additions and 83 deletions

pdelfin/assemblepipeline.py Normal file

@@ -0,0 +1,198 @@
import argparse
import os
import json
import hashlib
import logging
from collections import defaultdict
from typing import Optional
from tqdm import tqdm
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import boto3
from pypdf import PdfReader
from cached_path import cached_path
from smart_open import smart_open
from dataclasses import dataclass
# Import your existing modules if necessary
# from dolma_refine.evaluate.metrics import DocumentEditSimilarity
# from dolma_refine.evaluate.segmenters import SpacySegmenter
# from dolma_refine.evaluate.aligners import HirschbergAligner
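
# A "goldkey" identifies one page of one source PDF as "<s3 path>-<page number>",
# e.g. "s3://bucket/doc.pdf-3" (the path here is illustrative only); from_goldkey
# below splits on the last "-" to recover the path and the page number.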
@dataclass(frozen=True)
class NormalizedEntry:
    s3_path: str
    pagenum: int
    text: Optional[str]
    finish_reason: Optional[str]
    error: Optional[str] = None

    @staticmethod
    def from_goldkey(goldkey: str, **kwargs):
        s3_path = goldkey[:goldkey.rindex("-")]
        page_num = int(goldkey[goldkey.rindex("-") + 1:])
        return NormalizedEntry(s3_path, page_num, **kwargs)

    @property
    def goldkey(self):
        return f"{self.s3_path}-{self.pagenum}"
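
# normalize_json_entry accepts two record shapes: Birr batch results, which carry an
# "outputs" list (text + finish_reason) plus an optional "completion_error", and OpenAI
# batch results, which nest the reply under response.body.choices[0]. In both cases the
# model's JSON output is parsed for its "natural_text" field when possible, and the raw
# content is kept as a fallback.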
def normalize_json_entry(data: dict) -> NormalizedEntry:
    if "outputs" in data:
        # Birr case
        if data["outputs"] is None:
            text = None
            finish_reason = None
        else:
            text = data["outputs"][0]["text"]
            finish_reason = data["outputs"][0]["finish_reason"]

        # Try to parse the structured output if possible
        try:
            if text is not None:
                parsed_content = json.loads(text)
                text = parsed_content["natural_text"]
        except json.JSONDecodeError:
            pass

        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=text,
            finish_reason=finish_reason,
            error=data.get("completion_error", None)
        )
    else:
        # OpenAI case
        try:
            # Attempt to parse the JSON content from OpenAI's response
            parsed_content = json.loads(data["response"]["body"]["choices"][0]["message"]["content"])
            return NormalizedEntry.from_goldkey(
                goldkey=data["custom_id"],
                text=parsed_content["natural_text"],
                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
            )
        except json.JSONDecodeError:
            # Fallback if content is not valid JSON
            return NormalizedEntry.from_goldkey(
                goldkey=data["custom_id"],
                text=data["response"]["body"]["choices"][0]["message"]["content"],
                finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
            )

def parse_s3_path(s3_path):
    if not s3_path.startswith('s3://'):
        raise ValueError('Invalid S3 path')
    s3_path = s3_path[5:]
    bucket_name, _, key = s3_path.partition('/')
    return bucket_name, key
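
# Expected invocation (the bucket and prefix below are placeholders, not real paths):
#   python pdelfin/assemblepipeline.py s3://my-bucket/birr-output-prefix/ --output_dir output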
def main():
    parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
    parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
    parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(filename='processing.log', level=logging.INFO, format='%(asctime)s %(message)s')

    os.makedirs(args.output_dir, exist_ok=True)

    # Initialize S3 client
    s3 = boto3.client('s3')
    bucket_name, prefix = parse_s3_path(args.s3_path)

    # List all .json and .jsonl files in the specified S3 path
    paginator = s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    files = []
    for page in page_iterator:
        if 'Contents' in page:
            for obj in page['Contents']:
                key = obj['Key']
                if key.endswith('.json') or key.endswith('.jsonl'):
                    files.append(key)

    # Build documents mapping
    documents = defaultdict(list)

    print("Processing JSON files and building documents mapping...")
    for key in tqdm(files):
        file_s3_path = f's3://{bucket_name}/{key}'
        try:
            with smart_open(file_s3_path, 'r') as f:
                for line in f:
                    data = json.loads(line)
                    entry = normalize_json_entry(data)
                    documents[entry.s3_path].append(entry)
        except Exception as e:
            logging.error(f"Error processing file {file_s3_path}: {e}")
    total_documents = len(documents)
    successful_documents = 0
    total_pages = 0
    successful_pages = 0

    print("Processing documents...")
    for s3_path, entries in tqdm(documents.items()):
        try:
            # Download the PDF locally
            pdf_local_path = cached_path(s3_path, quiet=True)
            pdf = PdfReader(pdf_local_path)
            total_pages_in_pdf = len(pdf.pages)
        except Exception as e:
            logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
            continue

        total_pages += total_pages_in_pdf

        # Build mapping from pagenum to entry
        entry_by_pagenum = {entry.pagenum: entry for entry in entries}

        valid_entries = []
        missing_pages = []
        errors = []
        for page_num in range(total_pages_in_pdf):
            entry = entry_by_pagenum.get(page_num)
            if entry is None:
                missing_pages.append(page_num)
            elif entry.error is not None or entry.finish_reason != 'stop':
                errors.append(entry)
            else:
                valid_entries.append(entry)

        successful_pages += len(valid_entries)

        if not missing_pages and not errors:
            # Assemble text
            valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
            text = '\n'.join(entry.text for entry in valid_entries_sorted)

            # Generate a filename based on the s3_path
            doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
            output_filename = os.path.join(args.output_dir, f'{doc_hash}.json')

            output_data = {
                'source': s3_path,
                'total_pages': total_pages_in_pdf,
                'text': text
            }

            with open(output_filename, 'w') as f_out:
                json.dump(output_data, f_out)

            successful_documents += 1
        else:
            logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {[e.pagenum for e in errors]}')

    print(f'Total documents: {total_documents}')
    print(f'Successful documents: {successful_documents}')
    print(f'Total pages: {total_pages}')
    print(f'Successful pages: {successful_pages}')


if __name__ == '__main__':
    main()


@@ -285,6 +285,8 @@ def do_eval(gold_data_path: str, eval_data_path: str, review_page_name: str, rev
        # if len(pd["gold_text"]) < 200 and len(pd["eval_text"]) < 200:
        #     continue

        if "[Error processing this page: overrun" not in pd["eval_text"]:
            continue

        page_eval_data.append(pd)


@@ -1,81 +0,0 @@
import subprocess
from typing import Literal

import pymupdf
import pypdfium2 as pdfium


def get_page_text(
    local_pdf_path: str, page_num: int, pdf_engine: Literal["pdftotext", "pymupdf", "pdfium"] = "pdftotext"
) -> str:
    if pdf_engine == "pdftotext":
        pdftotext_result = subprocess.run(
            [
                "pdftotext",
                "-f",
                str(page_num),
                "-l",
                str(page_num),
                local_pdf_path,
                "-",
            ],
            timeout=60,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        assert pdftotext_result.returncode == 0
        return pdftotext_result.stdout.decode("utf-8")
    elif pdf_engine == "pymupdf":
        pm_doc = pymupdf.open(local_pdf_path)
        return pm_doc[page_num - 1].get_text()
    elif pdf_engine == "pdfium":
        pdf = pdfium.PdfDocument(local_pdf_path, autoclose=True)
        page = pdf[page_num - 1]
        textpage = page.get_textpage()

        # Extract text from the whole page
        result = textpage.get_text_range()
        pdf.close()
        return result
    else:
        raise NotImplementedError()


def get_document_text(local_pdf_path: str, pdf_engine: Literal["pdftotext", "pymupdf", "pdfium"] = "pdftotext") -> str:
    if pdf_engine == "pdftotext":
        pdftotext_result = subprocess.run(
            [
                "pdftotext",
                local_pdf_path,
                "-",
            ],
            timeout=60,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        assert pdftotext_result.returncode == 0
        return pdftotext_result.stdout.decode("utf-8")
    elif pdf_engine == "pymupdf":
        pm_doc = pymupdf.open(local_pdf_path)

        result = ""
        for page in pm_doc:
            result += page.get_text()
            result += "\n"

        return result
    elif pdf_engine == "pdfium":
        pdf = pdfium.PdfDocument(local_pdf_path, autoclose=True)

        result = ""
        for page in pdf:
            textpage = page.get_textpage()
            result += textpage.get_text_range()
            result += "\n"

        pdf.close()
        return result
    else:
        raise NotImplementedError()


@@ -1,8 +1,19 @@
# How this script works: it takes a list of PDFs to process and an output/scratch
# folder location, either local or in S3, to work with.
# On the first run, with an empty output folder, it queues up each page of each PDF to go into a VLM.
# The user then runs that task in birr, which writes its results to a new subfolder in the same location.
# On the next run, the script sees those output files and checks them; whenever it has a complete
# document, it builds a dolma doc for it, and that document is considered done.
# Any remaining pages that errored out, or that failed because their stop_reason was not "stop"
# (e.g. over length), are queued up as another, hopefully much smaller, set of tasks to send into
# batch inference again.
# This process repeats until you run the script with the --fallback option, at which point it just
# uses basic text extraction on any remaining pages and assembles the rest of the dolma docs.
# A rough sketch of the per-run, per-page decision is below.
#
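# Illustrative pseudocode only (these are not the actual helpers defined in this script),
# showing the decision described above:
#
#   for each page of each pdf:
#       if it has a valid output (no error, stop reason "stop"):  keep it for the dolma doc
#       elif --fallback was passed:                                fall back to plain text extraction
#       else:                                                      re-queue it for another batch inference round
#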
import os
import glob
import random
import subprocess
import base64
import argparse
import boto3
import json
@@ -258,6 +269,7 @@ def main():
    # Print the number of PDFs that resulted in at least one output
    print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
    print(f"Now you should run these prompts through mise/birr")


if __name__ == "__main__":
    main()