First stab at document assembly

Jake Poznanski 2024-10-09 22:19:16 +00:00
parent 847064f46f
commit c6bdf69d8f


@@ -88,10 +88,93 @@ def parse_s3_path(s3_path):
    bucket_name, _, key = s3_path.partition('/')
    return bucket_name, key

def process_document(s3_path, entries, output_dir):
    """
    Processes a single document:
      - Downloads the PDF
      - Validates and assembles text
      - Writes the output JSON if successful
      - Returns processing results for aggregation
    """
    try:
        # Download the PDF locally
        pdf_local_path = cached_path(s3_path, quiet=True)
        pdf = PdfReader(pdf_local_path)
        total_pages_in_pdf = len(pdf.pages)
    except Exception as e:
        logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
        return {
            'processed': 1,
            'successful_documents': 0,
            'successful_pages': 0,
            'total_pages': 0
        }

    # Build mapping from pagenum to entry
    entry_by_pagenum = {entry.pagenum: entry for entry in entries}

    valid_entries = []
    missing_pages = []
    errors = []

    # Iterate from 1 to total_pages_in_pdf inclusive
    for page_num in range(1, total_pages_in_pdf + 1):
        entry = entry_by_pagenum.get(page_num)
        if entry is None:
            missing_pages.append(page_num)
        elif entry.error is not None or entry.finish_reason != 'stop':
            errors.append(entry)
        else:
            valid_entries.append(entry)

    if not missing_pages and not errors:
        # Assemble text in page order
        valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
        text = '\n'.join(entry.text for entry in valid_entries_sorted if entry.text)

        # Generate a filename based on the s3_path
        doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
        output_filename = os.path.join(output_dir, f'{doc_hash}.json')

        output_data = {
            'source': s3_path,
            'total_pages': total_pages_in_pdf,
            'text': text
        }

        try:
            with open(output_filename, 'w') as f_out:
                json.dump(output_data, f_out)
            return {
                'processed': 1,
                'successful_documents': 1,
                'successful_pages': len(valid_entries),
                'total_pages': total_pages_in_pdf
            }
        except Exception as e:
            logging.error(f"Error writing output file {output_filename}: {e}")
            return {
                'processed': 1,
                'successful_documents': 0,
                'successful_pages': 0,
                'total_pages': total_pages_in_pdf
            }
    else:
        error_pages = [e.pagenum for e in errors]
        logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {error_pages}')
        return {
            'processed': 1,
            'successful_documents': 0,
            'successful_pages': len(valid_entries),
            'total_pages': total_pages_in_pdf
        }

def main():
    parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
    parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
    parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
    parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of worker threads')
    args = parser.parse_args()

    # Set up logging
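
For reference, a minimal sketch (not part of the commit) of the result contract process_document returns and that the reworked main() below sums up. The field names come from the diff above; the sample values are invented for illustration.

results = [
    {'processed': 1, 'successful_documents': 1, 'successful_pages': 12, 'total_pages': 12},
    {'processed': 1, 'successful_documents': 0, 'successful_pages': 7, 'total_pages': 9},   # some pages missing or errored
    {'processed': 1, 'successful_documents': 0, 'successful_pages': 0, 'total_pages': 0},   # PDF could not be downloaded
]

successful_documents = sum(r.get('successful_documents', 0) for r in results)
successful_pages = sum(r.get('successful_pages', 0) for r in results)
total_pages = sum(r.get('total_pages', 0) for r in results)
print(successful_documents, successful_pages, total_pages)  # 1 19 21
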
@@ -135,59 +218,28 @@ def main():
    total_pages = 0
    successful_pages = 0
-    print("Processing documents...")
-    for s3_path, entries in tqdm(documents.items()):
-        try:
-            # Download the PDF locally
-            pdf_local_path = cached_path(s3_path, quiet=True)
-            pdf = PdfReader(pdf_local_path)
-            total_pages_in_pdf = len(pdf.pages)
-        except Exception as e:
-            logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
-            continue
-        total_pages += total_pages_in_pdf
-        # Build mapping from pagenum to entry
-        entry_by_pagenum = {entry.pagenum: entry for entry in entries}
-        valid_entries = []
-        missing_pages = []
-        errors = []
-        for page_num in range(total_pages_in_pdf):
-            entry = entry_by_pagenum.get(page_num)
-            if entry is None:
-                missing_pages.append(page_num)
-            elif entry.error is not None or entry.finish_reason != 'stop':
-                errors.append(entry)
-            else:
-                valid_entries.append(entry)
-        successful_pages += len(valid_entries)
-        if not missing_pages and not errors:
-            # Assemble text
-            valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
-            text = '\n'.join(entry.text for entry in valid_entries_sorted)
-            # Generate a filename based on the s3_path
-            doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
-            output_filename = os.path.join(args.output_dir, f'{doc_hash}.json')
-            output_data = {
-                'source': s3_path,
-                'total_pages': total_pages_in_pdf,
-                'text': text
-            }
-            with open(output_filename, 'w') as f_out:
-                json.dump(output_data, f_out)
-            successful_documents += 1
-        else:
-            logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {[e.pagenum for e in errors]}')
+    print("Processing documents with ThreadPoolExecutor...")
+    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
+        # Prepare futures
+        future_to_s3 = {
+            executor.submit(
+                process_document,
+                s3_path,
+                entries,
+                args.output_dir
+            ): s3_path for s3_path, entries in documents.items()
+        }
+        # Use tqdm to display progress
+        for future in tqdm(as_completed(future_to_s3), total=len(future_to_s3)):
+            try:
+                result = future.result()
+                successful_documents += result.get('successful_documents', 0)
+                successful_pages += result.get('successful_pages', 0)
+                total_pages += result.get('total_pages', 0)
+            except Exception as e:
+                s3_path = future_to_s3[future]
+                logging.error(f"Error processing document {s3_path}: {e}")
    print(f'Total documents: {total_documents}')
    print(f'Successful documents: {successful_documents}')
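
A stripped-down sketch of the fan-out/aggregate pattern the new main() uses: one future per document submitted to a ThreadPoolExecutor, results collected with as_completed so a single failing document only logs an error instead of aborting the run. The stub worker and document paths here are made up; only the concurrency structure mirrors the diff.

import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_document_stub(s3_path):
    # Stand-in for process_document: returns the same kind of summary dict.
    return {'successful_documents': 1, 'successful_pages': 3, 'total_pages': 3}

documents = {f's3://example-bucket/doc{i}.pdf': [] for i in range(5)}  # hypothetical paths
successful_documents = successful_pages = total_pages = 0

with ThreadPoolExecutor(max_workers=8) as executor:
    future_to_s3 = {executor.submit(process_document_stub, s3): s3 for s3 in documents}
    for future in as_completed(future_to_s3):
        try:
            result = future.result()
            successful_documents += result.get('successful_documents', 0)
            successful_pages += result.get('successful_pages', 0)
            total_pages += result.get('total_pages', 0)
        except Exception as e:
            logging.error(f"Error processing document {future_to_s3[future]}: {e}")

print(successful_documents, successful_pages, total_pages)  # 5 15 15
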